import numpy as np
from env import OrbitZoo
from bodies import Satellite
from rl_algorithms.ppo import PPO
from constants import EARTH_RADIUS, INERTIAL_FRAME, MU
from math import radians
import traceback

from org.orekit.time import AbsoluteDate
from org.orekit.orbits import KeplerianOrbit, PositionAngleType
from org.orekit.orbits import KeplerianOrbit

# Environment configuration for HohmannEnvironment (forwarded to OrbitZoo).
# NOTE(review): Agent.__init__ also reads params["agent"]["device"] and
# params["agent"]["action_dim"], which are not set below — presumably OrbitZoo
# or Satellite fills them in before the agent is built; confirm.
params = {
        "satellites": [
            {"name": "agent",
             # Initial Cartesian position/velocity and its per-component
             # uncertainty; units not shown here, presumably SI (m, m/s) — confirm.
             "initial_state": {"x": 5337709.428463124, "y": 6339969.149911649, "z": 361504.73969662545, "x_dot": -5320.577430007447, "y_dot": 4465.13261069736, "z_dot": 526.2965261179082},
             "initial_state_uncertainty": {"x": 0.0001, "y": 0.0001, "z": 0.0001, "x_dot": 0.0001, "y_dot": 0.0001, "z_dot": 0.0001},
             # Spacecraft properties; presumably kg / kg / s / m — confirm in bodies.Satellite.
             "initial_mass": 200.0,
             "fuel_mass": 50.0,
             "isp": 310.0,
             "radius": 5.0,
             "save_steps_info": False,
             # PPO settings; Agent.__init__ passes these positionally to PPO.
             "agent": {
                "lr_actor": 0.0001,
                "lr_critic": 0.001,
                "gae_lambda": 0.95,
                "epochs": 10,
                "gamma": 0.99,
                "clip": 0.1,
                "action_std_init": 0.5,
                # Observation size; must match len(Agent.get_state()), i.e. the
                # equinoctial position plus the fuel value.
                "state_dim_actor": 7,
                "state_dim_critic": 7,
                # Per-component upper bounds: the policy's clipped [-1, 1] action
                # is rescaled to [0, bound]. The last component is a thrust
                # on/off flag compared against 0.5 after scaling.
                "action_space": [500, np.pi, 2*np.pi, 1.0],
             }},
        ],
        # Propagation step length; presumably seconds — confirm in OrbitZoo.
        "delta_t": 5.0,
        "forces": {
            "gravity_model":  "Newtonian",
            "third_bodies": {
                "active": False,
                "bodies": ["SUN", "MOON"],
            },
            # Coefficient maps are keyed by satellite name ("agent" above).
            "solar_radiation_pressure": {
                "active": False,
                "reflection_coefficients": {
                    "agent": 0.5,
                }
            },
            "drag": {
                "active": False,
                "drag_coefficients": {
                    "agent": 0.5,
                }
            }
        },
        # Rendering options; the training loop calls env.render() only when
        # "show" is True.
        "interface": {
            "show": False,
            "delay_ms": 0,
            "zoom": 20.0,
            "drifters": {
                "show": True,
                "show_label": True,
                "show_velocity": False,
                "show_trail": True,
                "trail_last_steps": 50,
                "color_body": (255, 255, 255),
                "color_label": (255, 255, 255),
                "color_velocity": (255, 255, 255),
                "color_trail": (255, 255, 255),
            },
            "satellites": {
                "show": True,
                "show_label": True,
                "show_velocity": False,
                "show_thrust": True,
                "show_trail": True,
                "trail_last_steps": 5000,
                "color_body": (255, 0, 0),
                "color_label": (255, 255, 255),
                "color_velocity": (255, 255, 255),
                "color_thrust": (0, 255, 0),
                "color_trail": (255, 0, 0),
            },
            "earth": {
                "show": True,
                "color": (0, 0, 255),
                "resolution": 70,
            },
            "equator_grid": {
                "show": False,
                "color": (30, 140, 200),
                "resolution": 10,
            },
            "timestamp": {
                "show": True,
            },
            # Reference orbit drawn by the interface; its values mirror
            # HohmannEnvironment.target_elements (a, e, i, pa, raan).
            "orbits": [
                {"a": 2030.0e3, "e": 0.01, "i": 5.0, "pa": 20.0, "raan": 20.0, "color": (255, 0, 255)},
            ],
        }
    }

