import numpy as np
from env import OrbitZoo
from bodies import Satellite
from rl_algorithms.ppo import PPO
from constants import EARTH_RADIUS, INERTIAL_FRAME, MU
from math import radians
import traceback
import random
import torch
from torch.utils.tensorboard import SummaryWriter

from org.orekit.time import AbsoluteDate
from org.orekit.orbits import KeplerianOrbit, PositionAngleType

def _make_agent_config(name):
    """Return a freshly built configuration dict for one PPO-controlled satellite.

    All agents in this experiment share the same initial state, physical
    properties and PPO hyperparameters; only ``name`` differs. Every call
    allocates new dicts/lists, so no two agents share a mutable config object.
    """
    state_axes = ("x", "y", "z", "x_dot", "y_dot", "z_dot")
    ppo_hyperparameters = {
        "lr_actor": 0.00001,
        "lr_critic": 0.0001,
        "gae_lambda": 0.95,
        "epochs": 3,
        "gamma": 0.99,
        "clip": 0.2,
        "action_std_init": 0.5,
        # 3 equinoctial elements + fuel + 4 anomalies (own + 3 others)
        "state_dim_actor": 8,
        "state_dim_critic": 8,
        # upper bounds the 2-D policy output is rescaled to (from [-1, 1] to [0, bound])
        "action_space": [5, 2*np.pi],
    }
    return {
        "name": name,
        # Cartesian state; |r| is about 42 164 km (GEO radius), assuming metres
        "initial_state": {
            "x": 32299497.899668593,
            "y": 27102496.774823245,
            "z": 0.0,
            "x_dot": -1976.3573913582284,
            "y_dot": 2355.3310214012895,
            "z_dot": 0.0,
        },
        "initial_state_uncertainty": dict.fromkeys(state_axes, 0.000001),
        "initial_mass": 200.0,
        "fuel_mass": 50.0,
        "isp": 3100.0,
        "radius": 5.0,
        "save_steps_info": False,
        "agent": ppo_hyperparameters,
    }


params = {
    # four identically configured agents named agent_1 ... agent_4
    "satellites": [_make_agent_config(f"agent_{index}") for index in range(1, 5)],
    "delta_t": 360.0,
    "forces": {
        "gravity_model": "Newtonian",
        "third_bodies": {
            "active": False,
            "bodies": ["SUN", "MOON"],
        },
        "solar_radiation_pressure": {
            "active": False,
            "reflection_coefficients": {
                "agent": 0.5,
            },
        },
        "drag": {
            "active": False,
            "drag_coefficients": {
                "agent": 0.5,
            },
        },
    },
    "interface": {
        "show": False,
        "delay_ms": 0,
        "zoom": 1.0,
        "drifters": {
            "show": True,
            "show_label": True,
            "show_velocity": False,
            "show_trail": True,
            "trail_last_steps": 50,
            "color_body": (255, 255, 255),
            "color_label": (255, 255, 255),
            "color_velocity": (255, 255, 255),
            "color_trail": (255, 255, 255),
        },
        "satellites": {
            "show": True,
            "show_label": True,
            "show_velocity": False,
            "show_thrust": True,
            "show_trail": True,
            "trail_last_steps": 50,
            "color_body": (255, 0, 0),
            "color_label": (255, 255, 255),
            "color_velocity": (255, 255, 255),
            "color_thrust": (0, 255, 0),
            "color_trail": (255, 0, 0),
        },
        "earth": {
            "show": True,
            "color": (0, 0, 255),
            "resolution": 70,
        },
        "equator_grid": {
            "show": False,
            "color": (30, 140, 200),
            "resolution": 10,
        },
        "timestamp": {
            "show": True,
        },
        "orbits": [
            {"a": 42164e3 - 6378e3, "e": 0.0001, "i": 0.0001, "pa": 20.0, "raan": 20.0, "color": (255, 0, 255)},
        ],
    },
}

