import numpy as np
import pandas as pd
from env import OrbitZoo

def acc(r_input, v_input, thrust, angle, r_e, rm, GM, m, C_d, A, F_t, rng=None):
    """Return the 2-D acceleration of the satellite in the Herrera model.

    The total force is point-mass gravity toward the origin, plus an
    altitude-dependent drag opposing the velocity (scaled by a random factor),
    plus a thrust vector whose direction is fixed in the x/y frame.

    Args:
        r_input: Position vector measured from the planet's center.
        v_input: Velocity vector.
        thrust: Throttle setting; the thrust magnitude is ``F_t * thrust``.
        angle: Thrust direction in radians, measured from the +x axis.
        r_e: Planet radius; altitude is ``|r_input| - r_e``.
        rm: Multiplier applied to the empirical density fit.
        GM: Gravitational parameter of the central body.
        m: Satellite mass.
        C_d: Drag coefficient.
        A: Reference area normal to the velocity.
        F_t: Maximum thrust force.
        rng: Optional ``numpy.random.Generator`` used to draw the drag
            perturbation. ``None`` (default) draws from the global
            ``np.random`` state, as before.

    Returns:
        Acceleration vector ``force / m``.
    """
    dist_sq = np.dot(r_input, r_input)
    dist = np.sqrt(dist_sq)
    speed = np.sqrt(np.dot(v_input, v_input))
    altitude = dist - r_e

    # Empirical power-law density fit, scaled by ``rm``.
    rho = rm / ((7.8974E-24 + 8.89106E-31 * altitude) * (141.89 + 0.00299 * altitude) ** 11.388)

    # rho*|v|^2*C_d*A*v_hat == rho*|v|*C_d*A*v. This form avoids dividing by
    # |v|, which previously produced NaN when the speed was exactly zero.
    # NOTE(review): the conventional drag law carries a factor 1/2; it is
    # omitted here (possibly absorbed into ``rm``) -- confirm.
    drag = rho * speed * C_d * A * v_input

    # Multiplicative drag noise: drag is scaled by (1 + 0.05 * N(0, 1)).
    if rng is None:
        perturbation = 0.05 * np.random.randn()
    else:
        perturbation = 0.05 * rng.standard_normal()

    gravity = -(GM * m / dist_sq) * (r_input / dist)
    thrust_force = F_t * thrust * np.array([np.cos(angle), np.sin(angle)])

    force = gravity - (drag + perturbation * drag) + thrust_force
    return force / m

def yoshida(r, v, dt, thrust, angle, r_e, rm, GM, m, C_d, A, F_t):
    """Advance ``(r, v)`` by ``dt`` using Yoshida's fourth-order coefficients.

    The step is three drift/kick stages followed by a closing drift. Each kick
    evaluates ``acc`` at the freshly drifted position and at the velocity
    from before that kick. ``acc`` includes velocity-dependent drag, so the
    textbook order/symplectic guarantees hold only approximately here.

    Returns:
        Tuple ``(r, v)`` after the step.
    """
    # Drift weights c1..c4 and kick weights d1..d3.
    drift = (
        0.6756035959798288170238,
        -0.1756035959798288170238,
        -0.1756035959798288170238,
        0.6756035959798288170238,
    )
    kick = (
        1.3512071919596576340476,
        -1.7024143839193152680953,
        1.3512071919596576340476,
    )

    pos, vel = r, v
    # zip stops after the three kick stages; the fourth drift follows below.
    for c, d in zip(drift, kick):
        pos = pos + c * dt * vel
        vel = vel + d * dt * acc(pos, vel, thrust, angle, r_e, rm, GM, m, C_d, A, F_t)

    # Closing drift (operand order c * v * dt kept from the original).
    pos = pos + drift[3] * vel * dt
    return pos, vel

def calculate_physics(r, v, thrust, angle, F_t, r_e, GM, m, C_d, A, steps, dt, rm):
    """Propagate the Herrera model across ``dt`` in ``steps`` equal substeps.

    Each substep is a single ``yoshida`` step of length ``dt / steps`` with the
    same control (``thrust``, ``angle``) and physical parameters. ``steps``
    is used both as a divisor and with ``range``, so it should be a positive
    integer.

    Returns:
        Tuple ``(r, v)`` after the full interval.
    """
    substep = dt / steps
    position, velocity = r, v
    for _substep_index in range(steps):
        position, velocity = yoshida(
            position, velocity, substep,
            thrust, angle, r_e, rm, GM, m, C_d, A, F_t,
        )
    return position, velocity

# --- Herrera model constants (SI units) ---
h = 5.5E5                # Height of satellite 550 km in meters
r_e = 6.371E6            # Radius of earth in meters
r_s = h + r_e            # Radius from center of earth
mp = 75.0                # Mass of propellant
mf = 25.0                # Mass of satellite w/o propellant
m = mp + mf              # Mass of satellite
dt = 1.0                 # Time step per loop iteration (s)
GM = 3.986004418E14      # Earth's gravitational parameter
C_d = 2.123              # Drag coefficient
A = 1.0                  # Surface area normal to velocity
F_t = 0.04               # Force of thrust
steps = 1                # Integrator substeps per dt
threshold = 1            # Threshold to end episode (NOTE(review): unused in this script)
rho_multiplier = 10000   # Rho is multiplied by this amount
step_fuel = 125          # Number of steps for fuel to run out, when at full thrust
orbit_v = np.sqrt(GM / r_s)  # Circular-orbit speed at radius r_s

