import traceback
from math import radians
from pathlib import Path

import numpy as np
import pandas as pd

# NOTE(review): the project modules are kept ahead of the `org.*` Orekit
# wrappers, as in the original order — presumably one of them initialises the
# Orekit JVM; confirm before reordering.
from env import OrbitZoo
from bodies import Satellite, Body
from rl_algorithms.ppo import PPO
from constants import EARTH_RADIUS, INERTIAL_FRAME, MU, ATTITUDE

from org.orekit.time import AbsoluteDate
from org.orekit.orbits import KeplerianOrbit, PositionAngleType, OrbitType
from org.hipparchus.geometry.euclidean.threed import Vector3D
from org.orekit.forces.gravity import NewtonianAttraction
from org.orekit.propagation.numerical import NumericalPropagator
from org.hipparchus.ode.nonstiff import DormandPrince853Integrator
from org.hipparchus.linear import Array2DRowRealMatrix
from org.orekit.propagation import StateCovariance, StateCovarianceMatrixProvider
from orekit import JArray_double

# Environment configuration passed to ColAvoidanceEnv (and on to OrbitZoo).
# One controlled satellite ("agent") and one uncontrolled drifter. The drifter
# starts at exactly the agent's position with the agent's velocity negated, i.e.
# the two bodies share a point on counter-rotating orbits. ColAvoidanceEnv.reset
# propagates both backwards by `days_before` days, so (presumably) they are set
# up to meet near this point when the episode runs forward again.
params = {
        "satellites": [
            {"name": "agent",
             "initial_state": {"x": 5337709.428463124, "y": 6339969.149911649, "z": 361504.73969662545, "x_dot": -5320.577430007447, "y_dot": 4465.13261069736, "z_dot": 526.2965261179082},
             "initial_state_uncertainty": {"x": 0.1, "y": 0.1, "z": 0.1, "x_dot": 0.1, "y_dot": 0.1, "z_dot": 0.1},
             "initial_mass": 250.0,
             "fuel_mass": 50.0,
             "isp": 3100.0,
             "radius": 10.0,
             "save_steps_info": False,
             # PPO hyper-parameters consumed by Agent.__init__.
             "agent": {
                "lr_actor": 0.0001,
                "lr_critic": 0.001,
                "gae_lambda": 0.95,
                "epochs": 10,
                "gamma": 0.95,
                "clip": 0.5,
                "action_std_init": 0.005,
                # Observation length: agent + drifter equinoctial position, fuel, PoC
                # (assumes 6 equinoctial entries per body — TODO confirm).
                "state_dim_actor": 14,
                "state_dim_critic": 14,
                # Upper bounds used to map the 4 policy outputs from [-1, 1] to [0, high].
                # The script labels them M, theta, phi, delta; the last one is the
                # thrust on/off decision (thrust is applied when its scaled value >= 0.5).
                "action_space": [5.0, np.pi, 2*np.pi, 1.0],
             }},
        ],
        "drifters": [
            {"name": "drifter",
             "initial_state": {"x": 5337709.428463124, "y": 6339969.149911649, "z": 361504.73969662545, "x_dot": 5320.577430007447, "y_dot": -4465.13261069736, "z_dot": -526.2965261179082},
             "initial_state_uncertainty": {"x": 0.1, "y": 0.1, "z": 0.1, "x_dot": 0.1, "y_dot": 0.1, "z_dot": 0.1},
             "initial_mass": 250.0,
             "radius": 5.0,
             "save_steps_info": False,},
        ],
        # Step length in seconds (steps_per_episode divides a duration in seconds by it).
        "delta_t": 1800.0,
        "forces": {
            "gravity_model": "Newtonian",
            "third_bodies": {
                "active": True,
                "bodies": ["SUN", "MOON", "JUPITER"],
            },
            # Radiation-pressure and drag coefficients are only given for the agent.
            "solar_radiation_pressure": {
                "active": True,
                "reflection_coefficients": {
                    "agent": 1.0,
                }
            },
            "drag": {
                "active": True,
                "drag_coefficients": {
                    "agent": 1.0,
                }
            }
        },
        # Rendering options; env.render() is only called when "show" is True.
        "interface": {
            "show": False,
            "delay_ms": 0,
            "drifters": {
                "show": True,
                "show_label": True,
                "show_velocity": True,
                "color_body": (255, 0, 0),
                "color_label": (255, 255, 255),
                "color_velocity": (255, 0, 0),
            },
            "satellites": {
                "show": True,
                "show_label": True,
                "show_velocity": True,
                "show_thrust": True,
                "color_body": (255, 255, 255),
                "color_label": (255, 255, 255),
                "color_velocity": (255, 255, 255),
                "color_thrust": (0, 255, 0),
            },
        }
    }

