import numpy as np
from env import OrbitZoo
from bodies import Satellite
from rl_algorithms.ppo import PPO
from constants import EARTH_RADIUS, INERTIAL_FRAME, MU
from math import radians
import pandas as pd

from org.orekit.time import AbsoluteDate
from org.orekit.orbits import KeplerianOrbit, PositionAngleType

def _satellite_config(name, color, epochs, action_std_init):
    """Return the configuration dictionary of one satellite.

    Every satellite of this scenario shares the same initial state, masses and
    most agent settings; only the four arguments differ.  Each call builds
    brand-new containers, so no mutable object is shared between satellites.
    """
    return {
        "name": name,
        "initial_state": {
            "x": 5337709.428463124,
            "y": 6339969.149911649,
            "z": 361504.73969662545,
            "x_dot": -5320.577430007447,
            "y_dot": 4465.13261069736,
            "z_dot": 526.2965261179082,
        },
        "initial_state_uncertainty": {
            component: 0.0001
            for component in ("x", "y", "z", "x_dot", "y_dot", "z_dot")
        },
        "initial_mass": 200.0,
        "fuel_mass": 50.0,
        "isp": 310.0,
        "radius": 5.0,
        "save_steps_info": False,
        "color": color,
        "agent": {
            "lr_actor": 0.0001,
            "lr_critic": 0.001,
            "gae_lambda": 0.95,
            "epochs": epochs,
            "gamma": 0.95,
            "clip": 0.5,
            "action_std_init": action_std_init,
            "state_dim_actor": 7,
            "state_dim_critic": 7,
            "action_space": [500, np.pi, 2*np.pi, 1.0],
        },
    }


# Full scenario description: the three satellites being compared, the
# propagation step, the force models and the interface options.
params = {
    "satellites": [
        _satellite_config("experiment_1", (255, 0, 0), epochs=10, action_std_init=0.01),
        _satellite_config("experiment_2", (255, 105, 0), epochs=5, action_std_init=0.2),
        _satellite_config("optimal", (0, 255, 0), epochs=10, action_std_init=0.05),
    ],
    "delta_t": 5.0,
    # Every perturbation below is switched off.
    "forces": {
        "gravity_model": "Newtonian",
        "third_bodies": {
            "active": False,
            "bodies": ["SUN", "MOON"],
        },
        "solar_radiation_pressure": {
            "active": False,
            "reflection_coefficients": {
                "agent": 0.5,
            },
        },
        "drag": {
            "active": False,
            "drag_coefficients": {
                "agent": 0.5,
            },
        },
    },
    "interface": {
        "show": True,
        "delay_ms": 0,
        "zoom": 30.0,
        "drifters": {
            "show": True,
            "show_label": True,
            "show_velocity": False,
            "show_trail": True,
            "trail_last_steps": 50,
            "color_body": (255, 255, 255),
            "color_label": (255, 255, 255),
            "color_velocity": (255, 255, 255),
            "color_trail": (255, 255, 255),
        },
        "satellites": {
            "show": True,
            "show_label": True,
            "show_velocity": False,
            "show_thrust": True,
            "show_trail": True,
            "trail_last_steps": 1000,
            "color_body": (255, 0, 0),
            "color_label": (255, 255, 255),
            "color_velocity": (255, 255, 255),
            "color_thrust": (0, 255, 0),
            "color_trail": (255, 0, 0),
        },
        "earth": {
            "show": True,
            "color": (0, 0, 255),
            "resolution": 60,
        },
        "equator_grid": {
            "show": False,
            "color": (30, 140, 200),
            "resolution": 10,
        },
        "timestamp": {
            "show": True,
        },
        "orbits": [
            {"a": 2030.0e3, "e": 0.01, "i": 5.0, "pa": 20.0, "raan": 20.0, "color": (0, 255, 0)},
        ],
    },
}

