import numpy as np
from env import OrbitZoo
from bodies import Satellite, Body
from rl_algorithms.ppo import PPO
import traceback

# Scenario configuration passed to ChaseEnvironment: one passive body
# ("leader"), one agent-controlled satellite ("follower"), the force models
# and the rendering options.
params = {
    # Bodies listed here are built as plain Body instances (no agent).
    "drifters": [
        {
            "name": "leader",
            # Alternative initial state, kept for reference:
            # "initial_state": {"x": 40135560.35763372, "y": 23106845.90661858, "z": 1381125.2706987557, "x_dot": -1465.0065468674454, "y_dot": 2531.305585514847, "z_dot": 448.626521969306},
            "initial_state": {
                "x": 28400575.47747919,
                "y": 16350780.090518527,
                "z": 977306.7111762252,
                "x_dot": -1774.918911925119,
                "y_dot": 3180.7991992443126,
                "z_dot": 301.02112008026137,
            },
            "initial_state_uncertainty": {
                "x": 0.0001,
                "y": 0.0001,
                "z": 0.0001,
                "x_dot": 0.0001,
                "y_dot": 0.0001,
                "z_dot": 0.0001,
            },
            "initial_mass": 5.0,
            "radius": 5.0,
            "save_steps_info": False,
        },
    ],
    # Bodies listed here are built as Chaser instances, each with a PPO agent.
    "satellites": [
        {
            "name": "follower",
            # Alternative initial states, kept for reference:
            # "initial_state": {"x": 12786485.356547935, "y": 7361435.699122934, "z": 440002.27957423026, "x_dot": -2645.248605885859, "y_dot": 4740.500870700536, "z_dot": 448.626521969306},
            # "initial_state": {"x": -5457131.054736777, "y": 14915325.387999892, "z": 2197545.5322520277, "x_dot": -5535.291222985718, "y_dot": 622.1106458722226, "z_dot": 221.19056105580083},
            # "initial_state": {"x": 11216482.8099675, "y": 1970242.5416199851, "z": -172373.8867882523, "x_dot": -1026.7607154630223, "y_dot": 5800.89091779048, "z_dot": -507.51219365715974},
            # "initial_state": {"x": 16936004.262182973, "y": -20186618.443412434, "z": 295619.0540209736, "x_dot": 2980.3614622711616, "y_dot": 2501.201148816514, "z_dot": 52.02240283350062},
            "initial_state": {
                "x": 12786485.356547935,
                "y": 7361435.699122934,
                "z": 440002.27957423026,
                "x_dot": -2645.248605885859,
                "y_dot": 4740.500870700536,
                "z_dot": 448.626521969306,
            },
            "initial_state_uncertainty": {
                "x": 0.0001,
                "y": 0.0001,
                "z": 0.0001,
                "x_dot": 0.0001,
                "y_dot": 0.0001,
                "z_dot": 0.0001,
            },
            "initial_mass": 500.0,
            "fuel_mass": 150.0,
            "isp": 3000.0,
            "radius": 5.0,
            "save_steps_info": False,
            # Settings read by Chaser when it constructs its PPO agent.
            "agent": {
                "lr_actor": 0.00001,
                "lr_critic": 0.0001,
                "gae_lambda": 0.95,
                "epochs": 1,
                "gamma": 0.99,
                "clip": 0.1,
                "action_std_init": 0.01,
                "state_dim_actor": 8,
                "state_dim_critic": 8,
                # Upper bound of each action component; normalised actions
                # are scaled onto [0, bound] before being applied.
                "action_space": [30, np.pi, 2 * np.pi],
            },
        },
    ],
    "delta_t": 500.0 / 7,
    "forces": {
        "gravity_model": "HolmesFeatherstone",
        "third_bodies": {
            "active": True,
            "bodies": ["SUN", "MOON"],
        },
        "solar_radiation_pressure": {
            "active": True,
            "reflection_coefficients": {
                "follower": 5.0,
            },
        },
        "drag": {
            "active": True,
            "drag_coefficients": {
                "follower": 5.0,
            },
        },
    },
    "interface": {
        "show": True,
        "delay_ms": 0,
        "zoom": 3.0,
        "drifters": {
            "show": True,
            "show_label": True,
            "show_velocity": False,
            "show_trail": True,
            "trail_last_steps": 300,
            "color_body": (255, 255, 255),
            "color_label": (255, 255, 255),
            "color_velocity": (0, 255, 255),
            "color_trail": (255, 255, 255),
        },
        "satellites": {
            "show": True,
            "show_label": True,
            "show_velocity": False,
            "show_thrust": True,
            "show_trail": True,
            "trail_last_steps": 3000,
            "color_body": (255, 0, 0),
            "color_label": (255, 255, 255),
            "color_velocity": (0, 255, 255),
            "color_thrust": (0, 255, 0),
            "color_trail": (255, 0, 0),
        },
        "earth": {
            "show": True,
            "color": (0, 0, 255),
            "resolution": 50,
        },
        "equator_grid": {
            "show": False,
            "color": (30, 140, 200),
            "resolution": 10,
        },
        "timestamp": {
            "show": True,
        },
        # No reference orbits are drawn; example entries kept for reference.
        "orbits": [
            # {"a": 40787484.01233797, "e": 0.017809787119984637, "i": 0.15507728170115928, "pa": 0.06411968016437242, "raan": 0.3304370800422834, "color": (255, 0, 255)},
            # {"a": 5_500_000 + 11_530_000, "e": 0.01, "i": 5.0, "pa": 20.0, "raan": 20.0, "color": (0, 255, 0)},
            # {"a": 2030.0e3, "e": 0.01, "i": 30.0, "pa": 40.0, "raan": 20.0, "color": (0, 255, 255)},
            # {"a": 16030.0e3, "e": 0.7, "i": 0.0001, "pa": 20.0, "raan": 20.0, "color": (255, 0, 255)},
        ],
    },
}

