import warnings

import gym
import numpy as np
from numpy.random import default_rng
from action_masking.sb3_contrib.common.utils import fetch_fn


class ActionReplacementWrapper(gym.Wrapper):
    """Action replacement wrapper for safe control.

    On every ``step`` the (optionally transformed) policy action is passed to
    ``dynamics_fn`` and the result is checked against ``safe_region``. If the
    check fails, the action is replaced before it reaches the wrapped
    environment: either by an action sampled from the safe action space (when
    ``use_sampling`` is set and a safe action is found) or by the output of
    ``safe_control_fn``. Details of the replacement are reported in
    ``info["replacement"]``.

    All ``*_fn`` arguments are resolved through ``fetch_fn(env, fn)``.
    NOTE(review): ``fetch_fn`` is assumed to return ``None`` for a ``None``
    argument; the ``is not None`` checks in this class rely on that.

    :param env: Gym environment
    :param safe_region: Safe region instance
    :param dynamics_fn: Dynamics function
    :param safe_control_fn: Verified fail safe action function
    :param punishment_fn: Reward punishment function
    :param alter_action_space: Alternative gym action space
    :param transform_action_space_fn: Action space transformation function
    :param continuous_safe_space_fn: Safe (continuous) action space function
    :param sampling_fn: Function to sample an action from the safe action space
    :param generate_wrapper_tuple: Generate tuple (wrapper action, environment reward)
    :param inv_transform_action_space_fn: Inverse action space transformation function
    :param use_sampling: Try to sample a replacement from the safe action space
        before falling back to ``safe_control_fn``
    :param sampling_seed: default_rng seed used to sample from the safe action space
    :raises ValueError: If the action space is neither ``Discrete`` nor ``Box``,
        or if ``use_sampling`` is set for a ``Box`` space without
        ``continuous_safe_space_fn``
    """

    def __init__(self,
                 env: gym.Env,
                 safe_region,
                 dynamics_fn,
                 safe_control_fn,
                 punishment_fn=None,
                 alter_action_space=None,
                 transform_action_space_fn=None,
                 continuous_safe_space_fn=None,
                 sampling_fn=None,
                 generate_wrapper_tuple=False,
                 inv_transform_action_space_fn=None,
                 use_sampling=False,
                 sampling_seed=None):

        super().__init__(env)
        self._safe_region = safe_region
        self._generate_wrapper_tuple = generate_wrapper_tuple
        self._use_sampling = use_sampling
        self._dynamics_fn = fetch_fn(self.env, dynamics_fn)
        self._punishment_fn = fetch_fn(self.env, punishment_fn)
        self._safe_control_fn = fetch_fn(self.env, safe_control_fn)
        self._continuous_safe_space_fn = fetch_fn(self.env, continuous_safe_space_fn)
        self._sampling_fn = fetch_fn(self.env, sampling_fn)
        self._transform_action_space_fn = fetch_fn(self.env, transform_action_space_fn)
        # Resolved unconditionally: ``step`` reads this attribute whenever
        # ``generate_wrapper_tuple`` is set, including when no
        # ``alter_action_space`` was given.
        self._inv_transform_action_space_fn = fetch_fn(self.env, inv_transform_action_space_fn)
        self._sampling_seed = sampling_seed
        self._rng = default_rng(seed=self._sampling_seed)

        if not hasattr(self.env, "action_space"):
            warnings.warn("Environment has no attribute ``action_space``")

        if alter_action_space is not None:
            self.action_space = alter_action_space
            if transform_action_space_fn is None:
                warnings.warn("Set ``alter_action_space`` but no ``transform_action_space_fn``")
            elif generate_wrapper_tuple and inv_transform_action_space_fn is None:
                warnings.warn("``generate_wrapper_tuple`` but no ``inv_transform_action_space_fn``")

        if not isinstance(self.action_space, (gym.spaces.Discrete, gym.spaces.Box)):
            raise ValueError(f"{type(self.action_space)} not supported")

        if use_sampling:
            if isinstance(self.action_space, gym.spaces.Discrete):
                self._num_actions = self.action_space.n
            elif self._continuous_safe_space_fn is None:
                raise ValueError(f"{type(self.action_space)} but no ``continuous_safe_space_fn``")
            elif self._sampling_fn is None:
                # Without a sampler nothing can be drawn from the continuous
                # safe space; ``step`` then uses the fail-safe controller.
                warnings.warn("``use_sampling`` but no ``sampling_fn``; falling back to ``safe_control_fn``")

    def _select_replacement(self):
        """Pick the action that replaces an unsafe policy action.

        :return: Tuple ``(safe_action, sampled)``; ``sampled`` is ``True`` when
            the action was drawn from the safe action space and ``False`` when
            it comes from ``safe_control_fn``
        """
        if self._use_sampling:
            if isinstance(self.action_space, gym.spaces.Discrete):
                # Enumerate all discrete actions and keep those whose result
                # from ``dynamics_fn`` lies in the safe region. A dedicated
                # local name is used so the policy action held by the caller
                # is never overwritten.
                safe_space = []
                for index in range(self._num_actions):
                    candidate = index if self._transform_action_space_fn is None \
                        else self._transform_action_space_fn(index)
                    if self._dynamics_fn(self.env, candidate) in self._safe_region:
                        safe_space.append(candidate)
                if safe_space:
                    return self._rng.choice(safe_space), True
            elif self._sampling_fn is not None:
                safe_space = self._continuous_safe_space_fn(self.env, self._safe_region)
                if safe_space is not None:
                    return self._sampling_fn(safe_space, self._rng), True

        # Fallback to verified fail-safe control
        return self._safe_control_fn(self.env, self._safe_region), False

    def step(self, action):
        """Step the environment, replacing ``action`` if it is not safe.

        ``info["replacement"]`` always holds the keys ``env_reward``,
        ``sample_action``, ``fail_safe_action``, ``pun_reward`` and
        ``policy_action`` (the policy action after the optional
        transformation). When an action was replaced and
        ``generate_wrapper_tuple`` is set, ``info["wrapper_tuple"]`` holds the
        replacement action (inverse-transformed if possible) and the
        environment reward.

        :param action: Action proposed by the policy
        :return: Tuple ``(obs, reward, done, info)``; ``reward`` is the output
            of ``punishment_fn`` when an action was replaced and a punishment
            function is configured, otherwise the environment reward
        """
        # Optional action transformation
        if self._transform_action_space_fn is not None:
            action = self._transform_action_space_fn(action)

        # NOTE(review): the second argument to ``contains`` is presumably a
        # numerical tolerance; confirm against the safe region implementation.
        if self._safe_region.contains(self._dynamics_fn(self.env, action), -1e-8):
            # Action is safe
            obs, reward, done, info = self.env.step(action)
            info["replacement"] = {"env_reward": reward,
                                   "sample_action": None,
                                   "fail_safe_action": None,
                                   "pun_reward": None}
        else:
            safe_action, sampled = self._select_replacement()
            obs, reward, done, info = self.env.step(safe_action)
            info["replacement"] = {"env_reward": reward,
                                   "sample_action": safe_action if sampled else None,
                                   "fail_safe_action": None if sampled else safe_action}

            if self._generate_wrapper_tuple:
                wrapper_action = self._inv_transform_action_space_fn(safe_action) \
                    if self._inv_transform_action_space_fn is not None else safe_action
                # Uses the environment reward, before any punishment is applied
                info["wrapper_tuple"] = (np.asarray([wrapper_action]), np.asarray([reward], dtype=np.float32))

            # Optional reward punishment
            if self._punishment_fn is not None:
                punishment = self._punishment_fn(self.env, action, reward, safe_action)
                info["replacement"]["pun_reward"] = punishment
                reward = punishment
            else:
                info["replacement"]["pun_reward"] = None

        info["replacement"]["policy_action"] = action
        assert not np.any(np.isnan(obs)), "NaN in observation"
        return obs, reward, done, info
