import astropy.units as u
import numpy as np
import gala.dynamics as gd
import gala.potential as gp
from gala.units import galactic

# Hernquist potential (mass 1e10 Msun, scale radius 1 kpc) evaluated in a
# non-rotating (static) reference frame.
pot = gp.HernquistPotential(
    m=1E10 * u.Msun,
    c=1.0 * u.kpc,
    units=galactic,
)
frame = gp.StaticFrame(units=galactic)
H = gp.Hamiltonian(potential=pot, frame=frame)

# Initial conditions: 5 kpc along x, moving only in the z direction.
w0 = gd.PhaseSpacePosition(
    pos=[5.0, 0.0, 0.0] * u.kpc,
    vel=[0.0, 0.0, 50.0] * u.km / u.s,
)

# NOTE(review): dt is passed without units, so it is presumably interpreted in
# the time unit of the `galactic` unit system (Myr) — confirm against gala docs.
orbit = H.integrate_orbit(w0, dt=0.5, n_steps=1000)

# Rotation axis for the rotating frame, normalized to a unit vector so that
# the magnitude of Omega is set entirely by frame_freq.
rotation_axis = np.array([8.2, -1.44, 3.25])
rotation_axis = rotation_axis / np.linalg.norm(rotation_axis)

# Angular-velocity vector of the frame: |Omega| = 42 km/s/kpc along the axis.
frame_freq = 42.0 * u.km / u.s / u.kpc
rot_frame = gp.ConstantRotatingFrame(
    Omega=frame_freq * rotation_axis,
    units=galactic,
)

# The static-frame orbit re-expressed in the rotating frame's coordinates.
orbit_to_rot = orbit.to_frame(rot_frame)

def _plot_with_title(orb, title):
    """Plot an orbit without point markers and add a figure-level title.

    Parameters
    ----------
    orb
        Orbit object whose ``plot`` method returns a matplotlib Figure.
    title : str
        Text for the figure's suptitle.

    Returns
    -------
    matplotlib.figure.Figure
        The figure produced by ``orb.plot``.
    """
    fig = orb.plot(marker='')
    fig.suptitle(title, fontsize=20, y=0.96)
    # tight_layout() recomputes every subplot margin, so it must run BEFORE
    # subplots_adjust(); in the other order the top margin reserved for the
    # suptitle is discarded and the title can overlap the top axes.
    fig.tight_layout()
    fig.subplots_adjust(top=0.92)
    return fig


fig1 = _plot_with_title(orbit, "Static frame")
fig2 = _plot_with_title(orbit_to_rot, "Rotating frame")