import numpy as np
from env import OrbitZoo
from bodies import Satellite, Body
from rl_algorithms.ppo import PPO
from constants import EARTH_RADIUS, INERTIAL_FRAME, MU, ATTITUDE
from math import radians
import traceback

from org.orekit.time import AbsoluteDate
from org.orekit.orbits import KeplerianOrbit, PositionAngleType, OrbitType
from org.hipparchus.geometry.euclidean.threed import Vector3D
from org.orekit.forces.gravity import NewtonianAttraction
from org.orekit.propagation.numerical import NumericalPropagator
from org.hipparchus.ode.nonstiff import DormandPrince853Integrator
from org.hipparchus.linear import Array2DRowRealMatrix
from org.orekit.propagation import StateCovariance, StateCovarianceMatrixProvider
from orekit import JArray_double

# Mission configuration for ColAvoidanceEnv / OrbitZoo: one PPO-controlled satellite
# ("agent") and one uncontrolled drifter. Cartesian state magnitudes are consistent
# with metres and m/s — presumably the units expected by `bodies`; confirm.
params = {
        "satellites": [
            # PPO-controlled satellite.
            # NOTE(review): Agent.__init__ also reads params["agent"]["device"] and
            # params["agent"]["action_dim"], which are not set in this dict — presumably
            # filled in upstream (Satellite/OrbitZoo); confirm.
            {"name": "agent",
             "initial_state": {"x": 5337709.428463124, "y": 6339969.149911649, "z": 361504.73969662545, "x_dot": -5320.577430007447, "y_dot": 4465.13261069736, "z_dot": 526.2965261179082},
             "initial_state_uncertainty": {"x": 0.1, "y": 0.1, "z": 0.1, "x_dot": 0.1, "y_dot": 0.1, "z_dot": 0.1},
             "initial_mass": 250.0,
             "fuel_mass": 50.0,
             "isp": 3100.0,
             "radius": 10.0,
             "save_steps_info": False,
             "agent": {
                "lr_actor": 0.0001,
                "lr_critic": 0.001,
                "gae_lambda": 0.95,
                "epochs": 5,
                "gamma": 0.95,
                "clip": 0.5,
                "action_std_init": 0.1,
                # Observation size: agent elements + drifter elements + fuel + POC.
                "state_dim_actor": 14,
                "state_dim_critic": 14,
                # Upper bounds used to rescale each action component from [-1, 1] to
                # [0, bound]; the 4th component gates thrust (thrust applied when >= 0.5).
                "action_space": [5.0, np.pi, 2*np.pi, 1.0],
             }},
        ],
        "drifters": [
            # Uncontrolled body: same initial position as the agent with the velocity
            # negated, presumably to set up a head-on conjunction.
            {"name": "drifter",
             "initial_state": {"x": 5337709.428463124, "y": 6339969.149911649, "z": 361504.73969662545, "x_dot": 5320.577430007447, "y_dot": -4465.13261069736, "z_dot": -526.2965261179082},
             "initial_state_uncertainty": {"x": 0.1, "y": 0.1, "z": 0.1, "x_dot": 0.1, "y_dot": 0.1, "z_dot": 0.1},
             "initial_mass": 250.0,
             "radius": 5.0,
             "save_steps_info": False,},
        ],
        # Step size in seconds (the training loop divides 2 days, in seconds, by it).
        "delta_t": 1800.0,
        "forces": {
            "gravity_model": "Newtonian",
            "third_bodies": {
                "active": False,
                "bodies": ["SUN", "MOON", "JUPITER"],
            },
            # Per-body coefficients, keyed by body name.
            "solar_radiation_pressure": {
                "active": False,
                "reflection_coefficients": {
                    "agent": 2.0,
                }
            },
            "drag": {
                "active": True,
                "drag_coefficients": {
                    "agent": 10.0,
                }
            }
        },
        # Rendering options; rendering is only invoked when "show" is True.
        "interface": {
            "show": False,
            "delay_ms": 0,
            "drifters": {
                "show": True,
                "show_label": True,
                "show_velocity": True,
                "color_body": (255, 0, 0),
                "color_label": (255, 255, 255),
                "color_velocity": (255, 0, 0),
            },
            "satellites": {
                "show": True,
                "show_label": True,
                "show_velocity": True,
                "show_thrust": True,
                "color_body": (255, 255, 255),
                "color_label": (255, 255, 255),
                "color_velocity": (255, 255, 255),
                "color_thrust": (0, 255, 0),
            },
        }
    }

