Source code for gala.dynamics.plot

# Third-party
import numpy as np

__all__ = ['plot_projections']

def _get_axes(dim, subplots_kwargs=dict()):
    """
    Parameters
    ----------
    dim : int
        Dimensionality of the orbit.
    subplots_kwargs : dict (optional)
        Dictionary of kwargs passed to :func:`~matplotlib.pyplot.subplots`.
    """

    import matplotlib.pyplot as plt

    if dim > 1:
        n_panels = int(dim * (dim - 1) / 2)
    else:
        n_panels = 1
    figsize = subplots_kwargs.pop('figsize', (4*n_panels, 4))
    fig, axes = plt.subplots(1, n_panels, figsize=figsize,
                             **subplots_kwargs)

    if n_panels == 1:
        axes = [axes]

    else:
        axes = axes.flat

    return axes

[docs]def plot_projections(x, relative_to=None, autolim=True, axes=None, subplots_kwargs=dict(), labels=None, plot_function=None, **kwargs): """ Given N-dimensional quantity, ``x``, make a figure containing 2D projections of all combinations of the axes. Parameters ---------- x : array_like Array of values. ``axis=0`` is assumed to be the dimensionality, ``axis=1`` is the time axis. See :ref:`shape-conventions` for more information. relative_to : bool (optional) Plot the values relative to this value or values. autolim : bool (optional) Automatically set the plot limits to be something sensible. axes : array_like (optional) Array of matplotlib Axes objects. subplots_kwargs : dict (optional) Dictionary of kwargs passed to :func:`~matplotlib.pyplot.subplots`. labels : iterable (optional) List or iterable of axis labels as strings. They should correspond to the dimensions of the input orbit. plot_function : callable (optional) The ``matplotlib`` plot function to use. By default, this is :func:`~matplotlib.pyplot.scatter`, but can also be, e.g., :func:`~matplotlib.pyplot.plot`. **kwargs All other keyword arguments are passed to the ``plot_function``. You can pass in any of the usual style kwargs like ``color=...``, ``marker=...``, etc. Returns ------- fig : `~matplotlib.Figure` """ # don't propagate changes back... x = np.array(x, copy=True) ndim = x.shape[0] # get axes object from arguments if axes is None: axes = _get_axes(dim=ndim, subplots_kwargs=subplots_kwargs) # if the quantities are relative if relative_to is not None: x -= relative_to # name of the plotting function plot_fn_name = plot_function.__name__ # automatically determine limits if autolim: lims = [] for i in range(ndim): max_,min_ = np.max(x[i]), np.min(x[i]) delta = max_ - min_ if delta == 0.: delta = 1. lims.append([min_ - delta*0.02, max_ + delta*0.02]) k = 0 for i in range(ndim): for j in range(ndim): if i >= j: continue # skip diagonal, upper triangle plot_func = getattr(axes[k], plot_fn_name) plot_func(x[i], x[j], **kwargs) if labels is not None: axes[k].set_xlabel(labels[i]) axes[k].set_ylabel(labels[j]) if autolim: axes[k].set_xlim(lims[i]) axes[k].set_ylim(lims[j]) k += 1 axes[0].figure.tight_layout() return axes[0].figure