import numpy as np
from env import OrbitZoo
from bodies import Satellite, Body
from rl_algorithms.dqn import DQN
from constants import EARTH_RADIUS, INERTIAL_FRAME, MU, ATTITUDE
from math import radians
import traceback

from org.orekit.time import AbsoluteDate
from org.orekit.orbits import KeplerianOrbit, PositionAngleType, OrbitType
from org.hipparchus.geometry.euclidean.threed import Vector3D
from org.orekit.forces.gravity import NewtonianAttraction
from org.orekit.propagation.numerical import NumericalPropagator
from org.hipparchus.ode.nonstiff import DormandPrince853Integrator
from org.hipparchus.linear import Array2DRowRealMatrix
from org.orekit.propagation import StateCovariance, StateCovarianceMatrixProvider
from orekit import JArray_double

# Scenario configuration for the collision-avoidance mission: one controllable
# satellite ("agent"), one uncontrolled body ("drifter"), the force models and
# the rendering options.
params = {
    # Controllable spacecraft (driven by the RL agent).
    "satellites": [
        {
            "name": "agent",
            "initial_state": {
                "x": 5337709.428463124, "y": 6339969.149911649, "z": 361504.73969662545,
                "x_dot": -5320.577430007447, "y_dot": 4465.13261069736, "z_dot": 526.2965261179082,
            },
            "initial_state_uncertainty": {
                "x": 0.1, "y": 0.1, "z": 0.1,
                "x_dot": 0.1, "y_dot": 0.1, "z_dot": 0.1,
            },
            "initial_mass": 250.0,
            "fuel_mass": 50.0,
            "isp": 3100.0,
            "radius": 10.0,
            "save_steps_info": False,
            # Learning hyperparameters consumed by the Agent class below.
            "agent": {
                "lr": 0.00005,
                "epochs": 1,
                "gamma": 0.95,
                "tau": 0.001,
                "memory_capacity": 10_000,
                "batch_size": 256,
                "state_dim_actor": 14,
                "state_dim_critic": 14,
                # Upper bound of each continuous action component.
                "action_space": [5.0, np.pi, 2*np.pi, 1.0],
            },
        },
    ],
    # Uncontrolled body: same initial position as the agent, opposite velocity.
    "drifters": [
        {
            "name": "drifter",
            "initial_state": {
                "x": 5337709.428463124, "y": 6339969.149911649, "z": 361504.73969662545,
                "x_dot": 5320.577430007447, "y_dot": -4465.13261069736, "z_dot": -526.2965261179082,
            },
            "initial_state_uncertainty": {
                "x": 0.1, "y": 0.1, "z": 0.1,
                "x_dot": 0.1, "y_dot": 0.1, "z_dot": 0.1,
            },
            "initial_mass": 250.0,
            "radius": 5.0,
            "save_steps_info": False,
        },
    ],
    # Propagation step (the training script divides a duration in seconds by it).
    "delta_t": 1800.0,
    # Force models; only drag is switched on besides gravity.
    "forces": {
        "gravity_model": "Newtonian",
        "third_bodies": {
            "active": False,
            "bodies": ["SUN", "MOON", "JUPITER"],
        },
        "solar_radiation_pressure": {
            "active": False,
            "reflection_coefficients": {
                "agent": 2.0,
            },
        },
        "drag": {
            "active": True,
            "drag_coefficients": {
                "agent": 10.0,
            },
        },
    },
    # Rendering options (rendering itself is disabled via "show": False).
    "interface": {
        "show": False,
        "delay_ms": 0,
        "zoom": 1.0,
        "drifters": {
            "show": True,
            "show_label": True,
            "show_velocity": False,
            "show_trail": True,
            "trail_last_steps": 300,
            "color_body": (255, 255, 255),
            "color_label": (255, 255, 255),
            "color_velocity": (0, 255, 255),
            "color_trail": (255, 255, 255),
        },
        "satellites": {
            "show": True,
            "show_label": True,
            "show_velocity": False,
            "show_thrust": True,
            "show_trail": True,
            "trail_last_steps": 500,
            "color_body": (255, 0, 0),
            "color_label": (255, 255, 255),
            "color_velocity": (0, 255, 255),
            "color_thrust": (0, 255, 0),
            "color_trail": (255, 0, 0),
        },
        "earth": {
            "show": True,
            "color": (0, 0, 255),
            "resolution": 50,
        },
        "equator_grid": {
            "show": False,
            "color": (30, 140, 200),
            "resolution": 10,
        },
        "timestamp": {
            "show": True,
        },
        # Optional reference orbits to draw (currently disabled):
        # "orbits": [
        #     {"a": 2030.0e3, "e": 0.01, "i": 5.0, "pa": 20.0, "raan": 20.0, "color": (0, 255, 0)},
        #     {"a": 2030.0e3, "e": 0.01, "i": 30.0, "pa": 40.0, "raan": 20.0, "color": (0, 255, 255)},
        #     {"a": 16030.0e3, "e": 0.7, "i": 0.0001, "pa": 20.0, "raan": 20.0, "color": (255, 0, 255)},
        # ],
    },
}

