"""
This file describes a 3D quadrotor environment.
Based on the longitudinal quadrotor environment from provably_safe_env.
"""

import math
from typing import Dict, List, Optional, Tuple, Union

import gym
from gym import spaces
import numpy as np
import pygame
from scipy.linalg import expm

from action_masking.util.sets import Zonotope

# Define constants
GRAVITY = 9.81  # [m/s^2]


class Quadrotor3DEnv(gym.Env):
    """A longitudinal quadrotor environment representing [1].


    [1] The benchmarks represents a linear quadrotor model, as
        described in S. Kaynama and C.J. Tomlin, "Benchmark: Flight envelope
        protection in autonomous quadrotors" in Workshop on Applied
        Verification for Continuous and Hybrid Systems, 2014.

    Input Arguments:
       -x:     system states x (dimension: [nx,1])
       -u:     control inputs u (dimension: [nu,1])
       -w:     process disturbances w (dimension: [nw,1])
    
    TODO: Test this class! This class is untested as we do not have the correct RCI sets yet.
    """

    metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 30}

    def __init__(
        self,
        dt: float = 0.05,
        K: float = 0.89 / 1.4,
        w: np.ndarray = np.array([0.0, 0.0, 0.0]),
        input_limits: np.ndarray = np.array([[-9.81, -0.5, -0.5, -0.5], [2.38, 0.5, 0.5, 0.5]]),
        goal_distance: float = 0.01,
        done_on_collision: bool = False,
        done_on_goal_reached: bool = False,
        state_constraints: Optional[np.ndarray] = np.array(
            [
                [-3.0, -3.0, -3.0, -3.0, -3.0, -3.0, -0.78539816, -0.78539816, -3.14159265, -3.0, -3.0, -3.0],
                [3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 0.78539816, 0.78539816, 3.14159265, 3.0, 3.0, 3.0],
            ]
        ),
        randomize_env: bool = True,
        start_state: np.ndarray = np.array([0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=np.float32),
        randomize_start_range: Optional[List[float]] = [0.1, 0.1, 0.1, 0.1, 0.1, 0.1],
        collision_reward: float = -1.0,
        goal_reward: float = 1.0,
        step_reward: float = -1.0,
        rew_state_weight: float = 1.0,
        rew_act_weight: float = 0.01,
        reward_shaping: bool = True,
        size: float = 3,
        render_mode: Optional[str] = None,
        seed: int = 42,
        safe_region: Optional[Union[Zonotope, np.ndarray]] = None,
    ):
        """Create the reach environment.

        State space:
        x = [x, y, z, dx, dy, dz, phi, theta, psi, dphi, dtheta, dpsi]


        Continuous dynamics:
        f[0] = x[3]
        f[1] = x[4]
        f[2] = x[5]
        f[3] = -g * x[7] + w[0]
        f[4] =  g * x[6] + w[1]
        f[5] = u[0] + w[2]
        f[6] = x[9]
        f[7] = x[10]
        f[8] = x[11]
        f[9] = u[1]
        f[10] = u[2]
        f[11] = u[3]


        Discrete dynamics:
        x_{k+1} = x_{k} + f(x_{k}) * dt

        Args:
            dt (float): Time step of the discrete dynamics.
            K (float): Thrust coefficient.
            w (ndarray): disturbance bound in x y z - shape (3,).
            input_limits (np.ndarray): Input limits in the form [[min_u1, min_u2, min_u3, min_u4], [max_u1, max_u2, min_u3, min_u4]].
            goal_distance (float): Distance to goal to consider it reached.
            done_on_collision (bool): If true, the episode is finished if the quadrotor collides with the constraint
                set.
            done_on_goal_reached (bool): If true, the episode is finished if the goal is reached.
            state_constraints (Optional, np.ndarray): State constraints of the form
                [[min_x, min_y,...], [max_x, max_y,...]].
            randomize_env (bool): If true, the environment is randomized on reset.
            start_state (np.ndarray): Start state of the quadrotor.
            randomize_start_range (Optional, List[float]): Range of the random start state in x, y, z, and phi, theta, psi.
            collision_reward (float): Reward for exceeding the state constraints.
            goal_reward (float): Reward for being inside the goal region.
            step_reward (float): Reward for each step.
            rew_state_weight (float): Weight of the state reward.
            rew_act_weight (float): Weight of the action reward.
            reward_shaping (bool): If true, the reward is shaped by the distance to the goal.
            size (float): Size of the environment. Defualts to 3.
            render_mode (Optional, str): Optional render mode. Can be "human" or "rgb_array".
            seed (int): Random seed.
            safe_region (Optional, Union[Zonotope, np.ndarray]): Safe region (RCI set) of the env. Needed to sample safe initial states from.
        """
        self.dt = dt
        self.K = K
        self.w = w
        self.n_dim = 12
        self.action_dim = 4
        self.A = np.array(
            [
                [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, -GRAVITY, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, GRAVITY, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            ]
        )
        self.B = np.array(
            [
                [0, 0, 0, 0],
                [0, 0, 0, 0],
                [0, 0, 0, 0],
                [0, 0, 0, 0],
                [0, 0, 0, 0],
                [1, 0, 0, 0],
                [0, 0, 0, 0],
                [0, 0, 0, 0],
                [0, 0, 0, 0],
                [0, 1, 0, 0],
                [0, 0, 1, 0],
                [0, 0, 0, 1],
            ]
        )
        self.E = np.array(
            [
                [0, 0, 0],
                [0, 0, 0],
                [0, 0, 0],
                [1, 0, 0],
                [0, 1, 0],
                [0, 0, 1],
                [0, 0, 0],
                [0, 0, 0],
                [0, 0, 0],
                [0, 0, 0],
                [0, 0, 0],
                [0, 0, 0],
            ]
        )

        # self.x_eq = np.array([0, 1, 0, 0, 0, 0]) # TODO
        # self.u_eq = np.array([GRAVITY/K,0]) # TODO
        # self.w_eq = np.array([0,0]) # TODO
        self.A_d = expm(self.A * dt)
        temp = np.eye(self.A.shape[0]) * dt
        interval = temp
        i = 1
        while not np.all(np.abs(temp) < 2.2204e-16):
            temp = (np.linalg.matrix_power(self.A, i) / math.factorial(i + 1)) * (dt ** (i + 1))  # temp * dt/(i+1) * self.A #
            interval = interval + temp
            i += 1
        # interval = np.ones(self.A.shape[0]) * dt + ( self.A / 2) * dt**2 + (np.linalg.matrix_power(self.A, 2) / 6)  * dt**3 + (np.linalg.matrix_power(self.A, 3) / 24)  * dt**4
        self.B_d = np.matmul(interval, self.B)
        self.E_d = np.matmul(interval, self.E)
        # For reward shaping only:
        self.u_range = np.array([input_limits[0], input_limits[1]])
        self.goal_distance = goal_distance
        self.done_on_collision = done_on_collision
        self.done_on_goal_reached = done_on_goal_reached
        self.state_constraints = np.array(state_constraints, dtype=np.float32)
        self.randomize_env = randomize_env
        self.start_state = np.array(start_state, dtype=np.float32)
        if randomize_start_range is None:
            self.randomize_start_range = np.array([0, 0, 0, 0, 0, 0], dtype=np.float32)
        else:
            self.randomize_start_range = np.array(randomize_start_range, dtype=np.float32)
        self.collision_reward = collision_reward
        self.goal_reward = goal_reward
        self.step_reward = step_reward
        self.rew_state_weight = rew_state_weight
        self.rew_act_weight = rew_act_weight
        self.reward_shaping = reward_shaping
        self.render_mode = render_mode
        self.rnd_seed = seed
        self.np_random = np.random.RandomState(self.rnd_seed)
        if self.render_mode is not None:
            self.window_size = 512  # The size of the PyGame window
        self.size = size
        self.observation_space = spaces.Box(
            low=-self.size, high=self.size, shape=(self.n_dim,), dtype="float32", seed=self.rnd_seed
        )
        self.action_space = spaces.Box(
            low=input_limits[0], high=input_limits[1], shape=(self.action_dim,), dtype="float32", seed=self.rnd_seed
        )
        self.state = np.zeros((self.n_dim,), dtype="float32")
        self.goal = start_state
        # For reward shaping
        self.max_distance = np.linalg.norm(state_constraints[1] - state_constraints[0])
        self._collision = False
        self.safe_region = safe_region
        if randomize_env and safe_region is None:
            raise ValueError("If randomize_env is True, safe_region must be provided to sample from.")
        """
        If human-rendering is used, `self.window` will be a reference
        to the window that we draw to. `self.clock` will be a clock that is
        used to ensure that the environment is rendered at the correct
        framerate in human-mode. They will remain `None` until human-mode
        is used for the first time.
        """
        assert render_mode is None or render_mode in self.metadata["render_modes"]
        self.render_mode = render_mode
        self.window = None
        self.clock = None
        self._goal_reached_counter = 0
        self._collision_counter = 0

    def _get_obs(self) -> np.ndarray:
        """Get the observation, which is the agent location."""
        return self.state

    def _get_info(self) -> Dict:
        """Get the info dictionary."""
        return {
            "distance": np.linalg.norm(self.state - self.goal),
            "collision": self._collision,
            "n_goal_reached": self._goal_reached_counter,
            "n_collision": self._collision_counter,
        }

    def _get_success(self) -> bool:
        """Get the information if the goal was reached.

        Returns:
            True if goal was reached, false otherwise.
        """
        return np.linalg.norm(self.goal - self.state) < self.goal_distance

    def _get_done(self) -> bool:
        """Return if the episode is finished.

        Returns:
            True if done flag should be set, false otherwise.
        """
        if self.done_on_collision and self._collision:
            return True
        if self.done_on_goal_reached and self._get_success():
            return True
        return False

    def _get_reward(self, act) -> float:
        """Compute the step reward."""
        reward = self.step_reward
        distance = np.linalg.norm(self.goal - self.state)
        goal_reached = distance < self.goal_distance
        if self._collision:
            reward += self.collision_reward
        if self.reward_shaping:
            dist = np.sum(self.rew_state_weight * distance)
            act_cost = np.mean((act - self.u_range[0]) / (self.u_range[1] - self.u_range[0]))
            dist += np.sum(self.rew_act_weight * act_cost)
            reward += np.exp(-dist)
        else:
            reward += goal_reached * self.goal_reward
        return reward

    def reset(self, seed: Optional[int] = None) -> Tuple[np.ndarray, Dict]:
        """
        Reset the environment.
        TODO: the seed argument cannot be passed through the DummyVecEnv wrapper in my opinion.
        """
        # We need the following line to seed self.np_random
        if seed is not None:
            self.np_random.seed(seed)
        else:
            self.np_random.seed(self.rnd_seed)
        self._collision = False
        # Choose the agent's and goal location.
        if self.randomize_env:
            # Sample safe initial state from RCI set -> This could also lead to state on the boundries of the RCI set which might be harder to learn from.
            self.state = self.safe_region.sample()

            # Previous implementation - while this yields to a random state around our predefined start state, it is not guaranteed to be safe
            # self.state = self.start_state
            # self.state[[0, 1, 2, 6, 7, 8]] = self.state[[0, 1, 2, 6, 7, 8]] + self.np_random.uniform(
            #     -self.randomize_start_range, self.randomize_start_range
            # )
        else:
            self.state = self.start_state

        observation = self._get_obs()

        if self.render_mode == "human":
            self._render_frame()
        self._goal_reached_counter = 0
        self._collision_counter = 0
        return observation

    def step(self, action: np.ndarray) -> Tuple[np.ndarray, float, bool, Dict]:
        """Run one timestep of the environment's dynamics.

        When end of episode is reached, you are responsible for
        calling `reset()` to reset this environment's state.

        Accepts an action and returns a tuple
            (observation, reward, done, info).

        Args:
            action (np.ndarray): an action provided by the agent

        Returns:
            observation (np.ndarray): observation: agent state.
            reward (float): amount of reward returned after previous action.
            done (bool): whether the episode has ended.
            info (dict): contains auxiliary diagnostic information.
        """
        assert not np.any(np.isnan(action)), "NaN in action"
        self._collision = False
        # TODO: Switch to bounded Gaussian noise?
        #w = self.np_random.rand(3)
        #w *= self.w # TODO: might adapt to be specific for each dimension so not one bounded value but for each dimension different values
        if self.collision_check_fn(action):
            self._collision = True
            self._collision_counter += 1
        # Discrete Dynamics
        self.state = self.dynamics_fn(action, self.state)
        observation = self._get_obs()
        goal_reached = self._get_success()
        self._goal_reached_counter += int(goal_reached)
        info = self._get_info()
        reward = self._get_reward(action)
        done = self._get_done()

        if self.render_mode == "human":
            self._render_frame()

        return observation, reward, done, info

    def dynamics_fn(self, u: np.ndarray, state: np.ndarray, w: np.ndarray = np.array([0.0, 0.0, 0.0])) -> np.ndarray:
        """
        Calculate the next state of the environment based on the given action and current state.

        State space:
        x = [x, y, z, dx, dy, dz, phi, theta, psi, dphi, dtheta, dpsi]

        Discrete dynamics:
        x_{t+1} = A_d * x_{t} + B_d * u_{t} + E_d * w_{t}

        where
            A_d = e^{A dt}
            B_d = e^{A dt} \int_0^dt e^{-A d\tau} d\tau B
            E_d = e^{A dt} \int_0^dt e^{-A d\tau} d\tau E

        Based on continuous dynamics:
        \dot{x}[0] = x[3]
        \dot{x}[1] = x[4]
        \dot{x}[2] = x[5]
        \dot{x}[3] = -g * x[7] + w[0]
        \dot{x}[4] =  g * x[6] + w[1]
        \dot{x}[5] = u[0] + w[2]
        \dot{x}[6] = x[9]
        \dot{x}[7] = x[10]
        \dot{x}[8] = x[11]
        \dot{x}[9] = u[1]
        \dot{x}[10] = u[2]
        \dot{x}[11] = u[3]

        Args:
            u (np.ndarray): the action to execute.
            state (np.ndarray): the current state of the environment.
            w (np.ndarray): the noise to add to the state.
        """
        next_state = self.A_d @ state + self.B_d @ u + self.E_d @ w
        return next_state

    def collision_check_fn(
        self,
        action: np.ndarray,
        w: np.ndarray = np.array([0.0, 0.0, 0.0]),
    ) -> bool:
        """Return true if the given action would collide.

        Args:
            u (np.ndarray): the action to execute.
            w (np.ndarray): the noise to add to the state.
        Returns:
            bool: True if the action would collide, False otherwise.
        """
        next_state = self.dynamics_fn(action, self.state)
        return np.any(next_state <= self.state_constraints[0]) or np.any(next_state >= self.state_constraints[1])

    def render(self):
        """Render a frame."""
        if self.render_mode == "rgb_array":
            return self._render_frame()

    def _render_frame(self):
        """
        Render a frame of the current state using rgb_array method.
        # TODO: adapt render method to 3D
        """
        if self.window is None and self.render_mode == "human":
            pygame.init()
            pygame.display.init()
            self.window = pygame.display.set_mode((self.window_size, self.window_size))
        if self.clock is None and self.render_mode == "human":
            self.clock = pygame.time.Clock()

        canvas = pygame.Surface((self.window_size, self.window_size))
        canvas.fill((255, 255, 255))
        agent_width = 30
        agent_height = 2
        target = 1 / (2 * self.size) * np.array([self.goal[0] + self.size, self.goal[1] + self.size])
        agent = 1 / (2 * self.size) * np.array([self.state[0] + self.size, self.state[1] + self.size])
        # First we draw the target
        points = []
        radius = math.sqrt((agent_height / 2) ** 2 + (agent_width / 2) ** 2)
        angle = math.atan2(agent_height / 2, agent_width / 2)
        angles = [angle, -angle + math.pi, angle + math.pi, -angle]
        rot_radians = self.goal[4]
        target_pos = self.window_size * np.array(target) - np.array([agent_width / 2, agent_height / 2])
        for angle in angles:
            y_offset = -1 * radius * math.sin(angle + rot_radians)
            x_offset = radius * math.cos(angle + rot_radians)
            points.append((target_pos[0] + x_offset, target_pos[1] + y_offset))
        pygame.draw.polygon(canvas, (120, 120, 120), points)
        # Now we draw the agent
        points = []
        radius = math.sqrt((agent_height / 2) ** 2 + (agent_width / 2) ** 2)
        angle = math.atan2(agent_height / 2, agent_width / 2)
        angles = [angle, -angle + math.pi, angle + math.pi, -angle]
        rot_radians = self.state[4]
        agent_pos = self.window_size * np.array(agent) - np.array([agent_width / 2, agent_height / 2])
        for angle in angles:
            y_offset = -1 * radius * math.sin(angle + rot_radians)
            x_offset = radius * math.cos(angle + rot_radians)
            points.append((agent_pos[0] + x_offset, agent_pos[1] + y_offset))
        pygame.draw.polygon(canvas, (255, 0, 0) if self._collision else (0, 255, 0), points)

        if self.render_mode == "human":
            # The following line copies our drawings from `canvas` to the
            # visible window
            self.window.blit(canvas, canvas.get_rect())
            pygame.event.pump()
            pygame.display.update()

            # We need to ensure that human-rendering occurs at the
            # predefined framerate.
            # The following line will automatically add a delay to keep
            # the framerate stable.
            self.clock.tick(self.metadata["render_fps"])
        else:  # rgb_array
            return np.transpose(np.array(pygame.surfarray.pixels3d(canvas)), axes=(1, 0, 2))

    def close(self):
        """Close the render window."""
        if self.window is not None:
            pygame.display.quit()
            pygame.quit()