class Agent(Satellite):
    """A Satellite whose maneuvers are chosen by a PPO policy stored in ``self.ppo``."""

    def __init__(self, params):
        super().__init__(params)
        cfg = params["agent"]
        # Positional arguments, in the order the PPO constructor takes them.
        # The literal True is forwarded unchanged; presumably it selects a
        # continuous action space — confirm against rl_algorithms.ppo.PPO.
        ppo_args = (
            cfg["device"],
            cfg["state_dim_actor"],
            cfg["state_dim_critic"],
            cfg["action_dim"],
            cfg["lr_actor"],
            cfg["lr_critic"],
            cfg["gae_lambda"],
            cfg["gamma"],
            cfg["epochs"],
            cfg["clip"],
            True,
            cfg["action_std_init"],
        )
        self.ppo = PPO(*ppo_args)

    def get_state(self):
        """Return the observation: the equinoctial position with the fuel value appended."""
        equinoctial = self.get_equinoctial_position()
        fuel = self.get_fuel()
        return equinoctial + [fuel]

    def print_networks(self):
        """Print the PPO policy's actor network, then its critic network."""
        policy = self.ppo.policy
        for network in (policy.actor, policy.critic):
            print(network)

class HohmannEnvironment(OrbitZoo):
    """OrbitZoo environment for a single-agent orbit-transfer (Hohmann) task.

    The first satellite is rewarded for moving its orbit towards the orbit
    described by ``target_elements``. The target counts as reached when the
    semi-major axis and the equinoctial elements (ex, ey, hx, hy) are all within
    ``tolerance``. Observations are the satellites' states with the first
    component divided by the target semi-major axis.
    """

    # Scaled value of the last action component at or above which thrust is
    # applied; shared by ``step`` and ``rewards`` so both use the same boundary.
    THRUST_THRESHOLD = 0.5

    # Weights of the per-element improvement terms for (a, ex, ey, hx, hy).
    IMPROVEMENT_WEIGHTS = (1000.0, 1.0, 1.0, 10.0, 10.0)

    # Gains of the two reward terms: the improvement (scaled by the first action
    # component) and a penalty on the second action component (disabled).
    ALPHA_IMPROVEMENT = 1.0
    ALPHA_PENALTY = 0.0

    def __init__(self, params):
        super().__init__(params)
        # Target orbit as (a, e, i, pa, raan, true anomaly). ``a`` is measured
        # above EARTH_RADIUS and the angles go through radians() in ``reset``.
        self.target_elements = [2030.0e3, 0.01, 5.0, 20.0, 20.0, 10.0]
        # Per-element tolerances on (a, ex, ey, hx, hy).
        self.tolerance = [100.0, 0.005, 0.005, 0.001, 0.001]
        # Upper bound of each action component after rescaling from [-1, 1],
        # taken per satellite from its "action_space" (no module-level global).
        self.action_high = {
            body_params["name"]: np.array(body_params["agent"]["action_space"], dtype=float)
            for body_params in params.get("satellites", [])
        }
        print("target:\t\t\t", self.target_elements)
        print("tolerance:\t\t", self.tolerance)

    def create_bodies(self, params):
        """Create the bodies: no drifters, and one Agent per ``params["satellites"]`` entry."""
        self.drifters = []
        self.satellites = [Agent(body_params) for body_params in params.get("satellites", [])]

    def normalize_states(self, states):
        """Divide each satellite's first state component by the target semi-major axis.

        Each satellite's entry in ``states`` is replaced by a normalized copy, so
        the list originally stored there is not modified. Returns ``states``.
        """
        target_a = self.target[0]
        for satellite in self.satellites:
            state = states[satellite.name].copy()
            state[0] /= target_a
            states[satellite.name] = state
        return states

    def reset(self, seed=None, options=None):
        """Reset the simulation and recompute the target elements.

        ``self.target`` becomes [a, ex, ey, hx, hy] of the target orbit, read from
        an Orekit KeplerianOrbit built from ``target_elements``.

        Returns:
            (observations, infos) keyed by satellite name.
        """
        super().reset(seed)

        self.hit_target = False
        a, e, i, pa, raan, anomaly = self.target_elements
        orbit = KeplerianOrbit(a + EARTH_RADIUS, e, radians(i), radians(pa), radians(raan), radians(anomaly),
                               0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
                               PositionAngleType.TRUE, INERTIAL_FRAME, AbsoluteDate(), MU)
        self.target = [orbit.getA(), orbit.getEquinoctialEx(), orbit.getEquinoctialEy(), orbit.getHx(), orbit.getHy()]

        observations = self.normalize_states({satellite.name: satellite.get_state() for satellite in self.satellites})
        infos = {satellite.name: {} for satellite in self.satellites}

        return observations, infos

    def _thrust_on(self, scaled_action):
        """Return True when the scaled on/off flag (last component) requests thrust."""
        return scaled_action[-1] >= self.THRUST_THRESHOLD

    def step(self, actions=None):
        """Advance every satellite by one step.

        Args:
            actions: mapping from satellite name to an action whose components are
                clipped to [-1, 1] and rescaled to [0, action_space[k]]. All but
                the last component are passed to ``Satellite.step``. The last one
                is an on/off flag: below THRUST_THRESHOLD a zero command is sent
                instead. ``None`` means every satellite coasts.

        Returns:
            (observations, rewards, terminations, truncations, infos), each keyed
            by satellite name.
        """
        if actions is None:
            # -1 everywhere scales to 0, so the thrust flag is off for all satellites.
            actions = {satellite.name: np.full(len(self.action_high[satellite.name]), -1.0)
                       for satellite in self.satellites}

        clipped_actions = {}
        scaled_actions = {}
        for satellite in self.satellites:
            name = satellite.name
            clipped = np.clip(actions[name], -1.0, 1.0)
            scaled = (clipped + 1) / 2 * self.action_high[name]
            clipped_actions[name] = clipped
            scaled_actions[name] = scaled[:-1] if self._thrust_on(scaled) else np.zeros(len(scaled) - 1)

        states_before = {satellite.name: satellite.get_state() for satellite in self.satellites}
        states = {satellite.name: satellite.step(scaled_actions[satellite.name]) for satellite in self.satellites}

        # Rewards use the raw (unnormalized) states, so compute them before normalizing.
        rewards, reward_terminations = self.rewards(states, clipped_actions, states_before)
        terminations = {satellite.name: False for satellite in self.satellites}
        terminations.update(reward_terminations)

        observations = self.normalize_states(states)

        truncations = {satellite.name: False for satellite in self.satellites}
        infos = {satellite.name: {} for satellite in self.satellites}
        return observations, rewards, terminations, truncations, infos

    def rewards(self, observations, actions, observations_before):
        """Compute the shaped reward for the first satellite.

        Args:
            observations: unnormalized states after the step, keyed by satellite name.
            actions: clipped actions in [-1, 1], keyed by satellite name.
            observations_before: unnormalized states before the step, keyed by name.

        Returns:
            A ``(rewards, terminations)`` pair of dicts that contain only the first
            satellite. Its termination flag is always False.

        If (a, ex, ey, hx, hy) are all within tolerance, the reward is 0. The
        model is checkpointed the first time this happens after a reset.
        Otherwise, each element still outside its tolerance contributes the
        reduction of its distance to the target, divided by the target's
        magnitude and weighted by IMPROVEMENT_WEIGHTS. With the thrust flag on,
        the reward is that improvement times the first action component mapped
        to [0, 1] (minus a currently disabled penalty). With it off, the reward
        is 0.
        """
        agent = self.satellites[0]
        agent_name = agent.name
        action = actions[agent_name]

        # The first five state components are compared with self.target (a, ex, ey, hx, hy).
        target = np.asarray(self.target, dtype=float)
        current = np.asarray(observations[agent_name][:5], dtype=float)
        before = np.asarray(observations_before[agent_name][:5], dtype=float)
        tolerance = np.asarray(self.tolerance, dtype=float)

        diff = np.abs(target - current)
        diff_before = np.abs(target - before)

        if np.all(diff <= tolerance):
            if not self.hit_target:
                print("Agent reached target. Saving the model.")
                agent.ppo.save("model_checkpoint_hit_target.pth")
                self.hit_target = True
            return {agent_name: 0}, {agent_name: False}

        # Normalize by |target| so a negative target element cannot flip the sign
        # of the improvement. A zero target falls back to the raw difference
        # instead of dividing by zero.
        scale = np.abs(target)
        scale[scale == 0.0] = 1.0
        relative_improvement = np.where(diff > tolerance, (diff_before - diff) / scale, 0.0)
        improvement = float(np.dot(self.IMPROVEMENT_WEIGHTS, relative_improvement))

        scaled_action = (action + 1) / 2 * self.action_high[agent_name]
        thrust_indicator = 1 if self._thrust_on(scaled_action) else 0

        reward = thrust_indicator * (
            self.ALPHA_IMPROVEMENT * ((action[0] + 1) / 2) * improvement
            - self.ALPHA_PENALTY * (action[1] + 1) / 2
        )

        return {agent_name: reward}, {agent_name: False}

