import numpy as np
from env import OrbitZoo
from bodies import Satellite
from rl_algorithms.ppo_std import PPO
import traceback
import torch
from torch.utils.tensorboard import SummaryWriter
import copy

# Environment configuration passed to HerreraEnvironment (a deep copy with a
# smaller initial action std is used for evaluation further below).
params = {
        "satellites": [
            # Single controllable satellite; "name" is also the key used by the
            # per-satellite coefficient dicts below.
            {"name": "agent",
             # Start state (presumably metres and m/s): x = 6378137.0 + 550e3,
             # and y_dot appears to be the circular speed sqrt(mu / x) for
             # Earth's mu — confirm.
             "initial_state": {"x": 6928137.0, "y": 0.0, "z": 0.0, "x_dot": 0.0, "y_dot": 7585.088535158763, "z_dot": 0.0},
             # Near-zero spread, so resets are effectively deterministic
             # (assumes these are noise scales on the initial state — TODO confirm).
             "initial_state_uncertainty": {"x": 1e-15, "y": 1e-15, "z": 1e-15, "x_dot": 1e-15, "y_dot": 1e-15, "z_dot": 1e-15},
             "initial_mass": 25.0,
             "fuel_mass": 75.0,
             # NOTE(review): unusually small specific impulse; units/intent are
             # not visible here — confirm against bodies.Satellite.
             "isp": 0.0067,
             "radius": 16.8,
             "save_steps_info": False,
             # PPO settings read by Agent.__init__ (note: no "device" or
             # "action_dim" keys here, although Agent.__init__ looks them up).
             "agent": {
                "lr_actor": 0.0001,
                "lr_critic": 0.001,
                "gae_lambda": 0.95,
                "epochs": 5,
                "gamma": 0.99,
                "clip": 0.03,
                "action_std_init": 0.5,
                # 8 matches the length of the observation built by Agent.get_state.
                "state_dim_actor": 8,
                "state_dim_critic": 8,
                # Not read anywhere in this file; presumably consumed by
                # OrbitZoo/PPO internals — confirm.
                "action_space": [0.04 / 50, 2 * np.pi / 6],
             }},
        ],
        # Propagation step (presumably seconds).
        "delta_t": 1.0,
        "forces": {
            "gravity_model":  "Newtonian",
            "third_bodies": {
                "active": False,
                "bodies": ["SUN", "MOON"],
            },
            "solar_radiation_pressure": {
                "active": False,
                # Keyed by satellite name.
                "reflection_coefficients": {
                    "agent": 0.5,
                }
            },
            # Drag is the only perturbation enabled in this experiment.
            "drag": {
                "active": True,
                # Keyed by satellite name.
                "drag_coefficients": {
                    "agent": 2.123,
                }
            }
        },
        "interface": {
            # Rendering toggle; the training loop calls env.render() only when True.
            "show": False,
            "delay_ms": 0,
            "zoom": 1.0,
            "drifters": {
                "show": True,
                "show_label": True,
                "show_velocity": False,
                "show_trail": True,
                "trail_last_steps": 50,
                "color_body": (255, 255, 255),
                "color_label": (255, 255, 255),
                "color_velocity": (255, 255, 255),
                "color_trail": (255, 255, 255),
            },
            "satellites": {
                "show": True,
                "show_label": True,
                "show_velocity": False,
                "show_thrust": True,
                "show_trail": True,
                "trail_last_steps": 5000,
                "color_body": (255, 0, 0),
                "color_label": (255, 255, 255),
                "color_velocity": (255, 255, 255),
                "color_thrust": (0, 255, 0),
                "color_trail": (255, 0, 0),
            },
            "earth": {
                "show": True,
                "color": (0, 0, 255),
                "resolution": 70,
            },
            "equator_grid": {
                "show": False,
                "color": (30, 140, 200),
                "resolution": 10,
            },
            "timestamp": {
                "show": True,
            },
            # Reference orbit drawn by the interface. a = 550e3 matches the
            # 550 km offset in initial_state x; presumably an altitude rather
            # than a semi-major axis — confirm.
            "orbits": [
                {"a": 550.0e3, "e": 0.00001, "i": 0.00001, "pa": 0.0, "raan": 0.0, "color": (0, 255, 0)},
            ],
        }
    }

