import numpy as np
from env import OrbitZoo
from bodies import Satellite
from rl_algorithms.ppo import PPO
from constants import EARTH_RADIUS, INERTIAL_FRAME, MU
from math import radians
import pandas as pd

from org.orekit.time import AbsoluteDate
from org.orekit.orbits import KeplerianOrbit, PositionAngleType

# Configuration for the whole run. It is passed to HohmannEnvironment(params)
# further down, and a few entries are also read directly by this script.
params = {
        # One entry per controllable satellite. HohmannEnvironment.create_bodies
        # builds one Agent from each entry of the "satellites" list it is given.
        "satellites": [
            {"name": "agent",
             # Initial Cartesian position/velocity components.
             # (presumably metres and metres/second — TODO confirm against Satellite)
             "initial_state": {"x": 5337709.428463124, "y": 6339969.149911649, "z": 361504.73969662545, "x_dot": -5320.577430007447, "y_dot": 4465.13261069736, "z_dot": 526.2965261179082},
             "initial_state_uncertainty": {"x": 0.0001, "y": 0.0001, "z": 0.0001, "x_dot": 0.0001, "y_dot": 0.0001, "z_dot": 0.0001},
             "initial_mass": 200.0,
             "fuel_mass": 50.0,
             "isp": 310.0,
             "radius": 5.0,
             "save_steps_info": False,
             # Hyperparameters forwarded to PPO by Agent.__init__.
             # NOTE(review): Agent.__init__ also reads "device" and "action_dim",
             # which are not set here — presumably they are filled in by the
             # Satellite base class; confirm.
             "agent": {
                "lr_actor": 0.0001,
                "lr_critic": 0.001,
                "gae_lambda": 0.95,
                "epochs": 5,
                "gamma": 0.95,
                "clip": 0.5,
                "action_std_init": 0.05,
                "state_dim_actor": 7,
                "state_dim_critic": 7,
                # Upper bounds of the four action components. Policy outputs are
                # clipped to [-1, 1] and rescaled to [0, bound]; the 4th component
                # is compared with 0.5 to decide whether the satellite thrusts.
                "action_space": [500, np.pi, 2*np.pi, 1.0],
             }},
        ],
        # Propagation step handed to the environment.
        # (presumably seconds — TODO confirm against OrbitZoo)
        "delta_t": 5.0,
        # Force-model switches consumed by the OrbitZoo base environment.
        "forces": {
            "gravity_model": "HolmesFeatherstone",
            "third_bodies": {
                "active": True,
                "bodies": ["SUN", "MOON"],
            },
            "solar_radiation_pressure": {
                "active": True,
                # Keyed by satellite name.
                "reflection_coefficients": {
                    "agent": 1.0,
                }
            },
            "drag": {
                "active": True,
                # Keyed by satellite name.
                "drag_coefficients": {
                    "agent": 1.0,
                }
            }
        },
        # Rendering options. With "show" set to False the rollout loop below
        # never calls env.render().
        "interface": {
            "show": False,
            "delay_ms": 0,
            "drifters": {
                "show": True,
                "show_label": True,
                "show_velocity": True,
                "color_body": (255, 0, 0),
                "color_label": (255, 255, 255),
                "color_velocity": (255, 0, 0),
            },
            "satellites": {
                "show": True,
                "show_label": True,
                "show_velocity": True,
                "show_thrust": True,
                "color_body": (255, 255, 255),
                "color_label": (255, 255, 255),
                "color_velocity": (255, 255, 255),
                "color_thrust": (0, 255, 0),
            },
        }
    }

class Agent(Satellite):
    """Satellite that carries its own PPO learner."""

    def __init__(self, params):
        """Initialise the satellite, then build a PPO instance from ``params["agent"]``.

        Args:
            params: Per-satellite configuration. Its ``"agent"`` entry must
                provide every hyperparameter key read below.
        """
        super().__init__(params)
        cfg = params["agent"]
        self.ppo = PPO(
            cfg["device"],
            cfg["state_dim_actor"],
            cfg["state_dim_critic"],
            cfg["action_dim"],
            cfg["lr_actor"],
            cfg["lr_critic"],
            cfg["gae_lambda"],
            cfg["gamma"],
            cfg["epochs"],
            cfg["clip"],
            # Positional flag expected by PPO at this slot.
            # (presumably "continuous action space" — TODO confirm against PPO)
            True,
            cfg["action_std_init"],
        )

    def get_state(self):
        """Return the equinoctial position with the current fuel appended as the last entry."""
        equinoctial = self.get_equinoctial_position()
        fuel = self.get_fuel()
        return equinoctial + [fuel]

    def print_networks(self):
        """Print the actor network followed by the critic network."""
        policy = self.ppo.policy
        print(policy.actor)
        print(policy.critic)

