Source code for gala.dynamics.orbit

import warnings

# Third-party
import astropy.coordinates as coord
import astropy.table as at
import astropy.units as u
import numpy as np
from scipy.signal import argrelmax

# Project
from gala.logging import logger

from ..io import quantity_from_hdf5, quantity_to_hdf5
from ..units import DimensionlessUnitSystem, UnitSystem, dimensionless
from ..util import atleast_2d
from .core import PhaseSpacePosition
from .plot import plot_projections
from .util import peak_to_peak_period

__all__ = ["Orbit"]


[docs] class Orbit(PhaseSpacePosition): """ Represents an orbit: positions and velocities (conjugate momenta) as a function of time. The class can be instantiated with Astropy representation objects (e.g., :class:`~astropy.coordinates.CartesianRepresentation`), Astropy :class:`~astropy.units.Quantity` objects, or plain Numpy arrays. If passing in Quantity or Numpy array instances for both position and velocity, they are assumed to be Cartesian. Array inputs are interpreted as dimensionless quantities. The input position and velocity objects can have an arbitrary number of (broadcastable) dimensions. For Quantity or array inputs, the first axes have special meaning: - ``axis=0`` is the coordinate dimension (e.g., x, y, z) - ``axis=1`` is the time dimension So if the input position array, ``pos``, has shape ``pos.shape = (3, 100)``, this would be a 3D orbit at 100 times (``pos[0]`` is ``x``, ``pos[1]``` is ``y``, etc.). For representing multiple orbits, the position array could have 3 axes, e.g., it might have shape `pos.shape = (3, 100, 8)`, where this is interpreted as a 3D position at 100 times for 8 different orbits. The same is true for velocity. The position and velocity arrays must have the same shape. If a time argument is specified, the position and velocity arrays must have the same number of timesteps as the length of the time object:: len(t) == pos.shape[1] Parameters ---------- pos : representation, quantity_like, or array_like Positions. If a numpy array (e.g., has no units), this will be stored as a dimensionless :class:`~astropy.units.Quantity`. See the note above about the assumed meaning of the axes of this object. vel : differential, quantity_like, or array_like Velocities. If a numpy array (e.g., has no units), this will be stored as a dimensionless :class:`~astropy.units.Quantity`. See the note above about the assumed meaning of the axes of this object. t : array_like, :class:`~astropy.units.Quantity` (optional) Array of times. 
If a numpy array (e.g., has no units), this will be stored as a dimensionless :class:`~astropy.units.Quantity`. hamiltonian : `~gala.potential.Hamiltonian` (optional) The Hamiltonian that the orbit was integrated in. """ def __init__(self, pos, vel, t=None, hamiltonian=None, potential=None, frame=None): super().__init__(pos=pos, vel=vel) if self.pos.ndim < 1: self.pos = self.pos.reshape(1) self.vel = self.vel.reshape(1) # TODO: check that Hamiltonian ndim is consistent with here if t is not None: t = np.atleast_1d(t) if self.pos.shape[0] != len(t): raise ValueError( "Position and velocity must have the same " "length along axis=1 as the length of the " "time array {} vs {}".format(len(t), self.pos.shape[0]) ) if not hasattr(t, "unit"): t = t * u.one self.t = t if hamiltonian is not None: self.potential = hamiltonian.potential self.frame = hamiltonian.frame else: self.potential = potential self.frame = frame def __getitem__(self, slice_): if isinstance(slice_, np.ndarray) or isinstance(slice_, list): slice_ = (slice_,) try: slice_ = tuple(slice_) except TypeError: slice_ = (slice_,) kw = dict() if self.t is not None: kw["t"] = self.t[slice_[0]] pos = self.pos[slice_] vel = self.vel[slice_] # if one time is sliced out, return a phasespaceposition try: int_tslice = int(slice_[0]) except TypeError: int_tslice = None if int_tslice is not None: return PhaseSpacePosition(pos=pos, vel=vel, frame=self.frame) else: return self.__class__( pos=pos, vel=vel, potential=self.potential, frame=self.frame, **kw ) @property def hamiltonian(self): if self.potential is None or self.frame is None: return None try: return self._hamiltonian except AttributeError: from gala.potential import Hamiltonian self._hamiltonian = Hamiltonian(potential=self.potential, frame=self.frame) return self._hamiltonian
[docs] def w(self, units=None): """ This returns a single array containing the phase-space positions. Parameters ---------- units : `~gala.units.UnitSystem` (optional) The unit system to represent the position and velocity in before combining into the full array. Returns ------- w : `~numpy.ndarray` A numpy array of all positions and velocities, without units. Will have shape ``(2*ndim, ...)``. """ if units is None: if self.hamiltonian is None: units = dimensionless else: units = self.hamiltonian.units return super().w(units=units)
# ------------------------------------------------------------------------ # Convert from Cartesian to other representations #
[docs] def represent_as(self, new_pos, new_vel=None): """ Represent the position and velocity of the orbit in an alternate coordinate system. Supports any of the Astropy coordinates representation classes. Parameters ---------- new_pos : :class:`~astropy.coordinates.BaseRepresentation` The type of representation to generate. Must be a class (not an instance), or the string name of the representation class. new_vel : :class:`~astropy.coordinates.BaseDifferential` (optional) Class in which any velocities should be represented. Must be a class (not an instance), or the string name of the differential class. If None, uses the default differential for the new position class. Returns ------- new_orbit : `gala.dynamics.Orbit` """ kw = dict() if self.t is not None: kw["t"] = self.t o = super().represent_as(new_pos=new_pos, new_vel=new_vel) return self.__class__(pos=o.pos, vel=o.vel, hamiltonian=self.hamiltonian, **kw)
# ------------------------------------------------------------------------ # Shape and size # ------------------------------------------------------------------------ @property def ntimes(self): return self.shape[0] @property def norbits(self): if len(self.shape) < 2: return 1 else: return self.shape[1]
[docs] def reshape(self, new_shape): """ Reshape the underlying position and velocity arrays. """ kw = dict() if self.t is not None: kw["t"] = self.t return self.__class__( pos=self.pos.reshape(new_shape), vel=self.vel.reshape(new_shape), hamiltonian=self.hamiltonian, **kw, )
# ------------------------------------------------------------------------ # Input / output #
[docs] def to_hdf5(self, f): """ Serialize this object to an HDF5 file. Requires ``h5py``. Parameters ---------- f : str, :class:`h5py.File` Either the filename or an open HDF5 file. """ f = super().to_hdf5(f) if self.potential is not None: import yaml from ..potential.potential.io import to_dict f["potential"] = yaml.dump(to_dict(self.potential)).encode("utf-8") if self.t is not None: quantity_to_hdf5(f, "time", self.t) return f
[docs] @classmethod def from_hdf5(cls, f): """ Load an object from an HDF5 file. Requires ``h5py``. Parameters ---------- f : str, :class:`h5py.File` Either the filename or an open HDF5 file. """ # TODO: this is duplicated code from PhaseSpacePosition if isinstance(f, str): import h5py f = h5py.File(f, mode="r") close = True else: close = False pos = quantity_from_hdf5(f["pos"]) vel = quantity_from_hdf5(f["vel"]) time = None if "time" in f: time = quantity_from_hdf5(f["time"]) frame = None if "frame" in f: g = f["frame"] frame_mod = g.attrs["module"] frame_cls = g.attrs["class"] frame_units = [u.Unit(x.decode("utf-8")) for x in g["units"]] if u.dimensionless_unscaled in frame_units: units = DimensionlessUnitSystem() else: units = UnitSystem(*frame_units) pars = dict() for k in g["parameters"]: pars[k] = quantity_from_hdf5(g["parameters/" + k]) exec("from {0} import {1}".format(frame_mod, frame_cls)) frame_cls = eval(frame_cls) frame = frame_cls(units=units, **pars) potential = None if "potential" in f: import yaml from ..potential.potential.io import from_dict _dict = yaml.load(f["potential"][()].decode("utf-8"), Loader=yaml.Loader) potential = from_dict(_dict) if close: f.close() return cls(pos=pos, vel=vel, t=time, frame=frame, potential=potential)
[docs] def orbit_gen(self): """ Generator for iterating over each orbit. """ if self.norbits == 1: yield self else: for i in range(self.norbits): yield self[:, i]
# ------------------------------------------------------------------------ # Computed dynamical quantities #
[docs] def potential_energy(self, potential=None): r""" The potential energy *per unit mass*: .. math:: E_\Phi = \Phi(\boldsymbol{q}) Returns ------- E : :class:`~astropy.units.Quantity` The potential energy. """ if self.hamiltonian is None and potential is None: raise ValueError( "To compute the potential energy, a potential" " object must be provided!" ) if potential is None: potential = self.hamiltonian.potential return super().potential_energy(potential)
[docs] def energy(self, hamiltonian=None): r""" The total energy *per unit mass*: Parameters ---------- hamiltonian : `gala.potential.Hamiltonian`, `gala.potential.PotentialBase` instance The Hamiltonian object to evaluate the energy. If a potential is passed in, this assumes a static reference frame. Returns ------- E : :class:`~astropy.units.Quantity` The total energy. """ if self.hamiltonian is None and hamiltonian is None: raise ValueError( "To compute the total energy, a hamiltonian" " object must be provided!" ) if hamiltonian is None: hamiltonian = self.hamiltonian else: from gala.potential import Hamiltonian hamiltonian = Hamiltonian(hamiltonian) return hamiltonian(self)
def _max_helper(self, arr, approximate=False): """ Helper function for computing extrema (apocenter, pericenter, z_height) and times of extrema. Parameters ---------- arr : `numpy.ndarray` """ assert self.norbits == 1 assert self.t[-1] > self.t[0] # time must increase _ix = argrelmax(arr.value, mode="wrap")[0] _ix = _ix[(_ix != 0) & (_ix != (len(arr) - 1))] # remove edges t = self.t.value approx_arr = arr[_ix] approx_t = t[_ix] if approximate: return approx_arr, approx_t * self.t.unit better_times = np.zeros(_ix.shape, dtype=float) better_arr = np.zeros(_ix.shape, dtype=float) for i, j in enumerate(_ix): tvals = t[j - 1 : j + 2] rvals = arr[j - 1 : j + 2].value coeffs = np.polynomial.polynomial.polyfit(tvals, rvals, 2) better_times[i] = (-coeffs[1]) / (2 * coeffs[2]) better_arr[i] = ( (coeffs[2] * better_times[i] ** 2) + (coeffs[1] * better_times[i]) + coeffs[0] ) return better_arr * arr.unit, better_times * self.t.unit def _max_return_helper(self, vals, times, return_times, reduce): if return_times: if len(vals) == 1: return vals[0], times[0] else: return vals, times elif reduce: return u.Quantity(vals).reshape(self.shape[1:]) else: return u.Quantity(vals)
[docs] def pericenter(self, return_times=False, func=np.mean, approximate=False): """ Estimate the pericenter(s) of the orbit by identifying local minima in the spherical radius, fitting a parabola around these local minima and then solving this parabola to find the pericenter(s). By default, this returns the mean of all local minima (pericenters). To get, e.g., the minimum pericenter, pass in ``func=np.min``. To get all pericenters, pass in ``func=None``. Parameters ---------- func : func (optional) A function to evaluate on all of the identified pericenter times. return_times : bool (optional) Also return the pericenter times. approximate : bool (optional) Compute an approximate pericenter by skipping interpolation. Returns ------- peri : float, :class:`~numpy.ndarray` Either a single number or an array of pericenters. times : :class:`~numpy.ndarray` (optional, see ``return_times``) If ``return_times=True``, also returns an array of the pericenter times. """ if return_times and func is not None: raise ValueError( "Cannot return times if reducing pericenters " "using an input function. Pass `func=None` if " "you want to return all individual pericenters " "and times." ) if func is None: reduce = False func = lambda x: x # noqa else: reduce = True # time must increase if self.t[-1] < self.t[0]: self = self[::-1] vals = [] times = [] for orbit in self.orbit_gen(): v, t = orbit._max_helper( -orbit.physicsspherical.r, approximate=approximate # pericenter ) vals.append(func(-v)) # negative for pericenter times.append(t) return self._max_return_helper(vals, times, return_times, reduce)
[docs] def apocenter(self, return_times=False, func=np.mean, approximate=False): """ Estimate the apocenter(s) of the orbit by identifying local maxima in the spherical radius, fitting a parabola around these local maxima and then solving this parabola to find the apocenter(s). By default, this returns the mean of all local maxima (apocenters). To get, e.g., the largest apocenter, pass in ``func=np.max``. To get all apocenters, pass in ``func=None``. Parameters ---------- func : func (optional) A function to evaluate on all of the identified apocenter times. return_times : bool (optional) Also return the apocenter times. approximate : bool (optional) Compute an approximate apocenter by skipping interpolation. Returns ------- apo : float, :class:`~numpy.ndarray` Either a single number or an array of apocenters. times : :class:`~numpy.ndarray` (optional, see ``return_times``) If ``return_times=True``, also returns an array of the apocenter times. """ if return_times and func is not None: raise ValueError( "Cannot return times if reducing apocenters " "using an input function. Pass `func=None` if " "you want to return all individual apocenters " "and times." ) if func is None: reduce = False func = lambda x: x # noqa else: reduce = True # time must increase if self.t[-1] < self.t[0]: self = self[::-1] vals = [] times = [] for orbit in self.orbit_gen(): v, t = orbit._max_helper( orbit.physicsspherical.r, approximate=approximate # apocenter ) vals.append(func(v)) times.append(t) return self._max_return_helper(vals, times, return_times, reduce)
[docs] def guiding_radius(self, potential=None, t=0.0, **root_kwargs): """ Compute the guiding-center radius Parameters ---------- potential : `gala.potential.PotentialBase` subclass instance (optional) The potential to compute the guiding radius in. t : quantity-like (optional) Time. **root_kwargs Any additional keyword arguments are passed to `~scipy.optimize.root`. Returns ------- Rg : :class:`~astropy.units.Quantity` Guiding-center radius. """ from .core import _guiding_radius_helper if potential is None: potential = self.potential if potential is None: raise ValueError( "You must specify a potential if it is not already defined on the orbit" ) R0s = np.atleast_1d( np.mean(self.cylindrical.rho).decompose(potential.units).value ) Lzs = self.angular_momentum()[2].decompose(potential.units).value mean_Lzs = np.atleast_1d(np.mean(Lzs, axis=0)) check = np.abs(np.std(Lzs, axis=0) / mean_Lzs) > 1e-8 if np.any(check): warnings.warn( f"{check.sum()} orbits do not have constant Lz (see orbits at indices: " f"{list(np.where(check)[0][:10])}, ...). Are you sure you are using an" " axisymmetric potential?", RuntimeWarning, ) Rgs = _guiding_radius_helper(R0s, mean_Lzs, potential, t=t, **root_kwargs) return Rgs.reshape(self.shape[1:]) * potential.units["length"]
[docs] def zmax(self, return_times=False, func=np.mean, approximate=False): """ Estimate the maximum ``z`` height of the orbit by identifying local maxima in the absolute value of the ``z`` position, fitting a parabola around these local maxima and then solving this parabola to find the maximum ``z`` height. By default, this returns the mean of all local maxima. To get, e.g., the largest ``z`` excursion, pass in ``func=np.max``. To get all ``z`` maxima, pass in ``func=None``. Parameters ---------- func : func (optional) A function to evaluate on all of the identified z maximum times. return_times : bool (optional) Also return the times of maximum. approximate : bool (optional) Compute approximate values by skipping interpolation. Returns ------- zs : float, :class:`~numpy.ndarray` Either a single number or an array of maximum z heights. times : :class:`~numpy.ndarray` (optional, see ``return_times``) If ``return_times=True``, also returns an array of the apocenter times. """ if return_times and func is not None: raise ValueError( "Cannot return times if reducing " "using an input function. Pass `func=None` if " "you want to return all individual values " "and times." ) if func is None: reduce = False func = lambda x: x # noqa else: reduce = True # time must increase if self.t[-1] < self.t[0]: self = self[::-1] vals = [] times = [] for orbit in self.orbit_gen(): v, t = orbit._max_helper( np.abs(orbit.cylindrical.z), approximate=approximate ) vals.append(func(v)) times.append(t) return self._max_return_helper(vals, times, return_times, reduce)
[docs] def eccentricity(self, **kw): r""" Returns the eccentricity computed from the mean apocenter and mean pericenter. .. math:: e = \frac{r_{\rm apo} - r_{\rm per}}{r_{\rm apo} + r_{\rm per}} Parameters ---------- **kw Any keyword arguments passed to ``apocenter()`` and ``pericenter()``. For example, ``approximate=True``. Returns ------- ecc : float The orbital eccentricity. """ ra = self.apocenter(**kw) rp = self.pericenter(**kw) return (ra - rp) / (ra + rp)
[docs] def estimate_period(self, radial=False): """ Estimate the period of the orbit. By default, computes the radial period. If ``radial==False``, this returns period estimates for each dimension of the orbit. Parameters ---------- radial : bool (optional) What period to estimate. If ``True``, estimates the radial period. If ``False``, estimates period in each dimension, e.g., if the orbit is 3D, along x, y, and z. Returns ------- periods : `~astropy.table.QTable` The estimated orbital periods for each phase-space component, for each orbit. """ if self.t is None: raise ValueError( "To compute the period, a time array is needed. " "Specify a time array when creating this object." ) if radial: # TODO: Remove this after deprecation cycle import warnings from gala.util import GalaDeprecationWarning warnings.warn( "Passing radial=True in estimate_period() is now deprecated. " "This method now returns period estimates for all orbital " "components. If you want to get just the radial period, use " "orbits.physicsspherical.estimate_period() instead.", GalaDeprecationWarning, ) r = self.physicsspherical.r.value if self.norbits == 1: T = u.Quantity(peak_to_peak_period(self.t, r)) else: T = u.Quantity( [peak_to_peak_period(self.t, r[:, n]) for n in range(r.shape[1])] ) return T else: periods = {} for k in self.pos_components.keys(): q = getattr(self, k).value if self.norbits == 1: T = u.Quantity(peak_to_peak_period(self.t, q)) else: T = u.Quantity( [ peak_to_peak_period(self.t, q[:, n]) for n in range(q.shape[1]) ] ) periods[k] = np.atleast_1d(T) return at.QTable(periods)
# ------------------------------------------------------------------------ # Misc. useful methods # ------------------------------------------------------------------------
[docs] def circulation(self): """ Determine which axes the Orbit circulates around by checking whether there is a change of sign of the angular momentum about an axis. Returns a 2D array with ``ndim`` integers per orbit point. If a box orbit, all integers will be 0. A 1 indicates circulation about the corresponding axis. TODO: clockwise / counterclockwise? For example, for a single 3D orbit: - Box and boxlet = [0, 0, 0] - z-axis (short-axis) tube = [0, 0, 1] - x-axis (long-axis) tube = [1, 0, 0] Returns ------- circulation : :class:`numpy.ndarray` An array that specifies whether there is circulation about any of the axes of the input orbit. For a single orbit, will return a 1D array, but for multiple orbits, the shape will be ``(3, norbits)``. """ L = self.angular_momentum() # if only 2D, add another empty axis if L.ndim == 2: single_orbit = True L = L[..., None] else: single_orbit = False ndim, ntimes, norbits = L.shape # initial angular momentum L0 = L[:, 0] # see if at any timestep the sign has changed circ = np.ones((ndim, norbits)) for ii in range(ndim): cnd = (np.sign(L0[ii]) != np.sign(L[ii, 1:])) | ( np.abs(L[ii, 1:]).value < 1e-13 ) ix = np.atleast_1d(np.any(cnd, axis=0)) circ[ii, ix] = 0 circ = circ.astype(int) if single_orbit: return circ.reshape((ndim,)) else: return circ
[docs] def align_circulation_with_z(self, circulation=None): """ If the input orbit is a tube orbit, this function aligns the circulation axis with the z axis and returns a copy. Parameters ---------- circulation : array_like (optional) Array of bits that specify the axis about which the orbit circulates. If not provided, will compute this using :meth:`~gala.dynamics.Orbit.circulation`. See that method for more information. Returns ------- orb : :class:`~gala.dynamics.Orbit` A copy of the original orbit object with circulation aligned with the z axis. """ if circulation is None: circulation = self.circulation() circulation = atleast_2d(circulation, insert_axis=1) cart = self.cartesian pos = cart.xyz vel = ( np.vstack( (cart.v_x.value[None], cart.v_y.value[None], cart.v_z.value[None]) ) * cart.v_x.unit ) if pos.ndim < 3: pos = pos[..., np.newaxis] vel = vel[..., np.newaxis] if circulation.shape[0] != self.ndim or circulation.shape[1] != pos.shape[2]: raise ValueError( "Shape of 'circulation' array should match the " "shape of the position/velocity (minus the time " "axis)." ) new_pos = pos.copy() new_vel = vel.copy() for n in range(pos.shape[2]): if circulation[2, n] == 1 or np.all(circulation[:, n] == 0): # already circulating about z or box orbit continue if sum(circulation[:, n]) > 1: logger.warning( "Circulation about multiple axes - are you sure " "the orbit has been integrated for long enough?" ) if circulation[0, n] == 1: circ = 0 elif circulation[1, n] == 1: circ = 1 else: raise RuntimeError("Should never get here...") new_pos[circ, :, n] = pos[2, :, n] new_pos[2, :, n] = pos[circ, :, n] new_vel[circ, :, n] = vel[2, :, n] new_vel[2, :, n] = vel[circ, :, n] return self.__class__( pos=new_pos.reshape(cart.xyz.shape), vel=new_vel.reshape(cart.xyz.shape), t=self.t, hamiltonian=self.hamiltonian, )
[docs] def plot(self, components=None, units=None, auto_aspect=True, **kwargs): """ Plot the positions in all projections. This is a wrapper around `~gala.dynamics.plot_projections` for fast access and quick visualization. All extra keyword arguments are passed to that function (the docstring for this function is included here for convenience). Parameters ---------- components : iterable (optional) A list of component names (strings) to plot. By default, this is the Cartesian positions ``['x', 'y', 'z']``. To plot Cartesian velocities, pass in the velocity component names ``['v_x', 'v_y', 'v_z']``. If the representation is different, the component names will be different. For example, for a Cylindrical representation, the components are ``['rho', 'phi', 'z']`` and ``['v_rho', 'pm_phi', 'v_z']``. units : `~astropy.units.UnitBase`, iterable (optional) A single unit or list of units to display the components in. auto_aspect : bool (optional) Automatically enforce an equal aspect ratio. relative_to : bool (optional) Plot the values relative to this value or values. autolim : bool (optional) Automatically set the plot limits to be something sensible. axes : array_like (optional) Array of matplotlib Axes objects. subplots_kwargs : dict (optional) Dictionary of kwargs passed to :func:`~matplotlib.pyplot.subplots`. labels : iterable (optional) List or iterable of axis labels as strings. They should correspond to the dimensions of the input orbit. plot_function : callable (optional) The ``matplotlib`` plot function to use. By default, this is :func:`~matplotlib.pyplot.scatter`, but can also be, e.g., :func:`~matplotlib.pyplot.plot`. **kwargs All other keyword arguments are passed to the ``plot_function``. You can pass in any of the usual style kwargs like ``color=...``, ``marker=...``, etc. 
Returns ------- fig : `~matplotlib.Figure` """ from gala.tests.optional_deps import HAS_MATPLOTLIB if not HAS_MATPLOTLIB: raise ImportError("matplotlib is required for visualization.") import matplotlib.pyplot as plt if components is None: if self.ndim == 1: # only a 1D orbit, so just plot time series components = ["t", self.pos.components[0]] else: components = self.pos.components x, labels = self._plot_prepare(components=components, units=units) kwargs.setdefault("plot_function", plt.plot) if kwargs["plot_function"] in [plt.plot, plt.scatter]: kwargs.setdefault("marker", "") kwargs.setdefault("labels", labels) if kwargs["plot_function"] == plt.plot: kwargs.setdefault("linestyle", "-") fig = plot_projections(x, **kwargs) if ( self.pos.get_name() == "cartesian" and all([not c.startswith("d_") for c in components]) and "t" not in components and auto_aspect ): for ax in fig.axes: ax.set(aspect="equal", adjustable="datalim") return fig
[docs] def plot_3d( self, components=None, units=None, auto_aspect=True, subplots_kwargs=None, **kwargs, ): """ Plot the specified 3D components. Parameters ---------- components : iterable (optional) A list of component names (strings) to plot. By default, this is the Cartesian positions ``['x', 'y', 'z']``. To plot Cartesian velocities, pass in the velocity component names ``['v_x', 'v_y', 'v_z']``. If the representation is different, the component names will be different. For example, for a Cylindrical representation, the components are ``['rho', 'phi', 'z']`` and ``['v_rho', 'pm_phi', 'v_z']``. units : `~astropy.units.UnitBase`, iterable (optional) A single unit or list of units to display the components in. auto_aspect : bool (optional) Automatically enforce an equal aspect ratio. ax : `matplotlib.axes.Axes` The matplotlib Axes object to draw on. subplots_kwargs : dict (optional) Dictionary of kwargs passed to :func:`~matplotlib.pyplot.subplots`. labels : iterable (optional) List or iterable of axis labels as strings. They should correspond to the dimensions of the input orbit. plot_function : str (optional) The ``matplotlib`` plot function to use. By default, this is 'plot' but can also be, e.g., 'scatter'. **kwargs All other keyword arguments are passed to the ``plot_function``. You can pass in any of the usual style kwargs like ``color=...``, ``marker=...``, etc. 
Returns ------- fig : `~matplotlib.Figure` """ from gala.tests.optional_deps import HAS_MATPLOTLIB if not HAS_MATPLOTLIB: raise ImportError("matplotlib is required for visualization.") import matplotlib.pyplot as plt from mpl_toolkits import mplot3d # noqa if components is None: components = self.pos.components if subplots_kwargs is None: subplots_kwargs = dict() if len(components) != 3: raise ValueError(f"The number of components ({len(components)}) must be 3") x, labels = self._plot_prepare(components=components, units=units) kwargs.setdefault("marker", "") kwargs.setdefault("linestyle", kwargs.pop("ls", "-")) plot_function_name = kwargs.pop("plot_function", "plot") ax = kwargs.pop("ax", None) subplots_kwargs.setdefault("constrained_layout", True) if ax is None: fig, ax = plt.subplots( figsize=(6, 6), subplot_kw=dict(projection="3d"), **subplots_kwargs ) else: fig = ax.figure plot_function = getattr(ax, plot_function_name) if x[0].ndim > 1: for n in range(x[0].shape[1]): plot_function(*[xx[:, n] for xx in x], **kwargs) else: plot_function(*x, **kwargs) ax.set_xlabel(labels[0]) ax.set_ylabel(labels[1]) ax.set_zlabel(labels[2]) if ( self.pos.get_name() == "cartesian" and all([not c.startswith("d_") for c in components]) and "t" not in components and auto_aspect ): for ax in fig.axes: ax.set(aspect="auto", adjustable="datalim") return fig, ax
[docs] def animate( self, components=None, units=None, stride=1, segment_nsteps=10, underplot_full_orbit=True, show_time=True, marker_style=None, segment_style=None, FuncAnimation_kwargs=None, orbit_plot_kwargs=None, axes=None, ): """ Animate an orbit or collection of orbits. Parameters ---------- components : iterable (optional) A list of component names (strings) to plot. By default, this is the Cartesian positions ``['x', 'y', 'z']``. To plot Cartesian velocities, pass in the velocity component names ``['v_x', 'v_y', 'v_z']``. If the representation is different, the component names will be different. For example, for a Cylindrical representation, the components are ``['rho', 'phi', 'z']`` and ``['v_rho', 'pm_phi', 'v_z']``. units : `~astropy.units.UnitBase`, iterable (optional) A single unit or list of units to display the components in. stride : int (optional) How often to draw a new frame, in terms of orbit timesteps. segment_nsteps : int (optional) How many timesteps to draw in an orbit segment trailing the timestep marker. Set this to 0 or None to disable. underplot_full_orbit : bool (optional) Controls whether to under-plot the full orbit as a thin line. show_time : bool (optional) Controls whether to show a label of the current timestep marker_style : dict or list of dict (optional) Matplotlib style arguments passed to `matplotlib.pyplot.plot` that control the plot style of the timestep marker. If a single dict is passed then the marker_style is applied to all orbits. If a list of dicts is passed then each dict will be applied to each orbit. segment_style : dict (optional) Matplotlib style arguments passed to `matplotlib.pyplot.plot` that control the plot style of the orbit segment. If a single dict is passed then the segment_style is applied to all orbits. If a list of dicts is passed then each dict will be applied to each orbit. FuncAnimation_kwargs : dict (optional) Keyword arguments passed through to `matplotlib.animation.FuncAnimation`. 
orbit_plot_kwargs : dict (optional) Keyword arguments passed through to `gala.dynamics.Orbit.plot`. axes : `matplotlib.axes.Axes` (optional) Where to draw the orbit. Returns ------- fig : `matplotlib.figure.Figure` anim : `matplotlib.animation.FuncAnimation` """ from gala.tests.optional_deps import HAS_MATPLOTLIB if not HAS_MATPLOTLIB: raise ImportError("matplotlib is required for visualization.") from matplotlib.animation import FuncAnimation if components is None: if self.ndim == 1: # only a 1D orbit, so just plot time series components = ["t", self.pos.components[0]] else: components = self.pos.components # Extract the relevant components, in the given unit system xs, _ = self._plot_prepare(components=components, units=units) xs = [atleast_2d(xx, insert_axis=1) for xx in xs] # Figure out which components to plot on which axes data_paired = [] for i in range(len(xs)): for j in range(len(xs)): if i >= j: continue # skip diagonal, upper triangle data_paired.append((xs[i], xs[j])) if FuncAnimation_kwargs is None: FuncAnimation_kwargs = dict() if orbit_plot_kwargs is None: orbit_plot_kwargs = dict() orbit_plot_kwargs.setdefault("zorder", 1) orbit_plot_kwargs.setdefault("color", "#aaaaaa") orbit_plot_kwargs.setdefault("linewidth", "1") orbit_plot_kwargs.setdefault("axes", axes) if marker_style is None: marker_style = [dict() for _ in range(self.norbits)] # if a single dict is passed then copy it into a list if isinstance(marker_style, dict): marker_style = [marker_style for _ in range(self.norbits)] # otherwise ensure the list is the right length elif len(marker_style) != self.norbits: raise ValueError( "Length of `marker_style` list must be equal to the number of orbits" ) for n in range(self.norbits): marker_style[n].setdefault("marker", "o") marker_style[n].setdefault("linestyle", marker_style[n].pop("ls", "None")) marker_style[n].setdefault("markersize", marker_style[n].pop("ms", 4.0)) marker_style[n].setdefault("color", marker_style[n].pop("c", "tab:red")) 
marker_style[n].setdefault("zorder", 100) if segment_style is None: segment_style = [dict() for _ in range(self.norbits)] # if a single dict is passed then copy it into a list if isinstance(segment_style, dict): segment_style = [segment_style for _ in range(self.norbits)] # otherwise ensure the list is the right length elif len(segment_style) != self.norbits: raise ValueError( "Length of `segment_style` list must be equal to the number of orbits" ) for n in range(self.norbits): segment_style[n].setdefault("marker", "None") segment_style[n].setdefault("linestyle", segment_style[n].pop("ls", "-")) segment_style[n].setdefault("linewidth", segment_style[n].pop("lw", 2.0)) segment_style[n].setdefault("color", segment_style[n].pop("c", "tab:blue")) segment_style[n].setdefault("zorder", 10) if segment_nsteps is None or segment_nsteps == 0: # HACK segment_style[n]["alpha"] = 0 # Use this to get a figure with axes with the right limits # Note: Labels are added by .plot() if not underplot_full_orbit: orbit_plot_kwargs["alpha"] = 0 fig = self.plot(components=components, units=units, **orbit_plot_kwargs) # Set up all of the (data-less) markers and line segments markers = [] segments = [] for n in range(self.norbits): _m = [] _s = [] for i in range(len(data_paired)): _m.append(fig.axes[i].plot([], [], **marker_style[n])[0]) _s.append(fig.axes[i].plot([], [], **segment_style[n])[0]) markers.append(_m) segments.append(_s) # record the time unit and set up data-less annotates if user wants timestep label if show_time: time_unit = self.t.unit times = [ fig.axes[i].annotate( "", xy=(0.98, 0.98), xycoords="axes fraction", ha="right", va="top" ) for i in range(len(data_paired)) ] def anim_func(n): i = max(0, n - segment_nsteps) for k in range(self.norbits): for j in range(len(data_paired)): markers[k][j].set_data( data_paired[j][0][n : n + 1, k], data_paired[j][1][n : n + 1, k] ) segments[k][j].set_data( data_paired[j][0][i : n + 1, k], data_paired[j][1][i : n + 1, k] ) if show_time: 
time_value = self.t[n : n + 1].value[0] for time in times: time.set_text(f"Time={time_value:1.1f} {time_unit}") artists = ( *[m for m in markers for x in m], *[s for s in segments for x in s], *times, ) else: artists = ( *[m for m in markers for x in m], *[s for s in segments for x in s], ) return artists anim = FuncAnimation( fig, anim_func, frames=np.arange(0, self.ntimes, stride), **FuncAnimation_kwargs, ) return fig, anim
def to_frame(self, frame, current_frame=None, **kwargs):
    """
    Transform this orbit to a different reference frame.

    If the requested frame compares equal to the current frame and no
    extra keyword arguments are given, the orbit itself is returned
    unchanged (no copy is made).

    Parameters
    ----------
    frame : `gala.potential.CFrameBase`
        The frame to transform to.
    current_frame : `gala.potential.CFrameBase` (optional)
        If the Orbit has no associated Hamiltonian, this specifies the
        current frame of the orbit. When omitted, ``self.frame`` is used.
    **kwargs
        Forwarded to the parent class ``to_frame``. If any of the frames
        involved is a ``ConstantRotatingFrame`` and ``t`` is not supplied
        here, the orbit's own times (``self.t``) are passed as ``t``.

    Returns
    -------
    orbit : `gala.dynamics.Orbit`
        The orbit in the new reference frame.
    """
    if current_frame is None:
        current_frame = self.frame

    # TODO: this short-circuit sux
    already_there = frame == current_frame
    if already_there and not kwargs:
        return self

    # TODO: need a better way to do this!
    from ..potential.frame.builtin import ConstantRotatingFrame

    # Work on a copy so the caller's keyword arguments are never mutated.
    forwarded = dict(kwargs)

    # A rotating frame needs the time array to perform the transformation.
    involves_rotation = any(
        isinstance(candidate, ConstantRotatingFrame)
        for candidate in (frame, current_frame, self.frame)
    )
    if involves_rotation and "t" not in forwarded:
        forwarded["t"] = self.t

    # TODO: this needs a re-write...
    transformed = super().to_frame(frame, current_frame, **forwarded)

    return Orbit(
        pos=transformed.pos,
        vel=transformed.vel,
        t=self.t,
        frame=frame,
        potential=self.potential,
    )
