import numpy as np
from env import OrbitZoo
from bodies import Satellite
from rl_algorithms.ddpg import DDPG
from constants import EARTH_RADIUS
import traceback
from torch.utils.tensorboard import SummaryWriter

# Experiment configuration handed to the environment constructor.
params = {
    # ---- Bodies -------------------------------------------------------
    "satellites": [
        {
            "name": "follower_1",
            # Cartesian position / velocity components at the start of an episode.
            "initial_state": {
                "x": 6129800.048013737,
                "y": 7280790.331579311,
                "z": 415150.31877402193,
                "x_dot": -5281.077303515858,
                "y_dot": 4683.916455224767,
                "z_dot": 543.1013245770663,
            },
            "initial_state_uncertainty": {
                "x": 0.0001,
                "y": 0.0001,
                "z": 0.0001,
                "x_dot": 0.0001,
                "y_dot": 0.0001,
                "z_dot": 0.0001,
            },
            "initial_mass": 500.0,
            "fuel_mass": 150.0,
            "isp": 3100.0,
            "radius": 5.0,
            "save_steps_info": False,
            # Hyper-parameters of the learning agent attached to this satellite.
            "agent": {
                "lr_actor": 0.00001,
                "lr_critic": 0.0001,
                "gamma": 0.99,
                "tau": 0.01,
                "memory_capacity": 10_000,
                "batch_size": 256,
                "K_epochs": 1,
                # Not passed to the DDPG constructor in this file.
                "policy_delay": 10,
                "state_dim_actor": 7,
                "state_dim_critic": 7,
                # Upper bounds used to scale normalized actions.
                # Presumably [thrust magnitude, polar angle, azimuth] -- confirm.
                "action_space": [0.6, np.pi, 2 * np.pi],
            },
        },
    ],
    # ---- Propagation --------------------------------------------------
    # Step size of the simulation (presumably seconds -- confirm).
    "delta_t": 500.0,
    # Every perturbation below is switched off in this configuration.
    "forces": {
        "gravity_model": "Newtonian",
        "third_bodies": {
            "active": False,
            "bodies": ["SUN", "MOON"],
        },
        "solar_radiation_pressure": {
            "active": False,
            "reflection_coefficients": {
                "agent": 0.5,
            },
        },
        "drag": {
            "active": False,
            "drag_coefficients": {
                "agent": 0.5,
            },
        },
    },
    # ---- Rendering ----------------------------------------------------
    "interface": {
        # Rendering is disabled by default; the training loop checks this flag.
        "show": False,
        "delay_ms": 0,
        "zoom": 3.0,
        "drifters": {
            "show": True,
            "show_label": True,
            "show_velocity": False,
            "show_trail": True,
            "trail_last_steps": 300,
            "color_body": (255, 255, 255),
            "color_label": (255, 255, 255),
            "color_velocity": (0, 255, 255),
            "color_trail": (255, 255, 255),
        },
        "satellites": {
            "show": True,
            "show_label": True,
            "show_velocity": False,
            "show_thrust": True,
            "show_trail": True,
            "trail_last_steps": 500,
            "color_body": (255, 0, 0),
            "color_label": (255, 255, 255),
            "color_velocity": (0, 255, 255),
            "color_thrust": (0, 255, 0),
            "color_trail": (255, 0, 0),
        },
        "earth": {
            "show": True,
            "color": (0, 0, 255),
            "resolution": 50,
        },
        "equator_grid": {
            "show": False,
            "color": (30, 140, 200),
            "resolution": 10,
        },
        "timestamp": {
            "show": True,
        },
        # Reference orbits drawn by the interface.
        "orbits": [
            {
                "a": 6300.0e3,
                "e": 0.23,
                "i": 5.3,
                "pa": 24.0,
                "raan": 24.0,
                "color": (0, 255, 0),
            },
        ],
    },
}

class Agent(Satellite):
    """Satellite that carries its own DDPG learner in ``self.ddpg``."""

    def __init__(self, params):
        """Build the satellite, then its DDPG learner from ``params["agent"]``.

        NOTE(review): ``"device"`` and ``"action_dim"`` are read here but are
        not part of the ``params`` literal in this file; presumably
        ``Satellite.__init__`` adds them -- confirm.
        """
        super().__init__(params)

        agent_cfg = params["agent"]
        # Keys listed in the positional order the DDPG constructor is called with.
        ddpg_arg_keys = (
            "device",
            "state_dim_actor",
            "state_dim_critic",
            "action_dim",
            "lr_actor",
            "lr_critic",
            "gamma",
            "tau",
            "memory_capacity",
            "batch_size",
            "K_epochs",
        )
        self.ddpg = DDPG(*(agent_cfg[key] for key in ddpg_arg_keys))
        # NOTE: a TD3 learner was previously tried here; it took the same
        # arguments followed by agent_cfg["policy_delay"].

    def get_state(self):
        """Return the equinoctial position followed by the current fuel.

        Assumes ``get_equinoctial_position()`` returns a list, so that ``+``
        appends the fuel entry -- TODO confirm.
        """
        orbital_elements = self.get_equinoctial_position()
        return orbital_elements + [self.get_fuel()]