class Chaser(Satellite):
    """Satellite that carries its own PPO agent in ``self.ppo``."""

    def __init__(self, params):
        """Initialise the satellite, then build its PPO agent.

        Args:
            params: Body configuration. Its ``"agent"`` entry must provide
                ``device``, ``state_dim_actor``, ``state_dim_critic``,
                ``action_dim``, ``lr_actor``, ``lr_critic``, ``gae_lambda``,
                ``gamma``, ``epochs``, ``clip`` and ``action_std_init``.
        """
        super().__init__(params)
        agent_cfg = params["agent"]
        self.ppo = PPO(
            agent_cfg["device"],
            agent_cfg["state_dim_actor"],
            agent_cfg["state_dim_critic"],
            agent_cfg["action_dim"],
            agent_cfg["lr_actor"],
            agent_cfg["lr_critic"],
            agent_cfg["gae_lambda"],
            agent_cfg["gamma"],
            agent_cfg["epochs"],
            agent_cfg["clip"],
            # Positional flag; its meaning is defined by PPO's signature,
            # which is not visible here.
            True,
            agent_cfg["action_std_init"],
        )

    def get_state(self):
        """Return the equinoctial position with the current fuel appended."""
        elements = self.get_equinoctial_position()
        fuel = self.get_fuel()
        return elements + [fuel]

    def print_networks(self):
        """Print the policy's actor network, then its critic network."""
        for network_name in ("actor", "critic"):
            print(getattr(self.ppo.policy, network_name))