# ------------------------------------------------------------------------ # Compatibility with other packages #
def to_galpy_orbit(self, ro=None, vo=None):
    """Convert this object to a ``galpy.Orbit`` instance.

    If this object has a frame, it is first transformed to a
    ``StaticFrame`` built from that frame's units.

    Parameters
    ----------
    ro : `astropy.units.Quantity` or `astropy.units.UnitBase`
        "Natural" length unit. When omitted, the ``ro`` value from the
        ``normalization`` section of the galpy configuration is used,
        taken to be in kpc.
    vo : `astropy.units.Quantity` or `astropy.units.UnitBase`
        "Natural" velocity unit. When omitted, the ``vo`` value from the
        ``normalization`` section of the galpy configuration is used,
        taken to be in km/s.

    Returns
    -------
    galpy_orbit : `galpy.orbit.Orbit`
    """
    from galpy.orbit import Orbit as GalpyOrbit
    from galpy.util.config import __config__ as galpy_config

    w = self
    if self.frame is not None:
        from ..potential import StaticFrame

        w = self.to_frame(StaticFrame(self.frame.units))

    # Fall back to galpy's configured normalization for missing units.
    if ro is None:
        ro = galpy_config.getfloat("normalization", "ro") * u.kpc
    if vo is None:
        vo = galpy_config.getfloat("normalization", "vo") * u.km / u.s

    # PhaseSpacePosition or Orbit:
    cyl = w.cylindrical

    # Tangential velocity: radius times angular rate, with angles treated
    # as dimensionless for the unit conversion.
    tangential = cyl.rho * cyl.pm_phi

    # Components in the order handed to galpy: R, vR, vT, z, vz, phi.
    components = [
        cyl.rho.to_value(ro).T,
        cyl.v_rho.to_value(vo).T,
        tangential.to_value(vo, u.dimensionless_angles()).T,
        cyl.z.to_value(ro).T,
        cyl.v_z.to_value(vo).T,
        cyl.phi.to_value(u.rad).T,
    ]
    galpy_orbit = GalpyOrbit(np.array(components).T, ro=ro, vo=vo)

    # Carry over the time array, expressed in natural time units (ro / vo).
    if w.t is not None:
        galpy_orbit.t = w.t.to_value(ro / vo)

    return galpy_orbit
