# Standard library
import re
from collections import namedtuple
# Third-party
import astropy.coordinates as coord
import astropy.units as u
import numpy as np
from astropy.coordinates import representation as r
from ..io import quantity_from_hdf5, quantity_to_hdf5
from ..units import DimensionlessUnitSystem, UnitSystem, _greek_letters
from ..util import atleast_2d
# Project
from . import representation_nd as rep_nd
from .plot import plot_projections
__all__ = ["PhaseSpacePosition"]

# Plain namedtuple base class; ``RepresentationMapping`` subclasses it so that
# ``default_unit`` can be given a default value.
_RepresentationMappingBase = namedtuple(
    "RepresentationMapping", "repr_name new_name default_unit"
)
class RepresentationMapping(_RepresentationMappingBase):
    """
    A ``(repr_name, new_name, default_unit)`` `~collections.namedtuple` used
    to override the component names of representation and differential
    classes as exposed on the `PhaseSpacePosition` and `Orbit` classes.

    ``default_unit`` may be omitted, in which case it is the string
    ``"recommended"``.
    """

    def __new__(cls, repr_name, new_name, default_unit="recommended"):
        # Supply the default for ``default_unit`` and defer to the tuple base.
        return _RepresentationMappingBase.__new__(
            cls, repr_name, new_name, default_unit
        )
class RegexRepresentationMapping(RepresentationMapping):
    """
    A `RepresentationMapping` whose ``repr_name`` is a regular expression.

    Component names matched by ``repr_name`` are renamed with
    ``new_name.format(*groups)``, where ``groups`` are the match's captured
    groups.
    """