class Drifter(Body):
    """Uncontrolled body that also owns a simplified "simulation" propagator.

    Besides the propagator inherited from ``Body``, every ``reset`` builds a
    second numerical propagator that only uses Newtonian attraction and
    propagates the state covariance alongside the state. It is used to look
    ahead without touching the body's real state.
    """

    def reset(self, seed=None):
        """Reset the body and rebuild the Newtonian-only simulation propagator.

        Args:
            seed: forwarded unchanged to ``Body.reset``.
        """
        super().reset(seed)

        orbit = self.current_state.getOrbit()
        tolerances = NumericalPropagator.tolerances(60.0, orbit, orbit.getType())
        integrator = DormandPrince853Integrator(1e-3, 500.0, JArray_double.cast_(tolerances[0]), JArray_double.cast_(tolerances[1]))
        integrator.setInitialStepSize(10.0)

        propagator = NumericalPropagator(integrator)
        propagator.setInitialState(self.current_state)
        propagator.setMu(MU)
        propagator.setOrbitType(self.orbit_type)
        propagator.setAttitudeProvider(ATTITUDE)
        propagator.addForceModel(NewtonianAttraction(MU))

        # seed the covariance propagation with the body's initial 6x6 covariance
        covariance_rows = self.dist.covariance_matrix.detach().numpy().tolist()
        matrix = Array2DRowRealMatrix(6, 6)
        for row_index in range(6):
            matrix.setRow(row_index, covariance_rows[row_index])
        initial_covariance = StateCovariance(matrix, self.current_date, INERTIAL_FRAME, OrbitType.CARTESIAN, PositionAngleType.MEAN)

        # the covariance provider needs the state transition matrix ("stm") harvester
        harvester = propagator.setupMatricesComputation("stm", None, None)
        self.simulation_covariance_provider = StateCovarianceMatrixProvider("covariance", "stm", harvester, initial_covariance)
        propagator.addAdditionalStateProvider(self.simulation_covariance_provider)
        self.simulation_propagator = propagator

    def step_simulation(self, seconds = None):
        """Propagate the simulation propagator relative to the current date.

        Only ``simulated_current_state`` / ``simulated_current_date`` are
        updated; the real ``current_state`` is left untouched.

        Args:
            seconds: offset from ``current_date``; defaults to ``delta_t``.
                An explicit ``0`` is honoured (it is not replaced by the default).

        Returns:
            The value of ``self.get_state()``.
        """
        if seconds is None:
            seconds = self.delta_t
        target_date = self.current_date.shiftedBy(float(seconds))
        self.simulated_current_state = self.simulation_propagator.propagate(target_date)
        # the simulated date must follow the simulated state, not the real one
        self.simulated_current_date = self.simulated_current_state.getDate()
        return self.get_state()

    def propagate_back(self, seconds = None):
        """Move the real state backwards in time with the main propagator.

        Args:
            seconds: how far to go back; defaults to ``delta_t``.

        Returns:
            The value of ``self.get_state()`` after the move.
        """
        if seconds is None:
            seconds = self.delta_t
        self.current_state = self.propagator.propagate(self.current_date.shiftedBy(-float(seconds)))
        self.current_date = self.current_state.getDate()
        return self.get_state()

    def get_simulated_covariance_matrix(self, state = None):
        """
        Get the covariance matrix propagated by the simulation propagator.

        If no state is provided, the simulated current state is used. While the
        body is still at its initial date, or when the provider cannot supply a
        covariance for the state, the initial covariance is returned instead.

        Returns:
            The 6x6 covariance as a nested list of floats.
        """
        if self.current_date == self.initial_date:
            return self.dist.covariance_matrix.detach().numpy().tolist()
        if state is None:
            state = self.simulated_current_state
        try:
            covariance_matrix = self.simulation_covariance_provider.getStateCovariance(state).getMatrix()
        except Exception:
            # best effort: fall back to the initial covariance
            return self.dist.covariance_matrix.detach().numpy().tolist()
        covariance_matrix = np.array([covariance_matrix.getRow(i) for i in range(6)], dtype=float)
        return covariance_matrix.tolist()