@classmethod
def from_galpy_orbit(cls, galpy_orbit):
    """Create an orbit instance from a ``galpy.Orbit`` instance.

    The galpy orbit is evaluated at its stored times (``galpy_orbit.t``)
    in cylindrical coordinates. Lengths are scaled by the galpy orbit's
    ``_ro`` (taken as kpc) and velocities by its ``_vo`` (taken as km/s).

    Parameters
    ----------
    galpy_orbit : :class:`galpy.orbit.Orbit`
        The galpy orbit to convert. It must expose a ``t`` attribute with
        the times at which the orbit is evaluated.

    Returns
    -------
    orbit : :class:`~gala.dynamics.Orbit`
        An instance of the class this constructor was called on, so
        subclasses get an instance of their own type.
    """
    ro = galpy_orbit._ro * u.kpc
    vo = galpy_orbit._vo * u.km / u.s
    ts = galpy_orbit.t

    rep = coord.CylindricalRepresentation(
        rho=galpy_orbit.R(ts) * ro,
        phi=galpy_orbit.phi(ts) * u.rad,
        z=galpy_orbit.z(ts) * ro,
    )

    # vT / rho is a velocity per length; treating angles as dimensionless
    # lets it be used as the angular rate d_phi.
    with u.set_enabled_equivalencies(u.dimensionless_angles()):
        dif = coord.CylindricalDifferential(
            d_rho=galpy_orbit.vR(ts) * vo,
            d_phi=galpy_orbit.vT(ts) * vo / rep.rho,
            d_z=galpy_orbit.vz(ts) * vo,
        )

    # Times are scaled by the natural time unit ro / vo.
    t = ts * ro / vo

    # Build via ``cls`` rather than a hard-coded class so that subclasses
    # using this alternate constructor are honoured.
    return cls(rep, dif, t=t)