class HohmannEnvironment(OrbitZoo):
    """Orbit-transfer environment that rewards progress towards a target orbit.

    Satellite states are read as ``[a, ex, ey, hx, hy, ...]`` (the first five
    entries are compared with the target's semi-major axis and equinoctial
    ``ex, ey, hx, hy``). Actions are 4-vectors in ``[-1, 1]`` that are rescaled
    to ``[0, action_space]``; the first three components are forwarded to the
    satellite as the thrust command and the fourth gates the thruster.

    ``reset()`` must be called before ``step()``, ``rewards()`` or
    ``normalize_states()``: it is what creates ``self.target`` and
    ``self.hit_target``.
    """

    # Scaled value of the 4th action component below which no thrust is applied.
    _THRUST_GATE = 0.5

    def __init__(self, params):
        """Create the environment and define the target orbit and its tolerances.

        Args:
            params: Environment configuration. The first satellite's
                ``["agent"]["action_space"]`` supplies the upper bounds used
                to rescale actions.
        """
        super().__init__(params)
        # [altitude above EARTH_RADIUS, eccentricity, then four angles in
        # degrees passed to KeplerianOrbit as i, pa, raan, anomaly].
        self.target_elements = [2030.0e3, 0.01, 5.0, 20.0, 20.0, 10.0]
        # Allowed absolute error on (a, ex, ey, hx, hy), in that order.
        self.tolerance = [100.0, 0.005, 0.005, 0.001, 0.001]
        # Keep the action upper bounds on the instance so step()/rewards() do
        # not depend on module-level globals defined by the calling script.
        satellites_cfg = params["satellites"] if "satellites" in params else []
        self._action_high = (
            np.array(satellites_cfg[0]["agent"]["action_space"]) if satellites_cfg else None
        )
        print("target:\t\t\t", self.target_elements)
        print("tolerance:\t\t", self.tolerance)

    def create_bodies(self, params):
        """Create one ``Agent`` per configured satellite; this scenario has no drifters."""
        self.drifters = []
        self.satellites = [Agent(body_params) for body_params in params["satellites"]] if "satellites" in params else []

    def normalize_states(self, states):
        """Divide each satellite's first state entry by the target semi-major axis.

        The state sequences are modified in place and the same ``states`` dict
        is returned.
        """
        scale = self.target[0]
        for satellite in self.satellites:
            states[satellite.name][0] /= scale
        return states

    def reset(self, seed=None, options=None):
        """Reset the simulation and rebuild the target orbit.

        Args:
            seed: Forwarded to the base-class ``reset``.
            options: Accepted for interface compatibility; not used.

        Returns:
            ``(observations, infos)`` dicts keyed by satellite name.
        """
        super().reset(seed)

        self.hit_target = False
        altitude, ecc, inc_deg, pa_deg, raan_deg, anomaly_deg = self.target_elements
        orbit = KeplerianOrbit(
            altitude + EARTH_RADIUS, ecc,
            radians(inc_deg), radians(pa_deg), radians(raan_deg), radians(anomaly_deg),
            # All element derivatives are zero.
            0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
            PositionAngleType.TRUE, INERTIAL_FRAME, AbsoluteDate(), MU,
        )
        # Target expressed in the same quantities as the first five state entries.
        self.target = [orbit.getA(), orbit.getEquinoctialEx(), orbit.getEquinoctialEy(), orbit.getHx(), orbit.getHy()]

        observations = self.normalize_states({satellite.name: satellite.get_state() for satellite in self.satellites})
        infos = {satellite.name: {} for satellite in self.satellites}

        return observations, infos

    def _thruster_on(self, scaled_action):
        """Return True unless the gate component is below the thrust threshold."""
        return not scaled_action[3] < self._THRUST_GATE

    def step(self, actions=None):
        """Apply one action per satellite and advance the simulation.

        Args:
            actions: Dict keyed by satellite name holding 4-component actions.
                Values outside ``[-1, 1]`` are clipped.

        Returns:
            ``(observations, rewards, terminations, truncations, infos)``.
            Terminations and truncations are always ``False``.
        """
        lower, upper = [-1, -1, -1, -1], [1, 1, 1, 1]
        clipped_actions = {
            satellite.name: np.clip(actions[satellite.name], lower, upper)
            for satellite in self.satellites
        }

        # Map [-1, 1] onto [0, bound]; drop the gate component before handing
        # the command to the satellite, or send zero thrust if the gate is closed.
        thrust_commands = {}
        for name, clipped in clipped_actions.items():
            scaled = ((clipped + 1) / 2) * self._action_high
            if self._thruster_on(scaled):
                thrust_commands[name] = scaled[:-1]
            else:
                thrust_commands[name] = np.array([0.0, 0.0, 0.0])

        states_before = {satellite.name: satellite.get_state() for satellite in self.satellites}
        states = {satellite.name: satellite.step(thrust_commands[satellite.name]) for satellite in self.satellites}

        # Rewards are computed on the raw states, before normalisation mutates them.
        rewards, _ = self.rewards(states, clipped_actions, states_before)
        observations = self.normalize_states(states)

        terminations = {satellite.name: False for satellite in self.satellites}
        truncations = {satellite.name: False for satellite in self.satellites}
        infos = {satellite.name: {} for satellite in self.satellites}
        return observations, rewards, terminations, truncations, infos

    @staticmethod
    def _relative_progress(error_before, error_now, target_value, tol):
        """Return the error reduction relative to the target magnitude.

        Yields 0 once the element is within tolerance. The magnitude of the
        target is used so that a negative target component does not flip the
        sign of the reward. A target component of exactly zero still raises
        ``ZeroDivisionError``.
        """
        if error_now > tol:
            return (error_before - error_now) / abs(target_value)
        return 0

    def rewards(self, observations, actions, observations_before):
        """Compute the reward of the first satellite for the last step.

        Args:
            observations: Raw (un-normalised) states after the step, by name.
            actions: Clipped actions in ``[-1, 1]``, by name.
            observations_before: Raw states before the step, by name.

        Returns:
            ``(rewards, terminations)`` dicts containing only the first
            satellite's name. The reward is 0 when every element is within
            tolerance (the model is saved the first time this happens in an
            episode) or when the thruster gate is closed.
        """
        agent_name = self.satellites[0].name
        state = observations[agent_name]
        state_before = observations_before[agent_name]
        action = actions[agent_name]
        target = self.target
        tolerance = self.tolerance
        scaled_action = ((action + 1) / 2) * self._action_high

        # Absolute errors on (a, ex, ey, hx, hy); zip stops at the 5 target entries.
        errors = [abs(goal - value) for goal, value in zip(target, state)]
        errors_before = [abs(goal - value) for goal, value in zip(target, state_before)]

        if all(error <= tol for error, tol in zip(errors, tolerance)):
            if not self.hit_target:
                print("Agent reached target. Saving the model.")
                self.satellites[0].ppo.save("model_checkpoint_hit_target.pth")
                self.hit_target = True
            return {agent_name: 0}, {agent_name: False}

        # Per-element weights for (a, ex, ey, hx, hy).
        weights = (1000, 1, 1, 10, 10)
        improvement = 0
        for weight, before, now, goal, tol in zip(weights, errors_before, errors, target, tolerance):
            improvement += weight * self._relative_progress(before, now, goal, tol)

        # Same gate as in step(), so a step that thrusts is always the one rewarded.
        thrust_indicator = 1 if self._thruster_on(scaled_action) else 0

        alpha_1 = 1
        alpha_2 = 0.5

        # NOTE(review): the penalty term uses action[1]; confirm this is the
        # intended component rather than action[0].
        reward = thrust_indicator * (alpha_1 * ((action[0] + 1) / 2) * improvement - alpha_2 * (action[1] + 1) / 2)

        return {agent_name: reward}, {agent_name: False}