class Drifter(Body):
    """Uncontrolled body with an extra Newtonian-only "simulation" propagator.

    ``reset`` builds a second numerical propagator (point-mass gravity only)
    whose state transition matrix feeds a covariance provider. That lets the
    covariance be queried at simulated states through
    ``get_simulated_covariance_matrix``.
    """

    def reset(self, seed=None):
        """Reset the body, then rebuild its simulation propagator from the new state."""
        super().reset(seed)
        self.simulation_propagator, self.simulation_covariance_provider = self._build_simulation_propagator()

    def _build_simulation_propagator(self):
        """Return ``(propagator, covariance_provider)`` seeded from the current state and covariance."""
        orbit = self.current_state.getOrbit()
        # Estimate tolerances for the element type the propagator integrates in (set
        # below via setOrbitType), not for the type the current orbit is stored as.
        tolerances = NumericalPropagator.tolerances(60.0, orbit, self.orbit_type)
        integrator = DormandPrince853Integrator(1e-3, 500.0, JArray_double.cast_(tolerances[0]), JArray_double.cast_(tolerances[1]))
        integrator.setInitialStepSize(10.0)

        propagator = NumericalPropagator(integrator)
        propagator.setInitialState(self.current_state)
        propagator.setMu(MU)
        propagator.setOrbitType(self.orbit_type)
        propagator.setAttitudeProvider(ATTITUDE)
        propagator.addForceModel(NewtonianAttraction(MU))

        # Initial 6x6 Cartesian covariance, taken from the body's state distribution.
        rows = self._initial_covariance()
        matrix = Array2DRowRealMatrix(6, 6)
        for i in range(6):
            matrix.setRow(i, rows[i])
        initial_covariance = StateCovariance(matrix, self.current_date, INERTIAL_FRAME, OrbitType.CARTESIAN, PositionAngleType.MEAN)

        # Propagate the state transition matrix ("stm") so the provider can map the
        # initial covariance onto states produced by this propagator.
        harvester = propagator.setupMatricesComputation("stm", None, None)
        provider = StateCovarianceMatrixProvider("covariance", "stm", harvester, initial_covariance)
        propagator.addAdditionalStateProvider(provider)
        return propagator, provider

    def _initial_covariance(self):
        """Covariance of the body's state distribution as a nested Python list."""
        return self.dist.covariance_matrix.detach().numpy().tolist()

    def step_simulation(self, seconds = None):
        """Propagate the simulation propagator ``seconds`` past the current date.

        ``seconds`` defaults to ``delta_t`` only when it is ``None``, so an explicit 0
        really means "at the current date". The simulated state and date are stored on
        the body. Returns ``self.get_state()``.
        """
        if seconds is None:
            seconds = self.delta_t
        self.simulated_current_state = self.simulation_propagator.propagate(self.current_date.shiftedBy(seconds))
        # The simulated date must come from the simulated state, not the real one.
        self.simulated_current_date = self.simulated_current_state.getDate()
        return self.get_state()

    def propagate_back(self, seconds = None):
        """Propagate the real state ``seconds`` (default ``delta_t``) into the past and return the new state."""
        if seconds is None:
            seconds = self.delta_t
        self.current_state = self.propagator.propagate(self.current_date.shiftedBy(-seconds))
        self.current_date = self.current_state.getDate()
        return self.get_state()

    def get_simulated_covariance_matrix(self, state = None):
        """
        Get the covariance matrix relative to a state. If no state is provided, the last
        simulated state (from ``step_simulation``) is used.

        Returns the initial covariance when the body is still at its initial date, or
        when the covariance provider cannot evaluate the given state.
        """
        if self.current_date == self.initial_date:
            return self._initial_covariance()
        if state is None:
            state = self.simulated_current_state
        try:
            matrix = self.simulation_covariance_provider.getStateCovariance(state).getMatrix()
        except Exception:
            # Best effort: fall back to the initial covariance when the provider raises
            # (e.g. for a state that does not carry the "stm" additional state).
            return self._initial_covariance()
        return np.array([matrix.getRow(i) for i in range(6)], dtype=float).tolist()

