Source code for gala.util

# coding: utf-8

""" General utilities. """

from __future__ import division, print_function

__author__ = "adrn <adrn@astro.columbia.edu>"

# Standard library
import collections
import sys
import multiprocessing

# Third-party
from astropy import log as logger
import numpy as np

__all__ = ['get_pool', 'rolling_window', 'atleast_2d',
           'assert_angles_allclose', 'assert_quantities_allclose']

class SerialPool(object):
    """
    Minimal in-process stand-in for a worker pool.

    Exposes the ``map`` and ``close`` methods so that code written against a
    pool-like object can also run serially in the current process.

    Parameters
    ----------
    **kwargs
        Accepted for call-compatibility with other pool constructors and
        ignored.
    """

    def __init__(self, **kwargs):
        # ``get_pool`` forwards its extra keyword arguments to whichever pool
        # it builds; accept (and drop) them here so the serial path does not
        # fail with a TypeError when pool options are supplied.
        pass

    def close(self):
        """ Do nothing; there are no worker resources to release. """
        return

    def map(self, function, tasks, callback=None):
        """
        Apply ``function`` to every task, in order, in the current process.

        Parameters
        ----------
        function : callable
            Called with a single task as its only argument.
        tasks : iterable
            The tasks to process.
        callback : callable (optional)
            If given, called with each result right after it is computed.

        Returns
        -------
        results : list
            The return values of ``function``, in task order.
        """
        results = []
        for task in tasks:
            result = function(task)
            if callback is not None:
                callback(result)
            results.append(result)
        return results

def get_pool(mpi=False, threads=None, **kwargs):
    """
    Get a pool object to pass to emcee for parallel processing.

    Three kinds of pool can be returned: an MPI pool when ``mpi`` is true, a
    Python multiprocessing pool when ``threads`` is greater than 1, and a
    ``SerialPool`` in every other case.

    Parameters
    ----------
    mpi : bool
        Use MPI or not. If specified, ignores the threads kwarg.
    threads : int (optional)
        If mpi is False and threads is greater than 1, use a Python
        multiprocessing pool with the specified number of threads.
    **kwargs
        Any other keyword arguments are passed through to the pool
        initializers.

    Returns
    -------
    pool
        An ``MPIPool``, a ``multiprocessing.Pool`` or a ``SerialPool``.
    """
    if mpi:
        # Imported lazily so mpipool is only required when MPI is requested
        from mpipool import MPIPool

        mpi_pool = MPIPool(**kwargs)

        # Only the master process continues past this point; every other
        # process calls wait() and then exits.
        if not mpi_pool.is_master():
            mpi_pool.wait()
            sys.exit(0)

        logger.debug("Running with MPI...")
        return mpi_pool

    use_multiprocessing = threads is not None and threads > 1
    if not use_multiprocessing:
        logger.debug("Running serial...")
        return SerialPool(**kwargs)

    logger.debug("Running with multiprocessing on {} cores..."
                 .format(threads))
    return multiprocessing.Pool(threads, **kwargs)
try:
    # Python 3.3+: the ABCs live in ``collections.abc``; the aliases in
    # ``collections`` itself were removed in Python 3.10.
    from collections.abc import Mapping as _Mapping
except ImportError:
    # Python 2 fallback
    from collections import Mapping as _Mapping


class ImmutableDict(_Mapping):
    """
    A read-only, hashable mapping built from keyword arguments.

    Only the read-only ``Mapping`` interface is exposed. The hash is computed
    from the items on first use and cached, so all values must be hashable
    for ``hash()`` to succeed.
    """

    @classmethod
    def from_dict(cls, somedict):
        """ Build an instance from an existing dict with string keys. """
        return cls(**somedict)

    def __init__(self, **kwargs):
        self._dict = kwargs
        self._hash = None  # computed lazily in __hash__

    def __getitem__(self, key):
        return self._dict[key]

    def __len__(self):
        return len(self._dict)

    def __iter__(self):
        return iter(self._dict)

    def __hash__(self):
        if self._hash is None:
            self._hash = hash(frozenset(self._dict.items()))
        return self._hash

    def __eq__(self, other):
        if isinstance(other, ImmutableDict):
            return self._dict == other._dict
        if isinstance(other, _Mapping):
            return self._dict == dict(other.items())
        # Let Python try the reflected comparison instead of raising
        # AttributeError for objects that have no ``_dict``.
        return NotImplemented

    def __repr__(self):
        return "<ImmutableDict {}>".format(self._dict.__repr__())

    def __str__(self):
        return self._dict.__str__()
