import gym
from gym import spaces
import numpy as np
import pygame
from typing import Dict, List, Optional, Tuple, Union
from copy import copy
import math
from scipy.linalg import expm
from scipy.integrate import quad_vec

from action_masking.util.sets import Zonotope

# Define constants
GRAVITY = 9.81  # [m/s^2]


class LongQuadrotorEnv(gym.Env):
    """A longitudinal quadrotor environment representing [1].

    [1] I. M. Mitchell et al. "Invariant, Viability and Discriminating
        Kernel Under-Approximation via Zonotope Scaling", 2019,
        Proceedings of the 22nd ACM International Conference on Hybrid
        Systems: Computation and Control, pp. 268-269

    Characteristica:
        - bounded disturbances
        - random initial state
        - variable control frequency

    TODO: Create LongQuadrotorEnv base class to reduce code duplication in LongQuadrotorCoupledDynamicsEnv.
    """

    metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 30}

    def __init__(self,
                 dt: float = 0.05,
                 K: float = 0.89/1.4,
                 d0: float = 70,
                 d1: float = 17,
                 n0: float = 55,
                 w: np.ndarray= np.array([0.0,0.0]),
                 input_limits: np.ndarray = np.array(
                     [[13.931460674157302, -0.261799387799149],
                      [16.931460674157300, 0.261799387799149]]),
                 goal_distance: float = 0.01,
                 done_on_collision: bool = False,
                 done_on_goal_reached: bool = False,
                 state_constraints: Optional[np.ndarray] = np.array(
                     [[-1.7, 0.3, -0.8, -1, -np.pi/12, -np.pi/2],
                      [1.7, 2.0, 0.8, 1.0, np.pi/12, np.pi/2]]),
                 randomize_env: bool = True,
                 start_state: np.ndarray = np.array([0, 1, 0, 0, 0, 0], dtype=np.float32),
                 randomize_start_range: Optional[List[float]] = [0.1, 0.1, 0.1],
                 collision_reward: float = -1.0,
                 goal_reward: float = 1.0,
                 step_reward: float = -1.0,
                 rew_state_weight: float = 1.0,
                 rew_act_weight: float = 0.01,
                 reward_shaping: bool = True,
                 size: float = 3,
                 render_mode: Optional[str] = None,
                 seed: int = 42,
                 safe_region: Optional[Union[Zonotope, np.ndarray]] = None,
                 ):
        """Create the reach environment.

        State space:
        x = [x, z, dx, dz, theta, dtheta]

        Continuous dynamics:
        f[0] = x[2]
        f[1] = x[3]
        f[2] = u[0]*K*sin(x[4]) + w[0]
        f[3] = -g + u[0] * K * cos(x[4]) + w[1]
        f[4] = x[5]
        f[5] = -d0 * x[4] - d1 * x[5] + n0 * u[1]

        Discrete dynamics:
        x_{k+1} = x_{k} + f(x_{k}) * dt

        Args:
            dt (float): Time step of the discrete dynamics.
            K (float): Thrust coefficient.
            d0 (float): Angular gain coefficient (theta).
            d1 (float): Angular gain coefficient (dtheta).
            n0 (float): Torque input coefficient.
            w (ndarray): disturbance bound in x z - shape (2,).
            input_limits (np.ndarray): Input limits in the form [[min_u1, min_u2], [max_u1, max_u2]].
            goal_distance (float): Distance to goal to consider it reached.
            done_on_collision (bool): If true, the episode is finished if the quadrotor collides with the constraint
                set.
            done_on_goal_reached (bool): If true, the episode is finished if the goal is reached.
            state_constraints (Optional, np.ndarray): State constraints of the form
                [[min_x, min_z,...], [max_x, max_z,...]].
            randomize_env (bool): If true, the environment is randomized on reset.
            start_state (np.ndarray): Start state of the quadrotor.
            randomize_start_range (Optional, List[float]): Range of the random start state in x, z, and theta.
            collision_reward (float): Reward for exceeding the state constraints.
            goal_reward (float): Reward for being inside the goal region.
            step_reward (float): Reward for each step.
            rew_state_weight (float): Weight of the state reward.
            rew_act_weight (float): Weight of the action reward.
            reward_shaping (bool): If true, the reward is shaped by the distance to the goal.
            size (float): Size of the environment. Defualts to 3.
            render_mode (Optional, str): Optional render mode. Can be "human" or "rgb_array".
            seed (int): Random seed.
            safe_region (Optional, Union[Zonotope, np.ndarray]): Safe region (RCI set) of the env. Needed to sample safe initial states from.
        """
        self.dt = dt
        self.K = K
        self.d0 = d0
        self.d1 = d1
        self.n0 = n0
        self.w = w
        self.n_dim = 6
        self.action_dim = 2
        self.A = np.array(
            [[0, 0, 1, 0, 0, 0], [0, 0, 0, 1, 0, 0], [0, 0, 0, 0, GRAVITY, 0], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 1],
             [0, 0, 0, 0, -d0, -d1]])
        self.B = np.array(
            [[0, 0], [0, 0], [0, 0], [K, 0], [0, 0], [0, n0]])
        self.E = np.array(
            [[0, 0], [0, 0], [1, 0], [0, 1], [0, 0], [0, 0]])
        self.x_eq = np.array([0, 1, 0, 0, 0, 0])
        self.u_eq = np.array([GRAVITY/K,0])
        self.w_eq = np.array([0,0])
        self.A_d = expm(self.A*dt)
        temp = np.eye(self.A.shape[0]) * dt
        interval = temp
        i = 1
        while not np.all(np.abs(temp) < 2.2204e-16):
            temp = (np.linalg.matrix_power(self.A, i) / math.factorial(i+1)) * (dt**(i+1))  #temp * dt/(i+1) * self.A #
            interval = interval + temp
            i += 1
        #interval = np.ones(self.A.shape[0]) * dt + ( self.A / 2) * dt**2 + (np.linalg.matrix_power(self.A, 2) / 6)  * dt**3 + (np.linalg.matrix_power(self.A, 3) / 24)  * dt**4
        self.B_d = np.matmul(interval, self.B)
        self.E_d = np.matmul(interval, self.E)
        # For reward shaping only:
        # TODO: isn't this exactly the same as input_limits? Why is this defined twice with different variables?
        u_LOW = np.array([-1.5+GRAVITY/self.K, -math.pi/12])
        u_HIGH = np.array([1.5+GRAVITY/self.K, math.pi/12])
        self.u_range = np.array([u_LOW, u_HIGH])
        self.goal_distance = goal_distance
        self.done_on_collision = done_on_collision
        self.done_on_goal_reached = done_on_goal_reached
        self.state_constraints = np.array(state_constraints, dtype=np.float32)
        self.randomize_env = randomize_env
        self.start_state = np.array(start_state, dtype=np.float32)
        if randomize_start_range is None:
            self.randomize_start_range = np.array([0, 0, 0], dtype=np.float32)
        else:
            self.randomize_start_range = np.array(randomize_start_range, dtype=np.float32)
        self.collision_reward = collision_reward
        self.goal_reward = goal_reward
        self.step_reward = step_reward
        self.rew_state_weight = rew_state_weight
        self.rew_act_weight = rew_act_weight
        self.reward_shaping = reward_shaping
        self.render_mode = render_mode
        self.rnd_seed = seed
        self.np_random = np.random.RandomState(self.rnd_seed)
        if self.render_mode is not None:
            self.window_size = 512  # The size of the PyGame window
        self.size = size
        self.observation_space = spaces.Box(low=-self.size,
                                            high=self.size,
                                            shape=(self.n_dim,),
                                            dtype='float32',
                                            seed=self.rnd_seed)
        self.action_space = spaces.Box(low=input_limits[0],
                                       high=input_limits[1],
                                       shape=(self.action_dim,),
                                       dtype='float32',
                                       seed=self.rnd_seed)
        self.state = np.zeros((self.n_dim,), dtype='float32')
        self.goal = start_state
        # For reward shaping
        self.max_distance = np.linalg.norm(state_constraints[1]-state_constraints[0])
        self._collision = False
        self.safe_region = safe_region
        if randomize_env and safe_region is None:
            raise ValueError("If randomize_env is True, safe_region must be provided to sample from.")
        """
        If human-rendering is used, `self.window` will be a reference
        to the window that we draw to. `self.clock` will be a clock that is
        used to ensure that the environment is rendered at the correct
        framerate in human-mode. They will remain `None` until human-mode
        is used for the first time.
        """
        assert (render_mode is None or
                render_mode in self.metadata["render_modes"])
        self.render_mode = render_mode
        self.window = None
        self.clock = None
        self._goal_reached_counter = 0
        self._collision_counter = 0

    def _get_obs(self) -> np.ndarray:
        """Get the observation, which is the agent location."""
        return self.state

    def _get_info(self) -> Dict:
        """Get the info dictionary."""
        return {
            "distance": np.linalg.norm(
                self.state - self.goal
            ),
            "collision": self._collision,
            "n_goal_reached": self._goal_reached_counter,
            "n_collision": self._collision_counter
        }

    def _get_success(self) -> bool:
        """Get the information if the goal was reached.

        Returns:
            True if goal was reached, false otherwise.
        """
        return np.linalg.norm(self.goal - self.state) < self.goal_distance

    def _get_done(self) -> bool:
        """Return if the episode is finished.

        Returns:
            True if done flag should be set, false otherwise.
        """
        if self.done_on_collision and self._collision:
            return True
        if self.done_on_goal_reached and self._get_success():
            return True
        return False

    def _get_reward(self, act) -> float:
        """Compute the step reward."""
        reward = self.step_reward
        distance = np.linalg.norm(self.goal - self.state)
        goal_reached = distance < self.goal_distance
        if self._collision:
            reward += self.collision_reward
        if self.reward_shaping:
            dist = np.sum(self.rew_state_weight * distance)
            act_cost = np.mean((act-self.u_range[0])/(self.u_range[1]-self.u_range[0]))
            dist += np.sum(self.rew_act_weight * act_cost)
            reward += np.exp(-dist)
        else:
            reward += goal_reached * self.goal_reward
        return reward

    def reset(self, seed: Optional[int] = None) -> Tuple[np.ndarray, Dict]:
        """
        Reset the environment.
        TODO: the seed argument cannot be passed through the DummyVecEnv wrapper in my opinion.
        """
        # We need the following line to seed self.np_random
        if seed is not None:
            self.np_random.seed(seed)
        else:
            self.np_random.seed(self.rnd_seed)
        self._collision = False
        # Choose the agent's and goal location.
        if self.randomize_env:
            # Sample safe initial state from RCI set -> This could also lead to state on the boundries of the RCI set which might be harder to learn from.
            self.state = self.safe_region.sample()

            # Previous implementation - while this yields to a random state around our predefined start state, it is not guaranteed to be safe
            # self.state = self.start_state
            # self.state[[0, 1, 4]] = self.state[[0, 1, 4]] + self.np_random.uniform(
            #     -self.randomize_start_range, self.randomize_start_range
            # )
        else:
            self.state = self.start_state

        observation = self._get_obs()

        if self.render_mode == "human":
            self._render_frame()
        self._goal_reached_counter = 0
        self._collision_counter = 0
        return observation

    def step(self,
             action: np.ndarray) -> Tuple[np.ndarray, float, bool, Dict]:
        """Run one timestep of the environment's dynamics.

        When end of episode is reached, you are responsible for
        calling `reset()` to reset this environment's state.

        Accepts an action and returns a tuple
            (observation, reward, done, info).

        Args:
            action (np.ndarray): an action provided by the agent

        Returns:
            observation (np.ndarray): observation: agent state.
            reward (float): amount of reward returned after previous action.
            done (bool): whether the episode has ended.
            info (dict): contains auxiliary diagnostic information.
        """
        assert not np.any(np.isnan(action)), "NaN in action"
        self._collision = False
        # TODO: Switch to bounded Gaussian noise?
        w = self.np_random.rand(2)
        w *= self.w
        if self.collision_check_fn(action):
            self._collision = True
            self._collision_counter += 1
        # Discrete Dynamics
        self.state = self.dynamics_fn(action, self.state, w)
        observation = self._get_obs()
        goal_reached = self._get_success()
        self._goal_reached_counter += int(goal_reached)
        info = self._get_info()
        reward = self._get_reward(action)
        done = self._get_done()

        if self.render_mode == "human":
            self._render_frame()

        return observation, reward, done, info

    def dynamics_fn(self,
                    u: np.ndarray,
                    state: np.ndarray,
                    w: np.ndarray = np.array([0.0, 0.0])) -> np.ndarray:
        """Calculate the next state of the environment based on the given action and current state.

        State space:
        x = [x, z, dx, dz, theta, dtheta]

        Discrete dynamics:
        x_{t+1} = A_d (x_{t} - x*) + B_d (u_{t} - u*) + E_d (w_{t} -w*)

        where
            A_d = e^{A dt}
            B_d = e^{A dt} \int_0^dt e^{-A d\tau} d\tau B
            E_d = e^{A dt} \int_0^dt e^{-A d\tau} d\tau E

        Based on continuous dynamics:
        \dot{x}[0] = x[2]
        \dot{x}[1] = x[3]
        \dot{x}[2] = u[0]*K*sin(x[4]) + w[0]
        \dot{x}[3] = -g + u[0] * K * cos(x[4]) + w[1]
        \dot{x}[4] = x[5]
        \dot{x}[5] = -d0 * x[4] - d1 * x[5] + n0 * u[1]

        With linear taylor expansion around the equilibrium point *
        \dot{x} = A (x - x*) + B (u - u*) + E (w - w*)

        Args:
            u (np.ndarray): the action to execute.
            state (np.ndarray): the current state of the environment.
            w (np.ndarray): the noise to add to the state.
        """
        next_state = self.x_eq + self.A_d@(state - self.x_eq) + self.B_d@(u - self.u_eq) + self.E_d@(w - self.w_eq)
        """
        next_state = copy(state)
        next_state[0] = state[0] + self.dt * (state[2])
        next_state[1] = state[1] + self.dt * (state[3])
        next_state[2] = state[2] + self.dt * (u[0] * self.K * np.sin(state[4]) + w[0])
        next_state[3] = state[3] + self.dt * (-GRAVITY + u[0] * self.K * np.cos(state[4]) + w[1])
        next_state[4] = state[4] + self.dt * (state[5])
        next_state[5] = state[5] + self.dt * (-self.d0 * state[4] - self.d1 * state[5] + self.n0 * u[1])
        """
        return next_state

    def collision_check_fn(
        self,
        action: np.ndarray,
        w: np.ndarray = np.array([0.0, 0.0]),
    ) -> bool:
        """Return true if the given action would collide.

        Args:
            u (np.ndarray): the action to execute.
            w (np.ndarray): the noise to add to the state.
        Returns:
            bool: True if the action would collide, False otherwise.
        """
        next_state = self.dynamics_fn(action, self.state, w)
        return np.any(next_state <= self.state_constraints[0]) or np.any(next_state >= self.state_constraints[1])

    def render(self):
        """Render a frame."""
        if self.render_mode == "rgb_array":
            return self._render_frame()

    def _render_frame(self):
        """Render a frame of the current state using rgb_array method."""
        if self.window is None and self.render_mode == "human":
            pygame.init()
            pygame.display.init()
            self.window = pygame.display.set_mode(
                (self.window_size, self.window_size))
        if self.clock is None and self.render_mode == "human":
            self.clock = pygame.time.Clock()

        canvas = pygame.Surface((self.window_size, self.window_size))
        canvas.fill((255, 255, 255))
        agent_width = 30
        agent_height = 2
        target = (1/(2*self.size) *
                  np.array([self.goal[0] + self.size,
                            self.goal[1] + self.size]))
        agent = (1/(2*self.size) *
                 np.array([self.state[0] + self.size,
                           self.state[1] + self.size]))
        # First we draw the target
        points = []
        radius = math.sqrt((agent_height / 2)**2 + (agent_width / 2)**2)
        angle = math.atan2(agent_height / 2, agent_width / 2)
        angles = [angle, -angle + math.pi, angle + math.pi, -angle]
        rot_radians = self.goal[4]
        target_pos = self.window_size * np.array(target) - np.array([agent_width/2, agent_height/2])
        for angle in angles:
            y_offset = -1 * radius * math.sin(angle + rot_radians)
            x_offset = radius * math.cos(angle + rot_radians)
            points.append((target_pos[0] + x_offset, target_pos[1] + y_offset))
        pygame.draw.polygon(canvas, (120, 120, 120), points)
        # Now we draw the agent
        points = []
        radius = math.sqrt((agent_height / 2)**2 + (agent_width / 2)**2)
        angle = math.atan2(agent_height / 2, agent_width / 2)
        angles = [angle, -angle + math.pi, angle + math.pi, -angle]
        rot_radians = self.state[4]
        agent_pos = self.window_size * np.array(agent) - np.array([agent_width/2, agent_height/2])
        for angle in angles:
            y_offset = -1 * radius * math.sin(angle + rot_radians)
            x_offset = radius * math.cos(angle + rot_radians)
            points.append((agent_pos[0] + x_offset, agent_pos[1] + y_offset))
        pygame.draw.polygon(canvas, (255, 0, 0) if self._collision else (0, 255, 0), points)

        if self.render_mode == "human":
            # The following line copies our drawings from `canvas` to the
            # visible window
            self.window.blit(canvas, canvas.get_rect())
            pygame.event.pump()
            pygame.display.update()

            # We need to ensure that human-rendering occurs at the
            # predefined framerate.
            # The following line will automatically add a delay to keep
            # the framerate stable.
            self.clock.tick(self.metadata["render_fps"])
        else:  # rgb_array
            return np.transpose(
                np.array(pygame.surfarray.pixels3d(canvas)), axes=(1, 0, 2)
            )

    def close(self):
        """Close the render window."""
        if self.window is not None:
            pygame.display.quit()
            pygame.quit()
