"""
This a modified version of the CartPole environment from Gymnasium, with continuous action space and state space.
The original evn can be found at:
https://github.com/Farama-Foundation/Gymnasium/blob/main/gymnasium/envs/classic_control/cartpole.py
"""
import copy
import math
from typing import Optional, Union

import numpy as np

import gymnasium as gym
import dataclasses
from gymnasium import spaces
from gymnasium.error import DependencyNotInstalled
import matplotlib.pyplot as plt
from agent.model_based.model_based_design_cartpole import MATRIX_P

# We define the configuration classes to make the environment configurable using configuration files
@dataclasses.dataclass
class RobotConfig:
    """Physical parameters of the cart-pole robot.

    The default values are measured on the real cart-pole system produced by Quanser:
    https://www.quanser.com/products/linear-servo-base-unit-inverted-pendulum/
    Units are presumably SI (m, kg, s, rad) -- TODO confirm against the hardware datasheet.
    """

    # Cart-position limit; CartPoleEnv terminates the episode once |x| exceeds it.
    x_threshold: float = 0.35
    # Angular-velocity limit (rad/s); the 'swing_up' task terminates once |theta_dot| exceeds it.
    theta_dot_threshold: float = 15
    gravity: float = 9.8
    mass_cart: float = 0.94
    mass_pole: float = 0.23
    # Multiplies the normalised action in [-1, 1] to obtain the force applied to the cart.
    action_mag: float = 10.0
    # Full pole length; the dynamics work with half of it (CartPoleEnv.half_length).
    length: float = 0.64
    # Coefficient of the cart friction term, which is proportional to the cart velocity.
    friction_cart: float = 10
    # Coefficient of the pole-joint friction term, which is proportional to the angular velocity.
    friction_pole: float = 0.0011
    # When False, both friction terms are removed from the equations of motion.
    with_friction: bool = True

@dataclasses.dataclass
class TaskConfig:
    """Task definition: control goal, reward shaping, episode length and reset behaviour."""

    # Weight of the quadratic action penalty (-action_penalty * action**2) used by every reward type.
    action_penalty: float = 0.01
    # Not read by CartPoleEnv in this module -- presumably consumed by training code; verify.
    crash_penalty: float = 0
    # State [x, x_dot, theta, theta_dot] used when task_reset_mode != 'random'. The default is the
    # pole hanging straight down (theta = pi). Each instance gets its own fresh list.
    ini_states: list = dataclasses.field(default_factory=lambda: [0.0, 0.0, math.pi, 0.0])
    control_goal_x: float = 0.0  # goal cart position
    control_goal_theta: float = 0.0  # goal pole angle (0 = upright)
    # Episodes are truncated once this many steps have been taken.
    max_episode_steps: int = 500
    # Not read by CartPoleEnv in this module (presumably used by the training/evaluation loop).
    evaluation_period: int = 10000
    num_episodes_to_run: int = 10  # number of episodes to evaluate the agent or collect trajectories
    # 'random' samples a random initial state on reset; any other value resets to ini_states.
    task_reset_mode: str = 'random'
    # When True, reset(scale=...) rescales pole mass, cart mass and pole length.
    change_dynamics: bool = False
    # Not read by CartPoleEnv in this module.
    context_horizon: int = 10
    # 'swing_up' or 'balance'; CartPoleEnv.step raises ValueError for any other value.
    task_type: str = 'swing_up'
    # 'exp', 'split' or 'safety' -- selects CartPoleEnv's reward function.
    reward_type: str = 'exp'


@dataclasses.dataclass
class SimulationConfig:
    """Numerical integration and rendering settings."""

    # Not read by CartPoleEnv in this module (presumably choose random vs. fixed resets for
    # training and evaluation in the calling code; verify).
    random_reset_train: bool = True
    random_reset_eval: bool = False
    # CartPoleEnv integrates once per step with time step sim_time_step * num_action_repeat.
    num_action_repeat: int = 20
    sim_time_step: float = 0.001
    # When True and render_mode == 'human', step() and reset() render a frame.
    enable_rendering: bool = False
    # 'euler' selects explicit Euler; any other value selects semi-implicit Euler.
    kinematics_integrator: str = 'euler'
    # 'human' (window) or 'rgb_array' (off-screen surface).
    render_mode: str = 'human'