class Agent(Satellite):
    """PPO-controlled satellite with an extra Newtonian-only "simulation" propagator.

    ``reset`` builds a second numerical propagator (point-mass gravity only)
    whose state transition matrix feeds a covariance provider. That lets the
    covariance be queried at simulated states through
    ``get_simulated_covariance_matrix``.
    """

    def __init__(self, params):
        super().__init__(params)
        agent_params = params["agent"]

        # "device" and "action_dim" may be absent from the params dict. Use them when
        # present (e.g. filled in upstream); otherwise fall back to CPU and to the size
        # of the action-space bounds.
        device = agent_params.get("device", "cpu")
        if "action_dim" in agent_params:
            action_dim = agent_params["action_dim"]
        else:
            action_dim = len(agent_params["action_space"])

        self.ppo = PPO(device,
                       agent_params["state_dim_actor"],
                       agent_params["state_dim_critic"],
                       action_dim,
                       agent_params["lr_actor"],
                       agent_params["lr_critic"],
                       agent_params["gae_lambda"],
                       agent_params["gamma"],
                       agent_params["epochs"],
                       agent_params["clip"],
                       True,  # presumably "continuous action space" — confirm against PPO
                       agent_params["action_std_init"])

    def reset(self, seed=None):
        """Reset the satellite, then rebuild its simulation propagator from the new state."""
        super().reset(seed)
        self.simulation_propagator, self.simulation_covariance_provider = self._build_simulation_propagator()

    def _build_simulation_propagator(self):
        """Return ``(propagator, covariance_provider)`` seeded from the current state and covariance."""
        orbit = self.current_state.getOrbit()
        # Estimate tolerances for the element type the propagator integrates in (set
        # below via setOrbitType), not for the type the current orbit is stored as.
        tolerances = NumericalPropagator.tolerances(60.0, orbit, self.orbit_type)
        integrator = DormandPrince853Integrator(1e-3, 500.0, JArray_double.cast_(tolerances[0]), JArray_double.cast_(tolerances[1]))
        integrator.setInitialStepSize(10.0)

        propagator = NumericalPropagator(integrator)
        propagator.setInitialState(self.current_state)
        propagator.setMu(MU)
        propagator.setOrbitType(self.orbit_type)
        propagator.setAttitudeProvider(ATTITUDE)
        propagator.addForceModel(NewtonianAttraction(MU))

        # Initial 6x6 Cartesian covariance, taken from the body's state distribution.
        rows = self._initial_covariance()
        matrix = Array2DRowRealMatrix(6, 6)
        for i in range(6):
            matrix.setRow(i, rows[i])
        initial_covariance = StateCovariance(matrix, self.current_date, INERTIAL_FRAME, OrbitType.CARTESIAN, PositionAngleType.MEAN)

        # Propagate the state transition matrix ("stm") so the provider can map the
        # initial covariance onto states produced by this propagator.
        harvester = propagator.setupMatricesComputation("stm", None, None)
        provider = StateCovarianceMatrixProvider("covariance", "stm", harvester, initial_covariance)
        propagator.addAdditionalStateProvider(provider)
        return propagator, provider

    def _initial_covariance(self):
        """Covariance of the satellite's state distribution as a nested Python list."""
        return self.dist.covariance_matrix.detach().numpy().tolist()

    def get_state(self):
        """Equinoctial position of the satellite followed by its remaining fuel."""
        return self.get_equinoctial_position() + [self.get_fuel()]

    def propagate_back(self, seconds = None):
        """Propagate the real state ``seconds`` (default ``delta_t``) into the past and return the new state."""
        if seconds is None:
            seconds = self.delta_t
        self.current_state = self.propagator.propagate(self.current_date.shiftedBy(-seconds))
        self.current_date = self.current_state.getDate()
        return self.get_state()

    def print_networks(self):
        """Print the PPO actor and critic networks."""
        print(self.ppo.policy.actor)
        print(self.ppo.policy.critic)

    def step_simulation(self, seconds = None):
        """Propagate the simulation propagator ``seconds`` past the current date.

        ``seconds`` defaults to ``delta_t`` only when it is ``None``, so an explicit 0
        really means "at the current date". The simulated state and date are stored on
        the satellite. Returns ``self.get_state()``, i.e. the real (not simulated) state.
        """
        if seconds is None:
            seconds = self.delta_t
        self.simulated_current_state = self.simulation_propagator.propagate(self.current_date.shiftedBy(seconds))
        # The simulated date must come from the simulated state, not the real one.
        self.simulated_current_date = self.simulated_current_state.getDate()
        return self.get_state()

    def get_simulated_covariance_matrix(self, state = None):
        """
        Get the covariance matrix relative to a state. If no state is provided, the last
        simulated state (from ``step_simulation``) is used.

        Returns the initial covariance when the satellite is still at its initial date,
        or when the covariance provider cannot evaluate the given state.
        """
        if self.current_date == self.initial_date:
            return self._initial_covariance()
        if state is None:
            state = self.simulated_current_state
        try:
            matrix = self.simulation_covariance_provider.getStateCovariance(state).getMatrix()
        except Exception:
            # Best effort: fall back to the initial covariance when the provider raises
            # (e.g. for a state that does not carry the "stm" additional state).
            return self._initial_covariance()
        return np.array([matrix.getRow(i) for i in range(6)], dtype=float).tolist()