def rolling_window(arr, window_size, stride=1, return_idx=False):
    """
    There is an example of an iterator for pure-Python objects in:
    http://stackoverflow.com/questions/6822725/rolling-or-sliding-window-iterator-in-python

    This is a rolling-window iterator for Numpy arrays, with window size and
    stride control. See examples below for demos.

    If the input is shorter than the window, the whole input is yielded
    exactly once. If the last window would run past the end of the input, a
    final, shorter segment is yielded and iteration stops.

    Parameters
    ----------
    arr : array_like
        Input numpy array.
    window_size : int
        Width of the window.
    stride : int (optional)
        Number of indices to advance the window each iteration step.
    return_idx : bool (optional)
        Whether to return the slice indices along with the array segment.

    Yields
    ------
    segment, or ((start, stop), segment) if ``return_idx`` is True.

    Raises
    ------
    ValueError
        If ``window_size`` is negative or ``stride`` is less than 1. Because
        this is a generator, the error is raised on first iteration.

    Examples
    --------
    >>> a = np.array([1,2,3,4,5,6])
    >>> for x in rolling_window(a, 3):
    ...     print(x)
    [1 2 3]
    [2 3 4]
    [3 4 5]
    [4 5 6]
    >>> for x in rolling_window(a, 2, stride=2):
    ...     print(x)
    [1 2]
    [3 4]
    [5 6]
    >>> for (i1,i2),x in rolling_window(a, 2, stride=2, return_idx=True):
    ...     print(i1, i2, x)
    0 2 [1 2]
    2 4 [3 4]
    4 6 [5 6]

    """

    window_size = int(window_size)
    stride = int(stride)

    if window_size < 0 or stride < 1:
        raise ValueError("window_size must be non-negative and stride must "
                         "be at least 1 (got window_size={}, stride={})"
                         .format(window_size, stride))

    arr_len = len(arr)
    if arr_len < window_size:
        if return_idx:
            yield (0, arr_len), arr
        else:
            yield arr
        # Stop here: falling through to the loop below would yield the same
        # (short) array a second time.
        return

    ix1 = 0
    while ix1 < arr_len:
        ix2 = ix1 + window_size
        result = arr[ix1:ix2]
        if return_idx:
            yield (ix1, ix2), result
        else:
            yield result

        # A truncated segment, or one that reaches the end, is the last one
        if len(result) < window_size or ix2 >= arr_len:
            break
        ix1 += stride
def atleast_2d(*arys, **kwargs):
    """
    View inputs as arrays with at least two dimensions.

    Parameters
    ----------
    arys1, arys2, ... : array_like
        One or more array-like sequences.  Non-array inputs are converted
        to arrays.  Arrays that already have two or more dimensions are
        preserved.
    insert_axis : int (optional)
        Where to create a new axis if input array(s) have <2 dim.
        Defaults to 0. Zero-dimensional inputs are always reshaped to
        ``(1, 1)``.

    Returns
    -------
    res, res2, ... : ndarray
        A single array if one input was given, otherwise a list of arrays,
        each with ``a.ndim >= 2``.  Copies are avoided where possible, and
        views with two or more dimensions are returned.

    Examples
    --------
    >>> atleast_2d(3.0)
    array([[3.]])

    >>> x = np.arange(3.0)
    >>> atleast_2d(x)
    array([[0., 1., 2.]])
    >>> atleast_2d(x, insert_axis=-1)
    array([[0.],
           [1.],
           [2.]])
    >>> atleast_2d(x).base is x
    True

    >>> atleast_2d(1, [1, 2], [[1, 2]])
    [array([[1]]), array([[1, 2]]), array([[1, 2]])]

    """
    insert_axis = kwargs.pop('insert_axis', 0)

    # Index that adds the new axis to a 1-D array: ``None`` at the requested
    # position and a full slice at the other.
    slc = [slice(None)] * 2
    slc[insert_axis] = None
    # NumPy requires a tuple for multidimensional indexing; indexing with a
    # list of slices/None is deprecated and rejected by recent versions.
    slc = tuple(slc)

    res = []
    for ary in arys:
        ary = np.asanyarray(ary)
        if ary.ndim == 0:
            result = ary.reshape(1, 1)
        elif ary.ndim == 1:
            result = ary[slc]
        else:
            result = ary
        res.append(result)

    if len(res) == 1:
        return res[0]
    else:
        return res
############################################################################## # Testing #
def assert_angles_allclose(x, y, **kwargs):
    """
    Like numpy's assert_allclose, but for angles (in radians).

    The angles are compared on the unit circle, so values that differ by a
    multiple of 2*pi are considered equal.

    Parameters
    ----------
    x, y : array_like
        Angles in radians.
    **kwargs
        Passed to ``numpy.allclose``. The angular separation is compared
        against zero, so only ``atol`` has an effect; ``rtol`` scales with
        the reference value, which is 0.

    Raises
    ------
    AssertionError
        If any angular separation exceeds the tolerance.
    """
    # Squared chord length between the two points on the unit circle
    c2 = (np.sin(x) - np.sin(y))**2 + (np.cos(x) - np.cos(y))**2

    # chord = 2*sin(diff/2)  =>  diff = 2*arcsin(chord/2). Unlike
    # arccos(1 - c2/2), this keeps full precision for tiny separations
    # (the arccos form rounds anything below ~1e-8 rad to exactly 0).
    # The minimum guards against rounding pushing the argument above 1.
    half_chord = np.minimum(0.5 * np.sqrt(c2), 1.0)
    diff = 2.0 * np.arcsin(half_chord)

    # Raise explicitly so the check is not stripped under ``python -O``
    if not np.allclose(diff, 0.0, **kwargs):
        raise AssertionError("Angles are not equal to within tolerance "
                             "(maximum angular separation: {} rad)"
                             .format(np.max(diff)))
def assert_quantities_allclose(x, y, **kwargs):
    """
    Like numpy's assert_allclose, but for quantities.

    ``y`` is converted to the unit of ``x`` and the resulting numerical
    values are compared with ``numpy.allclose``.

    Parameters
    ----------
    x, y : quantity-like
        Objects providing ``.unit``, ``.value`` and ``.to(unit)``
        (presumably astropy ``Quantity`` objects -- confirm against callers).
    **kwargs
        Passed to ``numpy.allclose`` (e.g. ``rtol``, ``atol``).

    Raises
    ------
    AssertionError
        If the values are not equal to within the tolerance.
    """
    y_value = y.to(x.unit).value
    x_value = x.value

    # Raise explicitly so the check is not stripped under ``python -O``
    if not np.allclose(x_value, y_value, **kwargs):
        raise AssertionError("Quantities are not equal to within tolerance "
                             "(compared in units of {})".format(x.unit))