# Environment instance driven by the rollout loop below.
env = HohmannEnvironment(params)

# Run settings. In the visible part of this file only `steps_per_episode` and
# `high` are read by the rollout loop; the remaining values are not referenced
# (presumably left over from a training script — confirm before removing).
time_step = 1
action_std_decay_freq = 10000
action_std_decay_rate = 0.05
update_freq = 512
min_action_std = 0.1
episodes = 1000
steps_per_episode = 1000
low = np.array([0.0, 0.0, 0.0, 0.0])
# Upper bounds used to rescale clipped actions from [-1, 1] to [0, bound].
high = np.array(params["satellites"][0]["agent"]["action_space"])

# Load a previously saved model into the first satellite's PPO instance.
# NOTE(review): the path uses Windows separators and is relative to the
# current working directory.
# Alternative checkpoint:
# env.satellites[0].ppo.load(".\\missions\\hohmann\\model_hohmann_experiment1.pth")
env.satellites[0].ppo.load(".\\missions\\hohmann\\model_hohmann_experiment2.pth")
best_score = -24

# Set to True by the rollout loop once an episode reaches the target.
hit_target = False

def _trajectory_row(step_index, state, action, fuel, velocity):
    """Build one row of the exported trajectory table.

    Args:
        step_index: Step number within the episode (0 for the initial state).
        state: Satellite state; entries 0-5 are stored as a, ex, ey, hx, hy, lm.
        action: Four scaled action components, stored as M, θ, Φ, δ.
        fuel: Value returned by the satellite's ``get_fuel()``.
        velocity: Cartesian velocity; entries 0-2 are stored as vx, vy, vz.
    """
    return {
        'step': step_index,
        'a': state[0],
        'ex': state[1],
        'ey': state[2],
        'hx': state[3],
        'hy': state[4],
        'lm': state[5],
        'M': action[0],
        'θ': action[1],
        'Φ': action[2],
        'δ': action[3],
        'fuel': fuel,
        'vx': velocity[0],
        'vy': velocity[1],
        'vz': velocity[2],
    }