class Agent(Satellite):
    """Satellite whose thrust is driven by its own PPO learner.

    The satellite keeps a persistent thrust command (``thrust_mag``,
    ``thrust_theta``) that the environment adjusts incrementally each step,
    and exposes an 8-element observation through :meth:`get_state`.
    """

    # Reference circular orbit used to normalise observations. The radius
    # equals the script's initial x-position (6378137.0 + 550e3) and the
    # speed equals its initial y-velocity.
    _REF_RADIUS = 6928137.0
    _REF_SPEED = 7585.088535158763

    def __init__(self, params):
        """Build the satellite and its PPO learner.

        Args:
            params: per-satellite configuration dict; its ``"agent"`` sub-dict
                supplies the PPO hyperparameters.
        """
        super().__init__(params)
        agent_params = params["agent"]
        # NOTE(review): the params dict in this script defines neither "device"
        # nor "action_dim"; they may be injected elsewhere (base class or env).
        # Use them when present, otherwise fall back instead of raising
        # KeyError: CUDA when available, and one action dimension per entry of
        # "action_space" (the environment feeds 2-D actions).
        if "device" in agent_params:
            device = agent_params["device"]
        else:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        if "action_dim" in agent_params:
            action_dim = agent_params["action_dim"]
        else:
            action_dim = len(agent_params["action_space"])
        self.ppo = PPO(device,
                       agent_params["state_dim_actor"],
                       agent_params["state_dim_critic"],
                       action_dim,
                       agent_params["lr_actor"],
                       agent_params["lr_critic"],
                       agent_params["gae_lambda"],
                       agent_params["gamma"],
                       agent_params["epochs"],
                       agent_params["clip"],
                       True,  # presumably "continuous action space" — confirm against PPO's signature
                       agent_params["action_std_init"])

        self.thrust_mag = 0.0
        self.thrust_theta = 0.0

    def reset(self, seed=None):
        """Reset the underlying satellite and zero the thrust command."""
        super().reset(seed)
        self.thrust_mag = 0.0
        self.thrust_theta = 0.0

    def get_state(self):
        """Return the 8-element observation fed to the PPO policy.

        Layout: x and y position divided by the reference radius; x and y
        velocity divided by the reference speed (z components are omitted);
        absolute deviation of ``|r|`` from the reference radius and of ``|v|``
        from the reference speed (not normalised); current thrust angle;
        current thrust magnitude.
        """
        position = np.array(self.get_cartesian_position())
        velocity = np.array(self.get_cartesian_velocity())
        radius_error = np.abs(np.linalg.norm(position) - self._REF_RADIUS)
        speed_error = np.abs(np.linalg.norm(velocity) - self._REF_SPEED)
        return (list(position[:2] / self._REF_RADIUS)
                + list(velocity[:2] / self._REF_SPEED)
                + [radius_error, speed_error, self.thrust_theta, self.thrust_mag])