# --- Initial Herrera state: on the +x axis, moving along +y at circular speed ---
theta = 0
thrust = 0.0
r = np.array([r_s, 0.0])
v = np.array([0.0, orbit_v])
step_fuel_used = 0

# OrbitZoo configuration mirroring the Herrera setup above. Mass and time
# step reference the shared constants so both simulators stay in sync.
params = {
    "satellites": [
        {
            "name": "agent",
            # NOTE(review): this is not (r, v) from above. It appears to be
            # the Herrera state after one dt step (x ~ r_s - GM/r_s^2*dt^2/2,
            # y ~ orbit_v*dt, x_dot ~ -GM/r_s^2*dt). If env.step() advances
            # from this state, OrbitZoo runs one step ahead of Herrera in the
            # comparison loop -- confirm the offset is intended.
            "initial_state": {"x": 6920995.839264945, "y": 7588.996868126882, "z": 0.0, "x_dot": -8.32146925323441, "y_dot": 7588.993777978686, "z_dot": 0.0},
            "initial_state_uncertainty": {"x": 1e-15, "y": 1e-15, "z": 1e-15, "x_dot": 1e-15, "y_dot": 1e-15, "z_dot": 1e-15},
            "initial_mass": mf,   # same dry mass as the Herrera model
            "fuel_mass": mp,      # same propellant mass as the Herrera model
            "isp": 0.0067,
            "radius": 16.8,
            "save_steps_info": False,
            "agent": {
                "lr_actor": 0.0001,
                "lr_critic": 0.001,
                "gae_lambda": 0.95,
                "epochs": 5,
                "gamma": 0.99,
                "clip": 0.03,
                "action_std_init": 0.01,
                "state_dim_actor": 8,
                "state_dim_critic": 8,
                "action_space": [0.04 / 50, 2 * np.pi / 6],
            },
        },
    ],
    "delta_t": dt,  # same step as the Herrera loop
    "forces": {
        "gravity_model": "Newtonian",
        "third_bodies": {
            "active": False,
            "bodies": ["SUN", "MOON"],
        },
        "solar_radiation_pressure": {
            "active": False,
            "reflection_coefficients": {
                "agent": 0.5,
            },
        },
        "drag": {
            "active": True,
            # NOTE(review): no entry for "agent", so OrbitZoo presumably uses
            # its own default coefficient, which may differ from C_d -- confirm.
            "drag_coefficients": {
                # "agent": 2.123,
            },
        },
    },
    "interface": {
        "show": False,
        "delay_ms": 0,
        "zoom": 1.0,
        "drifters": {
            "show": True,
            "show_label": True,
            "show_velocity": False,
            "show_trail": True,
            "trail_last_steps": 50,
            "color_body": (255, 255, 255),
            "color_label": (255, 255, 255),
            "color_velocity": (255, 255, 255),
            "color_trail": (255, 255, 255),
        },
        "satellites": {
            "show": True,
            "show_label": True,
            "show_velocity": False,
            "show_thrust": True,
            "show_trail": True,
            "trail_last_steps": 5000,
            "color_body": (255, 0, 0),
            "color_label": (255, 255, 255),
            "color_velocity": (255, 255, 255),
            "color_thrust": (0, 255, 0),
            "color_trail": (255, 0, 0),
        },
        "earth": {
            "show": True,
            "color": (0, 0, 255),
            "resolution": 70,
        },
        "equator_grid": {
            "show": False,
            "color": (30, 140, 200),
            "resolution": 10,
        },
        "timestamp": {
            "show": True,
        },
        "orbits": [
            {"a": 550.0e3, "e": 0.00001, "i": 0.00001, "pa": 0.0, "raan": 0.0, "color": (0, 255, 0)},
        ],
    },
}

env = OrbitZoo(params)
env.reset()
satellite = env.satellites[0]

# One row per step holding both propagators' 2-D position and velocity.
rows = []

for t in range(800):
    # Herrera model. No control is applied: thrust and heading stay zero, so
    # the fuel bookkeeping keeps the mass at mf + mp.
    thrust, theta = 0, 0
    step_fuel_used += thrust
    m = mf + (1 - step_fuel_used / step_fuel) * mp
    r, v = calculate_physics(r, v, thrust, theta, F_t, r_e, GM, m, C_d, A, steps, dt, rho_multiplier)

    # OrbitZoo: advance one environment step. A None action presumably means
    # "no thrust" -- confirm against OrbitZoo.step.
    env.step({'agent': None})
    orbitzoo_r = satellite.get_cartesian_position()
    orbitzoo_v = satellite.get_cartesian_velocity()

    rows.append({
        't': t,
        'herrera_rx': r[0],
        'herrera_ry': r[1],
        'herrera_vx': v[0],
        'herrera_vy': v[1],
        'orbitzoo_rx': orbitzoo_r[0],
        'orbitzoo_ry': orbitzoo_r[1],
        'orbitzoo_vx': orbitzoo_v[0],
        'orbitzoo_vy': orbitzoo_v[1],
    })

data = pd.DataFrame(rows)
data.to_csv('herrera_orbitzoo.csv')