@dataclasses.dataclass
class CartPoleGymConfig:
    """Top-level configuration consumed by CartPoleEnv.

    Each section is built by its own default_factory, so separate config instances never share
    section objects (or their mutable fields such as TaskParams.ini_states).
    """

    RobotParams: RobotConfig = dataclasses.field(default_factory=RobotConfig)  # physical parameters
    TaskParams: TaskConfig = dataclasses.field(default_factory=TaskConfig)  # goal, reward, episode settings
    SimulationParams: SimulationConfig = dataclasses.field(default_factory=SimulationConfig)  # integration/rendering


class CartPoleEnv(gym.Env[np.ndarray, Union[int, np.ndarray]]):
    """
    ## Description

    This environment is a modified version of the cart-pole problem described by Barto, Sutton, and Anderson in
    ["Neuronlike Adaptive Elements That Can Solve Difficult Learning Control Problem"](https://ieeexplore.ieee.org/document/6313077).
    A pole is attached by an un-actuated joint to a cart, which moves along a track. A cart friction term
    (proportional to the cart velocity) and a pole-joint friction term (proportional to the angular velocity)
    are included when `RobotParams.with_friction` is True. Depending on `TaskParams.task_type`, the goal is to
    swing the pole up and balance it (`'swing_up'`) or to keep a near-upright pole balanced (`'balance'`), by
    applying forces to the cart in the left and right direction.

    ## Action Space

    The action is a single value in `[-1, 1]`, given either as a scalar or as an `ndarray` with one element
    (e.g. `env.action_space.sample()`). It is multiplied by `RobotParams.action_mag` (default 10.0) to obtain
    the force applied to the cart.

    ## Observation Space

    The internal state `env.state` is `[x, x_dot, theta, theta_dot]`: cart position and velocity, pole angle
    (0 = upright, pi = hanging down) and pole angular velocity. The angle is converted to sin and cos to avoid
    discontinuities, so the observation is a float32 `ndarray` with shape `(5,)`:

    | Num | Observation           | Min           | Max          |
    |-----|-----------------------|---------------|--------------|
    | 0   | Cart Position         | -x_threshold  | x_threshold  |
    | 1   | Cart Velocity         | -Inf          | Inf          |
    | 2   | sin(Pole Angle)       | -1            | 1            |
    | 3   | cos(Pole Angle)       | -1            | 1            |
    | 4   | Pole Angular Velocity | -Inf          | Inf          |

    The default `x_threshold` is 0.35. The `'swing_up'` task terminates when |theta_dot| exceeds
    `RobotParams.theta_dot_threshold` (default 15 rad/s).

    **Note:** This differs from the original cartpole environment in that the sin and cos of the pole angle are
    observed instead of the angle itself, and the angle range is the full circle to enable the swing-up task.

    ## Rewards

    Selected by `TaskParams.reward_type`:
    - `'exp'`: `exp(-5 * d) - action_penalty * action**2`, where `d` is the distance between the pole tip and
      the pole-tip position at the control goal.
    - `'split'`: `-0.5 * (x - x_goal)**2 - 0.5 * e_theta**2 - action_penalty * action**2`, where `e_theta` is
      the angle error wrapped to `[-pi, pi]`.
    - `'safety'`: the `'split'` reward plus a penalty of -1 whenever `e^T P e > 1`, with `e` the state minus
      `model_based_equilibrium` and `P = MATRIX_P`.

    ## Starting State

    - Random reset (`TaskParams.task_reset_mode == 'random'`): `x ~ U(-0.4, 0.4) * x_threshold`, zero
      velocities, and `theta ~ N(pi, 0.8)` for swing-up or `theta ~ U(-0.2, 0.2)` otherwise. Samples come from
      `self.np_random`, so they are reproducible via `reset(seed=...)`.
    - Otherwise the state is reset to a copy of `TaskParams.ini_states` (default `[0, 0, pi, 0]`, pole down).
    - `reset(options={"reset_state": state})` sets the state explicitly.

    ## Episode End

    1. Termination: the cart position leaves `[-x_threshold, x_threshold]`; additionally the pole angle leaves
       `[-0.785, 0.785]` rad (about 45 degrees) for `'balance'`, or the angular velocity leaves
       `[-theta_dot_threshold, theta_dot_threshold]` for `'swing_up'`.
    2. Truncation: the number of steps reaches `TaskParams.max_episode_steps` (default 500).

    ## Usage

    The environment is constructed directly from a `CartPoleGymConfig`:

        env = CartPoleEnv(CartPoleGymConfig())
        obs, context = env.reset(seed=0)
        obs, reward, terminated, truncated, info = env.step(0.5)

    Unlike the Gymnasium API, `reset` returns `(observation, context_vector)` where `context_vector` is
    `[mass_pole, mass_cart, pole_length]` as float32. Each `step` integrates the dynamics once with time step
    `SimulationParams.sim_time_step * SimulationParams.num_action_repeat`.
    """

    metadata = {
        "render_modes": ["human", "rgb_array"],
        "render_fps": 200,
    }

    def __init__(self, env_config: CartPoleGymConfig):
        """Build the environment from ``env_config``.

        Raises:
            ValueError: if ``TaskParams.reward_type`` is not ``'exp'``, ``'split'`` or ``'safety'``.
        """
        self.env_config = env_config

        self.gravity = self.env_config.RobotParams.gravity
        self.masscart = self.env_config.RobotParams.mass_cart
        self.masspole = self.env_config.RobotParams.mass_pole
        self.pole_length = self.env_config.RobotParams.length

        # Derived quantities; `change_physics_contexts` recomputes them when the dynamics are rescaled.
        self.total_mass = self.masspole + self.masscart
        self.half_length = self.pole_length / 2
        self.pole_mass_length_half = self.masspole * self.half_length
        self.force_mag = self.env_config.RobotParams.action_mag

        # Time step of the single integration step performed per env step.
        self.tau = self.env_config.SimulationParams.sim_time_step * self.env_config.SimulationParams.num_action_repeat
        self.kinematics_integrator = self.env_config.SimulationParams.kinematics_integrator

        self.x_threshold = self.env_config.RobotParams.x_threshold
        self.theta_threshold_radians = 0.785  # ~45 degrees, only used by the 'balance' task

        # observation space: [x, x_dot, sin(theta), cos(theta), theta_dot]
        high = np.array(
            [
                self.x_threshold,  # cart position
                np.finfo(np.float32).max,  # cart velocity
                1,  # pole angle sin  max = sin(pi/2)
                1,  # pole angle cos  max = cos(0)
                np.finfo(np.float32).max,  # pole angular velocity
            ],
            dtype=np.float32,
        )
        self.observation_space = spaces.Box(-high, high, dtype=np.float32)

        # continuous action space for drl action, will be scaled up/down by force_mag in the step function
        self.action_space = spaces.Box(-1, 1, shape=(1,), dtype=np.float32)

        self.render_mode = self.env_config.SimulationParams.render_mode

        self.screen_width = 600
        self.screen_height = 400
        self.screen = None
        self.clock = None
        self.isopen = True
        self.state = None

        self.steps_beyond_terminated = None
        self._elapsed_steps = 0
        self.control_goal = [self.env_config.TaskParams.control_goal_x, self.env_config.TaskParams.control_goal_theta]  # [x, theta]
        # Full-state goal used by `predict_next_state`; set here so that method works without a prior
        # explicit call. Call `set_control_goal_array` again after changing `control_goal`.
        self.set_control_goal_array()
        self.model_based_equilibrium = np.array([0, 0, 0, 0], dtype=np.float32)
        self.reward_function = self.get_reward_function()

    @staticmethod
    def _observation_from_state(state):
        """Map ``[x, x_dot, theta, theta_dot]`` to the (float64) ``[x, x_dot, sin, cos, theta_dot]`` array."""
        x, x_dot, theta, theta_dot = state
        return np.array([x, x_dot, math.sin(theta), math.cos(theta), theta_dot])

    @staticmethod
    def _flatten_action(action):
        """Return ``action`` as a 1-D float64 array, or an empty array if it cannot be read as numbers."""
        try:
            return np.asarray(action, dtype=np.float64).reshape(-1)
        except (TypeError, ValueError):
            return np.empty(0)

    def step(self, action):
        """Advance the simulation by one control step.

        Args:
            action: normalised force in [-1, 1], either a scalar or an array with a single element
                (such as ``self.action_space.sample()``).

        Returns:
            ``(observation, reward, terminated, truncated, info)``; ``observation`` is float32 with shape (5,)
            and ``info`` is always an empty dict.

        Raises:
            AssertionError: if the action is outside the action space or ``reset`` has not been called.
            ValueError: if ``TaskParams.task_type`` is neither ``'swing_up'`` nor ``'balance'``.
        """
        # The action space has shape (1,); accept scalars as well as single-element arrays. The check is
        # done on a float32 copy (the space's dtype) while the force uses the caller's full precision.
        flat_action = self._flatten_action(action)
        assert flat_action.size == 1 and self.action_space.contains(
            flat_action.astype(np.float32)
        ), f"{action!r} ({type(action)}) invalid"
        action = float(flat_action[0])
        assert self.state is not None, "Call reset before using step method."
        x, x_dot, theta, theta_dot = self.state

        force = action * self.force_mag
        costheta = math.cos(theta)
        sintheta = math.sin(theta)

        robot = self.env_config.RobotParams
        if robot.with_friction:
            friction_cart, friction_pole = robot.friction_cart, robot.friction_pole
        else:
            # Zero coefficients reduce the equations below exactly to the frictionless model.
            friction_cart = friction_pole = 0.0

        # For the interested reader: https://coneural.org/florian/papers/05_cart_pole.pdf
        # The cart friction term here is proportional to x_dot. The current (possibly rescaled) pole mass
        # is used in every term so that `change_physics_contexts` affects the dynamics consistently.
        temp = (
            force
            + self.pole_mass_length_half * theta_dot ** 2 * sintheta
            - friction_cart * x_dot
        ) / self.total_mass
        thetaacc = (
            robot.gravity * sintheta
            - costheta * temp
            - friction_pole * theta_dot / self.pole_mass_length_half
        ) / (self.half_length * (4.0 / 3.0 - self.masspole * costheta ** 2 / self.total_mass))
        xacc = temp - self.pole_mass_length_half * thetaacc * costheta / self.total_mass

        if self.kinematics_integrator == "euler":
            x = x + self.tau * x_dot
            x_dot = x_dot + self.tau * xacc
            theta = theta + self.tau * theta_dot
            theta_dot = theta_dot + self.tau * thetaacc
        else:  # semi-implicit euler
            x_dot = x_dot + self.tau * xacc
            x = x + self.tau * x_dot
            theta_dot = theta_dot + self.tau * thetaacc
            theta = theta + self.tau * theta_dot

        self.state = (x, x_dot, theta, theta_dot)

        task_type = self.env_config.TaskParams.task_type
        cart_out_of_bounds = x < -self.x_threshold or x > self.x_threshold
        if task_type == 'balance':
            terminated = bool(
                cart_out_of_bounds
                or theta < -self.theta_threshold_radians
                or theta > self.theta_threshold_radians
            )
        elif task_type == 'swing_up':
            terminated = bool(
                cart_out_of_bounds
                or theta_dot < - robot.theta_dot_threshold
                or theta_dot > robot.theta_dot_threshold
            )
        else:
            raise ValueError(f"Unknown task type: {self.env_config.TaskParams.task_type}")

        self._elapsed_steps += 1

        truncated = bool(self._elapsed_steps >= self.env_config.TaskParams.max_episode_steps)
        reward = self.reward_function(self.state, action)
        if self.render_mode == "human" and self.env_config.SimulationParams.enable_rendering:
            self.render()

        observations = self._observation_from_state(self.state)
        return observations.astype(np.float32), reward, terminated, truncated, {}

    def reward_function_exp(self, state, action):
        """Exponential pole-tip distance score minus a quadratic action penalty."""
        distance_reward = self.get_distance_score(self._observation_from_state(state), self.control_goal)
        action_penalty = -1 * self.env_config.TaskParams.action_penalty * action * action
        return distance_reward + action_penalty

    def _quadratic_tracking_reward(self, state, action):
        """Negative weighted squared cart/angle errors minus a quadratic action penalty."""
        x, _, theta, _ = state
        distance_x = abs(self.control_goal[0] - x) ** 2
        # Wrap the angle error to [-pi, pi]: theta and theta + 2*pi are the same pole configuration.
        # math.remainder is exact, so errors already inside that range are left bit-for-bit unchanged.
        distance_theta = abs(math.remainder(self.control_goal[1] - theta, math.tau)) ** 2
        w_x = - 0.5
        w_theta = - 0.5
        distance_reward = w_x * distance_x + w_theta * distance_theta
        action_reward = -1 * self.env_config.TaskParams.action_penalty * action * action
        return distance_reward + action_reward

    def reward_function_split(self, state, action):
        """Quadratic tracking reward on cart position and (wrapped) pole angle."""
        return self._quadratic_tracking_reward(state, action)

    def reward_function_safety(self, state, action):
        """Quadratic tracking reward plus -1 when the state leaves the ``MATRIX_P`` safety envelope.

        The envelope test is ``e^T P e > 1`` with ``e = state - model_based_equilibrium``.
        """
        tracking_error = np.asarray(state, dtype=np.float64) - self.model_based_equilibrium
        envelope_triggering_penalty = -1 if tracking_error @ MATRIX_P @ tracking_error > 1 else 0
        return self._quadratic_tracking_reward(state, action) + envelope_triggering_penalty

    def get_reward_function(self):
        """Return the bound reward method selected by ``TaskParams.reward_type``.

        Raises:
            ValueError: for an unknown reward type (previously ``None`` was returned, which only failed
                later inside ``step``).
        """
        reward_functions = {
            'exp': self.reward_function_exp,
            'split': self.reward_function_split,
            'safety': self.reward_function_safety,
        }
        reward_type = self.env_config.TaskParams.reward_type
        if reward_type not in reward_functions:
            raise ValueError(f"Unknown reward type: {reward_type}")
        return reward_functions[reward_type]

    def _sample_random_state(self):
        """Draw a random initial state from ``self.np_random`` (seeded by ``reset(seed=...)``)."""
        x_limit = 0.4 * self.env_config.RobotParams.x_threshold
        ran_x = self.np_random.uniform(-x_limit, x_limit)
        if self.env_config.TaskParams.task_type == 'swing_up':
            ran_theta = self.np_random.normal(math.pi, 0.8)  # around the hanging-down position
        else:
            ran_theta = self.np_random.uniform(-0.2, 0.2)  # near upright for balancing
        return [ran_x, 0, ran_theta, 0]

    def reset(
            self,
            *,
            seed: Optional[int] = None,
            options: Optional[dict] = None,
            scale: Optional[np.ndarray] = None,
    ):
        """Reset the episode.

        Args:
            seed: seeds ``self.np_random``, which drives the random initial state.
            options: if it contains ``'reset_state'``, that ``[x, x_dot, theta, theta_dot]`` is used as the
                initial state.
            scale: factors ``[mass_pole, mass_cart, pole_length]`` relative to the configured values; only
                applied when ``TaskParams.change_dynamics`` is True.

        Returns:
            ``(observation, context_vector)``: the float32 observation of shape (5,) and the float32
            ``[mass_pole, mass_cart, pole_length]`` context.
        """
        super().reset(seed=seed)
        if options is not None and "reset_state" in options:
            self.state = options['reset_state']
        elif self.env_config.TaskParams.task_reset_mode == 'random':
            self.state = self._sample_random_state()
        else:
            # Copy so that writes to env.state can never mutate the shared config list.
            self.state = list(self.env_config.TaskParams.ini_states)  # pointing-down equilibrium by default

        self.steps_beyond_terminated = None
        self._elapsed_steps = 0

        if self.render_mode == "human" and self.env_config.SimulationParams.enable_rendering:
            self.render()

        context_vector = np.array([self.masspole, self.masscart, self.pole_length], dtype=np.float32)
        if self.env_config.TaskParams.change_dynamics and scale is not None:
            # override the context vector with the new values
            context_vector = self.change_physics_contexts(scale)

        observations = self._observation_from_state(self.state)
        return observations.astype(np.float32), context_vector

    def render(self):
        """Draw the current state with pygame; returns an RGB array in ``'rgb_array'`` mode."""
        if self.render_mode is None:
            if self.spec is not None:
                hint = f'e.g. gym.make("{self.spec.id}", render_mode="rgb_array")'
            else:
                # Constructed directly (not via gym.make): the mode comes from the config.
                hint = 'e.g. via SimulationParams.render_mode in the environment config'
            gym.logger.warn(
                "You are calling render method without specifying any render mode. "
                "You can specify the render_mode at initialization, " + hint
            )
            return

        try:
            import pygame
            from pygame import gfxdraw
        except ImportError as e:
            raise DependencyNotInstalled(
                'pygame is not installed, run `pip install "gymnasium[classic-control]"`'
            ) from e

        if self.screen is None:
            pygame.init()
            if self.render_mode == "human":
                pygame.display.init()
                self.screen = pygame.display.set_mode(
                    (self.screen_width, self.screen_height)
                )
            else:  # mode == "rgb_array"
                self.screen = pygame.Surface((self.screen_width, self.screen_height))
        if self.clock is None:
            self.clock = pygame.time.Clock()

        world_width = self.x_threshold * 2
        scale = self.screen_width / world_width
        polewidth = 10.0
        polelen = scale * (2 * self.half_length) * 0.5
        cartwidth = 50.0
        cartheight = 30.0

        if self.state is None:
            return None

        x = self.state

        self.surf = pygame.Surface((self.screen_width, self.screen_height))
        self.surf.fill((255, 255, 255))

        l, r, t, b = -cartwidth / 2, cartwidth / 2, cartheight / 2, -cartheight / 2
        axleoffset = cartheight / 4.0
        cartx = x[0] * scale + self.screen_width / 2.0  # MIDDLE OF CART
        carty = 100  # TOP OF CART
        cart_coords = [(l, b), (l, t), (r, t), (r, b)]
        cart_coords = [(c[0] + cartx, c[1] + carty) for c in cart_coords]
        gfxdraw.aapolygon(self.surf, cart_coords, (0, 0, 0))
        gfxdraw.filled_polygon(self.surf, cart_coords, (0, 0, 0))

        l, r, t, b = (
            -polewidth / 2,
            polewidth / 2,
            polelen - polewidth / 2,
            -polewidth / 2,
        )

        pole_coords = []
        for coord in [(l, b), (l, t), (r, t), (r, b)]:
            coord = pygame.math.Vector2(coord).rotate_rad(-x[2])
            coord = (coord[0] + cartx, coord[1] + carty + axleoffset)
            pole_coords.append(coord)
        gfxdraw.aapolygon(self.surf, pole_coords, (202, 152, 101))
        gfxdraw.filled_polygon(self.surf, pole_coords, (202, 152, 101))

        gfxdraw.aacircle(
            self.surf,
            int(cartx),
            int(carty + axleoffset),
            int(polewidth / 2),
            (129, 132, 203),
        )
        gfxdraw.filled_circle(
            self.surf,
            int(cartx),
            int(carty + axleoffset),
            int(polewidth / 2),
            (129, 132, 203),
        )

        gfxdraw.hline(self.surf, 0, self.screen_width, carty, (0, 0, 0))

        self.surf = pygame.transform.flip(self.surf, False, True)
        self.screen.blit(self.surf, (0, 0))
        if self.render_mode == "human":
            pygame.event.pump()
            self.clock.tick(self.metadata["render_fps"])
            pygame.display.flip()

        elif self.render_mode == "rgb_array":
            return np.transpose(
                np.array(pygame.surfarray.pixels3d(self.screen)), axes=(1, 0, 2)
            )

    def close(self):
        """Shut down pygame if a screen was opened."""
        if self.screen is not None:
            import pygame
            pygame.display.quit()
            pygame.quit()
            self.isopen = False
            # Drop stale handles so a later render() re-initialises pygame instead of using a closed display.
            self.screen = None
            self.clock = None

    # Performance evaluation metrics
    def get_distance_score(self, observations, control_goal, distance_score_factor=5):
        """Return ``exp(-distance_score_factor * d)`` for the pole-tip distance ``d`` to the goal.

        Args:
            observations: ``[x, x_dot, sin(theta), cos(theta), theta_dot]``.
            control_goal: ``[x_goal, theta_goal]``.
            distance_score_factor: steepness of the exponential (default 5, the previous fixed value).

        The tip is placed at ``(x + L sin(theta), L cos(theta))`` with ``L = RobotParams.length``.
        """
        cart_position = observations[0]
        pendulum_angle_sin = observations[2]
        pendulum_angle_cos = observations[3]

        target_cart_position = control_goal[0]
        target_pendulum_angle = control_goal[1]

        pendulum_length = self.env_config.RobotParams.length

        pendulum_tip_position = np.array(
            [cart_position + pendulum_length * pendulum_angle_sin, pendulum_length * pendulum_angle_cos])

        target_tip_position = np.array(
            [target_cart_position + pendulum_length * np.sin(target_pendulum_angle),
             pendulum_length * np.cos(target_pendulum_angle)])

        distance = np.linalg.norm(target_tip_position - pendulum_tip_position)
        return np.exp(-distance * distance_score_factor)

    def get_performance_score(self):
        """Distance score of the current state with respect to ``self.control_goal``."""
        return self.get_distance_score(self._observation_from_state(self.state), self.control_goal)

    def change_physics_contexts(self, scale):
        """Rescale the physical parameters relative to the configured nominal values.

        Args:
            scale: factors for ``[mass_pole, mass_cart, pole_length]``; scaling is always relative to the
                config values, so repeated calls do not compound.

        Returns:
            float32 array ``[mass_pole, mass_cart, pole_length]`` with the new values.
        """
        new_mass_pole = scale[0] * self.env_config.RobotParams.mass_pole
        new_mass_cart = scale[1] * self.env_config.RobotParams.mass_cart
        new_pole_length = scale[2] * self.env_config.RobotParams.length

        self.masspole = new_mass_pole
        self.masscart = new_mass_cart
        self.pole_length = new_pole_length
        self.total_mass = self.masspole + self.masscart
        self.half_length = self.pole_length / 2
        self.pole_mass_length_half = self.masspole * self.half_length

        return np.array([new_mass_pole, new_mass_cart, new_pole_length], dtype=np.float32)

    def set_control_goal_array(self):
        """Refresh the full-state goal ``[x_goal, 0, theta_goal, 0]`` from ``self.control_goal``."""
        self.control_goal_array = np.array([self.control_goal[0], 0, self.control_goal[1], 0])

    def generate_control_tasks(self, mode='train'):
        """Generate control tasks, each characterised by an initial state and a control goal.

        Args:
            mode: ``'train'`` gives 100 tasks with goals drawn from ``U(-0.5, 0.5) * x_threshold``; any other
                value gives evaluation tasks with 10 evenly spaced goals in ``[-0.8, 0.8] * x_threshold``.

        Returns:
            list of dicts with ``'init_states'`` (``[0, 0, theta0, 0]``, ``theta0 ~ U(-0.25, 0.25)``) and
            ``'control_goal'`` (``[x_goal, 0]``). Random numbers come from the global ``np.random`` state.
        """
        task_list = []
        if mode == 'train':
            x_range = self.env_config.RobotParams.x_threshold * 0.5
            x_left = - 1 * x_range
            x_right = - x_left
            # NOTE(review): x_init_list is drawn but unused (the initial x is always 0). It is kept because
            # removing it would shift the global random stream and change the generated goals.
            x_init_list = np.random.uniform(x_left, x_right, 100)
            x_goal_list = np.random.uniform(x_left, x_right, 100)
            for x_init, x_goal in zip(x_init_list, x_goal_list):
                ran_theta = np.random.uniform(-0.25, 0.25)
                task_list.append({'init_states': [0, 0, ran_theta, 0], 'control_goal': [x_goal, 0]})
        else:
            x_range = self.env_config.RobotParams.x_threshold * 0.8
            x_left = - 1 * x_range
            x_right = - x_left
            # Goals are fixed (evenly spaced) for evaluation; the initial angle is still random.
            x_goal_list = np.linspace(start=x_left, stop=x_right, num=10)
            for x_goal in x_goal_list:
                ran_theta = np.random.uniform(-0.25, 0.25)
                # NOTE(review): every evaluation task is appended twice (20 tasks in total); confirm
                # that this duplication is intended.
                task_list.append({'init_states': [0, 0, ran_theta, 0], 'control_goal': [x_goal, 0]})
                task_list.append({'init_states': [0, 0, ran_theta, 0], 'control_goal': [x_goal, 0]})
        return task_list

    def plot_task_result(self, task_list):
        """Plot start, end and goal of each task in the (x, theta) plane.

        Args:
            task_list: dicts with ``'init_states'``, ``'control_goal'``, ``'crash'`` (bool) and
                ``'final_state'`` (``[x, x_dot, theta, theta_dot]``).

        Returns:
            the matplotlib figure; arrows are red for crashed tasks and blue otherwise.
        """
        fig = plt.figure(figsize=(6, 6))
        for task_info in task_list:
            x, _, y, _ = task_info['init_states']
            x1, y1 = task_info['control_goal']
            crash = task_info['crash']
            x_final, _, y_final, _ = task_info["final_state"]

            dx = x_final - x
            dy = y_final - y
            color = 'red' if crash else 'blue'

            # Draw arrow from the start (x, y) to the final state (x_final, y_final)
            plt.arrow(x, y, dx, dy, color=color, width=0.005, length_includes_head=True)
            plt.scatter(x, y, marker='o', color='black', s=50)
            plt.scatter(x_final, y_final, marker='^', color='black', s=50)
            plt.scatter(x1, y1, marker='*', color='green', s=50)

        plt.xlabel('X')
        plt.ylabel('$theta$')
        plt.grid(True)  # Add grid lines
        from matplotlib.patches import Patch
        from matplotlib.lines import Line2D

        # Custom legend handles: arrow colours (success / failure) and markers (start / end / goal)
        success_patch = Patch(color='blue', label='Success')
        failure_patch = Patch(color='red', label='Failure')
        start_marker = Line2D([0], [0], marker='o', color='w', label='Start',
                              markerfacecolor='black', markersize=8)
        end_marker = Line2D([0], [0], marker='^', color='w', label='End',
                            markerfacecolor='black', markersize=8)
        goal_marker = Line2D([0], [0], marker='*', color='green', label='Goal',
                             markerfacecolor='black', markersize=8)

        plt.legend(handles=[success_patch, failure_patch, start_marker, end_marker, goal_marker])
        plt.title('Multi-Task Control for CartPole')
        return fig

    def predict_next_state(self, action_input, System_A_matrix, System_B_matrix):
        """Predict the next state with a linear model around ``self.control_goal_array``.

        Computes ``A @ (state - goal) + B^T @ u + goal``. NOTE(review): shapes are not checked here;
        presumably ``A`` is (4, 4), ``B`` is (1, 4) and ``action_input`` is a scalar -- TODO confirm, since
        an array-valued ``action_input`` would broadcast the result to a matrix.
        """
        action_input = np.array([action_input]).transpose()
        current_state = np.array(copy.deepcopy(self.state) - self.control_goal_array).transpose()
        next_state = System_A_matrix @ current_state + System_B_matrix.transpose() @ action_input
        next_state = next_state.transpose() + self.control_goal_array
        return next_state