class Agent(Satellite):
    """Satellite that owns a PPO instance built from its ``"agent"`` settings."""

    def __init__(self, params):
        """Initialise the satellite, then create ``self.ppo``.

        Args:
            params: Configuration of this satellite; the ``"agent"`` entry
                supplies the PPO constructor arguments.
        """
        super().__init__(params)
        # NOTE(review): "device" and "action_dim" are not present in the
        # literal configuration of this file - presumably they are filled in
        # by Satellite.__init__; confirm against the base class.
        settings = params["agent"]
        self.ppo = PPO(
            settings["device"],
            settings["state_dim_actor"],
            settings["state_dim_critic"],
            settings["action_dim"],
            settings["lr_actor"],
            settings["lr_critic"],
            settings["gae_lambda"],
            settings["gamma"],
            settings["epochs"],
            settings["clip"],
            True,  # positional flag whose meaning is defined by PPO
            settings["action_std_init"],
        )

    def get_state(self):
        """Return the equinoctial position with the value of ``get_fuel()`` appended."""
        equinoctial = self.get_equinoctial_position()
        return equinoctial + [self.get_fuel()]

    def print_networks(self):
        """Print the actor network followed by the critic network."""
        policy = self.ppo.policy
        print(policy.actor)
        print(policy.critic)