class PhaseSpacePosition:
    """
    Phase-space positions: positions and conjugate momenta (velocities),
    optionally associated with a gala reference frame. See ``__init__`` for
    the accepted input types.
    """

    # Maps astropy representation / differential classes to the attribute
    # names exposed on this class (and an optional default display unit).
    representation_mappings = {
        r.CartesianRepresentation: [RepresentationMapping("xyz", "xyz")],
        r.SphericalCosLatDifferential: [
            RepresentationMapping("d_lon_coslat", "pm_lon_coslat", u.mas / u.yr),
            RepresentationMapping("d_lat", "pm_lat", u.mas / u.yr),
            RepresentationMapping("d_distance", "radial_velocity"),
        ],
        r.SphericalDifferential: [
            RepresentationMapping("d_lon", "pm_lon", u.mas / u.yr),
            RepresentationMapping("d_lat", "pm_lat", u.mas / u.yr),
            RepresentationMapping("d_distance", "radial_velocity"),
        ],
        r.PhysicsSphericalDifferential: [
            RepresentationMapping("d_phi", "pm_phi", u.mas / u.yr),
            RepresentationMapping("d_theta", "pm_theta", u.mas / u.yr),
            RepresentationMapping("d_r", "radial_velocity"),
        ],
        r.CartesianDifferential: [
            RepresentationMapping("d_x", "v_x"),
            RepresentationMapping("d_y", "v_y"),
            RepresentationMapping("d_z", "v_z"),
            RepresentationMapping("d_xyz", "v_xyz"),
        ],
        r.CylindricalDifferential: [
            RepresentationMapping("d_rho", "v_rho"),
            RepresentationMapping("d_phi", "pm_phi"),
            RepresentationMapping("d_z", "v_z"),
        ],
        rep_nd.NDCartesianRepresentation: [RepresentationMapping("xyz", "xyz")],
        rep_nd.NDCartesianDifferential: [
            RepresentationMapping("d_xyz", "v_xyz"),
            # ``[0-9]+`` together with a full-string match in get_components
            # maps d_x10 -> v_x10 instead of colliding with d_x1 -> v_x1
            RegexRepresentationMapping("d_x([0-9]+)", "v_x{0}"),
        ],
    }
    # The unit-spherical differentials share the spherical name mappings.
    representation_mappings[
        r.UnitSphericalCosLatDifferential
    ] = representation_mappings[r.SphericalCosLatDifferential]
    representation_mappings[r.UnitSphericalDifferential] = representation_mappings[
        r.SphericalDifferential
    ]

    def __init__(self, pos, vel=None, frame=None):
        """
        Represents phase-space positions, i.e. positions and conjugate momenta
        (velocities).

        The class can be instantiated with Astropy representation objects (e.g.,
        :class:`~astropy.coordinates.CartesianRepresentation`), Astropy
        :class:`~astropy.units.Quantity` objects, or plain Numpy arrays.

        If passing in representation objects, the default representation is
        taken to be the class that is passed in.

        If passing in Quantity or Numpy array instances for both position and
        velocity, they are assumed to be Cartesian. Array inputs are interpreted
        as dimensionless quantities. The input position and velocity objects can
        have an arbitrary number of (broadcastable) dimensions. For Quantity or
        array inputs, the first axis (0) has special meaning:

            - `axis=0` is the coordinate dimension (e.g., x, y, z for Cartesian)

        So if the input position array, `pos`, has shape `pos.shape = (3, 100)`,
        this would represent 100 3D positions (`pos[0]` is `x`, `pos[1]` is `y`,
        etc.). The same is true for velocity.

        Parameters
        ----------
        pos : representation, quantity_like, or array_like
            Positions. If a numpy array (e.g., has no units), this will be
            stored as a dimensionless :class:`~astropy.units.Quantity`. See
            the note above about the assumed meaning of the axes of this object.
        vel : differential, quantity_like, or array_like
            Velocities. If a numpy array (e.g., has no units), this will be
            stored as a dimensionless :class:`~astropy.units.Quantity`. See
            the note above about the assumed meaning of the axes of this object.
            May be omitted if ``pos`` is a representation carrying an ``"s"``
            differential.
        frame : :class:`~gala.potential.FrameBase` (optional)
            The reference frame of the input phase-space positions.

        Raises
        ------
        TypeError
            If no velocity data is available, or ``frame`` is not a
            ``FrameBase`` instance.
        ValueError
            If the position and velocity shapes differ.
        """
        if isinstance(pos, coord.Galactocentric):
            pos = pos.data

        if not isinstance(pos, coord.BaseRepresentation):
            # assume Cartesian if not specified
            if not hasattr(pos, "unit"):
                pos = pos * u.one

            # axis 0 is the coordinate dimension; 3D gets astropy's class
            ndim = pos.shape[0]
            if ndim == 3:
                pos = coord.CartesianRepresentation(pos)
            else:
                pos = rep_nd.NDCartesianRepresentation(pos)

        else:
            # NOTE(review): every representation input is treated as 3D,
            # including rep_nd.NDCartesianRepresentation instances (e.g. from
            # __getitem__ / reshape of an N-D object) -- confirm intended.
            ndim = 3

        if vel is None:
            if "s" not in pos.differentials:
                raise TypeError(
                    "You must specify velocity data when creating "
                    "a {0} object.".format(self.__class__.__name__)
                )
            vel = pos.differentials["s"]

        if not isinstance(vel, coord.BaseDifferential):
            # assume representation is same as pos if not specified
            if not hasattr(vel, "unit"):
                vel = vel * u.one

            if ndim == 3:
                name = pos.__class__.get_name()
                Diff = coord.representation.DIFFERENTIAL_CLASSES[name]
                vel = Diff(*vel)
            else:
                Diff = rep_nd.NDCartesianDifferential
                vel = Diff(vel)

        # make sure shape is the same
        if pos.shape != vel.shape:
            raise ValueError(
                "Position and velocity must have the same shape "
                f"{pos.shape} vs. {vel.shape}"
            )

        # local import: the potential subpackage is imported lazily here
        from ..potential.frame import FrameBase

        if frame is not None and not isinstance(frame, FrameBase):
            raise TypeError(
                "Input reference frame must be a FrameBase " "subclass instance."
            )

        self.pos = pos
        self.vel = vel
        self.frame = frame
        self.ndim = ndim

    def __getitem__(self, slyce):
        """Slice positions and velocities together, keeping the same frame."""
        return self.__class__(
            pos=self.pos[slyce], vel=self.vel[slyce], frame=self.frame
        )

    def get_components(self, which):
        """
        Get the component name dictionary for the desired object.

        The returned dictionary maps component names on this class to component
        names on the desired object.

        Parameters
        ----------
        which : str
            Can either be ``'pos'`` or ``'vel'`` to get the components for the
            position or velocity object.

        Returns
        -------
        mapping : dict
            ``{name_on_this_class: name_on_representation}``. Components with
            no entry in ``representation_mappings`` map to themselves.
        """
        obj = getattr(self, which)
        mappings = self.representation_mappings.get(obj.__class__, [])

        old_to_new = dict()
        for name in obj.components:
            for m in mappings:
                if isinstance(m, RegexRepresentationMapping):
                    # full-string match so e.g. d_x10 is not read as d_x1
                    match = re.fullmatch(m.repr_name, name)
                    if match is not None:
                        old_to_new[name] = m.new_name.format(*match.groups())
                elif m.repr_name == name:
                    old_to_new[name] = m.new_name

        return {old_to_new.get(name, name): name for name in obj.components}

    @property
    def pos_components(self):
        """Mapping of exposed position attribute names to component names."""
        return self.get_components("pos")

    @property
    def vel_components(self):
        """Mapping of exposed velocity attribute names to component names."""
        return self.get_components("vel")

    def _get_extra_mappings(self, which):
        """
        Non-regex mappings whose new name is not already a component name
        (e.g. ``xyz`` or ``v_xyz``), as ``{new_name: repr_name}``.
        """
        mappings = self.representation_mappings.get(getattr(self, which).__class__, [])
        existing = self.get_components(which)  # hoisted out of the loop
        return {
            m.new_name: m.repr_name
            for m in mappings
            if m.new_name not in existing
            and not isinstance(m, RegexRepresentationMapping)
        }

    def __dir__(self):
        """
        Override the builtin `dir` behavior to include representation and
        differential names.
        """
        dir_values = set(self.pos_components.keys())
        dir_values |= set(self.vel_components.keys())
        dir_values |= set(self._get_extra_mappings("pos").keys())
        dir_values |= set(self._get_extra_mappings("vel").keys())
        dir_values |= set(r.REPRESENTATION_CLASSES.keys())
        dir_values |= set(super().__dir__())
        return sorted(dir_values)

    def __getattr__(self, attr):
        """
        Allow access to attributes on the ``pos`` and ``vel`` representation and
        differential objects, and to alternate representations by name (e.g.
        ``obj.spherical``).
        """
        # Prevent infinite recursion: private names, and ``pos``/``vel`` when
        # they are missing (partially-initialized instance), fall through to
        # the default lookup, which raises AttributeError.
        if attr.startswith("_") or attr in ("pos", "vel"):
            return self.__getattribute__(attr)

        pos_comps = {**self.pos_components, **self._get_extra_mappings("pos")}
        if attr in pos_comps:
            return getattr(self.pos, pos_comps[attr])

        vel_comps = {**self.vel_components, **self._get_extra_mappings("vel")}
        if attr in vel_comps:
            return getattr(self.vel, vel_comps[attr])

        if attr in r.REPRESENTATION_CLASSES:
            return self.represent_as(attr)

        return self.__getattribute__(attr)  # Raise AttributeError.

    @property
    def data(self):
        """The position representation with the velocity attached as a differential."""
        return self.pos.with_differentials(self.vel)

    # ------------------------------------------------------------------------
    # Convert from Cartesian to other representations
    #
    def represent_as(self, new_pos, new_vel=None):
        """
        Represent the position and velocity of the orbit in an alternate
        coordinate system. Supports any of the Astropy coordinates
        representation classes.

        Parameters
        ----------
        new_pos : :class:`~astropy.coordinates.BaseRepresentation`
            The type of representation to generate. Must be a class (not an
            instance), or the string name of the representation class.
        new_vel : :class:`~astropy.coordinates.BaseDifferential` (optional)
            Class in which any velocities should be represented. Must be a class
            (not an instance), or the string name of the differential class. If
            None, uses the default differential for the new position class.

        Returns
        -------
        new_psp : `gala.dynamics.PhaseSpacePosition`

        Raises
        ------
        ValueError
            If this object is not 3D.
        """
        if self.ndim != 3:
            raise ValueError("Can only change representation for " "ndim=3 instances.")

        # get the name of the desired representation
        pos_name = new_pos if isinstance(new_pos, str) else new_pos.get_name()

        if isinstance(new_vel, str):
            vel_name = new_vel
        elif new_vel is None:
            vel_name = pos_name
        else:
            vel_name = new_vel.get_name()

        Representation = coord.representation.REPRESENTATION_CLASSES[pos_name]
        Differential = coord.representation.DIFFERENTIAL_CLASSES[vel_name]

        new_pos = self.pos.represent_as(Representation)
        new_vel = self.vel.represent_as(Differential, self.pos)

        return self.__class__(pos=new_pos, vel=new_vel, frame=self.frame)

    def to_frame(self, frame, current_frame=None, **kwargs):
        """
        Transform to a new reference frame.

        Parameters
        ----------
        frame : `~gala.potential.FrameBase`
            The frame to transform to.
        current_frame : `gala.potential.CFrameBase`
            The current frame the phase-space position is in. Required if this
            object was created without a frame; overrides ``self.frame``
            otherwise.
        **kwargs
            Any additional arguments are passed through to the individual frame
            transformation functions (see:
            `~gala.potential.frame.builtin.transformations`).

        Returns
        -------
        psp : `gala.dynamics.PhaseSpacePosition`
            The phase-space position in the new reference frame.

        Raises
        ------
        ValueError
            If no current frame is known, or no transformation function exists
            for the requested pair of frames.
        """
        from ..potential.frame.builtin import transformations as frame_trans

        if self.frame is None and current_frame is None:
            raise ValueError(
                f"If no frame was specified when this {self} was "
                "initialized, you must pass the current frame in "
                "via the current_frame argument to transform to a "
                "new frame."
            )
        elif self.frame is not None and current_frame is None:
            current_frame = self.frame

        def _short_name(frame_obj):
            # "StaticFrame" -> "static". Remove the literal "Frame" suffix;
            # str.rstrip("Frame") would strip any trailing F/r/a/m/e characters.
            name = frame_obj.__class__.__name__
            if name.endswith("Frame"):
                name = name[: -len("Frame")]
            return name.lower()

        func_name = f"{_short_name(current_frame)}_to_{_short_name(frame)}"
        if not hasattr(frame_trans, func_name):
            raise ValueError(
                "Unsupported frame transformation: {} to {}".format(
                    current_frame, frame
                )
            )

        trans_func = getattr(frame_trans, func_name)
        pos, vel = trans_func(current_frame, frame, self, **kwargs)
        return PhaseSpacePosition(pos=pos, vel=vel, frame=frame)

    def to_coord_frame(self, frame, galactocentric_frame=None, **kwargs):
        """
        Transform the orbit from Galactocentric, cartesian coordinates to
        Heliocentric coordinates in the specified Astropy coordinate frame.

        Parameters
        ----------
        frame : :class:`~astropy.coordinates.BaseCoordinateFrame`
            The frame instance specifying the desired output frame.
            For example, :class:`~astropy.coordinates.ICRS`.
        galactocentric_frame : :class:`~astropy.coordinates.Galactocentric`
            This is the assumed frame that the position and velocity of this
            object are in. The ``Galactocentric`` instance should have parameters
            specifying the position and motion of the sun in the Galactocentric
            frame, but no data. Defaults to ``coord.Galactocentric()``.

        Returns
        -------
        c : :class:`~astropy.coordinates.BaseCoordinateFrame`
            An instantiated coordinate frame containing the positions and
            velocities from this object transformed to the specified coordinate
            frame.

        Raises
        ------
        ValueError
            If this object is not 3D.
        astropy.units.UnitConversionError
            If the position or velocity is dimensionless.
        """
        if self.ndim != 3:
            raise ValueError("Can only change representation for " "ndim=3 instances.")

        if galactocentric_frame is None:
            galactocentric_frame = coord.Galactocentric()

        pos_keys = list(self.pos_components.keys())
        vel_keys = list(self.vel_components.keys())
        if (
            getattr(self, pos_keys[0]).unit == u.one
            or getattr(self, vel_keys[0]).unit == u.one
        ):
            raise u.UnitConversionError(
                "Position and velocity must have "
                "dimensioned units to convert to a "
                "coordinate frame."
            )

        # first we need to turn the position into a Galactocentric instance
        gc_c = galactocentric_frame.realize_frame(self.pos.with_differentials(self.vel))
        c = gc_c.transform_to(frame)
        return c

    # Pseudo-backwards compatibility
    def w(self, units=None):
        """
        This returns a single array containing the phase-space positions.

        Parameters
        ----------
        units : `~gala.units.UnitSystem` (optional)
            The unit system to represent the position and velocity in
            before combining into the full array. May be omitted only when
            both position and velocity are dimensionless.

        Returns
        -------
        w : `~numpy.ndarray`
            A numpy array of all positions and velocities, without units.
            Will have shape ``(2*ndim, ...)``.

        Raises
        ------
        ValueError
            If ``units`` is omitted for dimensioned data.
        """
        cart = self.cartesian if self.ndim == 3 else self

        xyz = cart.xyz
        d_xyz = cart.v_xyz

        x_unit = xyz.unit
        v_unit = d_xyz.unit
        if (units is None or isinstance(units, DimensionlessUnitSystem)) and (
            x_unit == u.one and v_unit == u.one
        ):
            units = DimensionlessUnitSystem()

        elif units is None:
            raise ValueError("A UnitSystem must be provided.")

        x = xyz.decompose(units).value
        if x.ndim < 2:
            x = atleast_2d(x, insert_axis=1)

        v = d_xyz.decompose(units).value
        if v.ndim < 2:
            v = atleast_2d(v, insert_axis=1)

        return np.vstack((x, v))

    @classmethod
    def from_w(cls, w, units=None, **kwargs):
        """
        Create an instance of this class from a single array specifying
        positions and velocities. This is mainly for backwards-compatibility
        and it is not recommended for new users.

        Parameters
        ----------
        w : array_like
            The array of phase-space positions; the first half of axis 0 holds
            positions and the second half velocities.
        units : `~gala.units.UnitSystem` (optional)
            The unit system that the input position+velocity array, ``w``,
            is represented in.
        **kwargs
            Any additional keyword arguments passed to the class initializer.

        Returns
        -------
        obj : instance of this class
        """
        w = np.array(w)

        ndim = w.shape[0] // 2
        pos = w[:ndim]
        vel = w[ndim:]

        # TODO: this is bad form - UnitSystem should know what to do with a
        # Dimensionless
        if units is not None and not isinstance(units, DimensionlessUnitSystem):
            units = UnitSystem(units)
            pos = pos * units["length"]
            vel = vel * units["length"] / units["time"]  # from _core_units

        return cls(pos=pos, vel=vel, **kwargs)

    # ------------------------------------------------------------------------
    # Input / output
    #
    def to_hdf5(self, f):
        """
        Serialize this object to an HDF5 file.

        Requires ``h5py``.

        Parameters
        ----------
        f : str, :class:`h5py.File`
            Either the filename or an open HDF5 file. A filename is opened in
            write mode (``"w"``, truncating any existing file).

        Returns
        -------
        f : :class:`h5py.File`
            The (still open) file that was written to; the caller is
            responsible for closing it.
        """
        if isinstance(f, str):
            import h5py

            # must be writable: the datasets and groups below are created in it
            f = h5py.File(f, mode="w")

        if self.frame is not None:
            frame_group = f.create_group("frame")
            frame_group.attrs["module"] = self.frame.__module__
            frame_group.attrs["class"] = self.frame.__class__.__name__

            units = [str(x).encode("utf8") for x in self.frame.units.to_dict().values()]
            frame_group.create_dataset("units", data=units)

            d = frame_group.create_group("parameters")
            for k, par in self.frame.parameters.items():
                quantity_to_hdf5(d, k, par)

        # represent_as only supports 3D; N-D objects are already Cartesian
        cart = self.represent_as("cartesian") if self.ndim == 3 else self
        quantity_to_hdf5(f, "pos", cart.xyz)
        quantity_to_hdf5(f, "vel", cart.v_xyz)

        return f

    @classmethod
    def from_hdf5(cls, f):
        """
        Load an object from an HDF5 file.

        Requires ``h5py``.

        Parameters
        ----------
        f : str, :class:`h5py.File`
            Either the filename or an open HDF5 file. A filename is opened
            read-only and closed again before returning.

        Notes
        -----
        If the file stores a frame, the module it names is imported to
        construct the frame, so only load files from trusted sources.
        """
        import importlib

        opened_here = isinstance(f, str)
        if opened_here:
            import h5py

            f = h5py.File(f, mode="r")

        try:
            pos = quantity_from_hdf5(f["pos"])
            vel = quantity_from_hdf5(f["vel"])

            frame = None
            if "frame" in f:
                g = f["frame"]

                frame_mod = g.attrs["module"]
                frame_cls_name = g.attrs["class"]
                # NOTE: string attributes may be returned as bytes depending on
                # how they were written; normalize to str
                if isinstance(frame_mod, bytes):
                    frame_mod = frame_mod.decode("utf-8")
                if isinstance(frame_cls_name, bytes):
                    frame_cls_name = frame_cls_name.decode("utf-8")

                frame_units = [u.Unit(x.decode("utf-8")) for x in g["units"]]
                if u.dimensionless_unscaled in frame_units:
                    units = DimensionlessUnitSystem()
                else:
                    units = UnitSystem(*frame_units)

                pars = dict()
                for k in g["parameters"]:
                    pars[k] = quantity_from_hdf5(g["parameters/" + k])

                # Resolve the class via importlib instead of exec/eval: names
                # bound by exec() in a function are not visible to a later
                # eval() on Python >= 3.13 (PEP 667), and exec would run
                # arbitrary text stored in the file.
                frame_cls = getattr(importlib.import_module(frame_mod), frame_cls_name)
                frame = frame_cls(units=units, **pars)
        finally:
            if opened_here:
                f.close()

        return cls(pos=pos, vel=vel, frame=frame)

    # ------------------------------------------------------------------------
    # Computed dynamical quantities
    #
    def kinetic_energy(self):
        r"""
        The kinetic energy *per unit mass*:

        .. math::

            E_K = \frac{1}{2} \, |\boldsymbol{v}|^2

        Returns
        -------
        E : :class:`~astropy.units.Quantity`
            The kinetic energy.
        """
        if isinstance(self.vel, (r.CartesianDifferential, rep_nd.NDCartesianDifferential)):
            return 0.5 * self.vel.norm() ** 2
        # Non-Cartesian differentials (e.g. spherical) need the base position
        # to convert their components into a physical speed.
        return 0.5 * self.vel.norm(self.pos) ** 2

    def potential_energy(self, potential):
        r"""
        The potential energy *per unit mass*:

        .. math::

            E_\Phi = \Phi(\boldsymbol{q})

        Parameters
        ----------
        potential : `gala.potential.PotentialBase`
            The potential object to compute the energy from.

        Returns
        -------
        E : :class:`~astropy.units.Quantity`
            The potential energy.
        """
        # TODO: check that potential ndim is consistent with here
        return potential.energy(self)

    def energy(self, hamiltonian):
        r"""
        The total energy *per unit mass* (e.g., kinetic + potential):

        Parameters
        ----------
        hamiltonian : `gala.potential.Hamiltonian`, `gala.potential.PotentialBase` instance
            The Hamiltonian object to evaluate the energy. If a potential is
            passed in, this assumes a static reference frame.

        Returns
        -------
        E : :class:`~astropy.units.Quantity`
            The total energy.
        """
        from gala.potential import Hamiltonian

        hamiltonian = Hamiltonian(hamiltonian)
        return hamiltonian(self)

    def angular_momentum(self):
        r"""
        Compute the angular momentum for the phase-space positions contained
        in this object:

        .. math::

            \boldsymbol{L} = \boldsymbol{q} \times \boldsymbol{p}

        See :ref:`shape-conventions` for more information about the shapes of
        input and output objects.

        Returns
        -------
        L : :class:`~astropy.units.Quantity`
            Array of angular momentum vectors.

        Examples
        --------

            >>> import numpy as np
            >>> import astropy.units as u
            >>> pos = np.array([1., 0, 0]) * u.au
            >>> vel = np.array([0, 2*np.pi, 0]) * u.au/u.yr
            >>> w = PhaseSpacePosition(pos, vel)
            >>> w.angular_momentum() # doctest: +FLOAT_CMP
            <Quantity [0. ,0. ,6.28318531] AU2 / yr>
        """
        cart = self.represent_as(coord.CartesianRepresentation)
        return cart.pos.cross(cart.vel).xyz

    def guiding_radius(self, potential, t=0.0, **root_kwargs):
        """
        Compute the guiding-center radius

        Parameters
        ----------
        potential : `gala.potential.PotentialBase` subclass instance
            The potential to compute the guiding radius in.
        t : quantity-like (optional)
            Time.
        **root_kwargs
            Any additional keyword arguments are passed to `~scipy.optimize.root`.

        Returns
        -------
        Rg : :class:`~astropy.units.Quantity`
            Guiding-center radius, with the same shape as this object. NaN
            where the root finder did not converge.
        """
        # Work in Cartesian so x/y exist regardless of the stored representation
        cart = self.represent_as(coord.CartesianRepresentation)
        # Flatten so the helper solves one scalar root per point, whatever the
        # shape of this object; the result is reshaped back below.
        R0s = np.ravel(
            np.sqrt(cart.x**2 + cart.y**2).decompose(potential.units).value
        )
        Lzs = np.ravel(self.angular_momentum()[2].decompose(potential.units).value)
        Rgs = _guiding_radius_helper(R0s, Lzs, potential, t, **root_kwargs)
        return Rgs.reshape(self.shape) * potential.units["length"]

    # ------------------------------------------------------------------------
    # Misc. useful methods
    #
    def _plot_prepare(self, components, units):
        """
        Prepare the ``PhaseSpacePosition`` or subclass for passing to a plotting
        routine to plot all projections of the object.

        Returns the list of unitless value arrays and the matching LaTeX axis
        labels.
        """

        # components to plot
        if components is None:
            components = self.pos.components
        n_comps = len(components)

        # if units not specified, get units from the components
        if units is not None:
            if isinstance(units, u.UnitBase):
                units = [units] * n_comps  # global unit

            elif len(units) != n_comps:
                raise ValueError(
                    "You must specify a unit for each axis, or a "
                    "single unit for all axes."
                )

        labels = []
        x = []
        for i, name in enumerate(components):
            val = getattr(self, name)

            if units is not None:
                val = val.to(units[i])
                unit = units[i]
            else:
                unit = val.unit

            if val.unit != u.one:
                uu = unit.to_string(format="latex_inline")
                unit_str = " [{}]".format(uu)
            else:
                unit_str = ""

            # Figure out how to fancy display the component name
            if name.startswith("d_"):
                dot = True
                name = name[2:]
            else:
                dot = False

            if name in _greek_letters:
                name = r"\{}".format(name)

            if dot:
                name = r"\dot{{{}}}".format(name)

            labels.append("${}$".format(name) + unit_str)
            x.append(val.value)

        return x, labels

    def plot(self, components=None, units=None, auto_aspect=True, **kwargs):
        """
        Plot the positions in all projections. This is a wrapper around
        `~gala.dynamics.plot_projections` for fast access and quick
        visualization. All extra keyword arguments are passed to that function
        (the docstring for this function is included here for convenience).

        Parameters
        ----------
        components : iterable (optional)
            A list of component names (strings) to plot. By default, this is the
            Cartesian positions ``['x', 'y', 'z']``. To plot Cartesian
            velocities, pass in the velocity component names
            ``['v_x', 'v_y', 'v_z']``.
        units : `~astropy.units.UnitBase`, iterable (optional)
            A single unit or list of units to display the components in.
        auto_aspect : bool (optional)
            Automatically enforce an equal aspect ratio (only applied when all
            plotted components are Cartesian positions).
        relative_to : bool (optional)
            Plot the values relative to this value or values.
        autolim : bool (optional)
            Automatically set the plot limits to be something sensible.
        axes : array_like (optional)
            Array of matplotlib Axes objects.
        subplots_kwargs : dict (optional)
            Dictionary of kwargs passed to :func:`~matplotlib.pyplot.subplots`.
        labels : iterable (optional)
            List or iterable of axis labels as strings. They should correspond to
            the dimensions of the input orbit.
        plot_function : callable (optional)
            The ``matplotlib`` plot function to use. By default, this is
            :func:`~matplotlib.pyplot.scatter`, but can also be, e.g.,
            :func:`~matplotlib.pyplot.plot`.
        **kwargs
            All other keyword arguments are passed to the ``plot_function``.
            You can pass in any of the usual style kwargs like ``color=...``,
            ``marker=...``, etc.

        Returns
        -------
        fig : `~matplotlib.Figure`
        """
        from gala.tests.optional_deps import HAS_MATPLOTLIB

        if not HAS_MATPLOTLIB:
            raise ImportError("matplotlib is required for visualization.")
        import matplotlib.pyplot as plt

        if components is None:
            components = self.pos.components

        x, labels = self._plot_prepare(components=components, units=units)

        kwargs.setdefault("plot_function", plt.scatter)
        if kwargs["plot_function"] in [plt.plot, plt.scatter]:
            kwargs.setdefault("marker", ".")
        kwargs.setdefault("labels", labels)
        kwargs.setdefault("autolim", False)

        fig = plot_projections(x, **kwargs)

        # An equal aspect ratio only makes sense when every axis is a
        # Cartesian position component (not a velocity).
        pos_names = self.pos_components
        if (
            self.pos.get_name() == "cartesian"
            and all(c in pos_names for c in components)
            and auto_aspect
        ):
            for ax in fig.axes:
                ax.set(aspect="equal", adjustable="datalim")

        return fig

    # ------------------------------------------------------------------------
    # Display
    #
    def __repr__(self):
        return "<{} {}, dim={}, shape={}>".format(
            self.__class__.__name__, self.pos.get_name(), self.ndim, self.pos.shape
        )

    def __str__(self):
        return "pos={}\nvel={}".format(self.pos, self.vel)

    # ------------------------------------------------------------------------
    # Shape and size
    #
    @property
    def shape(self):
        """
        This is *not* the shape of the position or velocity arrays. That is
        accessed by doing, e.g., ``obj.x.shape``.
        """
        return self.pos.shape

    def reshape(self, new_shape):
        """
        Reshape the underlying position and velocity arrays.
        """
        return self.__class__(
            pos=self.pos.reshape(new_shape),
            vel=self.vel.reshape(new_shape),
            frame=self.frame,
        )