class Drifter(Body):
    """Uncontrolled body with an extra Newtonian-only "simulation" propagator.

    ``reset`` builds a second numerical propagator (point-mass gravity only)
    that also carries the 6x6 Cartesian state covariance forward through a
    state transition matrix. ``step_simulation`` and
    ``get_simulated_covariance_matrix`` use it; the regular propagator
    inherited from ``Body`` is left untouched.

    NOTE(review): these methods duplicate ``Agent``'s; keep the two in sync.
    """

    def reset(self, seed=None):
        """Reset the body, then rebuild its simulation propagator and covariance provider.

        Args:
            seed: forwarded unchanged to ``Body.reset``.
        """
        super().reset(seed)

        orbit = self.current_state.getOrbit()
        # 60.0 is the position-accuracy argument (dP) of Orekit's tolerance helper.
        tolerances = NumericalPropagator.tolerances(60.0, orbit, orbit.getType())
        # Min/max step 1e-3 and 500.0, absolute/relative tolerances from above.
        integrator = DormandPrince853Integrator(1e-3, 500.0, JArray_double.cast_(tolerances[0]), JArray_double.cast_(tolerances[1]))
        integrator.setInitialStepSize(10.0)

        propagator = NumericalPropagator(integrator)
        propagator.setInitialState(self.current_state)
        propagator.setMu(MU)
        propagator.setOrbitType(self.orbit_type)
        propagator.setAttitudeProvider(ATTITUDE)
        propagator.addForceModel(NewtonianAttraction(MU))

        # Seed the propagated covariance with the body's current distribution covariance.
        covariance_rows = self.dist.covariance_matrix.detach().numpy().tolist()
        matrix = Array2DRowRealMatrix(6, 6)
        for i in range(6):
            matrix.setRow(i, covariance_rows[i])
        initial_covariance = StateCovariance(matrix, self.current_date, INERTIAL_FRAME, OrbitType.CARTESIAN, PositionAngleType.MEAN)
        # The STM harvested under "stm" is what maps the covariance forward in time.
        harvester = propagator.setupMatricesComputation("stm", None, None)
        self.simulation_covariance_provider = StateCovarianceMatrixProvider("covariance", "stm", harvester, initial_covariance)
        propagator.addAdditionalStateProvider(self.simulation_covariance_provider)
        self.simulation_propagator = propagator

    def step_simulation(self, seconds=None):
        """Propagate the simulation propagator to ``current_date + seconds``.

        Args:
            seconds: offset from ``self.current_date``; ``None`` means
                ``self.delta_t``. ``0`` is honoured (propagate to the current
                date) instead of being replaced by ``delta_t``.

        Returns:
            ``self.get_state()``. NOTE(review): this appears to be the regular
            (non-simulated) state; the simulated one is stored in
            ``self.simulated_current_state`` — confirm what callers expect.
        """
        if seconds is None:
            seconds = self.delta_t
        self.simulated_current_state = self.simulation_propagator.propagate(self.current_date.shiftedBy(seconds))
        # Date of the simulated state, not of the regular current state.
        self.simulated_current_date = self.simulated_current_state.getDate()
        return self.get_state()

    def propagate_back(self, seconds=None):
        """Propagate the regular propagator ``seconds`` into the past and adopt that state.

        Args:
            seconds: how far back from ``self.current_date``; ``None`` means ``self.delta_t``.

        Returns:
            ``self.get_state()`` after the update.
        """
        if seconds is None:
            seconds = self.delta_t
        self.current_state = self.propagator.propagate(self.current_date.shiftedBy(-seconds))
        self.current_date = self.current_state.getDate()
        return self.get_state()

    def get_simulated_covariance_matrix(self, state=None):
        """Return the simulation-propagated 6x6 covariance as nested lists.

        Args:
            state: Orekit state to evaluate the covariance at; defaults to
                ``self.simulated_current_state``.

        Returns:
            A list of 6 row lists of floats. Falls back to ``self.dist``'s
            covariance while the body is still at its initial date, or when
            the covariance provider raises for ``state``.
        """
        if self.current_date == self.initial_date:
            return self.dist.covariance_matrix.detach().numpy().tolist()
        if state is None:
            state = self.simulated_current_state
        try:
            orekit_matrix = self.simulation_covariance_provider.getStateCovariance(state).getMatrix()
        except Exception:
            # Deliberate best effort (presumably when `state` was not produced by the
            # simulation propagator); narrower than a bare except so Ctrl-C still works.
            return self.dist.covariance_matrix.detach().numpy().tolist()
        rows = [orekit_matrix.getRow(i) for i in range(6)]
        return np.array(rows, dtype=float).tolist()