class Agent(Satellite):
    """Satellite whose manoeuvres are chosen by its own PPO learner (``self.ppo``)."""

    def __init__(self, params):
        """Build the satellite, then its PPO learner from ``params["agent"]``.

        NOTE(review): ``"device"`` and ``"action_dim"`` are not part of this
        file's config dicts; presumably the base class or the environment adds
        them before this point — confirm.
        """
        super().__init__(params)
        agent_params = params["agent"]
        self.ppo = PPO(agent_params["device"],
                       agent_params["state_dim_actor"],
                       agent_params["state_dim_critic"],
                       agent_params["action_dim"],
                       agent_params["lr_actor"],
                       agent_params["lr_critic"],
                       agent_params["gae_lambda"],
                       agent_params["gamma"],
                       agent_params["epochs"],
                       agent_params["clip"],
                       True,  # NOTE(review): presumably "continuous action space" — check PPO's signature
                       agent_params["action_std_init"])

    def get_state(self):
        """Placeholder: returns None (observations are assembled by the environment)."""
        pass

    def __get_random_orbit__(self, seed = None):
        """Return the base class's random orbit with a freshly drawn mean anomaly.

        The base orbit is converted to Keplerian elements; a, e, i, perigee
        argument and RAAN are kept, and the anomaly is replaced by a uniformly
        drawn whole number of degrees in [0, 359], interpreted as a mean anomaly.

        Args:
            seed: optional seed for the anomaly draw. ``None`` keeps using the
                global ``random`` state, as before.
                NOTE(review): the seed is not forwarded to the base class because
                its signature is not visible here.
        """
        rng = random if seed is None else random.Random(seed)
        base = super().__get_random_orbit__()
        base = KeplerianOrbit(base.getPVCoordinates(), base.getFrame(), base.getDate(), base.getMu())
        # randrange(360) excludes 360 deg, which duplicated 0 deg under randint(0, 360)
        mean_anomaly = radians(float(rng.randrange(360)))
        return KeplerianOrbit(base.getA(), base.getE(), base.getI(), base.getPerigeeArgument(),
                              base.getRightAscensionOfAscendingNode(), mean_anomaly,
                              PositionAngleType.MEAN, base.getFrame(), base.getDate(), base.getMu())

    def print_networks(self):
        """Print the actor and critic networks of the current policy."""
        print(self.ppo.policy.actor)
        print(self.ppo.policy.critic)