# Roll out the loaded policy episode after episode until one of them brings the
# satellite within tolerance of the target. No learning update is performed:
# the PPO buffer is cleared after every episode.
# NOTE(review): every episode is reset with the same seed (42); if that seed
# also fixes the policy's sampling, a failing episode would repeat forever —
# confirm.
while not hit_target:
    print("reset")

    data = []
    current_ep_rewards = {str(satellite): 0 for satellite in env.satellites}
    observations, _ = env.reset(42)

    agent = env.satellites[0]
    initial_velocity = agent.get_cartesian_velocity()
    initial_state = agent.get_state()
    data.append(_trajectory_row(0, initial_state, (0, 0, 0, 0), agent.get_fuel(), initial_velocity))

    for t in range(1, steps_per_episode + 1):
        try:

            if not agent.has_fuel():
                break

            # A hand-scripted transfer can be substituted here for debugging,
            # e.g. actions = {'agent': np.array([0.232, -1.0, -1.0, 1.0])} at the
            # burn steps and np.array([-1.0, -1.0, -1.0, -1.0]) otherwise.
            actions = {str(satellite): satellite.ppo.select_action(observations[str(satellite)]) for satellite in env.satellites}
            observations, rewards, terminations, _, _ = env.step(actions)

            # Recompute the physical action values for logging only.
            scaled_actions = {str(satellite): ((np.clip(actions[str(satellite)], [-1,-1,-1,-1], [1,1,1,1]) + 1) / 2) * high for satellite in env.satellites}
            action = scaled_actions[agent.name]
            cartesian_velocity = agent.get_cartesian_velocity()
            state = agent.get_state()

            # check if hit target: every one of (a, ex, ey, hx, hy) within tolerance
            errors = [abs(goal - value) for goal, value in zip(env.target, state)]
            if all(error <= tol for error, tol in zip(errors, env.tolerance)):
                hit_target = True
                print("Target hit")

            data.append(_trajectory_row(t, state, action, agent.get_fuel(), cartesian_velocity))

            # save rewards and is_terminals
            for satellite in env.satellites:
                satellite.ppo.buffer.rewards.append(rewards[str(satellite)])
                satellite.ppo.buffer.is_terminals.append(terminations[str(satellite)])
            current_ep_rewards = {str(satellite): current_ep_rewards[str(satellite)] + rewards[str(satellite)] for satellite in env.satellites}
        except Exception as e:
            # If there's a propagation error, drop the entries that
            # select_action stored for the step that never got a reward. Only
            # pop where such an unmatched entry exists, so a failure before
            # select_action neither raises IndexError nor removes valid data.
            for satellite in env.satellites:
                buffer = satellite.ppo.buffer
                rewarded_steps = len(buffer.rewards)
                for entries in (buffer.states, buffer.actions, buffer.logprobs, buffer.state_values):
                    if len(entries) > rewarded_steps:
                        entries.pop()
            print(e)
            break
        if params["interface"]["show"]:
            env.render()

    agent.ppo.buffer.clear()

    # Final semi-major axis relative to EARTH_RADIUS, i.e. the same reference
    # the target altitude is defined against.
    print(data[-1]['a'] - EARTH_RADIUS)

print(current_ep_rewards)

# Export the trajectory of the last (successful) episode.
df = pd.DataFrame(data)
df.to_csv("hohmann_results_new_2.csv")