class Agent(Satellite):
    """Controlled satellite driven by a PPO policy.

    Like ``Drifter``, ``reset`` also builds a Newtonian-only "simulation"
    propagator that carries the 6x6 Cartesian covariance through a state
    transition matrix.

    NOTE(review): the simulation methods duplicate ``Drifter``'s; keep in sync.
    """

    def __init__(self, params):
        """Build the satellite and its PPO learner from ``params["agent"]``.

        NOTE(review): ``device`` and ``action_dim`` are not present in this
        file's ``params`` dict — presumably filled in by ``Satellite`` or
        ``OrbitZoo``; confirm, otherwise this raises ``KeyError``.
        """
        super().__init__(params)

        agent_params = params["agent"]
        self.ppo = PPO(agent_params["device"],
                       agent_params["state_dim_actor"],
                       agent_params["state_dim_critic"],
                       agent_params["action_dim"],
                       agent_params["lr_actor"],
                       agent_params["lr_critic"],
                       agent_params["gae_lambda"],
                       agent_params["gamma"],
                       agent_params["epochs"],
                       agent_params["clip"],
                       True,
                       agent_params["action_std_init"])

    def reset(self, seed=None):
        """Reset the satellite, then rebuild its simulation propagator and covariance provider.

        Args:
            seed: forwarded unchanged to ``Satellite.reset``.
        """
        super().reset(seed)

        orbit = self.current_state.getOrbit()
        # 60.0 is the position-accuracy argument (dP) of Orekit's tolerance helper.
        tolerances = NumericalPropagator.tolerances(60.0, orbit, orbit.getType())
        # Min/max step 1e-3 and 500.0, absolute/relative tolerances from above.
        integrator = DormandPrince853Integrator(1e-3, 500.0, JArray_double.cast_(tolerances[0]), JArray_double.cast_(tolerances[1]))
        integrator.setInitialStepSize(10.0)

        propagator = NumericalPropagator(integrator)
        propagator.setInitialState(self.current_state)
        propagator.setMu(MU)
        propagator.setOrbitType(self.orbit_type)
        propagator.setAttitudeProvider(ATTITUDE)
        propagator.addForceModel(NewtonianAttraction(MU))

        # Seed the propagated covariance with the body's current distribution covariance.
        covariance_rows = self.dist.covariance_matrix.detach().numpy().tolist()
        matrix = Array2DRowRealMatrix(6, 6)
        for i in range(6):
            matrix.setRow(i, covariance_rows[i])
        initial_covariance = StateCovariance(matrix, self.current_date, INERTIAL_FRAME, OrbitType.CARTESIAN, PositionAngleType.MEAN)
        # The STM harvested under "stm" is what maps the covariance forward in time.
        harvester = propagator.setupMatricesComputation("stm", None, None)
        self.simulation_covariance_provider = StateCovarianceMatrixProvider("covariance", "stm", harvester, initial_covariance)
        propagator.addAdditionalStateProvider(self.simulation_covariance_provider)
        self.simulation_propagator = propagator

    def get_state(self):
        """Return the equinoctial position followed by the remaining fuel."""
        return self.get_equinoctial_position() + [self.get_fuel()]

    def propagate_back(self, seconds=None):
        """Propagate the regular propagator ``seconds`` into the past and adopt that state.

        Args:
            seconds: how far back from ``self.current_date``; ``None`` means ``self.delta_t``.

        Returns:
            ``self.get_state()`` after the update.
        """
        if seconds is None:
            seconds = self.delta_t
        self.current_state = self.propagator.propagate(self.current_date.shiftedBy(-seconds))
        self.current_date = self.current_state.getDate()
        return self.get_state()

    def print_networks(self):
        """Print the PPO actor and critic networks."""
        print(self.ppo.policy.actor)
        print(self.ppo.policy.critic)

    def step_simulation(self, seconds=None):
        """Propagate the simulation propagator to ``current_date + seconds``.

        Args:
            seconds: offset from ``self.current_date``; ``None`` means
                ``self.delta_t``. ``0`` is honoured (propagate to the current
                date) instead of being replaced by ``delta_t``.

        Returns:
            ``self.get_state()``, i.e. the regular equinoctial position + fuel;
            the simulated state is stored in ``self.simulated_current_state``.
        """
        if seconds is None:
            seconds = self.delta_t
        self.simulated_current_state = self.simulation_propagator.propagate(self.current_date.shiftedBy(seconds))
        # Date of the simulated state, not of the regular current state.
        self.simulated_current_date = self.simulated_current_state.getDate()
        return self.get_state()

    def get_simulated_covariance_matrix(self, state=None):
        """Return the simulation-propagated 6x6 covariance as nested lists.

        Args:
            state: Orekit state to evaluate the covariance at; defaults to
                ``self.simulated_current_state``.

        Returns:
            A list of 6 row lists of floats. Falls back to ``self.dist``'s
            covariance while the body is still at its initial date, or when
            the covariance provider raises for ``state``.
        """
        if self.current_date == self.initial_date:
            return self.dist.covariance_matrix.detach().numpy().tolist()
        if state is None:
            state = self.simulated_current_state
        try:
            orekit_matrix = self.simulation_covariance_provider.getStateCovariance(state).getMatrix()
        except Exception:
            # Deliberate best effort (presumably when `state` was not produced by the
            # simulation propagator); narrower than a bare except so Ctrl-C still works.
            return self.dist.covariance_matrix.detach().numpy().tolist()
        rows = [orekit_matrix.getRow(i) for i in range(6)]
        return np.array(rows, dtype=float).tolist()