class KolosaEnvironment(OrbitZoo):
    """Orbit-transfer environment that rewards reaching a fixed target.

    The first five components of each satellite's observation are compared
    against ``self.target``; an episode terminates for a satellite once every
    component is within the matching entry of ``self.tolerances``.

    Dictionaries exchanged with callers are keyed by ``satellite.name``
    (the original code already required ``str(satellite) == satellite.name``).
    """

    def __init__(self, params):
        """Create the environment and cache the per-satellite action bounds.

        Args:
            params: configuration dictionary; ``params["satellites"]`` (when
                present) holds one entry per satellite, each with an
                ``["agent"]["action_space"]`` list of upper action bounds.
        """
        super().__init__(params)
        # Desired values for observation[0:5]. Presumably the equinoctial
        # elements returned first by Agent.get_state() -- confirm.
        self.target = [6300.0e3 + EARTH_RADIUS, 0.1539, 0.1709, 0.0423, 0.0188]
        self.tolerances = [10.0e3, 0.01, 0.01, 0.001, 0.001]

        # Upper action bound of every satellite, taken from its own config so
        # that step() does not depend on module-level globals. Assumes
        # self.satellites is in the same order as params["satellites"], which
        # holds when it is built by create_bodies().
        satellite_params = params["satellites"] if "satellites" in params else []
        self._action_high = {
            satellite.name: np.array(body_params["agent"]["action_space"])
            for satellite, body_params in zip(self.satellites, satellite_params)
        }

    def create_bodies(self, params):
        """Instantiate one ``Agent`` per configured satellite and no drifters."""
        self.drifters = []
        self.satellites = [Agent(body_params) for body_params in params["satellites"]] if "satellites" in params else []

    def reset(self, seed=None, options=None):
        """Reset all bodies and every satellite's exploration noise.

        Args:
            seed: forwarded unchanged to each body's ``reset``.
            options: accepted but not used by this method.

        Returns:
            ``(observations, infos)``: observations keyed by ``str(body)`` and
            an empty info dict per satellite name.
        """
        observations = {str(body): body.reset(seed) for body in self.drifters + self.satellites}

        for satellite in self.satellites:
            satellite.ddpg.ou_noise.reset()

        infos = {satellite.name: {} for satellite in self.satellites}

        return observations, infos

    def step(self, actions=None):
        """Apply one action per satellite and return the transition.

        Args:
            actions: mapping ``satellite.name -> array`` with components in
                ``[-1, 1]``. Must be provided; ``None`` is not a usable value.

        Returns:
            ``(observations, rewards, terminations, truncations, infos)``,
            each a dict keyed by satellite name. Truncations are always
            ``False`` and infos are always empty.
        """
        # Map each normalized component from [-1, 1] to [0, high].
        scaled_actions = {
            satellite.name: ((actions[satellite.name] + 1) / 2) * self._action_high[satellite.name]
            for satellite in self.satellites
        }

        # All actions are scaled before any satellite is advanced.
        observations = {
            satellite.name: satellite.step(scaled_actions[satellite.name])
            for satellite in self.satellites
        }

        rewards, terminations = self.rewards(observations)

        truncations = {satellite.name: False for satellite in self.satellites}
        infos = {satellite.name: {} for satellite in self.satellites}
        return observations, rewards, terminations, truncations, infos

    def rewards(self, observations):
        """Compute the reward and termination flag of every satellite.

        The reward is the negative weighted sum of the absolute differences
        between the target and the first five observation components (the
        first difference is divided by ``target[0]``), plus 1 when all
        differences are within tolerance.

        Args:
            observations: mapping ``satellite.name -> observation`` where the
                observation is indexable with at least ``len(self.target)``
                entries.

        Returns:
            ``(rewards, terminations)``, both keyed by satellite name.
        """
        target = np.array(self.target, dtype=float)
        tolerances = np.array(self.tolerances, dtype=float)
        # Weights of the individual differences; constant, so built once.
        alphas = np.array([1, 1, 1, 10, 10])
        n_components = len(target)

        reward_by_name = {}
        terminations = {}

        for satellite in self.satellites:
            observation = observations[satellite.name]
            observed = np.array([observation[i] for i in range(n_components)], dtype=float)

            diffs = np.abs(target - observed)

            # Tolerances are checked on the raw differences, before the first
            # one is rescaled below.
            reached = bool((diffs <= tolerances).all())
            terminations[satellite.name] = reached

            reward = 0
            if reached:
                reward += 1
                print('hit')

            # Express the first difference relative to its target value so it
            # is comparable in magnitude with the other components.
            diffs[0] /= target[0]
            reward += -np.dot(alphas, diffs)

            reward_by_name[satellite.name] = reward

        return reward_by_name, terminations