env = HohmannEnvironment(params)

# Training schedule; time_step counts environment steps across all episodes.
time_step = 1
action_std_decay_freq = 10000
action_std_decay_rate = 0.05
update_freq = 4096
min_action_std = 0.05
episodes = 10000
steps_per_episode = 1000
# Action bounds; clipped [-1, 1] actions are rescaled to [0, high].
low = np.array([0.0, 0.0, 0.0, 0.0])
high = np.array(params["satellites"][0]["agent"]["action_space"])

# load model
# env.satellites[0].ppo.load(".\\missions\\hohmann\\model_hohmann_experiment1.pth")
# env.satellites[0].ppo.load(".\\missions\\hohmann\\model_hohmann_experiment2.pth")
best_score = -1e7
start_episode = 1


def _discard_unrewarded_transitions(satellites):
    """Drop buffered states/actions/logprobs that have no matching reward.

    Presumably ``select_action`` records state, action and logprob before
    ``env.step`` runs. If a step then fails, those entries would be left without
    a reward. Trimming down to ``len(rewards)`` (rather than popping exactly
    once) also keeps the buffer aligned when the failure happened before
    ``select_action`` or after the reward was stored.
    """
    for satellite in satellites:
        buffer = satellite.ppo.buffer
        while len(buffer.states) > len(buffer.rewards):
            buffer.states.pop()
            buffer.actions.pop()
            buffer.logprobs.pop()


