import warnings
# Silence RuntimeWarning for the whole process (this modifies the global warnings filter, not just this module).
warnings.filterwarnings("ignore", category=RuntimeWarning)

import orekit
from orekit.pyhelpers import setup_orekit_curdir

# Start the Java VM behind the Orekit Python wrapper. The project modules and the `org.orekit` / `java`
# imports below are deliberately placed after this call — presumably because Java classes cannot be
# imported (or instantiated at import time, e.g. frames in `constants`) before the VM is running.
orekit.initVM()
# Load Orekit's data (per the helper's name, from "orekit-data" relative to the current working directory).
setup_orekit_curdir("orekit-data")

from constants import ITRF, EARTH_FLATTENING, EARTH_RADIUS, MU, INERTIAL_FRAME
from bodies import Body, Satellite
from pettingzoo import ParallelEnv, AECEnv
from forces import ThirdBodyForce, SolarRadiationForce, DragForce
import gymnasium as gym
import numpy as np
import torch
import os
# Set before importing the interface module; presumably it imports pygame, whose startup banner this hides.
os.environ["PYGAME_HIDE_SUPPORT_PROMPT"] = "1"
from interface.interface import Interface

# Orekit Java classes (imported after orekit.initVM() above).
from org.orekit.bodies import OneAxisEllipsoid
from org.orekit.forces.gravity.potential import GravityFieldFactory
from org.orekit.forces.gravity import HolmesFeatherstoneAttractionModel, NewtonianAttraction, LenseThirringRelativity
from org.orekit.propagation import PropagatorsParallelizer
from org.orekit.propagation.sampling import PythonMultiSatFixedStepHandler
from org.orekit.time import AbsoluteDate, TimeScalesFactory

from java.util import Arrays