class Agent(Satellite):
    """Controllable satellite with a DQN policy and a simulation propagator.

    On top of ``Satellite`` it owns a ``DQN`` learner and, after every
    ``reset``, a Newtonian-only numerical propagator that also propagates the
    state covariance. The latter is used to look ahead without touching the
    real state.
    """

    def __init__(self, params):
        """Create the satellite and its DQN learner.

        Args:
            params: satellite parameters; ``params["agent"]`` must provide the
                learner settings read below. ``"device"`` is not part of the
                configuration literal in this file, so it is presumably
                injected before this point -- TODO confirm.
        """
        super().__init__(params)

        agent_params = params["agent"]
        self.dqn = DQN(agent_params["device"],
                       agent_params["state_dim_actor"],
                       7, # 6 possible directions + do nothing
                       agent_params["lr"],
                       agent_params["gamma"],
                       agent_params["tau"],
                       agent_params["memory_capacity"],
                       agent_params["batch_size"],
                       agent_params["epochs"])

    def reset(self, seed=None):
        """Reset the satellite and rebuild the Newtonian-only simulation propagator.

        Args:
            seed: forwarded unchanged to ``Satellite.reset``.
        """
        super().reset(seed)

        orbit = self.current_state.getOrbit()
        tolerances = NumericalPropagator.tolerances(60.0, orbit, orbit.getType())
        integrator = DormandPrince853Integrator(1e-3, 500.0, JArray_double.cast_(tolerances[0]), JArray_double.cast_(tolerances[1]))
        integrator.setInitialStepSize(10.0)

        propagator = NumericalPropagator(integrator)
        propagator.setInitialState(self.current_state)
        propagator.setMu(MU)
        propagator.setOrbitType(self.orbit_type)
        propagator.setAttitudeProvider(ATTITUDE)
        propagator.addForceModel(NewtonianAttraction(MU))

        # seed the covariance propagation with the body's initial 6x6 covariance
        covariance_rows = self.dist.covariance_matrix.detach().numpy().tolist()
        matrix = Array2DRowRealMatrix(6, 6)
        for row_index in range(6):
            matrix.setRow(row_index, covariance_rows[row_index])
        initial_covariance = StateCovariance(matrix, self.current_date, INERTIAL_FRAME, OrbitType.CARTESIAN, PositionAngleType.MEAN)

        # the covariance provider needs the state transition matrix ("stm") harvester
        harvester = propagator.setupMatricesComputation("stm", None, None)
        self.simulation_covariance_provider = StateCovarianceMatrixProvider("covariance", "stm", harvester, initial_covariance)
        propagator.addAdditionalStateProvider(self.simulation_covariance_provider)
        self.simulation_propagator = propagator

    def get_state(self):
        """Return the equinoctial position list with the remaining fuel appended."""
        return self.get_equinoctial_position() + [self.get_fuel()]

    def propagate_back(self, seconds = None):
        """Move the real state backwards in time with the main propagator.

        Args:
            seconds: how far to go back; defaults to ``delta_t``.

        Returns:
            The value of ``self.get_state()`` after the move.
        """
        if seconds is None:
            seconds = self.delta_t
        self.current_state = self.propagator.propagate(self.current_date.shiftedBy(-float(seconds)))
        self.current_date = self.current_state.getDate()
        return self.get_state()

    def step_simulation(self, seconds = None):
        """Propagate the simulation propagator relative to the current date.

        Only ``simulated_current_state`` / ``simulated_current_date`` are
        updated; the real ``current_state`` is left untouched.

        Args:
            seconds: offset from ``current_date``; defaults to ``delta_t``.
                An explicit ``0`` is honoured (it is not replaced by the default).

        Returns:
            The value of ``self.get_state()``.
        """
        if seconds is None:
            seconds = self.delta_t
        target_date = self.current_date.shiftedBy(float(seconds))
        self.simulated_current_state = self.simulation_propagator.propagate(target_date)
        # the simulated date must follow the simulated state, not the real one
        self.simulated_current_date = self.simulated_current_state.getDate()
        return self.get_state()

    def get_simulated_covariance_matrix(self, state = None):
        """
        Get the covariance matrix propagated by the simulation propagator.

        If no state is provided, the simulated current state is used. While the
        body is still at its initial date, or when the provider cannot supply a
        covariance for the state, the initial covariance is returned instead.

        Returns:
            The 6x6 covariance as a nested list of floats.
        """
        if self.current_date == self.initial_date:
            return self.dist.covariance_matrix.detach().numpy().tolist()
        if state is None:
            state = self.simulated_current_state
        try:
            covariance_matrix = self.simulation_covariance_provider.getStateCovariance(state).getMatrix()
        except Exception:
            # best effort: fall back to the initial covariance
            return self.dist.covariance_matrix.detach().numpy().tolist()
        covariance_matrix = np.array([covariance_matrix.getRow(i) for i in range(6)], dtype=float)
        return covariance_matrix.tolist()