class HerreraEnvironment(OrbitZoo):
    """OrbitZoo environment in which every satellite is a PPO-driven ``Agent``.

    Actions are 2-vectors in [-1, 1] that incrementally adjust each agent's
    persistent thrust command. ``step`` returns per-satellite dicts of
    observations, rewards, terminations, truncations and infos.
    """

    # Clipped actions are rescaled from [-1, 1] to [ACTION_LOW, ACTION_HIGH]
    # before being turned into thrust increments. Same values as the
    # module-level ``low``/``high`` defined by the training script; held here
    # so the class does not depend on globals defined after it.
    ACTION_LOW = np.array([0.0, 0.0])
    ACTION_HIGH = np.array([1.0, 1.0])
    # An episode terminates once observation[4] (the radius deviation reported
    # by Agent.get_state) exceeds this value.
    MAX_RADIUS_DEVIATION = 1.0
    # Number of steps over which the per-step reward grows by 1.0.
    REWARD_HORIZON = 800

    def __init__(self, params):
        super().__init__(params)

    def create_bodies(self, params):
        """Create no drifters and one ``Agent`` per entry of ``params["satellites"]``."""
        self.drifters = []
        self.satellites = [Agent(body_params) for body_params in params.get("satellites", [])]

    def reset(self, seed=None, options=None):
        """Reset the simulation and step counter; return ``(observations, infos)``."""
        super().reset(seed)

        self.step_counter = 0

        observations = {satellite.name: satellite.get_state() for satellite in self.satellites}
        infos = {satellite.name: {} for satellite in self.satellites}

        return observations, infos

    def step(self, actions=None):
        """Apply ``actions`` to every satellite and propagate one step.

        Args:
            actions: dict mapping ``str(satellite)`` to a 2-element action whose
                components are clipped to [-1, 1]. Component 0 changes the
                thrust magnitude by up to +/-0.02 (result clipped to [0, 1]);
                component 1 changes the thrust angle by up to +/-pi/6 (result
                wrapped to [-pi, pi)). ``None`` applies a neutral all-zero
                action, leaving every thrust command unchanged.

        Returns:
            ``(observations, rewards, terminations, truncations, infos)``;
            truncations are always ``False``.
        """
        self.step_counter += 1

        neutral = np.zeros(2)
        for satellite in self.satellites:
            raw = neutral if actions is None else actions[str(satellite)]
            clipped = np.clip(raw, [-1, -1], [1, 1])
            scaled = self.ACTION_LOW + ((clipped + 1) / 2) * (self.ACTION_HIGH - self.ACTION_LOW)
            satellite.thrust_mag += scaled[0] * 0.04 - 0.02
            satellite.thrust_mag = np.clip(satellite.thrust_mag, 0.0, 1.0)
            satellite.thrust_theta += scaled[1] * np.pi / 3 - np.pi / 6
            # Wrap the angle into [-pi, pi).
            satellite.thrust_theta -= (2 * np.pi) * np.floor((satellite.thrust_theta + np.pi) * (1 / (2 * np.pi)))

        # Each satellite is propagated with ITS OWN thrust command (the previous
        # version used a leaked loop variable, so with several satellites every
        # call went to the last one).
        observations = {
            satellite.name: satellite.step(np.array([satellite.thrust_mag, satellite.thrust_theta, 0.0]))
            for satellite in self.satellites
        }
        rewards, terminations = self.rewards(observations)

        truncations = {satellite.name: False for satellite in self.satellites}
        infos = {satellite.name: {} for satellite in self.satellites}

        return observations, rewards, terminations, truncations, infos

    def rewards(self, observations):
        """Return ``(rewards, terminations)`` dicts keyed by ``str(satellite)``.

        A satellite terminates when ``observation[4]`` exceeds
        ``MAX_RADIUS_DEVIATION`` or it has no fuel left. Terminal steps earn 0;
        otherwise the reward is ``step_counter / REWARD_HORIZON + 0.5``.
        """
        rewards = {}
        terminations = {}

        for satellite in self.satellites:
            key = str(satellite)
            observation = observations[key]
            terminated = observation[4] > self.MAX_RADIUS_DEVIATION or not satellite.has_fuel()
            rewards[key] = 0 if terminated else (self.step_counter / self.REWARD_HORIZON) + 0.5
            terminations[key] = terminated

        return rewards, terminations

# Evaluation copy of the configuration; only the initial action std differs.
params_eval = copy.deepcopy(params)
params_eval["satellites"][0]["agent"]["action_std_init"] = 0.01

env = HerreraEnvironment(params)
env_eval = HerreraEnvironment(params_eval)

# To resume from a saved checkpoint, uncomment:
# env.satellites[0].ppo.load("model_checkpoint.pth")

time_step = 1  # global step counter across episodes
# Action-std decay settings; only used by the disabled decay code in the loop.
action_std_decay_freq = 10000
action_std_decay_rate = 0.05
min_action_std = 0.5
update_freq = 800  # run a PPO update whenever time_step is a multiple of this
episodes = 10000
steps_per_episode = 800
# Action rescaling bounds ([-1, 1] -> [low, high]). HerreraEnvironment.step may
# read these as module globals, so keep them defined.
low = np.array([0.0, 0.0])
high = np.array([1.0, 1.0])
best_score = 0
best_score_eval = 0
eval_freq = 5000  # evaluate every `eval_freq` training time steps
eval_episodes = 100  # episodes averaged per evaluation

experiment = 4
writer = SummaryWriter(log_dir=f"runs/herrera_{experiment}")


def _realign_buffer(buffer):
    """Drop trailing states/actions/logprobs that have no matching reward.

    The training loop appends rewards/is_terminals only after a successful
    ``env.step``, while action selection presumably appends to the other three
    lists (they are exactly what the cleanup pops). Trimming only the surplus
    keeps all lists aligned no matter where a failure occurred.
    """
    n_complete = len(buffer.rewards)
    for field in (buffer.states, buffer.actions, buffer.logprobs):
        while len(field) > n_complete:
            field.pop()