class MARLGEOEnv(OrbitZoo):
    """Multi-agent OrbitZoo environment where PPO satellites share a GEO-like orbit.

    Observation of each satellite (a list):
        ``get_equinoctial_position()[:3] + [fuel] + [own anomaly] + [other anomalies]``
    where an "anomaly" is element 5 of ``get_equinoctial_position()`` shifted
    by +pi. Actions are policy outputs clipped to [-1, 1], rescaled per
    satellite to [0, action_space] and padded with a trailing 0 before being
    passed to ``satellite.step``.
    """

    def __init__(self, params):
        super().__init__(params)
        # Per-satellite action upper bounds used by `step` (previously read from a module global).
        # NOTE(review): assumes each satellite's `name` equals its "name" entry in params.
        self.action_high = {
            body_params["name"]: np.array(body_params["agent"]["action_space"])
            for body_params in params.get("satellites", [])
        }

        # geostationary orbit: [a - EARTH_RADIUS, e, i (deg), pa (deg), raan (deg), true anomaly (deg)]
        target = [42_164_000 - EARTH_RADIUS, 0.01, 0.01, 20.0, 20.0, 0.0]
        orbit = KeplerianOrbit(target[0] + EARTH_RADIUS, target[1], radians(target[2]), radians(target[3]),
                               radians(target[4]), radians(target[5]),
                               0.0, 0.0, 0.0, 0.0, 0.0, 0.0,  # zero element derivatives
                               PositionAngleType.TRUE, INERTIAL_FRAME, AbsoluteDate(), MU)
        self.target_orbit = np.array([orbit.getA(), orbit.getEquinoctialEx(), orbit.getEquinoctialEy()])

        # +/- pi/8 windows around four evenly spaced anomaly slots (not read elsewhere in this class)
        target_anomalies = np.array([-np.pi, -1/2*np.pi, 0.0, 1/2*np.pi])
        tolerance_anomaly = 1/8*np.pi
        self.tolerances = [[anomaly - tolerance_anomaly, anomaly + tolerance_anomaly] for anomaly in target_anomalies]

    def create_bodies(self, params):
        """Create the bodies: no drifters and one PPO `Agent` per params["satellites"] entry."""
        self.drifters = []
        self.satellites = [Agent(body_params) for body_params in params.get("satellites", [])]

    def _get_observations(self):
        """Return a dict mapping each satellite name to its observation list (see class docstring)."""
        positions = [satellite.get_equinoctial_position() for satellite in self.satellites]
        anomalies = [position[5] + np.pi for position in positions]
        observations = {}
        for i, (satellite, position) in enumerate(zip(self.satellites, positions)):
            other_anomalies = anomalies[:i] + anomalies[i + 1:]
            observations[satellite.name] = position[:3] + [satellite.get_fuel()] + [anomalies[i]] + other_anomalies
        return observations

    def reset(self, seed=None, options=None):
        """Reset the simulation and return ``(observations, infos)`` keyed by satellite name."""
        super().reset(seed)
        observations = self._get_observations()
        infos = {satellite.name: {} for satellite in self.satellites}
        return observations, infos

    def step(self, actions=None):
        """Advance every satellite by one environment step.

        Args:
            actions: dict mapping satellite name to its raw policy output.

        Returns:
            ``(observations, rewards, terminations, truncations, infos)``, each a
            dict keyed by satellite name; truncations are always False and infos empty.
        """
        scaled_actions = {}
        for satellite in self.satellites:
            clipped = np.clip(actions[satellite.name], -1, 1)
            # map [-1, 1] -> [0, action_space] and append a 0 third component
            scaled_actions[satellite.name] = list((clipped + 1) / 2 * self.action_high[satellite.name]) + [0]

        for satellite in self.satellites:
            satellite.step(scaled_actions[satellite.name])

        observations = self._get_observations()
        rewards, terminations = self.rewards(observations, scaled_actions)
        truncations = {satellite.name: False for satellite in self.satellites}
        infos = {satellite.name: {} for satellite in self.satellites}
        return observations, rewards, terminations, truncations, infos

    def rewards(self, observations, actions):
        """Compute per-satellite rewards and termination flags.

        reward = -(1e-8 * |altitude - 42164e3| + 1e-2 * mean pair penalty + 10 * action[0])

        The pair penalty is shared by all agents: for every pair of satellites
        whose angular separation is below ``2*pi / n_satellites`` it adds the
        normalized shortfall ``(target - separation) / target``.

        Args:
            observations: dict of observation lists as built by `_get_observations`.
            actions: dict of scaled actions; element 0 is presumably the thrust magnitude.

        Returns:
            ``(rewards, terminations)`` dicts keyed by satellite name; terminations are always False.
        """
        rewards = {}
        terminations = {}
        if not self.satellites:
            return rewards, terminations

        # NOTE(review): 42164e3 is the GEO radius from Earth's centre; confirm get_altitude() uses that reference
        target_altitude = 42164e3
        n_satellites = len(self.satellites)
        target_distance = 2 * np.pi / n_satellites

        # the last n_satellites entries of any observation are the anomalies of all satellites
        anomalies = observations[self.satellites[0].name][-n_satellites:]
        anomaly_penalty = 0
        total_pairs = 0
        for i, anomaly_i in enumerate(anomalies):
            for anomaly_j in anomalies[i + 1:]:
                angular_difference = np.abs(anomaly_i - anomaly_j) % (2 * np.pi)
                angular_difference = min(angular_difference, 2 * np.pi - angular_difference)
                if angular_difference < target_distance:
                    anomaly_penalty += (target_distance - angular_difference) / target_distance
                total_pairs += 1
        avg_anomaly_penalty = anomaly_penalty / total_pairs if total_pairs > 0 else 0

        for satellite in self.satellites:
            altitude_penalty = np.abs(satellite.get_altitude() - target_altitude)
            action_penalty = actions[satellite.name][0]
            # weights from experiment 21
            rewards[satellite.name] = - (1e-8*altitude_penalty + 1e-2*avg_anomaly_penalty + 1e1*action_penalty)
            terminations[satellite.name] = False

        return rewards, terminations

    def rewards_old(self, observations):
        """Legacy reward (not called by `step`).

        Penalizes the normalized squared error of the first three observation
        entries against ``self.target_orbit`` ([a, ex, ey]), scaled by 1e-7.
        Returns ``(rewards, truncations)`` keyed by satellite name.
        """
        rewards = {}
        truncations = {}

        target_orbit = self.target_orbit

        for satellite in self.satellites:
            observation = observations[satellite.name]
            orbit = observation[:3]
            orbit_penalty = np.ones(3) @ ((orbit - target_orbit)**2 / target_orbit)
            reward = - (orbit_penalty / 1e7)
            rewards[satellite.name] = reward
            truncations[satellite.name] = False

        return rewards, truncations

def _rollback_incomplete_step(satellites):
    """Discard buffer entries recorded for a step that never received rewards.

    Action selection presumably records one entry each in ``buffer.states``,
    ``buffer.actions`` and ``buffer.logprobs`` (these are the lists this loop
    has always rolled back), while ``buffer.rewards`` only grows after the whole
    step succeeded. Trimming the three lists back to ``len(buffer.rewards)``
    removes exactly the partial step, no matter which satellite's action
    selection or the environment step raised.
    """
    for satellite in satellites:
        buffer = satellite.ppo.buffer
        complete_steps = len(buffer.rewards)
        for entries in (buffer.states, buffer.actions, buffer.logprobs):
            while len(entries) > complete_steps:
                entries.pop()