class OrbitZoo(ParallelEnv):
    """
    A class that uses high-fidelity orbital dynamics and works as a high-level MARL framework to develop space missions, focused on LEO.
    If that's not enough for you, it also has an awesome and lightweight interface.

    Example for 3D visualization of a system moving in time:
    ```
    from env import OrbitZoo    # don't forget to import this class (duh)
    params = {...}              # describe the system as a JSON-like dict (body characteristics, several kinds of forces, interface customization, ...)
    env = OrbitZoo(params)      # initialize the environment
    while True:
        env.step(actions)       # propagate the system by 'delta_t' seconds; 'actions' maps each satellite name to its action
                                # (it may be omitted only when there are no satellites; comment this line to see the system frozen in time!)
        env.render()            # show a frame of the interface (does nothing if the interface is disabled)
    ```

    To reset the environment to its initial state, guess what, you have a ```reset(seed=None)``` function.

    To get information regarding the current state of a body, just get the body instance by its 'name':
    ```
    body = env.get_body('name')
    position = body.get_cartesian_position()        # returns the current cartesian position (x, y, z), in m
    velocity = body.get_cartesian_velocity()        # returns the current cartesian velocity (vx, vy, vz), in m/s
    elements = body.get_equinoctial_elements()      # returns the current equinoctial parameters (a, ex, ey, hx, hy, lm)
    covariance = body.get_covariance_matrix()       # returns the current cartesian covariance matrix (6x6)
    # (...) there are many more methods you can explore
    ```
    """
    metadata = {
        "name": "orbitzoo_v0",
    }

    def __init__(self, params):
        """
        Build the environment from a configuration dict.

        Keys read from 'params':
            - "satellites" (optional): list of dicts, each with a unique "name" and an "agent" dict holding
              "state_dim_actor" (observation size) and "action_space" (upper bound of each action component).
            - "drifters" (optional): list of dicts describing bodies that are not agents.
            - "initial_date" (optional): dict with "year", "month", "day", "hour", "minute", "second",
              interpreted in UTC. Defaults to Orekit's default 'AbsoluteDate()'.
            - "forces" (required): "gravity_model" ('HolmesFeatherstone', 'Newtonian' or 'LenseThirring'),
              "third_bodies" {"active", "bodies"}, "solar_radiation_pressure" {"active", "reflection_coefficients"},
              "drag" {"active", "drag_coefficients"} and optionally "parallel_propagation" (default False).
            - "delta_t" (optional, default 500.0): propagation step, in seconds.
            - "interface" (optional): interface options; the interface is only created when "show" is truthy.

        'params' is modified in place: every body dict receives "initial_date", and every satellite's
        "agent" dict receives "action_dim" and "device" (fields needed to create its agent).

        Raises:
            ValueError: if body names are not unique, the gravity model is not supported, or a
                reflection/drag coefficient is given for a body that does not exist.
        """
        has_satellites = "satellites" in params

        # PettingZoo agents: one per satellite, identified by its name
        self.agents = [satellite["name"] for satellite in params["satellites"]] if has_satellites else []
        self.possible_agents = self.agents[:]

        # always defined, so an unknown agent fails with KeyError instead of AttributeError
        self.observation_spaces = {}
        self.action_spaces = {}

        if has_satellites:

            # initialize device for training
            self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

            for satellite in params["satellites"]:
                name = satellite["name"]
                agent_params = satellite["agent"]
                state_dim_actor = agent_params["state_dim_actor"]
                action_space = agent_params["action_space"]
                # unbounded observations; action component k lies in [0, action_space[k]]
                self.observation_spaces[name] = gym.spaces.Box(low=np.full(state_dim_actor, -np.inf), high=np.full(state_dim_actor, np.inf), dtype=np.float64)
                self.action_spaces[name] = gym.spaces.Box(low=np.zeros(len(action_space)), high=np.array(action_space), dtype=np.float64)
                # add needed fields to create the agent (PPO) for this satellite
                agent_params["action_dim"] = self.action_spaces[name].shape[0]
                agent_params["device"] = self.device

        # initial date shared by every body (UTC when given explicitly)
        if "initial_date" in params:
            date = params["initial_date"]
            self.initial_date = AbsoluteDate(int(date["year"]), int(date["month"]), int(date["day"]), int(date["hour"]), int(date["minute"]), float(date["second"]), TimeScalesFactory.getUTC())
        else:
            self.initial_date = AbsoluteDate()
        for group in ("drifters", "satellites"):
            for body_params in params.get(group, []):
                body_params["initial_date"] = self.initial_date

        # create bodies
        self.create_bodies(params)

        # check if all bodies names are unique
        body_names = [body.name for body in self._bodies()]
        if len(body_names) != len(set(body_names)):
            raise ValueError("There are bodies with the same name. Attribute 'name' should be a unique identifier.")

        forces_params = params["forces"]

        # gravity model
        earth = OneAxisEllipsoid(EARTH_RADIUS, EARTH_FLATTENING, ITRF)
        gravity_model_name = forces_params["gravity_model"]
        if gravity_model_name == 'HolmesFeatherstone':
            # normalized spherical-harmonics field truncated at degree 8, order 8
            gravity_field = GravityFieldFactory.getNormalizedProvider(8, 8)
            gravity_model = HolmesFeatherstoneAttractionModel(earth.getBodyFrame(), gravity_field)
        elif gravity_model_name == 'Newtonian':
            gravity_model = NewtonianAttraction(MU)
        elif gravity_model_name == 'LenseThirring':
            gravity_model = LenseThirringRelativity(MU, INERTIAL_FRAME)
        else:
            raise ValueError("The provided 'gravity_model' is not supported. Available models are: 'HolmesFeatherstone', 'Newtonian' and 'LenseThirring'.")

        # every body is subject to the selected gravity model
        forces = {body.name: [gravity_model] for body in self._bodies()}

        # third bodies: one force instance per third body, shared by every body
        third_bodies_params = forces_params["third_bodies"]
        if third_bodies_params["active"]:
            for third_body_name in third_bodies_params["bodies"]:
                third_body_force = ThirdBodyForce(third_body_name)
                for body_forces in forces.values():
                    body_forces.append(third_body_force)

        # solar radiation pressure: only for the bodies listed in 'reflection_coefficients'
        srp_params = forces_params["solar_radiation_pressure"]
        if srp_params["active"]:
            for body_name, coefficient in srp_params["reflection_coefficients"].items():
                body = self._require_body(body_name, "reflection_coefficients")
                forces[body_name].append(SolarRadiationForce(earth, body.surface_area, coefficient, body.is_starlink))

        # drag: only for the bodies listed in 'drag_coefficients'
        drag_params = forces_params["drag"]
        if drag_params["active"]:
            for body_name, coefficient in drag_params["drag_coefficients"].items():
                body = self._require_body(body_name, "drag_coefficients")
                forces[body_name].append(DragForce(earth, body.surface_area, coefficient, body.is_starlink))

        self.parallel_propagation = forces_params.get("parallel_propagation", False)
        self.delta_t = params.get("delta_t", 500.0)

        for body in self._bodies():
            body.forces = forces[body.name]
            body.delta_t = self.delta_t

        # bodies are reset only once all of them are configured
        for body in self._bodies():
            body.reset()

        # build the parallel propagator here too, so step() works even if reset() is never called
        if self.parallel_propagation:
            self._setup_parallel_propagator()

        # initialize interface if requested
        interface_params = params.get("interface", {})
        self.has_interface = interface_params.get("show", False)
        if self.has_interface:
            self.interface = Interface(self._bodies(), interface_params)

    def _bodies(self):
        """Return a new list with every body: drifters first, then satellites."""
        return self.drifters + self.satellites

    def _require_body(self, name, setting):
        """Return the body called 'name', raising ValueError (mentioning 'setting') if it does not exist."""
        body = self.get_body(name)
        if body is None:
            raise ValueError(f"'{setting}' references a body named '{name}', but no body has that name.")
        return body

    def _setup_parallel_propagator(self):
        """(Re)build the Orekit parallelizer over every body's current propagator and restart the clock at 'initial_date'."""
        self.current_date = self.initial_date
        propagators = Arrays.asList([body.propagator for body in self._bodies()])
        # the handler callbacks are no-ops; step() only uses the states returned by propagate()
        prop_handler = PropagationHandler()
        self.propagator = PropagatorsParallelizer(propagators, self.delta_t, prop_handler)

    def get_body(self, name):
        """
        Returns the body instance that has the given 'name', or None if no body has that name.
        """
        return next((body for body in self._bodies() if body.name == name), None)

    def create_bodies(self, params):
        """
        Create all bodies instances, split in 'drifters' and 'satellites' lists.
        This function is useful for Reinforcement Learning, if you happen to create new implementations for Body and Satellite classes.
        """
        self.drifters = [Body(body_params) for body_params in params["drifters"]] if "drifters" in params else []
        self.satellites = [Satellite(body_params) for body_params in params["satellites"]] if "satellites" in params else []

    def reset(self, seed=None, options=None):
        """
        Reset every body (forwarding 'seed'), then the parallel propagator and the interface when enabled.

        'options' is accepted for PettingZoo API compatibility and is not used.

        Returns:
            (observations, infos): dicts keyed by agent name holding empty placeholders.
        """
        for body in self._bodies():
            body.reset(seed)

        if self.parallel_propagation:
            self._setup_parallel_propagator()

        if self.has_interface:
            self.interface.reset()

        observations = {agent: tuple() for agent in self.agents}

        # Get dummy infos. Necessary for proper parallel_to_aec conversion
        infos = {agent: {} for agent in self.agents}

        return observations, infos

    def step(self, actions=None):
        """
        Propagate every body by 'delta_t' seconds.

        Args:
            actions: mapping from each satellite name to its action. Required whenever there are
                satellites; may be omitted for systems made only of drifters.

        Returns:
            (observations, rewards, terminations, truncations, infos): dicts keyed by agent name
            holding placeholder values.

        Raises:
            ValueError: if some satellite has no action. Checked before anything is propagated, so a
                bad call never leaves the system partially stepped.
        """
        missing = [satellite.name for satellite in self.satellites if actions is None or satellite.name not in actions]
        if missing:
            raise ValueError(f"Missing actions for satellites: {missing}.")

        if self.parallel_propagation:
            for satellite in self.satellites:
                satellite.change_thrust(actions[satellite.name])
            start_date = self.current_date
            self.current_date = start_date.shiftedBy(self.delta_t)
            states = list(self.propagator.propagate(start_date, self.current_date))
            # assumes the parallelizer returns states in the order of its propagators (drifters, then satellites)
            for body, state in zip(self._bodies(), states):
                body.current_state = state
                body.current_date = self.current_date
        else:
            for drifter in self.drifters:
                drifter.step()
            for satellite in self.satellites:
                satellite.step(actions[satellite.name])

        observations = {agent: () for agent in self.agents}
        rewards = {agent: 0 for agent in self.agents}
        terminations = {agent: False for agent in self.agents}
        truncations = {agent: False for agent in self.agents}
        infos = {agent: {} for agent in self.agents}

        return observations, rewards, terminations, truncations, infos

    def step_back(self, actions=None):
        """
        Call 'step_back()' on every body. 'actions' is not used.
        """
        # NOTE(review): in parallel-propagation mode self.current_date is not rolled back here — confirm intended.
        for body in self._bodies():
            body.step_back()

    def rewards(self, states_before, actions, states_after):
        """
        Optional function used for Reinforcement Learning (you can also change the input parameters).
        """
        pass

    def render(self):
        """Draw one frame of the interface; does nothing when the interface is disabled."""
        if self.has_interface:
            self.interface.frame()

    def observation_space(self, agent):
        """Return the observation space of 'agent' (a satellite name)."""
        return self.observation_spaces[agent]

    def action_space(self, agent):
        """Return the action space of 'agent' (a satellite name)."""
        return self.action_spaces[agent]
    
class PropagationHandler(PythonMultiSatFixedStepHandler):
    """
    No-op multi-satellite fixed-step handler passed to Orekit's PropagatorsParallelizer.

    OrbitZoo only uses the states returned by 'PropagatorsParallelizer.propagate', so every
    callback below intentionally does nothing.
    """

    def __init__(self, propagators=None, initial_date=None, delta_t=None):
        """
        All arguments are optional and only kept for reference; the handler never uses them.

        Args:
            propagators: the propagators being run in parallel.
            initial_date: propagation start date.
            delta_t: step size, in seconds.
        """
        # NOTE(review): Orekit's Python* wrapper base classes are normally constructed with no
        # arguments, so nothing is forwarded to the Java base class.
        super().__init__()
        self.propagators = propagators
        self.initial_date = initial_date
        self.delta_t = delta_t

    def init(self, states0, t, step):
        """Called once before propagation starts; nothing to initialize."""
        pass

    def handleStep(self, states):
        """Called at each fixed step with the states of all propagators; intermediate states are not needed."""
        pass

    def finish(self, finalStates):
        """Called once at the end of propagation; final states are read from propagate()'s return value instead."""
        pass