class ColAvoidanceEnv(OrbitZoo):
    """Collision-avoidance environment built on ``OrbitZoo``.

    ``self.satellites[0]`` is the controlled agent and ``self.drifters[0]`` is
    the conjunction partner. The agent's observation is
    ``agent equinoctial position + drifter equinoctial position + [fuel, poc]``,
    where ``poc`` (probability of collision, from ``predicted_poc``) is always
    the last entry.
    """

    # Days both bodies are propagated backwards on reset; may be overridden per instance.
    days_before = 4
    # PoC above which a conjunction is treated as dangerous by the reward.
    POC_THRESHOLD = 1e-6

    def __init__(self, params):
        super().__init__(params)
        # Nominal orbit: [altitude above EARTH_RADIUS, e, i, perigee arg, RAAN, true anomaly],
        # angles in degrees (converted with radians() below).
        target = [2_000_000, 0.01, 5.0, 20.0, 20.0, 10.0]
        orbit = KeplerianOrbit(target[0] + EARTH_RADIUS, target[1], radians(target[2]), radians(target[3]), radians(target[4]), radians(target[5]), 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, PositionAngleType.TRUE, INERTIAL_FRAME, AbsoluteDate(), MU)
        # Compared against the first five observation entries in `rewards`.
        self.target = [orbit.getA(), orbit.getEquinoctialEx(), orbit.getEquinoctialEy(), orbit.getHx(), orbit.getHy()]
        print(self.target)
        # Upper bounds for mapping policy outputs from [-1, 1] to [0, high]
        # (previously read from a module-level `high` global).
        try:
            self.action_high = np.asarray(params["satellites"][0]["agent"]["action_space"], dtype=float)
        except (KeyError, IndexError):
            self.action_high = None

    def create_bodies(self, params):
        """Instantiate drifters and PPO-driven satellites from ``params`` (either list may be absent)."""
        self.drifters = [Drifter(body_params) for body_params in params["drifters"]] if "drifters" in params else []
        self.satellites = [Agent(body_params) for body_params in params["satellites"]] if "satellites" in params else []

    def _build_observations(self, poc):
        """Return the agent's observation dict, keyed by the agent's name."""
        agent = self.satellites[0]
        drifter = self.drifters[0]
        return {agent.name: agent.get_equinoctial_position() + drifter.get_equinoctial_position() + [agent.get_fuel()] + [poc]}

    def reset(self, seed=None, options=None, t=1000):
        """Reset all bodies, rewind them ``days_before`` days and compute the first observation.

        Args:
            seed: forwarded to each body's ``reset``.
            options: unused; kept for interface compatibility.
            t: number of remaining steps, also the TCA search horizon of ``predicted_poc``.

        Returns:
            ``(observations, infos, tca, closest_approach)``.
        """
        self.t = t
        seconds_back = float(60 * 60 * 24 * self.days_before)

        for body in self.drifters + self.satellites:
            body.reset(seed)
            # Carry the covariance from the reset epoch over the backward propagation.
            covariance = body.get_covariance_matrix()
            body.propagate_back(seconds_back)
            body.set_covariance_matrix(covariance)

        poc, tca, closest_approach = self.predicted_poc()
        observations = self._build_observations(poc)
        infos = {satellite.name: {} for satellite in self.satellites}
        return observations, infos, tca, closest_approach

    def step(self, actions=None, last_observations=None):
        """Advance drifters and satellites by one step.

        Args:
            actions: dict ``name -> array`` of 4 raw policy outputs per satellite
                (required despite the ``None`` default). Outputs are clipped to
                [-1, 1] and scaled to [0, action_high]; the first three scaled
                values are the thrust command, the fourth is the on/off decision
                (no thrust when it is below 0.5).
            last_observations: previous observations, used by ``rewards``.

        Returns:
            ``(observations, rewards, terminations, truncations, infos, tca, closest_approach)``.
        """
        for drifter in self.drifters:
            drifter.step()

        clipped_actions = {satellite.name: np.clip(actions[satellite.name], -1.0, 1.0) for satellite in self.satellites}

        for satellite in self.satellites:
            scaled = ((clipped_actions[satellite.name] + 1) / 2) * self.action_high
            thrust = np.array([0.0, 0.0, 0.0]) if scaled[3] < 0.5 else scaled[:3]
            # NOTE(review): meaning of the 10.0 argument is defined by Satellite.step — confirm.
            satellite.step(thrust, 10.0)

        self.t -= 1

        poc, tca, closest_approach = self.predicted_poc()
        observations = self._build_observations(poc)

        rewards, terminations = self.rewards(observations, last_observations, clipped_actions)

        truncations = {satellite.name: False for satellite in self.satellites}
        infos = {satellite.name: {} for satellite in self.satellites}
        return observations, rewards, terminations, truncations, infos, tca, closest_approach

    def rewards(self, observations, last_observations, actions):
        """Compute the agent's reward; episodes are never terminated here.

        When the previous PoC was below ``POC_THRESHOLD``, thrusting (clipped
        decision > 0) costs -10. Otherwise the reward is minus a weighted
        distance from ``self.target`` minus 10 if the PoC is still above the
        threshold. If ``last_observations`` is ``None``, the current
        observations stand in for them.

        Returns:
            ``({agent_name: reward}, {agent_name: False})``.
        """
        agent_name = self.satellites[0].name
        state = observations[agent_name]
        if last_observations is None:
            last_observations = observations
        last_poc = last_observations[agent_name][-1]

        if last_poc < self.POC_THRESHOLD:
            # Conjunction was harmless: only penalise using the thruster.
            decision = actions[agent_name][3]
            reward = -10 if decision > 0 else 0
        else:
            # Conjunction was dangerous: penalise remaining risk and orbit deviation.
            target = np.array(self.target)
            elements_diff = np.abs(np.array(state[:len(target)]) - target)
            elements_diff[0] /= target[0]  # semi-major axis as a relative error
            elements_weights = np.array([1000, 1, 1, 10, 10])
            distance_penalty = elements_weights @ elements_diff
            poc_penalty = 10 if state[-1] > self.POC_THRESHOLD else 0
            reward = -(distance_penalty + poc_penalty)

        return {agent_name: reward}, {agent_name: False}

    def predicted_poc_old(self):
        """Legacy PoC estimate using the bodies' own simulation propagators.

        Samples the agent-drifter distance every ``delta_t`` for ``self.t``
        steps, moves both simulation propagators to the closest sample, refines
        with a linear TCA and calls ``Body.poc_rederivation_simulation``.

        Returns:
            ``(poc, tca, closest_approach)``; ``poc`` is 0.0 when the helper returns ``None``.
        """
        agent = self.satellites[0]
        drifter = self.drifters[0]

        agent.simulated_current_state = agent.current_state
        drifter.simulated_current_state = drifter.current_state

        agent.simulation_propagator.setInitialState(agent.simulated_current_state)
        drifter.simulation_propagator.setInitialState(drifter.simulated_current_state)

        agent_simulated_state = agent.simulated_current_state
        drifter_simulated_state = drifter.simulated_current_state

        # calculate time of closest approach (TCA) on the delta_t grid
        closest_step = 0
        closest_approach = 1e15
        for t in range(self.t):
            agent_pos = agent_simulated_state.getPVCoordinates().getPosition()
            drifter_pos = drifter_simulated_state.getPVCoordinates().getPosition()
            agent_vector = Vector3D([agent_pos.getX(), agent_pos.getY(), agent_pos.getZ()])
            drifter_vector = Vector3D([drifter_pos.getX(), drifter_pos.getY(), drifter_pos.getZ()])
            distance = Vector3D.distance(agent_vector, drifter_vector)
            if distance < closest_approach:
                closest_approach = distance
                closest_step = t
            agent_simulated_state = agent.simulation_propagator.propagate(agent_simulated_state.getDate().shiftedBy(agent.delta_t))
            drifter_simulated_state = drifter.simulation_propagator.propagate(drifter_simulated_state.getDate().shiftedBy(drifter.delta_t))
        tca = closest_step * agent.delta_t

        agent.simulation_propagator.setInitialState(agent.simulated_current_state)
        drifter.simulation_propagator.setInitialState(drifter.simulated_current_state)

        # put bodies on tca with simulation propagator
        agent.step_simulation(tca)
        drifter.step_simulation(tca)

        agent_pv = agent.simulated_current_state.getPVCoordinates()
        drifter_pv = drifter.simulated_current_state.getPVCoordinates()
        chaser_pos = np.array([agent_pv.getPosition().getX(), agent_pv.getPosition().getY(), agent_pv.getPosition().getZ()])
        target_pos = np.array([drifter_pv.getPosition().getX(), drifter_pv.getPosition().getY(), drifter_pv.getPosition().getZ()])
        chaser_vel = np.array([agent_pv.getVelocity().getX(), agent_pv.getVelocity().getY(), agent_pv.getVelocity().getZ()])
        target_vel = np.array([drifter_pv.getVelocity().getX(), drifter_pv.getVelocity().getY(), drifter_pv.getVelocity().getZ()])
        relative_pos = chaser_pos - target_pos
        relative_vel = chaser_vel - target_vel
        # Linear-motion refinement of the TCA around the sampled closest point.
        tca_linear = - float((relative_vel.T @ relative_pos) / (relative_vel.T @ relative_vel))
        chaser_cov = np.array(agent.get_simulated_covariance_matrix())[:3, :3]
        target_cov = np.array(drifter.get_simulated_covariance_matrix())[:3, :3]

        # calculate probability of collision (poc)
        poc = Body.poc_rederivation_simulation(chaser_pos, chaser_vel, chaser_cov, target_pos, target_vel, target_cov, tca_linear, agent.radius, drifter.radius)

        if poc is None:
            poc = 0.0

        return poc, tca, closest_approach

    def predicted_poc(self):
        """Predict the PoC between agent and drifter at the sampled time of closest approach.

        Builds Newtonian-only ``Body`` copies of both bodies from their current
        Cartesian states (agent with 1e-10 uncertainties; drifter given its
        current covariance), steps them ``self.t`` times to find the closest
        sample, then re-propagates to that time and calls ``Body.poc_rederivation``.

        Returns:
            ``(poc, tca, closest_approach)``: ``poc`` is 0.0 when the helper
            returns ``None``; ``tca`` is the closest sample index times
            ``agent.delta_t``; ``closest_approach`` is the minimum
            ``Body.get_distance`` over the samples.
        """
        agent = self.satellites[0]
        drifter = self.drifters[0]

        agent_state = agent.get_cartesian_position() + agent.get_cartesian_velocity()
        drifter_state = drifter.get_cartesian_position() + drifter.get_cartesian_velocity()

        # create copies of current bodies (with negligible uncertainty)
        agent_params = {"name": "agent",
                        "initial_state": {"x": agent_state[0], "y": agent_state[1], "z": agent_state[2], "x_dot": agent_state[3], "y_dot": agent_state[4], "z_dot": agent_state[5]},
                        "initial_state_uncertainty": {"x": 1e-10, "y": 1e-10, "z": 1e-10, "x_dot": 1e-10, "y_dot": 1e-10, "z_dot": 1e-10},
                        "initial_mass": agent.get_mass(),
                        "radius": agent.radius,
                        "save_steps_info": False,
                        "initial_date": agent.initial_date}

        drifter_params = {"name": "drifter",
                          "initial_state": {"x": drifter_state[0], "y": drifter_state[1], "z": drifter_state[2], "x_dot": drifter_state[3], "y_dot": drifter_state[4], "z_dot": drifter_state[5]},
                          "initial_state_uncertainty": {"x": 1e-10, "y": 1e-10, "z": 1e-10, "x_dot": 1e-10, "y_dot": 1e-10, "z_dot": 1e-10},
                          "initial_mass": drifter.get_mass(),
                          "radius": drifter.radius,
                          "save_steps_info": False,
                          "initial_date": drifter.initial_date}

        # simplest gravity model (NewtonianAttraction)
        gravity_model = NewtonianAttraction(MU)

        agent_sim = Body(agent_params)
        drifter_sim = Body(drifter_params)
        agent_sim.delta_t = agent.delta_t
        agent_sim.forces = [gravity_model]
        drifter_sim.delta_t = drifter.delta_t
        drifter_sim.forces = [gravity_model]

        agent_sim.reset()
        drifter_sim.reset()

        # only the drifter copy gets the real covariance; the agent copy keeps ~zero uncertainty
        drifter_sim.set_covariance_matrix(drifter.get_covariance_matrix())

        # calculate time of closest approach (TCA) on the delta_t grid
        closest_step = 0
        closest_approach = 1e15
        for t in range(self.t):
            distance = Body.get_distance(agent_sim, drifter_sim)
            if distance < closest_approach:
                closest_approach = distance
                closest_step = t
            agent_sim.step()
            drifter_sim.step()
        tca = closest_step * agent.delta_t

        # restart both copies and jump straight to the TCA
        agent_sim.reset()
        drifter_sim.reset()
        drifter_sim.set_covariance_matrix(drifter.get_covariance_matrix())
        agent_sim.step(tca)
        drifter_sim.step(tca)

        # calculate probability of collision (poc) at tca
        poc = Body.poc_rederivation(agent_sim, drifter_sim)

        if poc is None:
            poc = 0.0

        return poc, tca, closest_approach