def _average_networks(satellites, attr="critic"):
    """Overwrite one sub-network of every agent with the mean over all agents.

    ``attr`` names the sub-network on ``ppo.policy`` and ``ppo.policy_old``
    (e.g. ``"critic"`` or ``"actor"``). Floating-point entries are averaged
    element-wise; other entries (e.g. integer counters) cannot be averaged and
    are copied from the last satellite.
    """
    if not satellites:
        return
    state_dicts = [getattr(satellite.ppo.policy, attr).state_dict() for satellite in satellites]
    averaged = {}
    for key, reference in state_dicts[0].items():
        values = [state_dict[key] for state_dict in state_dicts]
        if reference.is_floating_point():
            averaged[key] = torch.stack(values).mean(dim=0)
        else:
            averaged[key] = values[-1].clone()
    for satellite in satellites:
        getattr(satellite.ppo.policy, attr).load_state_dict(averaged)
        getattr(satellite.ppo.policy_old, attr).load_state_dict(averaged)


env = MARLGEOEnv(params)

time_step = 1                  # global step counter across episodes
action_std_decay_freq = 10000  # decay the policies' action std every N global steps
action_std_decay_rate = 0.05
update_freq = 1024             # run a PPO update every N global steps
min_action_std = 0.05
episodes = 10000
steps_per_episode = 500 # little more than 2 revolutions (2 days)
low = np.array([0.0, 0.0])  # not used below
high = np.array(params["satellites"][0]["agent"]["action_space"])

experiment = 22

writer = SummaryWriter(log_dir=f"runs/experiment_{experiment}")

# load models
# for i in range(len(env.satellites)):
#     env.satellites[i].ppo.load(f"model_geo_agent{i}.pth")

# start at -inf so the first finished episode always produces a checkpoint
best_score = float("-inf")
start_episode = 1

for episode in range(start_episode, start_episode + episodes):
    current_ep_rewards = {str(satellite): 0 for satellite in env.satellites}
    observations, _ = env.reset()
    for t in range(1, steps_per_episode + 1):

        # end the episode as soon as any satellite runs out of fuel
        if any(not satellite.has_fuel() for satellite in env.satellites):
            break

        try:
            # select actions with policies
            actions = {str(satellite): satellite.ppo.select_action(observations[str(satellite)]) for satellite in env.satellites}
            # apply step (propagation can fail, e.g. when a satellite hits Earth)
            observations, rewards, terminations, _, _ = env.step(actions)
        except Exception as exc:
            # keep every buffer aligned with its rewards by dropping the partial step
            _rollback_incomplete_step(env.satellites)
            print(f"hit Earth ({type(exc).__name__})")
            break

        # save rewards and is_terminals
        for satellite in env.satellites:
            key = str(satellite)
            satellite.ppo.buffer.rewards.append(rewards[key])
            satellite.ppo.buffer.is_terminals.append(terminations[key])
            current_ep_rewards[key] += rewards[key]

        # decay action std of output action distribution
        if time_step % action_std_decay_freq == 0:
            for satellite in env.satellites:
                if satellite.ppo.action_std > min_action_std:
                    satellite.ppo.decay_action_std(action_std_decay_rate, min_action_std)

        time_step += 1

        if any(terminations.values()):
            break

        if params["interface"]["show"]:
            env.render()

        if time_step % update_freq == 0:

            for satellite in env.satellites:
                try:
                    satellite.ppo.update_gae()
                except Exception:
                    # drop this agent's batch and keep training the others
                    satellite.ppo.buffer.clear()
                    # NOTE(review): this leaves both policies in eval mode afterwards — confirm intended
                    satellite.ppo.policy.eval()
                    satellite.ppo.policy_old.eval()
                    traceback.print_exc()

            # share one averaged critic across all agents; actors stay independent
            # (add _average_networks(env.satellites, "actor") to share actors as well)
            _average_networks(env.satellites, "critic")

    # set the last experience termination flag to True
    for satellite in env.satellites:
        if len(satellite.ppo.buffer.is_terminals) > 0:
            satellite.ppo.buffer.is_terminals[-1] = True

    # per-agent episode returns plus their average
    final_score = sum(current_ep_rewards[str(satellite)] for satellite in env.satellites) / len(env.satellites)
    current_ep_rewards["average"] = final_score

    writer.add_scalars("Reward/Followers", current_ep_rewards, episode)
    writer.add_scalars("Mass/Fuel", {satellite.name: satellite.get_fuel() for satellite in env.satellites}, episode)

    if final_score > best_score:
        print(f'>>>>>>>> Best score of {final_score} found in episode {episode}. Saving the models.')
        best_score = final_score
        for i, satellite in enumerate(env.satellites):
            satellite.ppo.save(f"model_geo_agent{i}.pth")

writer.close()