class ColAvoidanceEnv(OrbitZoo):
    """Single-agent collision-avoidance environment built on ``OrbitZoo``.

    The first satellite is the agent and the first drifter is the object to
    avoid. Each observation is the agent's equinoctial position, the drifter's
    equinoctial position, the agent's fuel and the predicted probability of
    collision (POC), concatenated in that order.
    """

    def __init__(self, params):
        """Build the environment and the nominal orbit the agent should keep.

        Args:
            params: scenario configuration; forwarded to ``OrbitZoo`` and read
                for the action bounds of the first satellite.
        """
        super().__init__(params)

        # nominal orbit: [height added to EARTH_RADIUS, e, i, pa, raan, anomaly];
        # the four angles are given in degrees and converted below
        target = [2_000_000, 0.01, 5.0, 20.0, 20.0, 10.0]
        orbit = KeplerianOrbit(target[0] + EARTH_RADIUS, target[1], radians(target[2]), radians(target[3]), radians(target[4]), radians(target[5]), 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, PositionAngleType.TRUE, INERTIAL_FRAME, AbsoluteDate(), MU)
        self.target = [orbit.getA(), orbit.getEquinoctialEx(), orbit.getEquinoctialEy(), orbit.getHx(), orbit.getHy()]

        # upper bound of every action component, used to rescale [-1, 1] actions;
        # kept on the instance so ``step`` does not rely on a module-level global
        satellites_params = params.get("satellites") or []
        self._action_high = (
            np.array(satellites_params[0]["agent"]["action_space"])
            if satellites_params else None
        )

    def create_bodies(self, params):
        """Instantiate drifters and satellites with this mission's body classes."""
        self.drifters = [Drifter(body_params) for body_params in params["drifters"]] if "drifters" in params else []
        self.satellites = [Agent(body_params) for body_params in params["satellites"]]  if "satellites" in params else []

    def _observe(self):
        """Build the observation dict for the (single) agent at the current state."""
        agent = self.satellites[0]
        drifter = self.drifters[0]
        poc, _tca, _closest_approach = self.predicted_poc()
        observation = agent.get_equinoctial_position() + drifter.get_equinoctial_position() + [agent.get_fuel()] + [poc]
        return {agent.name: observation}

    def reset(self, seed=None, options=None, t = 1000):
        """Reset all bodies and move them two days before their initial state.

        Args:
            seed: forwarded to every body's ``reset``.
            options: unused.
            t: number of steps remaining in the episode; also the look-ahead
                horizon (in steps) used by the POC prediction.

        Returns:
            ``(observations, infos)`` dictionaries keyed by satellite name.
        """
        self.t = t

        # reset all bodies, propagate back 2 days and restore the covariance matrices
        two_days = float(60 * 60 * 24 * 2)
        for body in self.drifters + self.satellites:
            body.reset(seed)
            covariance = body.get_covariance_matrix()
            body.propagate_back(two_days)
            body.set_covariance_matrix(covariance)

        observations = self._observe()
        infos = {satellite.name: {} for satellite in self.satellites}

        return observations, infos

    def step(self, actions=None, last_observations=None):
        """Advance the drifters and the agent by one step.

        Args:
            actions: dict mapping the satellite name to a 4-component array in
                [-1, 1]; the last component decides whether to thrust at all.
            last_observations: observations returned by the previous call;
                required by ``rewards`` even though it defaults to ``None``.

        Returns:
            ``(observations, rewards, terminations, truncations, infos)``.
        """
        for drifter in self.drifters:
            drifter.step()

        agent = self.satellites[0]

        # map every component from [-1, 1] to [0, high]
        action = ((actions[agent.name] + 1) / 2) * self._action_high

        # last component is the thrust decision: below 0.5 means "do not thrust"
        if action[3] < 0.5:
            thrust = np.array([0.0,0.0,0.0])
        else:
            thrust = action[:-1]

        agent.step(thrust, 10.0)

        self.t -= 1

        observations = self._observe()
        rewards, terminations = self.rewards(observations, last_observations, actions)

        truncations = {satellite.name: False for satellite in self.satellites}
        infos = {satellite.name: {} for satellite in self.satellites}
        return observations, rewards, terminations, truncations, infos

    def rewards(self, observations, last_observations, actions):
        """Compute the reward of the agent for the last transition.

        When the previous POC was low, any decision to thrust is penalized.
        Otherwise the agent is penalized for its distance to the nominal orbit
        and for a POC that is still high.

        Returns:
            ``(rewards, terminations)`` dictionaries keyed by the agent name;
            the episode is never terminated here.
        """
        agent_name = self.satellites[0].name
        state = observations[agent_name]
        last_state = last_observations[agent_name]
        target = self.target

        # index 13 is the POC (last entry of the 14-element observation)
        last_poc = last_state[13]

        # if last POC was low, penalize the usage of thrust
        if last_poc < 1e-6:
            decision = actions[agent_name][3]
            reward = -100 if decision > 0 else 0
        # if last POC was high, penalize if POC is still high and penalize the distance from the nominal orbit
        else:
            elements_diff = np.abs(np.array(state[:5]) - np.array(target))
            # semi-major axis difference is made relative to the target value
            elements_diff[0] /= target[0]
            elements_weights = np.array([1000, 1, 1, 10, 10])
            distance_penalty = elements_weights.T @ elements_diff
            poc = state[13]
            poc_penalty = 10 if poc > 1e-6 else 0
            reward = - (distance_penalty + poc_penalty)

        return {agent_name: reward / 100}, {agent_name: False}

    def predicted_poc_old(self):
        """Predict the POC with the bodies' own simulation propagators.

        Scans the next ``self.t`` steps for the closest approach, moves both
        simulation propagators to that time and evaluates the POC there.

        Returns:
            ``(poc, tca, closest_approach)``; ``poc`` is ``0.0`` when the
            underlying computation returns ``None``.
        """
        agent = self.satellites[0]
        drifter = self.drifters[0]

        def as_array(vector):
            return np.array([vector.getX(), vector.getY(), vector.getZ()])

        agent.simulated_current_state = agent.current_state
        drifter.simulated_current_state = drifter.current_state

        agent.simulation_propagator.setInitialState(agent.simulated_current_state)
        drifter.simulation_propagator.setInitialState(drifter.simulated_current_state)

        agent_simulated_state = agent.simulated_current_state
        drifter_simulated_state = drifter.simulated_current_state

        # calculate time of closest approach (TCA)
        closest_step = 0
        closest_approach = 1e15
        for t in range(self.t):
            agent_pos = agent_simulated_state.getPVCoordinates().getPosition()
            drifter_pos = drifter_simulated_state.getPVCoordinates().getPosition()
            distance = Vector3D.distance(agent_pos, drifter_pos)
            if distance < closest_approach:
                closest_approach = distance
                closest_step = t
            agent_simulated_state = agent.simulation_propagator.propagate(agent_simulated_state.getDate().shiftedBy(agent.delta_t))
            drifter_simulated_state = drifter.simulation_propagator.propagate(drifter_simulated_state.getDate().shiftedBy(drifter.delta_t))
        tca = closest_step * agent.delta_t

        agent.simulation_propagator.setInitialState(agent.simulated_current_state)
        drifter.simulation_propagator.setInitialState(drifter.simulated_current_state)

        # put bodies on tca with simulation propagator
        agent.step_simulation(tca)
        drifter.step_simulation(tca)

        agent_pv = agent.simulated_current_state.getPVCoordinates()
        drifter_pv = drifter.simulated_current_state.getPVCoordinates()
        chaser_pos = as_array(agent_pv.getPosition())
        target_pos = as_array(drifter_pv.getPosition())
        chaser_vel = as_array(agent_pv.getVelocity())
        target_vel = as_array(drifter_pv.getVelocity())

        # linear correction of the TCA from the relative position and velocity
        relative_pos = chaser_pos - target_pos
        relative_vel = chaser_vel - target_vel
        tca_linear = - float((relative_vel.T @ relative_pos) / (relative_vel.T @ relative_vel))

        # only the position block of each covariance is used
        chaser_cov = np.array(agent.get_simulated_covariance_matrix())[:3, :3]
        target_cov = np.array(drifter.get_simulated_covariance_matrix())[:3, :3]

        # calculate probability of collision (poc)
        poc = Body.poc_rederivation_simulation(chaser_pos, chaser_vel, chaser_cov, target_pos, target_vel, target_cov, tca_linear, agent.radius, drifter.radius)

        if poc is None:
            poc = 0.0

        return poc, tca, closest_approach

    def predicted_poc(self):
        """Predict the POC using throw-away copies of the agent and the drifter.

        The copies start at the bodies' current Cartesian state, use Newtonian
        attraction only and are stepped ``self.t`` times to find the closest
        approach; the POC is then evaluated at that time.

        Returns:
            ``(poc, tca, closest_approach)``; ``poc`` is ``0.0`` when the
            underlying computation returns ``None``.
        """
        agent = self.satellites[0]
        drifter = self.drifters[0]

        def simulation_copy(name, body):
            # copy of the body at its current state, with (almost) no uncertainty
            pv = body.get_cartesian_position() + body.get_cartesian_velocity()
            keys = ("x", "y", "z", "x_dot", "y_dot", "z_dot")
            copy_params = {"name": name,
                           "initial_state": dict(zip(keys, pv[:6])),
                           "initial_state_uncertainty": {key: 1e-10 for key in keys},
                           "initial_mass": body.get_mass(),
                           "radius": body.radius,
                           "save_steps_info": False,
                           "initial_date": body.initial_date}
            return Body(copy_params)

        # simplest gravity model (NewtonianAttraction)
        gravity_model = NewtonianAttraction(MU)

        agent_sim = simulation_copy("agent", agent)
        drifter_sim = simulation_copy("drifter", drifter)
        agent_sim.delta_t = agent.delta_t
        agent_sim.forces = [gravity_model]
        drifter_sim.delta_t = drifter.delta_t
        drifter_sim.forces = [gravity_model]

        def restart_copies():
            agent_sim.reset()
            drifter_sim.reset()
            # in order to get the correct covariance propagation, set the original covariance matrix
            agent_sim.set_covariance_matrix(agent.get_covariance_matrix())
            drifter_sim.set_covariance_matrix(drifter.get_covariance_matrix())

        restart_copies()

        # calculate time of closest approach (TCA)
        closest_step = 0
        closest_approach = 1e15
        for t in range(self.t):
            distance = Body.get_distance(agent_sim, drifter_sim)
            if distance < closest_approach:
                closest_approach = distance
                closest_step = t
            agent_sim.step()
            drifter_sim.step()
        tca = closest_step * agent.delta_t

        # start over and jump straight to the TCA
        restart_copies()
        agent_sim.step(tca)
        drifter_sim.step(tca)

        # calculate probability of collision (poc) at tca
        poc = Body.poc_rederivation(agent_sim, drifter_sim)

        if poc is None:
            poc = 0.0

        return poc, tca, closest_approach