class ColAvoidanceEnv(OrbitZoo):
    """Collision-avoidance mission with one PPO-controlled satellite and one drifter.

    The agent observation is ``agent equinoctial elements + drifter equinoctial
    elements + [fuel] + [poc]``. Here ``poc`` is the collision probability
    predicted over the remaining ``self.t`` steps (see ``predicted_poc``).
    """

    # POC above which the situation is treated as a dangerous conjunction.
    POC_THRESHOLD = 1e-6

    def __init__(self, params):
        super().__init__(params)
        # Nominal orbit: [altitude above EARTH_RADIUS, e, i, argument of perigee, RAAN,
        # true anomaly], angles in degrees (order of Orekit's KeplerianOrbit constructor).
        target = [2_000_000, 0.01, 5.0, 20.0, 20.0, 10.0]
        orbit = KeplerianOrbit(target[0] + EARTH_RADIUS, target[1], radians(target[2]), radians(target[3]), radians(target[4]), radians(target[5]), 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, PositionAngleType.TRUE, INERTIAL_FRAME, AbsoluteDate(), MU)
        # Reward reference: [a, ex, ey, hx, hy] of the nominal orbit.
        self.target = [orbit.getA(), orbit.getEquinoctialEx(), orbit.getEquinoctialEy(), orbit.getHx(), orbit.getHy()]
        # Upper bounds used to rescale policy outputs from [-1, 1] to [0, high],
        # taken from the first satellite's "action_space".
        satellites = params.get("satellites") or []
        self.action_high = np.array(satellites[0]["agent"]["action_space"]) if satellites else None

    def create_bodies(self, params):
        self.drifters = [Drifter(body_params) for body_params in params["drifters"]] if "drifters" in params else []
        self.satellites = [Agent(body_params) for body_params in params["satellites"]] if "satellites" in params else []

    def reset(self, seed=None, options=None, t=1000):
        """Reset every body, rewind it 2 days and return ``(observations, infos)``.

        ``t`` is the number of remaining steps; it bounds the POC prediction horizon.
        ``options`` is accepted for API compatibility and ignored.
        """
        self.t = t

        # Reset each body, propagate it back 2 days, then re-apply the covariance captured
        # right after the reset.
        for body in self.drifters + self.satellites:
            body.reset(seed)
            covariance = body.get_covariance_matrix()
            body.propagate_back(float(60 * 60 * 24 * 2))
            body.set_covariance_matrix(covariance)

        observations = self._observations()
        infos = {satellite.name: {} for satellite in self.satellites}
        return observations, infos

    def step(self, actions=None, last_observations=None):
        """Advance the drifters and the agent by one step.

        ``actions`` maps each satellite key to a raw 4-component policy output; it is
        clipped to [-1, 1] and rescaled to [0, ``action_high``]. The 4th rescaled
        component gates thrust: below 0.5 a zero thrust is applied, otherwise the first
        three components are passed to ``satellite.step``. ``last_observations`` must be
        the observations from the previous ``reset``/``step``, because the reward depends
        on the previous POC.

        Returns ``(observations, rewards, terminations, truncations, infos)``.
        """
        for drifter in self.drifters:
            drifter.step()

        clipped_actions = {str(sat): np.clip(actions[str(sat)], -1.0, 1.0) for sat in self.satellites}

        satellite = self.satellites[0]
        action = ((clipped_actions[str(satellite)] + 1) / 2) * self.action_high
        thrust = np.array([0.0, 0.0, 0.0]) if action[3] < 0.5 else action[:-1]
        satellite.step(thrust, 10.0)

        self.t -= 1

        observations = self._observations()
        rewards, terminations = self.rewards(observations, last_observations, clipped_actions)

        truncations = {sat.name: False for sat in self.satellites}
        infos = {sat.name: {} for sat in self.satellites}
        return observations, rewards, terminations, truncations, infos

    def _observations(self):
        """Build ``{agent_name: observation}`` from the current states and the predicted POC."""
        agent = self.satellites[0]
        drifter = self.drifters[0]
        poc, _, _ = self.predicted_poc()
        return {agent.name: agent.get_equinoctial_position() + drifter.get_equinoctial_position() + [agent.get_fuel()] + [poc]}

    def rewards(self, observations, last_observations, actions):
        """Return ``({agent: reward}, {agent: terminated})`` for one transition.

        * If the previous POC was below ``POC_THRESHOLD``, firing (clipped 4th action
          component > 0) costs -100; otherwise the reward is 0.
        * Otherwise the reward is minus a weighted distance between the agent's
          ``[a, ex, ey, hx, hy]`` and ``self.target`` (``a`` taken relative to the
          target), minus a further 10 if the POC is still above the threshold.

        Episodes never terminate here.
        """
        agent_name = self.satellites[0].name
        state = observations[agent_name]
        last_state = last_observations[agent_name]
        target = self.target

        # The POC is always the last entry of an observation.
        last_poc = last_state[-1]

        if last_poc < self.POC_THRESHOLD:
            # Low risk: penalize any use of thrust.
            decision = actions[agent_name][3]
            reward = -100 if decision > 0 else 0
        else:
            # High risk: penalize drift from the nominal orbit and a still-high POC.
            elements_diff = np.abs(np.array(state[:5]) - np.array(target))
            elements_diff[0] /= target[0]
            elements_weights = np.array([1000, 1, 1, 10, 10])
            distance_penalty = elements_weights @ elements_diff
            poc_penalty = 10 if state[-1] > self.POC_THRESHOLD else 0
            reward = -(distance_penalty + poc_penalty)

        return {agent_name: reward}, {agent_name: False}

    @staticmethod
    def _position_velocity(state):
        """Return ``(position, velocity)`` of an Orekit state as NumPy arrays."""
        pv = state.getPVCoordinates()
        p, v = pv.getPosition(), pv.getVelocity()
        return (np.array([p.getX(), p.getY(), p.getZ()]),
                np.array([v.getX(), v.getY(), v.getZ()]))

    def predicted_poc_old(self):
        """Legacy POC prediction using the bodies' own simulation propagators.

        Scans ``self.t`` steps of ``delta_t`` for the closest approach, moves both
        simulated states to that time, and refines the encounter with a linear
        relative-motion model. Returns ``(poc, tca, closest_approach)``, with ``poc``
        set to 0.0 when the computation yields ``None``.
        """
        agent = self.satellites[0]
        drifter = self.drifters[0]

        agent.simulated_current_state = agent.current_state
        drifter.simulated_current_state = drifter.current_state
        agent.simulation_propagator.setInitialState(agent.simulated_current_state)
        drifter.simulation_propagator.setInitialState(drifter.simulated_current_state)

        # Coarse scan for the time of closest approach (TCA).
        agent_state = agent.simulated_current_state
        drifter_state = drifter.simulated_current_state
        closest_step = 0
        closest_approach = 1e15
        for step_index in range(self.t):
            distance = Vector3D.distance(agent_state.getPVCoordinates().getPosition(),
                                         drifter_state.getPVCoordinates().getPosition())
            if distance < closest_approach:
                closest_approach, closest_step = distance, step_index
            agent_state = agent.simulation_propagator.propagate(agent_state.getDate().shiftedBy(agent.delta_t))
            drifter_state = drifter.simulation_propagator.propagate(drifter_state.getDate().shiftedBy(drifter.delta_t))
        tca = closest_step * agent.delta_t

        agent.simulation_propagator.setInitialState(agent.simulated_current_state)
        drifter.simulation_propagator.setInitialState(drifter.simulated_current_state)

        # Move both simulated states to the TCA. With tca == 0 they are already there,
        # and step_simulation may treat a zero argument as "use delta_t".
        if tca > 0:
            agent.step_simulation(tca)
            drifter.step_simulation(tca)

        chaser_pos, chaser_vel = self._position_velocity(agent.simulated_current_state)
        target_pos, target_vel = self._position_velocity(drifter.simulated_current_state)
        relative_pos = chaser_pos - target_pos
        relative_vel = chaser_vel - target_vel
        # Linearised offset to closest approach: argmin_t |r + v t| = -(r . v) / (v . v).
        tca_linear = -float((relative_vel @ relative_pos) / (relative_vel @ relative_vel))
        chaser_cov = np.array(agent.get_simulated_covariance_matrix())[:3, :3]
        target_cov = np.array(drifter.get_simulated_covariance_matrix())[:3, :3]

        poc = Body.poc_rederivation_simulation(chaser_pos, chaser_vel, chaser_cov, target_pos, target_vel, target_cov, tca_linear, agent.radius, drifter.radius)
        if poc is None:
            poc = 0.0
        return poc, tca, closest_approach

    @staticmethod
    def _simulation_copy(body, name, gravity_model):
        """Create a ``Body`` clone of ``body`` at its current Cartesian state with negligible uncertainty."""
        state = body.get_cartesian_position() + body.get_cartesian_velocity()
        keys = ("x", "y", "z", "x_dot", "y_dot", "z_dot")
        clone = Body({
            "name": name,
            "initial_state": {key: state[i] for i, key in enumerate(keys)},
            "initial_state_uncertainty": dict.fromkeys(keys, 1e-10),
            "initial_mass": body.get_mass(),
            "radius": body.radius,
            "save_steps_info": False,
            "initial_date": body.initial_date,
        })
        clone.delta_t = body.delta_t
        clone.forces = [gravity_model]
        return clone

    def predicted_poc(self):
        """Predict the POC between the agent and the drifter over the remaining ``self.t`` steps.

        Newtonian-only clones of both bodies are stepped ``self.t`` times to find the
        closest approach. They are then restarted and moved to that time, where
        ``Body.poc_rederivation`` is evaluated. Returns ``(poc, tca, closest_approach)``,
        with ``poc`` set to 0.0 when the computation yields ``None``.
        """
        agent = self.satellites[0]
        drifter = self.drifters[0]

        # Simplest gravity model (point mass), shared by both clones.
        gravity_model = NewtonianAttraction(MU)
        agent_sim = self._simulation_copy(agent, "agent", gravity_model)
        drifter_sim = self._simulation_copy(drifter, "drifter", gravity_model)

        def restart():
            # Put both clones back at their initial state, carrying the real bodies'
            # covariance so uncertainty is propagated from the correct starting value.
            for sim, real in ((agent_sim, agent), (drifter_sim, drifter)):
                sim.reset()
                sim.set_covariance_matrix(real.get_covariance_matrix())

        restart()

        # Coarse scan for the time of closest approach (TCA).
        closest_step = 0
        closest_approach = 1e15
        for step_index in range(self.t):
            distance = Body.get_distance(agent_sim, drifter_sim)
            if distance < closest_approach:
                closest_approach, closest_step = distance, step_index
            agent_sim.step()
            drifter_sim.step()
        tca = closest_step * agent.delta_t

        restart()
        # With tca == 0 the clones are already at the TCA; skipping also guards against
        # step() interpreting a zero argument as "use the default delta_t".
        if tca > 0:
            agent_sim.step(tca)
            drifter_sim.step(tca)

        poc = Body.poc_rederivation(agent_sim, drifter_sim)
        if poc is None:
            poc = 0.0
        return poc, tca, closest_approach