# Run exactly `episodes` episodes.
for episode in range(start_episode, start_episode + episodes):
    current_ep_rewards = {str(satellite): 0 for satellite in env.satellites}
    observations, _ = env.reset(42)
    for _ in range(steps_per_episode):
        try:

            if not env.satellites[0].has_fuel():
                break

            # select actions with policies
            actions = {str(satellite): satellite.ppo.select_action(observations[str(satellite)]) for satellite in env.satellites}

            # apply step
            observations, rewards, terminations, _, _ = env.step(actions)

            # save rewards and is_terminals, and accumulate the episode score
            for satellite in env.satellites:
                name = str(satellite)
                satellite.ppo.buffer.rewards.append(rewards[name])
                satellite.ppo.buffer.is_terminals.append(terminations[name])
                current_ep_rewards[name] += rewards[name]

            # decay action std of output action distribution
            if time_step % action_std_decay_freq == 0 and env.satellites[0].ppo.action_std > min_action_std:
                for satellite in env.satellites:
                    satellite.ppo.decay_action_std(action_std_decay_rate, min_action_std)

            time_step += 1

            # keyed by str(satellite), like every other per-satellite dict here
            if terminations[str(env.satellites[0])]:
                break

        except Exception:
            # e.g. a propagation error: keep each agent's buffer aligned, then end the episode
            _discard_unrewarded_transitions(env.satellites)
            traceback.print_exc()
            break

        if params["interface"]["show"]:
            env.render()

        if time_step % update_freq == 0:
            for satellite in env.satellites:
                satellite.ppo.update_gae()

    # mark the last stored transition of the episode as terminal for every agent
    for satellite in env.satellites:
        if satellite.ppo.buffer.is_terminals:
            satellite.ppo.buffer.is_terminals[-1] = True

    # show scores at the end of episode
    for satellite in env.satellites:
        score = current_ep_rewards[str(satellite)]
        print(f'Episode {episode}: {score}, Fuel: {satellite.get_fuel()}, Std: {satellite.ppo.action_std}')
        if score > best_score:
            print(f'>>>>>>>> Best score of {score} found. Saving the model.')
            best_score = score
            satellite.ppo.save("model_checkpoint.pth")