env = ColAvoidanceEnv(params)

# --- counters updated by the training loop ---
time_step = 1        # global step counter across all episodes
update_counter = 0   # number of DQN update rounds performed so far

# --- exploration schedule (epsilon-greedy) ---
current_action_epsilon = 0.5
min_action_epsilon = 0.1
action_epsilon_decay_rate = 0.05    # amount subtracted at each decay
action_epsilon_decay_freq = 1000    # steps between two decays

# --- learning schedule ---
update_freq = 256
target_update_freq = 10             # update rounds between target-network syncs
episodes = 10_000

# an episode covers two days of propagation steps plus 10 extra steps
steps_per_episode = int(float(60 * 60 * 24 * 2) / params["delta_t"] + 10)

# bounds of the continuous action vector
low = np.zeros(4)
high = np.array(params["satellites"][0]["agent"]["action_space"])

print(f'steps_per_episode: {steps_per_episode}')

# load model
# env.satellites[0].dqn.load('model_checkpoint.pth')
# env.satellites[0].ppo.load(".\\missions\\col_avoidance\\model_col_avoidance.pth")
best_score = -1e7 # -3808.6208931599754
start_episode = 1 # 39

# print("filling buffer...")
# while len(env.satellites[0].dqn.memory) < env.satellites[0].dqn.memory.size / 4:
#     observations, _ = env.reset(42, t=steps_per_episode)
#     for t in range(1, steps_per_episode + 1):
#         try:
#             actions_index = {str(satellite): satellite.dqn.select_action(observations[str(satellite)], 1.0) for satellite in env.satellites}
#             actions = {}
#             for satellite in env.satellites:
#                 index = actions_index[str(satellite)]
#                 if index == 0: # forward
#                     actions[str(satellite)] = np.array([1, -1, -1, 1])
#                 elif index == 1: # out (left)
#                     actions[str(satellite)] = np.array([1, 0, -1, 1])
#                 elif index == 2: # behind
#                     actions[str(satellite)] = np.array([1, 1, -1, 1])
#                 elif index == 3: # in (right)
#                     actions[str(satellite)] = np.array([1, 0, 0, 1])
#                 elif index == 4: # up
#                     actions[str(satellite)] = np.array([1, 0, -0.5, 1])
#                 elif index == 5: # down
#                     actions[str(satellite)] = np.array([1, 0, 0.5, 1])
#                 elif index == 6: # nothing
#                     actions[str(satellite)] = np.array([-1, -1, -1, -1])
#             # apply step
#             next_observations, rewards, terminations, _, _ = env.step(actions, observations)
#             # save transition
#             for satellite in env.satellites:
#                 mask = 1 if terminations[satellite.name] or t == steps_per_episode else 0
#                 reward = rewards[satellite.name]
#                 state = observations[satellite.name]
#                 next_state = next_observations[satellite.name]
#                 action = actions_index[satellite.name]
#                 satellite.dqn.memory.add((state, action, mask, next_state, reward))
#             observations = next_observations
            