# ---- Environment --------------------------------------------------------
env = KolosaEnvironment(params)

# ---- Action bounds ------------------------------------------------------
# `high` is taken from the first satellite's configured action space.
high = np.array(params["satellites"][0]["agent"]["action_space"])
low = np.zeros(3)

# ---- Run length and counters ---------------------------------------------
episodes = 10_000
steps_per_episode = 692
start_episode = 1
time_step = 1            # global step counter, incremented once per env step
best_score = -10_000_000.0

# Kept for the alternative (currently disabled) update schedule.
update_freq = 1024
min_action_std = 0.05

# ---- Logging --------------------------------------------------------------
experiment = 18
writer = SummaryWriter(log_dir=f"runs/experiment_{experiment}")

# To resume from a saved model, load it into every satellite before training:
# for i in range(len(env.satellites)):
#     env.satellites[i].ddpg.load(f"model_kolosa")

# Training loop. Wrapped in try/finally so the TensorBoard writer is always
# closed, even when training raises or is interrupted.
try:
    # Runs exactly `episodes` episodes, numbered from `start_episode`.
    for episode in range(start_episode, start_episode + episodes):
        # Per-episode accumulators, keyed by satellite name.
        current_ep_rewards = {satellite.name: 0 for satellite in env.satellites}
        current_ep_value_loss = {satellite.name: 0 for satellite in env.satellites}
        current_ep_policy_loss = {satellite.name: 0 for satellite in env.satellites}

        # Every episode is reset with the same seed.
        observations, _ = env.reset(42)

        for _step in range(steps_per_episode):
            try:
                # End the episode as soon as any satellite has no fuel left.
                if not all(satellite.has_fuel() for satellite in env.satellites):
                    break

                # Select noisy actions with each policy and clip them to [-1, 1].
                actions = {
                    satellite.name: np.clip(
                        satellite.ddpg.select_action(observations[satellite.name], satellite.ddpg.ou_noise),
                        [-1, -1, -1],
                        [1, 1, 1],
                    )
                    for satellite in env.satellites
                }

                next_observations, rewards, terminations, _truncations, _infos = env.step(actions)

                # Store the transition in each satellite's replay memory.
                for satellite in env.satellites:
                    name = satellite.name
                    # Stored as 1 when the episode terminated, else 0.
                    # NOTE(review): originally named `mask` -- confirm the
                    # replay buffer expects a "done" flag and not its inverse.
                    done_flag = 1 if terminations[name] else 0
                    satellite.ddpg.memory.add(
                        (observations[name], actions[name], rewards[name], next_observations[name], done_flag)
                    )

                observations = next_observations
                for satellite in env.satellites:
                    current_ep_rewards[satellite.name] += rewards[satellite.name]

                time_step += 1

                # A terminated satellite ends the episode for everyone; the
                # render/update below is skipped for this last step.
                if any(terminations[satellite.name] for satellite in env.satellites):
                    break

            except Exception:
                # Best effort: report the failure, abandon this episode and
                # carry on with the next one.
                traceback.print_exc()
                break

            if params["interface"]["show"]:
                env.render()

            # One learning update per environment step once the replay memory
            # holds more than a batch.
            for satellite in env.satellites:
                if len(satellite.ddpg.memory) > satellite.ddpg.batch_size:
                    value_loss, policy_loss = satellite.ddpg.update()
                    current_ep_value_loss[satellite.name] += value_loss
                    current_ep_policy_loss[satellite.name] += policy_loss

        # Episode score: sum of the accumulated rewards of all satellites.
        final_score = 0
        for satellite in env.satellites:
            final_score += current_ep_rewards[satellite.name]

        # NOTE: logged under the tag "average" but holds the sum; the two
        # coincide when there is a single satellite.
        current_ep_rewards["average"] = final_score

        writer.add_scalars("Reward/Followers", current_ep_rewards, episode)
        writer.add_scalars("Mass/Fuel", {satellite.name: satellite.get_fuel() for satellite in env.satellites}, episode)
        writer.add_scalars("Loss/Value", current_ep_value_loss, episode)
        writer.add_scalars("Loss/Policy", current_ep_policy_loss, episode)

        if final_score > best_score:
            print(f'>>>>>>>> Best score of {final_score} found in episode {episode}. Saving the models.')
            best_score = final_score
            # Every satellite saves under the same name, so with several
            # satellites presumably only the last one is kept -- confirm.
            for satellite in env.satellites:
                satellite.ddpg.save("model_kolosa")
finally:
    writer.close()