class ChaseEnvironment(OrbitZoo):
    """Chase scenario in which the first satellite pursues the first drifter.

    Observations, rewards and terminations are dictionaries keyed by
    ``'follower'``; truncations and infos are keyed by each satellite's
    ``name``.
    """

    def __init__(self, params):
        """Build the environment from the scenario dictionary ``params``.

        Args:
            params: Scenario configuration. When it contains satellites,
                ``params["satellites"][0]["agent"]["action_space"]`` supplies
                the upper bound of each action component used by ``step``.
        """
        super().__init__(params)
        # Not read anywhere in this class; kept because it is a public attribute.
        self.tolerance = [1000.0, 0.01, 0.01, 0.001, 0.001]
        # Keep the action bounds on the instance so that step() does not rely
        # on module-level globals being defined by the calling script.
        if "satellites" in params and params["satellites"]:
            self._action_high = np.array(params["satellites"][0]["agent"]["action_space"])
        else:
            self._action_high = None

    def create_bodies(self, params):
        """Instantiate drifters as ``Body`` and satellites as ``Chaser``.

        A missing ``"drifters"`` or ``"satellites"`` entry yields an empty list.
        """
        self.drifters = [Body(body_params) for body_params in params["drifters"]] if "drifters" in params else []
        self.satellites = [Chaser(body_params) for body_params in params["satellites"]] if "satellites" in params else []

    def _build_observation(self):
        """Return the chaser's equinoctial position, the target's element 5 and the chaser's fuel.

        Element 5 of the target's equinoctial position is the same entry that
        ``rewards`` treats as the anomaly.
        """
        chaser_elements = self.satellites[0].get_equinoctial_position()
        target_anomaly = self.drifters[0].get_equinoctial_position()[5]
        chaser_fuel = self.satellites[0].get_fuel()
        return chaser_elements + [target_anomaly] + [chaser_fuel]

    def reset(self, seed=None, options=None):
        """Reset every body and return the initial observations and infos.

        Args:
            seed: Forwarded to each body's ``reset``.
            options: Accepted for interface compatibility; not used.

        Returns:
            ``(observations, infos)`` where ``observations`` is
            ``{'follower': observation}`` and ``infos`` maps each satellite
            name to an empty dict.
        """
        for body in self.drifters + self.satellites:
            body.reset(seed)

        observation = self._build_observation()

        self.reached_orbit = False

        observations = {'follower': observation}
        infos = {satellite.name: {} for satellite in self.satellites}

        return observations, infos

    def step(self, actions=None):
        """Advance all drifters and the first satellite by one step.

        Args:
            actions: Dictionary keyed by ``str(satellite)`` holding a
                three-component action. Each component is clipped to
                ``[-1, 1]`` and mapped linearly onto ``[0, high]``, where
                ``high`` comes from the configured ``action_space``.

        Returns:
            ``(observations, rewards, terminations, truncations, infos)``.
        """
        for drifter in self.drifters:
            drifter.step()

        satellite = self.satellites[0]

        # scale actions: clip to [-1, 1], then map onto [0, high]
        clipped_action = np.clip(actions[str(satellite)], [-1, -1, -1], [1, 1, 1])
        scaled_action = ((clipped_action + 1) / 2) * self._action_high

        # step (the value returned by the satellite is not used)
        satellite.step(scaled_action)

        # build observation
        observations = {'follower': self._build_observation()}

        # compute reward
        rewards, terminations = self.rewards()

        truncations = {sat.name: False for sat in self.satellites}
        infos = {sat.name: {} for sat in self.satellites}
        return observations, rewards, terminations, truncations, infos

    def rewards(self):
        """Return the reward and termination dictionaries for the current state.

        The reward is the negated weighted sum of: the relative difference of
        element 0 (``a``), the absolute differences of elements 1-4
        (``ex``, ``ey``, ``hx``, ``hy``) and the absolute wrapped difference
        of element 5 (the anomaly), divided by 1000. Termination is always
        ``False``.
        """
        chaser = self.satellites[0].get_equinoctial_position()
        target = self.drifters[0].get_equinoctial_position()

        # Parenthesised on purpose: ``x % 2*np.pi`` evaluates as
        # ``(x % 2) * np.pi`` because ``%`` and ``*`` share precedence and
        # associate left to right.
        two_pi = 2 * np.pi
        chaser_anomaly = chaser[5] % two_pi
        target_anomaly = target[5] % two_pi

        a_diff = abs(target[0] - chaser[0])
        ex_diff = abs(target[1] - chaser[1])
        ey_diff = abs(target[2] - chaser[2])
        hx_diff = abs(target[3] - chaser[3])
        hy_diff = abs(target[4] - chaser[4])

        r_a = a_diff / target[0]
        r_ex = ex_diff
        r_ey = ey_diff
        r_hx = hx_diff
        r_hy = hy_diff
        # Wrap the angular difference into [-pi, pi] before taking its magnitude.
        anomaly_diff = chaser_anomaly - target_anomaly
        r_lv = np.abs(np.arctan2(np.sin(anomaly_diff), np.cos(anomaly_diff)))

        # Fixed weights of the individual terms.
        alpha_a = 1000
        alpha_ex = 1
        alpha_ey = 1
        alpha_hx = 10
        alpha_hy = 10
        alpha_lv = 0.001

        # Earlier experiments (a tolerance-gated anomaly term and a
        # distance-based penalty) are not part of the sum.
        reward = -(alpha_a*r_a + alpha_ex*r_ex + alpha_ey*r_ey + alpha_hx*r_hx + alpha_hy*r_hy + alpha_lv*r_lv)

        return {'follower': reward / 1000}, {'follower': False}