#         except Exception as e:
#             traceback.print_exc()
#             break

# print("starting training...")
# Discrete DQN action index -> normalized 4-component action handed to env.step.
# The last component is the thrust decision (1 = thrust, -1 = do nothing).
DISCRETE_ACTIONS = (
    (1, -1, -1, 1),     # 0: forward
    (1, 0, -1, 1),      # 1: out (left)
    (1, 1, -1, 1),      # 2: behind
    (1, 0, 0, 1),       # 3: in (right)
    (1, 0, -0.5, 1),    # 4: up
    (1, 0, 0.5, 1),     # 5: down
    (-1, -1, -1, -1),   # 6: nothing
)

for episode in range(start_episode, start_episode + episodes + 1):
    current_ep_rewards = {str(satellite): 0 for satellite in env.satellites}

    observations, _ = env.reset(42, t=steps_per_episode)
    for t in range(1, steps_per_episode + 1):
        try:

            # an agent without fuel cannot act any more: end the episode
            if not env.satellites[0].has_fuel():
                break

            # select actions with policies
            actions_index = {str(satellite): satellite.dqn.select_action(observations[str(satellite)], current_action_epsilon) for satellite in env.satellites}

            # translate every discrete index into its continuous action
            actions = {
                str(satellite): np.array(DISCRETE_ACTIONS[int(actions_index[str(satellite)])])
                for satellite in env.satellites
            }

            # apply step
            next_observations, rewards, terminations, _, _ = env.step(actions, observations)

            # store the transition; mask is 1 on the last transition of an episode
            for satellite in env.satellites:
                mask = 1 if terminations[satellite.name] or t == steps_per_episode else 0
                reward = rewards[satellite.name]
                state = observations[satellite.name]
                next_state = next_observations[satellite.name]
                action = actions_index[satellite.name]
                satellite.dqn.memory.add((state, action, mask, next_state, reward))

            observations = next_observations
            current_ep_rewards = {str(satellite): current_ep_rewards[str(satellite)] + rewards[str(satellite)] for satellite in env.satellites}

            # NOTE(review): as written this triggers on every step that is NOT a
            # multiple of update_freq (once time_step > 256). If an update every
            # update_freq steps was intended, the test should be "== 0" -- confirm
            # before changing, since it alters the training schedule.
            if time_step % update_freq and time_step > 256:
                for satellite in env.satellites:
                    satellite.dqn.update()
                update_counter += 1

            # decay exploration, clamped so float drift cannot push epsilon below the minimum
            if time_step % action_epsilon_decay_freq == 0 and current_action_epsilon > min_action_epsilon:
                current_action_epsilon = max(min_action_epsilon, current_action_epsilon - action_epsilon_decay_rate)

            # update target network
            if update_counter % target_update_freq == 0:
                for satellite in env.satellites:
                    satellite.dqn.Q_target.load_state_dict(satellite.dqn.Q.state_dict())

            time_step += 1

            # terminations is keyed by satellite name
            if terminations[env.satellites[0].name]:
                break

        except Exception:
            # report the failure and abandon the current episode
            traceback.print_exc()
            break

        if params["interface"]["show"]:
            env.render()

    # show scores at the end of episode
    for satellite in env.satellites:
        score = current_ep_rewards[str(satellite)]
        print(f'Episode {episode}: {score}, Fuel: {satellite.get_fuel()}, epsilon: {current_action_epsilon}')
        if score > best_score:
            print(f'>>>>>>>> Best score of {score} found. Saving the model.')
            best_score = score
            satellite.dqn.save("model_checkpoint")