# ------------------------- IMPORTS -------------------------
from abc import ABC, abstractmethod
import math
from typing import List, Union
import gym
from gym.wrappers import TimeLimit
import numpy as np

from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.monitor import Monitor



# TODO: remove this import once the circular-import issue is resolved.
# For now it is necessary for the imports below to work (circular-import workaround).
from action_masking.sb3_contrib import (
    InformerWrapper,
    ActionMaskingWrapper,
    ActionProjectionWrapper,
    ActionReplacementWrapper
)
from action_masking.util.safe_region import ControlInvariantSetZonotope, ControlInvariantSetPolytope
from action_masking.util.sets import Zonotope


from action_masking.util.util  import (
    Algorithm,
    ActionSpace,
    Approach,
    ContMaskingMode,
    TransitionTuple,
)
from action_masking.util.zono_safe_input_set import calc_safe_input_set


class Benchmark(ABC):
    """
    Abstract base class for a benchmark.
    """

    def __init__(
        self,
        config: dict,
        use_zonotope: bool = True,
    ):
        """
        Constructor: Initializes the benchmark.
        Args:
            config (dict): Configuration of the benchmark.
            use_zonotope (bool): Use zonotope representation for the environment benchmark.
        """
        self.steps = config.get("steps") # Number of training steps
        self.train_iters = config.get("train_iters") # Number of models per configuration
        self.n_eval_ep = config.get("n_eval_ep", 10) # Number of evaluation episodes per trained model
        self.max_episode_steps = config.get("max_episode_steps") # Max number of steps per episode

        # Set hyperparameters
        self.punishment = config.get("punishment", -0.1)
        self.find_seeds = config.get("find_seeds", False)
        self.safe_center_obs = config.get("safe_center_obs", False)
        self.log_polytope_space = config.get("log_polytope_space", False)

        # self.gravity = config.get("gravity", 9.81)
        # self.K = config.get("K", 0.89/1.4)
        self.dt = config.get("dt", 0.05)
        self.noise_bound = config.get("noise_bound", 0.08)
        self.action_space_area_eq = config.get("action_space_area_eq") # Volume of the action space in the equilibrium state

        self.x_goal = np.array(config.get("x_goal"))
        self.x_lim_low = np.array(config.get("x_lim_low"))
        self.x_lim_high = np.array(config.get("x_lim_high"))
        self.u_space = np.array(config.get("u_space"))

        self.u_low = self.u_space[0]
        self.u_high = self.u_space[1]
        self.u_eq = (self.u_high + self.u_low) / 2
        
        self.allowable_input_set_factors = np.array((self.u_space[1] - self.u_space[0]) / 2).T
        # optionally, disable the noise
        # self.noise_set = np.array([])


        # The current control invariant set is calculated for noise < 0.1.
        
        self.noise = np.array(config.get("noise_vector", None)) * self.noise_bound   # TODO: might adapt to be specific for each dimension so not one bounded value but for each dimension different values
        self.noise_set = np.array(config.get("noise_set", None)) * self.noise_bound   # TODO: might adapt to be specific for each dimension so not one bounded value but for each dimension different values

        self.noise_set_zonotope = Zonotope(G=self.noise_bound*np.eye(self.noise.shape[0]).T,c=np.zeros(self.noise.shape).reshape(-1,1))   # TODO: might adapt to be specific for each dimension so not one bounded value but for each dimension different values

        # If we use a zonotope representation for the environment benchmark
        if use_zonotope:
            self.template_input_set = Zonotope(G=self.allowable_input_set_factors.reshape(-1,1)*np.array(config.get("G")),c=self.u_eq.reshape(-1,1)) # interval of full action space 
           
            # Zonotope safe region
            self.safe_region = ControlInvariantSetZonotope(S_ctrl_csv=config.get("S_ctrl_csv", 'matlab/S_ctrl_LongQuadrotor.csv'),
                                                        S_RCI_csv=config.get("S_RCI_csv", 'matlab/S_RCI_LongQuadrotor.csv'),
                                                        x_goal=self.x_goal,
                                                        x_lim_low=self.x_lim_low,
                                                        x_lim_high=self.x_lim_high)
            
            # Polytope safe region still needed for action replacement and action projection 
            self.safe_region_polytope = ControlInvariantSetPolytope(polytope_csv=None,vertices_csv=config.get("vertices_csv", 'matlab/vertices_LongQuadrotor.csv'),
                                                                    A_pkl=config.get("A_pkl", 'matlab/A_longquadrotor.pkl'),
                                                                    b_pkl=config.get("b_pkl", 'matlab/b_longquadrotor.pkl'),
                                                                    S_ctrl_csv=config.get("S_ctrl_csv", 'matlab/S_ctrl_LongQuadrotor.csv'),
                                                                    S_RCI_csv=config.get("S_RCI_csv", 'matlab/S_RCI_LongQuadrotor.csv'),
                                                                    x_goal=self.x_goal,
                                                                    x_lim_low=self.x_lim_low,
                                                                    x_lim_high=self.x_lim_high
                                                                    )


    @abstractmethod
    def safe_control_fn(self, env, safe_region):
        """ 
        Placeholder method:
        Returns a failsafe action from the RCI controller.
        Args:
            env: environment
            safe_region: safe region
        Returns: failsafe action
        """
        pass
    
    @abstractmethod
    def sampling_fn(self):
        """
        Placeholder method: 
        Sample a safe action.
        """
        pass

    
    def unactuated_dynamics_fn(self, env):
        """
        Return the next state of the environment when no action is applied.
        Noise is disabled.
        # TODO: adapt this won't work for 3d quadrotor anymore because it doesn't use the eq variables -> where should we override in 2d or 3d quadrotor
        Args:
            env: environment
        Returns:
            next_state: next state of the environment
        """
        state = env.get_attr("state")[0] if isinstance(env, DummyVecEnv) else env.state
        next_state = env.A_d @ (state) # + env.B_d @ (-1* env.u_eq) # TODO: check if @ (-1* env.u_eq) is still needed 
        
        return next_state


    def actuated_dynamics_fn(self, env):
        """
        Returns the effect that the action has on the next state.
        Args:
            env: environment
        """
        return env.B_d
        
        # state = env.get_attr("state")[0] if isinstance(env, DummyVecEnv) else env.state
        # return np.array([[
        #     0,
        #     0,
        #     env.dt * env.K * np.sin(state[4]),
        #     env.dt * env.K * np.cos(state[4]),
        #     0,
        #     0],
        # [0,
        #     0,
        #     0,
        #     0,
        #     0,
        #     env.dt * env.n0]]).T
        

    def dynamics_fn(self, env, action):
        """
        Return the vertices of the set of possible next states of the environment when the action is applied.
        Args:
            env: environment
            action: action
        Returns:
            next_states: next states
        """
    
        f = self.unactuated_dynamics_fn(env)
        g = self.actuated_dynamics_fn(env)
        if self.noise_set.size > 0:
            next_states = f + np.dot(g, action) + (env.E_d @ self.noise_set.T).T
        else:
            next_states = f + np.dot(g, action)
        return next_states




    @abstractmethod
    def continuous_safe_space_fn(self, env, safe_region) -> np.ndarray:
        """
        Placeholder method:
        Compute the safe (continuous) action space given the current state.
        Args:
            env: environment
            safe_region: safe region
        Returns:
            vertices: vertices of the set of possible next states
            H, d: halfspaces of the safe action space: H a <= d
        """
        pass


    def continuous_safe_space_fn_masking_zonotope(self, env, safe_region: ControlInvariantSetZonotope) -> Zonotope:
        """
        Compute safe (continuous) action space for masking in zonotope representation.
        Args:
            env: environment
            safe_region: safe region
        Returns:
            Zonotope: zonotope representation of the safe action space
        """
        # A = env.A_d
        # B = env.B_d
        # X = x + env.x_eq - env.A_d @ - env.x_eq - env.B_d @ env.u_eq -> already shift the state by constant parts, point for now Generator matrix is 0
        # XS = RCI set in generator space
        # W = self.noise_set_zonotope

        state = env.get_attr("state")[0] if isinstance(env, DummyVecEnv) else env.state
        X = state # + env.x_eq - env.A_d @ - env.x_eq - env.B_d @ env.u_eq
        XS_c = safe_region.RCI_zonotope.c # - (env.x_eq - env.A_d @ env.x_eq - env.B_d @ env.u_eq).reshape((6,1)) # TODO: check if removing this is correct
        safe_action_set = calc_safe_input_set(
        A=env.A_d,
        B=env.B_d,
        X=Zonotope(
            G=np.eye(6)*0.0,
            c=X.reshape((6, 1))),
        tU=self.template_input_set,
        W=self.noise_set_zonotope.map(env.E_d),
        XS=Zonotope(
            G=safe_region.RCI_zonotope.G,
            c=XS_c),
        UF=Zonotope(G=np.eye(self.u_eq.shape[0])*self.allowable_input_set_factors,c=self.u_eq.reshape((-1,1))),
        )
        return safe_action_set
    
    
    # @abstractmethod
    # def continuous_safe_space_fn_masking_inner_interval(self, env, safe_region):
    #     """
    #     TODO: Check if necessary here
    #     Placeholder method: 
    #     Args:
    #         env: environment
    #         safe_region: safe region

    #     Return the vertices of the set of possible next states of the environment when the action is applied.
    #     """
    #     pass




    def create_env(self, env_name: str, space: ActionSpace, approach: Approach, transition_tuple: TransitionTuple, sampling_seed: int=None, continuous_action_masking_mode: ContMaskingMode=None) ->  DummyVecEnv:
        """
            Create and wrap the environment.

            Args:
                env_name (str): Name of the environment.
                space (ActionSpace): Action space to use. Can be either Discrete or Continuous.
                approach (Approach): Approach to use. Can be either Baseline, Sample, Failsafe, Masking, or Projection.
                TransitionTuple (TransitionTuple): Tuple of transition functions.
                    Can be either Naive, AdaptionPenalty, SafeAction, or Both.
                sampling_seed (int): Seed for sampling.
                continuous_action_masking_mode (ContMasinkgMode): Mode for continuous action masking.
            Returns:
                env: Wrapped environment.
        """
        punishment_fn = None
        generate_wrapper_tuple = False

        # Define reward punishment function / adaption penalty
        if transition_tuple is TransitionTuple.AdaptionPenalty or transition_tuple is TransitionTuple.Both:
            def punishment_fn(env, action, reward, safe_action):
                return reward + self.punishment

        # Wrapper information is necessary for (s, a_phi, s′, r) tuples
        if transition_tuple is TransitionTuple.SafeAction or transition_tuple is TransitionTuple.Both:
            generate_wrapper_tuple = True

        start_state = self.x_goal
        # Create simple pendulum gym environment
        # specify everything that is different from the default here
        if self.noise_set.size > 0 and self.noise is not None:
            env = gym.make(env_name,
                        dt=self.dt,
                        start_state=start_state,
                        randomize_start_range=[0.0, 0.0, 0.0],
                        randomize_env=self.config.get("randomize_env", False),
                        w=self.noise,
                        collision_reward= 0.0,
                        safe_region=self.safe_region,
                        )
        else:
            env = gym.make(env_name,
                        dt=self.dt,
                        start_state=start_state,
                        randomize_start_range=[0.0, 0.0, 0.0],
                        randomize_env=self.config.get("randomize_env", False),
                        collision_reward= 0.0,
                        safe_region=self.safe_region,
                        )
        env = TimeLimit(env, max_episode_steps=self.max_episode_steps)

        if space is ActionSpace.Discrete:
            alter_action_space = gym.spaces.Discrete(100)

            def transform_action_space_fn(action):
                """Transform discrete action to continuous action.
                # TODO: make more general for 3D
                """
                # [0..99] -> [a_min, a_max]
                u_1 = np.clip((action % 10) * ((self.u_high[0]-self.u_low[0])/10) + self.u_low[0], self.u_low[0], self.u_high[0])
                u_2 = np.clip(np.floor(action/10) * ((self.u_high[1]-self.u_low[1])/10) + self.u_low[1], self.u_low[1], self.u_high[1])
                return np.array([u_1, u_2])

            def inv_transform_action_space_fn(u):
                """Transform continuous action to discrete.
                # TODO: make more general for 3D
                """
                # [a_min, a_max] -> [0..99]
                a_1 = np.clip(np.floor((u[0]-self.u_low[0]) / (self.u_high[0]-self.u_low[0]) * 10), 0, 9)
                a_2 = np.clip(np.floor((u[1]-self.u_low[1]) / (self.u_high[1]-self.u_low[1]) * 10), 0, 9)
                return int(a_1 + 10*a_2)
        else:
            alter_action_space = gym.spaces.Box(low=-1, high=1, shape=(self.u_space.shape[1],), dtype=np.float32)

            def transform_action_space_fn(action):
                """Convert action from [-1, 1] to [u_min, u_max]."""
                # [-1,1] -> [u_min, u_max]
                return np.clip(((action + 1)/2) * (self.u_high-self.u_low) + self.u_low, self.u_low, self.u_high)

            def inv_transform_action_space_fn(u):
                """Convert action from [u_min, u_max] to [-1, 1]."""
                # [a_min, a_max] -> [-1,1]
                return np.clip(((u - self.u_low) / (self.u_high-self.u_low)) * 2 - 1, -1, 1)
            
            def transform_action_space_zonotope_fn(action: Union[np.ndarray, List[float], float], safe_space: Zonotope) -> np.ndarray:
                """
                Convert action in generator dimension (number of zonotope alphas) to action in action space and convert to be in [u_min, u_max].
                Args:
                    action: action in generator dimension
                    safe_space: safe space
                Returns:
                    action in action space in [u_min, u_max]
                """
                #TODO: this might be wrong as we have to take the general G and c instead of the safe_space
                action = safe_space.G @ action + np.squeeze(safe_space.c)
                # TODO: check if clipping makes sense this way
                return np.clip(action, self.u_low, self.u_high)

            def inv_transform_action_space_zonotope_fn(u: np.ndarray, safe_space: Zonotope) -> np.ndarray:
                """
                Convert action in action space to action in generator dimension (number of zonotope alphas) and convert to be in [-1, 1].
                Args:
                    u: action in action space
                    safe_space: safe space
                Returns:
                    action in generator dimension in [-1, 1]
                """
                #TODO: this might be wrong as we have to take the general G and c instead of the safe_space
                action = np.linalg.pinv(safe_space.G) @ (u - np.squeeze(safe_space.c))
                # TODO: check if clipping makes sense this way
                return np.clip(action, -1, 1) # np.clip(((u - self.u_low) / (self.u_high-self.u_low)) * 2 - 1, -1, 1)


        # Wrap environment
        if approach is Approach.Baseline:
            env = InformerWrapper(
                env=env,
                alter_action_space=alter_action_space,
                transform_action_space_fn=transform_action_space_fn
            )
        elif approach is Approach.Sample or approach is Approach.FailSafe:
            env = ActionReplacementWrapper(
                env,
                safe_region=self.safe_region_polytope,
                dynamics_fn=self.dynamics_fn,
                safe_control_fn=self.safe_control_fn,
                punishment_fn=punishment_fn,
                alter_action_space=alter_action_space,
                transform_action_space_fn=transform_action_space_fn,
                use_sampling=approach is Approach.Sample,
                continuous_safe_space_fn=self.continuous_safe_space_fn,
                sampling_fn=self.sampling_fn,
                generate_wrapper_tuple=generate_wrapper_tuple,
                inv_transform_action_space_fn=inv_transform_action_space_fn,
                sampling_seed=sampling_seed,
            )
        elif approach is Approach.Masking:
            if env.action_space is gym.spaces.Box and continuous_action_masking_mode is None:
                raise ValueError("action masking mode not set")
            if continuous_action_masking_mode and (continuous_action_masking_mode == ContMaskingMode.Generator or continuous_action_masking_mode == ContMaskingMode.Ray):
                generator_dim = self.template_input_set.G.shape[1]
            else:
                generator_dim = None
            if continuous_action_masking_mode and continuous_action_masking_mode == ContMaskingMode.Interval:
                env = ActionMaskingWrapper(
                    env,
                    safe_region=self.safe_region_polytope,
                    dynamics_fn=self.dynamics_fn,
                    safe_control_fn=self.safe_control_fn,
                    punishment_fn=punishment_fn,
                    continuous_safe_space_fn=self.continuous_safe_space_fn_masking_inner_interval,
                    alter_action_space=alter_action_space,
                    transform_action_space_fn=transform_action_space_fn,
                    generate_wrapper_tuple=generate_wrapper_tuple,
                    inv_transform_action_space_fn=inv_transform_action_space_fn,
                    continuous_action_masking_mode=continuous_action_masking_mode,
                    generator_dim=generator_dim,
                    safe_center_obs=self.safe_center_obs,
                )
            else:
                env = ActionMaskingWrapper(
                    env,
                    safe_region=self.safe_region,
                    dynamics_fn=self.dynamics_fn,
                    safe_control_fn=self.safe_control_fn,
                    punishment_fn=punishment_fn,
                    continuous_safe_space_fn=self.continuous_safe_space_fn_masking_zonotope, 
                    continuous_action_space_fn_polytope=self.continuous_safe_space_fn_masking_inner_interval if self.log_polytope_space else None,
                    safe_region_polytope=self.safe_region_polytope if self.log_polytope_space else None,
                    alter_action_space=alter_action_space,
                    # transform_action_space_fn=transform_action_space_zonotope_fn if continuous_action_masking_mode == ContMaskingMode.Generator else transform_action_space_fn,
                    transform_action_space_fn=transform_action_space_fn if (continuous_action_masking_mode == ContMaskingMode.ConstrainedNormal or continuous_action_masking_mode == ContMaskingMode.Ray)  else transform_action_space_zonotope_fn,
                    generate_wrapper_tuple=generate_wrapper_tuple,
                    inv_transform_action_space_fn=inv_transform_action_space_fn if (continuous_action_masking_mode == ContMaskingMode.ConstrainedNormal or continuous_action_masking_mode == ContMaskingMode.Ray) else inv_transform_action_space_zonotope_fn,
                    # inv_transform_action_space_fn=inv_transform_action_space_zonotope_fn if continuous_action_masking_mode == ContMaskingMode.Generator else inv_transform_action_space_fn,
                    continuous_action_masking_mode=continuous_action_masking_mode,
                    generator_dim=generator_dim,
                    safe_center_obs=self.safe_center_obs,
                )
        elif approach is Approach.Projection:
            env = ActionProjectionWrapper(
                env,
                safe_region=self.safe_region_polytope,
                action_limits=self.u_space,
                dynamics_fn=self.dynamics_fn,
                actuated_dynamics_fn=self.actuated_dynamics_fn,
                unactuated_dynamics_fn=self.unactuated_dynamics_fn,
                noise_set=self.noise_set,
                punishment_fn=punishment_fn,
                alter_action_space=alter_action_space,
                transform_action_space_fn=transform_action_space_fn,
                generate_wrapper_tuple=generate_wrapper_tuple,
                inv_transform_action_space_fn=inv_transform_action_space_fn,
                safe_control_fn=self.safe_control_fn,
                gamma=1,  # Minimal alteration
                safety_margin=0.01
            )

        return DummyVecEnv([lambda: Monitor(env)])

    @abstractmethod
    def run_experiment(self, 
                   alg: Algorithm,
                   policy: BasePolicy,
                   space: ActionSpace,
                   approach: Approach,
                   transition_tuple: TransitionTuple,
                   path: str,
                   continuous_action_masking_mode=None
        ):
        """
        Placeholder method:
        Run the experiment and save the model.
        Args:
            alg (Algorithm): Algorithm to use.
            policy (BasePolicy): Policy to use.
            space (ActionSpace): Action space to use. Can be either Discrete or Continuous.
            approach (Approach): Approach to use. Can be either Baseline, Sample, Failsafe, Masking, or Projection.
            TransitionTuple (TransitionTuple): Tuple of transition functions.
                Can be either Naive, AdaptionPenalty, SafeAction, or Both.
            path (str): Path to save the model.
            continuous_action_masking_mode (ContMasinkgMode): Mode for continuous action masking can be either Generator, Ray or Interval.
        """
        pass
        