env = ColAvoidanceEnv(params)

env.days_before = 4

# NOTE(review): time_step is incremented below but otherwise unused; the other
# training knobs in this group are not referenced in this file.
time_step = 1
action_std_decay_freq = 1000
action_std_decay_rate = 0.05
update_freq = 256
min_action_std = 0.05
episodes = 10000
# Steps needed to cover the `days_before` rewind, plus 10 extra steps.
steps_per_episode = int(float(60 * 60 * 24 * env.days_before) / params["delta_t"] + 10)
low = np.array([0.0, 0.0, 0.0, 0.0])
high = np.array(params["satellites"][0]["agent"]["action_space"])

print(f'steps_per_episode: {steps_per_episode}')

# load model (built with pathlib so the relative path works on any OS)
model_path = Path("missions") / "col_avoidance" / "model_col_avoidance.pth"
env.satellites[0].ppo.load(str(model_path))
data = []
start_step = time_step
current_ep_rewards = {str(satellite): 0 for satellite in env.satellites}
agent_name = env.satellites[0].name

observations, _, _, _ = env.reset(42, t=steps_per_episode)
for t in range(1, steps_per_episode + 1):
    try:

        # Running out of fuel ends the episode with a fixed penalty.
        if not env.satellites[0].has_fuel():
            current_ep_rewards = {str(satellite): current_ep_rewards[str(satellite)] - 1_000 for satellite in env.satellites}
            break

        # select actions with policies
        # (baseline: use {agent_name: np.array([-1.0, -1.0, -1.0, -1.0])} to never thrust)
        actions = {str(satellite): satellite.ppo.select_action(observations[str(satellite)]) for satellite in env.satellites}

        # apply step
        observations, rewards, terminations, _, _, tca, closest_approach = env.step(actions, observations)

        current_ep_rewards = {str(satellite): current_ep_rewards[str(satellite)] + rewards[str(satellite)] for satellite in env.satellites}

        # Re-derive the scaled action the environment applied, for logging.
        clipped_actions = {str(satellite): np.clip(actions[str(satellite)], [-1,-1,-1,-1], [1,1,1,1]) for satellite in env.satellites}
        scaled_actions = {str(satellite): ((clipped_actions[satellite.name] + 1) / 2) * high for satellite in env.satellites}
        state = observations[agent_name]
        action = scaled_actions[agent_name]
        pos = env.satellites[0].get_cartesian_position()
        elements = env.satellites[0].get_equinoctial_position()
        tca_steps = (tca / params["delta_t"])
        data.append({
                    'step': t,
                    "M": action[0],
                    "theta": action[1],
                    "phi": action[2],
                    "delta": action[3],
                    "x": pos[0],
                    "y": pos[1],
                    "z": pos[2],
                    "a": elements[0],
                    "ex": elements[1],
                    "ey": elements[2],
                    "hx": elements[3],
                    "hy": elements[4],
                    "poc": state[-1],  # PoC is the last observation entry
                    # A TCA of 0 steps is logged as the full episode length.
                    "tca": tca_steps if tca_steps != 0 else steps_per_episode,
                    "ca": closest_approach
        })

        time_step += 1

        # `rewards` keys its terminations dict by the first satellite's name.
        if terminations[agent_name]:
            break

    except Exception:
        # Report the real failure first so cleanup problems cannot mask it.
        traceback.print_exc()
        # If there's a propagation error, drop the last buffered entries of each
        # agent (presumably appended by select_action); guarded against empty buffers.
        for satellite in env.satellites:
            buffer = satellite.ppo.buffer
            for entries in (buffer.states, buffer.actions, buffer.logprobs):
                if entries:
                    entries.pop()
        break

    if params["interface"]["show"]:
        env.render()

# show scores at the end of episode
for satellite in env.satellites:
    score = current_ep_rewards[str(satellite)]
    print(f'Score: {score}, Fuel: {satellite.get_fuel()}, Std: {satellite.ppo.action_std}')

df = pd.DataFrame(data)
df.to_csv(f"collision_avoidance_{env.days_before}days.csv")