def _evaluate(train_env, eval_env, n_episodes, max_steps):
    """Copy the first training agent's weights into ``eval_env`` and score it.

    Runs ``n_episodes`` episodes (each reset with seed 42, at most
    ``max_steps`` steps) and returns the per-episode total rewards of the first
    evaluation satellite. Evaluation buffers are cleared after every episode,
    even if it fails, so they never accumulate.
    """
    source = train_env.satellites[0]
    target = eval_env.satellites[0]
    # NOTE(review): the policy exposes ``log_std`` (it is logged below); if it
    # is a registered parameter, load_state_dict also copies the training std
    # and overrides the 0.01 "action_std_init" of params_eval — confirm
    # against rl_algorithms.ppo_std.
    target.ppo.policy.load_state_dict(source.ppo.policy.state_dict())
    target.ppo.policy_old.load_state_dict(source.ppo.policy_old.state_dict())
    scores = []
    for _ in range(n_episodes):
        observations_eval, _ = eval_env.reset(42)
        score = 0
        try:
            for _ in range(max_steps):
                actions_eval = {
                    str(sat): sat.ppo.select_action(observations_eval[str(sat)])
                    for sat in eval_env.satellites
                }
                observations_eval, rewards_eval, terminations_eval, _, _ = eval_env.step(actions_eval)
                score += rewards_eval[str(target)]
                if any(terminations_eval[str(sat)] for sat in eval_env.satellites):
                    break
        finally:
            # Experience gathered during evaluation is never trained on.
            for sat in eval_env.satellites:
                sat.ppo.buffer.clear()
        scores.append(score)
    return scores


try:
    for episode in range(1, episodes + 1):
        current_ep_rewards = {str(satellite): 0 for satellite in env.satellites}
        losses = {satellite.name: 0 for satellite in env.satellites}
        observations, _ = env.reset(42)
        for t in range(1, steps_per_episode + 1):
            # Only action selection and propagation are guarded: a failure here
            # (e.g. a propagation error) ends the episode.
            try:
                actions = {
                    str(satellite): satellite.ppo.select_action(observations[str(satellite)])
                    for satellite in env.satellites
                }
                observations, rewards, terminations, _, _ = env.step(actions)
            except Exception:
                # Discard the data recorded for the failed step so every buffer
                # list stays aligned with rewards/is_terminals.
                for satellite in env.satellites:
                    _realign_buffer(satellite.ppo.buffer)
                traceback.print_exc()
                break

            # Store rewards and terminal flags for the PPO update.
            for satellite in env.satellites:
                key = str(satellite)
                satellite.ppo.buffer.rewards.append(rewards[key])
                satellite.ppo.buffer.is_terminals.append(terminations[key])
                current_ep_rewards[key] += rewards[key]

            # Action-std decay (disabled):
            # if time_step % action_std_decay_freq == 0 and env.satellites[0].ppo.action_std > min_action_std:
            #     for satellite in env.satellites:
            #         satellite.ppo.decay_action_std(action_std_decay_rate, min_action_std)

            time_step += 1

            if time_step % eval_freq == 0:
                try:
                    scores = _evaluate(env, env_eval, eval_episodes, steps_per_episode)
                except Exception:
                    # A failed evaluation must not end the training episode or
                    # touch the training buffers.
                    traceback.print_exc()
                else:
                    eval_mean = np.mean(scores)
                    eval_std = np.std(scores)
                    print(f'EVAL >>> mean: {eval_mean}, std: {eval_std}')
                    if eval_mean > best_score_eval:
                        print(f'EVAL >>>>>>>> Best score of {eval_mean} found. Saving the model.')
                        best_score_eval = eval_mean
                        env.satellites[0].ppo.save("model_checkpoint_eval.pth")

            if params["interface"]["show"]:
                env.render()

            # Checked before the termination break so that a terminal step that
            # lands on an update boundary does not skip the update.
            if time_step % update_freq == 0:
                for satellite in env.satellites:
                    losses[satellite.name] += satellite.ppo.update_gae()

            if any(terminations[str(satellite)] for satellite in env.satellites):
                break

        # Mark each agent's last stored transition as terminal (presumably so the
        # GAE computation does not chain consecutive episodes together).
        for satellite in env.satellites:
            if len(satellite.ppo.buffer.is_terminals) > 0:
                satellite.ppo.buffer.is_terminals[-1] = True

        # Keep the checkpoint with the best training episode score.
        for satellite in env.satellites:
            score = current_ep_rewards[str(satellite)]
            if score > best_score:
                print(f'>>>>>>>> Best score of {score} found. Saving the model.')
                best_score = score
                satellite.ppo.save("model_checkpoint.pth")

        writer.add_scalars("Reward", current_ep_rewards, episode)
        writer.add_scalars("Fuel", {satellite.name: satellite.get_fuel() for satellite in env.satellites}, episode)
        writer.add_scalars("Steps", {satellite.name: t for satellite in env.satellites}, episode)
        # Per-satellite sum of update losses in this episode (0 if no update ran).
        writer.add_scalars("Loss", losses, episode)
        for satellite in env.satellites:
            satellite_stds = torch.exp(satellite.ppo.policy.log_std).tolist()
            writer.add_scalars("Std", {'M': satellite_stds[0], 'theta': satellite_stds[1]}, episode)
finally:
    # Flush TensorBoard events even if training crashes or is interrupted.
    writer.close()