env = ColAvoidanceEnv(params)

# Training hyper-parameters.
time_step = 1                   # global step counter across all episodes
action_std_decay_freq = 5000    # decay the policy action std every N global steps
action_std_decay_rate = 0.05
update_freq = 256               # call ppo.update_gae() every N global steps
min_action_std = 0.05
episodes = 10000
# Episode length: the 2 days the bodies are rewound in reset(), plus 10 spare steps.
steps_per_episode = int(float(60 * 60 * 24 * 2) / params["delta_t"] + 10)
low = np.array([0.0, 0.0, 0.0, 0.0])  # not used by the loop below
# Upper action bounds; policy outputs in [-1, 1] are rescaled to [0, high].
high = np.array(params["satellites"][0]["agent"]["action_space"])

print(f'steps_per_episode: {steps_per_episode}')

# load model
# env.satellites[0].ppo.load(".\\missions\\col_avoidance\\model_col_avoidance.pth")
best_score = -1e5
start_episode = 1

# Rollout-buffer lists written during one environment step; rolled back together on failure.
_BUFFER_FIELDS = ("states", "actions", "logprobs", "rewards", "is_terminals")

# Run exactly `episodes` episodes.
for episode in range(start_episode, start_episode + episodes):
    current_ep_rewards = {str(satellite): 0 for satellite in env.satellites}

    observations, _ = env.reset(42, t=steps_per_episode)
    for _step in range(1, steps_per_episode + 1):
        # Snapshot buffer lengths so a failed step can be undone exactly, whatever
        # point it failed at.
        buffer_sizes = [
            {field: len(getattr(satellite.ppo.buffer, field)) for field in _BUFFER_FIELDS}
            for satellite in env.satellites
        ]
        try:
            if not env.satellites[0].has_fuel():
                break

            # select actions with policies
            actions = {str(satellite): satellite.ppo.select_action(observations[str(satellite)]) for satellite in env.satellites}

            # apply step
            observations, rewards, terminations, _, _ = env.step(actions, observations)

            # save rewards and is_terminals
            for satellite in env.satellites:
                satellite.ppo.buffer.rewards.append(rewards[str(satellite)])
                satellite.ppo.buffer.is_terminals.append(terminations[str(satellite)])
            current_ep_rewards = {str(satellite): current_ep_rewards[str(satellite)] + rewards[str(satellite)] for satellite in env.satellites}

            # decay action std of output action distribution
            if time_step % action_std_decay_freq == 0 and env.satellites[0].ppo.action_std > min_action_std:
                for satellite in env.satellites:
                    satellite.ppo.decay_action_std(action_std_decay_rate, min_action_std)

            time_step += 1

            # terminations is keyed the same way as in the buffer loop above.
            if terminations[str(env.satellites[0])]:
                break

        except Exception:
            # If there's a propagation error (or any other failure), drop every buffer
            # entry recorded during this step so the buffer lists stay aligned.
            for satellite, sizes in zip(env.satellites, buffer_sizes):
                for field, size in sizes.items():
                    entries = getattr(satellite.ppo.buffer, field)
                    del entries[size:]
            traceback.print_exc()
            break

        if params["interface"]["show"]:
            env.render()

        if time_step % update_freq == 0:
            for satellite in env.satellites:
                satellite.ppo.update_gae()

    # set the last experience termination flag to True
    if len(env.satellites[0].ppo.buffer.is_terminals) > 0:
        env.satellites[0].ppo.buffer.is_terminals[-1] = True

    # show scores at the end of episode
    for satellite in env.satellites:
        score = current_ep_rewards[str(satellite)]
        print(f'Episode {episode}: {score}, Fuel: {satellite.get_fuel()}, Std: {satellite.ppo.action_std}')
        if score > best_score:
            print(f'>>>>>>>> Best score of {score} found. Saving the model.')
            best_score = score
            satellite.ppo.save("model_checkpoint.pth")