# Evaluation run: load a trained policy for the follower and roll it out for
# a single episode, rendering each step when the interface is enabled.
env = ChaseEnvironment(params)

# Run settings. Only steps_per_episode and high are referenced in the visible
# part of this script; the remaining names are kept as module-level values.
time_step = 1
action_std_decay_freq = 10000
action_std_decay_rate = 0.05
update_freq = 4096
min_action_std = 0.05
episodes = 10000
steps_per_episode = 10000
low = np.array([0.0, 0.0, 0.0])
high = np.array(params["satellites"][0]["agent"]["action_space"])

# load model
# Forward slashes are accepted on Windows as well as POSIX systems, unlike
# the backslash-separated form, which only resolves on Windows.
env.satellites[0].ppo.load("./missions/chase_target/model_chase.pth")
best_score = -1e7
start_episode = 1

start_step = time_step
current_ep_rewards = {str(satellite): 0 for satellite in env.satellites}
observations, _ = env.reset(42)

data = []

for t in range(1, steps_per_episode + 1):
    # Tracks whether the policy was queried during this step, so the error
    # handler only discards buffer entries that belong to this step.
    action_selected = False
    try:

        if not env.satellites[0].has_fuel():
            # Out of fuel: send a fixed action whose first component is -1,
            # which the environment maps to the lower bound (0).
            actions = {str(satellite): np.array([-1.0, 0.0, 0.0]) for satellite in env.satellites}
        else:
            actions = {str(satellite): satellite.ppo.select_action(observations[str(satellite)]) for satellite in env.satellites}
            action_selected = True

        # apply step
        observations, rewards, terminations, _, _ = env.step(actions)
        if terminations[env.agents[0]]:
            break

        # Not used further in the visible part of this script.
        elements_follower = env.satellites[0].get_equinoctial_position()
        elements_leader = env.drifters[0].get_equinoctial_position()

    except Exception:
        # Report the failure first so a problem during cleanup cannot hide it.
        traceback.print_exc()
        # if there's a propagation error, remove last values from each satellite agent buffer
        # (presumably appended by select_action; skipped when no action was sampled)
        if action_selected:
            for satellite in env.satellites:
                buffer = satellite.ppo.buffer
                for entries in (buffer.states, buffer.actions, buffer.logprobs):
                    if entries:
                        entries.pop()
        break

    if params["interface"]["show"]:
        env.render()