def _guiding_radius_rootfunc(R, Lz, potential, t):
    """
    Residual ``Lz - R * v_c(R)`` whose root in ``R`` is the guiding radius.

    ``v_c`` is evaluated as ``sqrt(R * |dPhi/dr|)``, with the radial gradient
    taken from ``potential.c_instance.d_dr`` at the point ``(R[0], 0, 0)``.
    ``R`` is the length-1 array supplied by the root finder.
    """
    point = np.array([[R[0], 0.0, 0.0]])
    grad_r = potential.c_instance.d_dr(point, potential.G, t=np.array([t]))
    v_circ = np.sqrt(R * np.abs(grad_r))
    return Lz - R * v_circ
def _guiding_radius_helper(R0s, Lzs, potential, t, **root_kwargs):
    """
    Solve for the guiding-center radius of each ``(R0, Lz)`` pair.

    Parameters
    ----------
    R0s : array_like
        1D array of initial guesses for the radius (e.g. the current
        cylindrical radius), in the potential's unit system.
    Lzs : array_like
        1D array of z angular momenta, same length as ``R0s``; the absolute
        value is used.
    potential : `gala.potential.PotentialBase` subclass instance
    t : float
        Time at which to evaluate the potential gradient.
    **root_kwargs
        Passed to `scipy.optimize.root`. Defaults: ``method="hybr"`` and
        ``options=dict(xtol=1e-5)``.

    Returns
    -------
    Rgs : `numpy.ndarray`
        Float array of guiding radii; NaN where the root finder failed.
    """
    from scipy.optimize import root

    root_kwargs.setdefault("options", dict(xtol=1e-5))
    root_kwargs.setdefault("method", "hybr")

    # Always a float result (integer-typed guesses must not truncate the roots
    # or reject NaN), pre-filled with NaN so failed solves need no extra branch.
    Rgs = np.full(np.shape(R0s), np.nan, dtype=float)
    for i, (R0, Lz) in enumerate(zip(R0s, Lzs)):
        res = root(
            _guiding_radius_rootfunc, R0, args=(np.abs(Lz), potential, t), **root_kwargs
        )
        if res.success:
            Rgs[i] = res.x[0]

    return Rgs