class HohmannEnvironment(OrbitZoo):
    """OrbitZoo environment whose satellites are ``Agent`` instances sharing one target orbit.

    The target is declared in ``__init__`` as Keplerian elements and converted
    by ``reset`` into the five values ``[a, ex, ey, hx, hy]`` used by
    ``normalize_states`` and ``rewards``.
    """

    def __init__(self, params):
        """Build the environment.

        Args:
            params: Environment configuration, forwarded to ``OrbitZoo``.
        """
        super().__init__(params)
        # Keplerian target in the order consumed by reset():
        # [value added to EARTH_RADIUS to obtain a, e, i, pa, raan, anomaly].
        # The four angles are given in degrees and converted with radians().
        self.target_elements = [2030.0e3, 0.01, 5.0, 20.0, 20.0, 10.0]
        # Compared in rewards() with |target - current| of [a, ex, ey, hx, hy].
        self.tolerance = [100.0, 0.005, 0.005, 0.001, 0.001]
        # rewards() reads this flag the first time a satellite is within
        # tolerance, so it has to exist before that call.
        self.hit_target = False
        # Upper bounds of each satellite's action space, kept on the instance
        # so that rewards() does not depend on a module-level variable.
        self._action_high = {
            body["name"]: np.array(body["agent"]["action_space"])
            for body in params.get("satellites", [])
            if "action_space" in body.get("agent", {})
        }
        print("target:\t\t\t", self.target_elements)
        print("tolerance:\t\t", self.tolerance)

    def create_bodies(self, params):
        """Create one ``Agent`` per entry of ``params["satellites"]`` and no drifters."""
        self.drifters = []
        self.satellites = [Agent(body_params) for body_params in params.get("satellites", [])]

    def normalize_states(self, states):
        """Divide the first entry of every satellite state by the target's ``a``.

        The state sequences are modified in place and the same ``states``
        mapping is returned.  ``self.target`` is created by ``reset``, so
        ``reset`` has to run before this method.
        """
        for satellite in self.satellites:
            states[satellite.name][0] /= self.target[0]
        return states

    def reset(self, seed=None, options=None):
        """Reset every body and recompute the target.

        Args:
            seed: Forwarded to ``reset`` of every drifter and satellite.
            options: Accepted but not used.

        Returns:
            ``(observations, infos)``, both keyed by satellite name;
            observations are the normalised satellite states and every info
            dictionary is empty.
        """
        for body in self.drifters + self.satellites:
            body.reset(seed)

        elements = self.target_elements
        orbit = KeplerianOrbit(
            elements[0] + EARTH_RADIUS,
            elements[1],
            radians(elements[2]),
            radians(elements[3]),
            radians(elements[4]),
            radians(elements[5]),
            # The six derivative arguments are all zero.
            0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
            PositionAngleType.TRUE,
            INERTIAL_FRAME,
            AbsoluteDate(),
            MU,
        )
        self.target = [
            orbit.getA(),
            orbit.getEquinoctialEx(),
            orbit.getEquinoctialEy(),
            orbit.getHx(),
            orbit.getHy(),
        ]
        print("target:\t\t", self.target)

        observations = self.normalize_states(
            {satellite.name: satellite.get_state() for satellite in self.satellites}
        )
        infos = {satellite.name: {} for satellite in self.satellites}

        return observations, infos

    def step(self, actions=None):
        """Apply one action per satellite and advance every satellite by one step.

        Args:
            actions: Mapping satellite name -> four-component action.  The
                first three components are forwarded to ``Satellite.step``
                when the fourth is at least 0.5; otherwise a zero action is
                sent.  The mapping itself is not modified.

        Returns:
            ``(observations, rewards, terminations, truncations, infos)``, all
            keyed by satellite name.  Reward computation is currently
            disabled: rewards are 0 and both flags are ``False``.
        """
        # Work on the environment's own satellites (not on a module-level
        # instance) and collect the commands in a new dict so the caller's
        # ``actions`` is left untouched.
        commands = {}
        for satellite in self.satellites:
            action = actions[satellite.name]
            if action[3] < 0.5:
                commands[satellite.name] = np.array([0.0, 0.0, 0.0])
            else:
                commands[satellite.name] = action[:-1]

        states = {
            satellite.name: satellite.step(commands[satellite.name])
            for satellite in self.satellites
        }

        # self.rewards() is not used for now.  To re-enable it, snapshot
        # satellite.get_state() for every satellite *before* the step above
        # and pass that snapshot as ``observations_before``.
        rewards = {satellite.name: 0 for satellite in self.satellites}
        terminations = {satellite.name: False for satellite in self.satellites}

        observations = self.normalize_states(states)

        truncations = {satellite.name: False for satellite in self.satellites}
        infos = {satellite.name: {} for satellite in self.satellites}
        return observations, rewards, terminations, truncations, infos

    def rewards(self, observations, actions, observations_before):
        """Reward every satellite for the progress made towards the target.

        Not called by ``step`` at the moment.

        Args:
            observations: Mapping satellite name -> state after the step; its
                first five entries are read as ``[a, ex, ey, hx, hy]``.
            actions: Mapping satellite name -> four-component action array;
                presumably in ``[-1, 1]`` since it is rescaled with
                ``(action + 1) / 2`` - confirm against the caller.
            observations_before: Same layout as ``observations``, taken before
                the step.

        Returns:
            ``(rewards, truncations)``, both holding one entry per satellite.
            A satellite within tolerance of the target on all five elements
            gets reward 0; every truncation flag is ``False``.
        """
        # alpha_a, alpha_ex, alpha_ey, alpha_hx, alpha_hy
        weights = (1000, 1, 1, 10, 10)
        alpha_1 = 1
        alpha_2 = 0.5

        target = self.target
        tolerance = self.tolerance

        rewards = {}
        truncations = {}

        for satellite in self.satellites:
            agent_name = satellite.name
            state = observations[agent_name]
            state_before = observations_before[agent_name]
            action = actions[agent_name]
            scaled_action = ((action + 1) / 2) * self._action_high[agent_name]

            # Absolute distance to the target, element by element (zip stops
            # after the five target values).
            diffs = [abs(goal - value) for goal, value in zip(target, state)]
            diffs_before = [abs(goal - value) for goal, value in zip(target, state_before)]

            truncations[agent_name] = False

            if all(diff <= limit for diff, limit in zip(diffs, tolerance)):
                if not self.hit_target:
                    print("Agent reached target. Saving the model.")
                    # NOTE(review): the checkpoint is always taken from the
                    # first satellite, whichever one reached the target.
                    self.satellites[0].ppo.save("model_checkpoint_hit_target.pth")
                    self.hit_target = True
                # Record the result and keep going so that the remaining
                # satellites also get an entry.
                rewards[agent_name] = 0
                continue

            # Weighted relative improvement; elements already within their
            # tolerance contribute nothing.
            improvement = 0
            for weight, diff_before, diff, goal, limit in zip(weights, diffs_before, diffs, target, tolerance):
                gain = (diff_before - diff) / goal if diff > limit else 0
                improvement += weight * gain

            thrust_indicator = 1 if scaled_action[3] > 0.5 else 0

            rewards[agent_name] = thrust_indicator * (
                alpha_1 * ((action[0] + 1) / 2) * improvement - alpha_2 * (action[1] + 1) / 2
            )

        return rewards, truncations

# Replay script: builds the environment, loads the actions recorded for each
# satellite and feeds them to the simulation step by step, rendering every
# step when the interface is enabled.
env = HohmannEnvironment(params)

# eval() is called on the policies of every satellite, however many the
# configuration declares.
for satellite in env.satellites:
    satellite.ppo.policy.eval()
    satellite.ppo.policy_old.eval()

# Several of the settings below are not read by the replay loop; they are kept
# so that any code referring to these module-level names keeps working.
time_step = 1
action_std_decay_freq = 10000
action_std_decay_rate = 0.05
update_freq = 512
min_action_std = 0.1
episodes = 1000
steps_per_episode = 3000
low = np.array([0.0, 0.0, 0.0, 0.0])
# Action-space entry of the first configured satellite.
high = np.array(params["satellites"][0]["agent"]["action_space"])

# load model
# env.satellites[0].ppo.load(".\\missions\\hohmann\\model_hohmann_experiment1.pth")
# env.satellites[0].ppo.load(".\\missions\\hohmann\\model_hohmann_experiment2.pth")
best_score = -24

hit_target = False

# Forward slashes are accepted on every platform, unlike "\\"-separated paths.
df1 = pd.read_csv("missions/hohmann/hohmann_experiment1.csv")
df2 = pd.read_csv("missions/hohmann/hohmann_experiment2.csv")
df3 = pd.read_csv("missions/hohmann/hohmann_optimal.csv")

# Extract the four action columns once instead of indexing the data frames at
# every step; row t of each table is the action replayed at step t.
action_columns = ['M', 'θ', 'Φ', 'δ']
recorded_actions = {
    'experiment_1': df1[action_columns].to_numpy(),
    'experiment_2': df2[action_columns].to_numpy(),
    'optimal': df3[action_columns].to_numpy(),
}

current_ep_rewards = {str(satellite): 0 for satellite in env.satellites}
observations, _ = env.reset()

# for _ in range(200):
#     env.render()

# NOTE(review): with the satellite order declared in ``params``, index 0 is
# "experiment_1" and index 1 is "experiment_2", so the "optimal" label below
# does not match the satellite named "optimal" - confirm which satellites
# should be recorded before re-enabling the CSV export at the bottom.
pos_optimal = env.satellites[0].get_cartesian_position()
pos_exp = env.satellites[1].get_cartesian_position()

data = [{
    'x_optimal': pos_optimal[0],
    'y_optimal': pos_optimal[1],
    'z_optimal': pos_optimal[2],
    'x_exp': pos_exp[0],
    'y_exp': pos_exp[1],
    'z_exp': pos_exp[2],
}]

for t in range(1, steps_per_episode + 1):
    try:

        # Only the first satellite's fuel decides when the replay stops.
        if not env.satellites[0].has_fuel():
            break

        if t < 1000:
            # Copy the row so the recorded table is never modified downstream.
            actions = {name: table[t].copy() for name, table in recorded_actions.items()}
        else:
            # After the replayed steps, every satellite gets a zero action.
            actions = {name: np.zeros(4) for name in recorded_actions}

        # actions = {str(satellite): satellite.ppo.select_action(observations[str(satellite)]) for satellite in env.satellites}
        observations, rewards, terminations, _, _ = env.step(actions)

        # pos_optimal = env.satellites[0].get_cartesian_position()
        # pos_exp = env.satellites[1].get_cartesian_position()
        # data.append({
        #     'x_optimal': pos_optimal[0],
        #     'y_optimal': pos_optimal[1],
        #     'z_optimal': pos_optimal[2],
        #     'x_exp': pos_exp[0],
        #     'y_exp': pos_exp[1],
        #     'z_exp': pos_exp[2],
        # })

        # save rewards and is_terminals
        for satellite in env.satellites:
            key = str(satellite)
            satellite.ppo.buffer.rewards.append(rewards[key])
            satellite.ppo.buffer.is_terminals.append(terminations[key])
            current_ep_rewards[key] += rewards[key]
        # if terminations[env.agents[0]]:
        #     print("Agent reached target")
        #     break
    except Exception as e:
        # if there's a propagation error, remove last values from each satellite agent buffer
        for satellite in env.satellites:
            buffer = satellite.ppo.buffer
            for entries in (buffer.states, buffer.actions, buffer.logprobs, buffer.state_values):
                # Nothing in this script fills these buffers (the
                # select_action call above is commented out), so popping
                # unconditionally would raise and hide the original error.
                if entries:
                    entries.pop()
        print(e)
        # traceback.print_exc()
        break
    if params["interface"]["show"]:
        env.render()

print(current_ep_rewards)

# df = pd.DataFrame(data)
# df.to_csv("hohmann_teste.csv")