import warnings

import gym
import numpy as np
from numpy.random import default_rng
from action_masking.sb3_contrib.common.utils import fetch_fn

class ActionReplacementWrapper(gym.Wrapper):
    """
    Replace unsafe actions with safe ones before they reach the wrapped environment.

    An action is treated as unsafe when ``dynamics_fn(env, action)`` is not contained in
    ``safe_region``. An unsafe action is replaced by an action sampled at random from the
    safe action space (only if ``use_sampling`` is set and a safe action is available),
    otherwise by the verified fail-safe action returned by ``safe_control_fn``.
    Replacement details are reported in ``info["replacement"]``.

    :param env: Gym environment
    :param safe_region: Safe region instance; must support the ``in`` operator
    :param dynamics_fn: Dynamics function, called as ``dynamics_fn(env, action)``
    :param safe_control_fn: Verified fail safe action function, called as
        ``safe_control_fn(env, safe_region)``
    :param punishment_fn: Reward punishment function, called as
        ``punishment_fn(env, action, reward, safe_action)``; its result replaces the reward
    :param alter_action_space: Alternative gym action space exposed by the wrapper
    :param transform_action_space_fn: Action space transformation function, applied to the
        policy action (and to each discrete candidate index when sampling)
    :param continuous_safe_space_fn: Safe (continuous) action space function, called as
        ``continuous_safe_space_fn(env, safe_region)``; its result is unpacked into
        ``Generator.uniform`` (presumably ``(low, high)`` — confirm), and a ``None``
        result falls back to ``safe_control_fn``
    :param generate_wrapper_tuple: Store ``(wrapper action, environment reward)`` in
        ``info["wrapper_tuple"]`` whenever an action is replaced
    :param inv_transform_action_space_fn: Inverse action space transformation function,
        used to map the replacement action back for the wrapper tuple
    :param use_sampling: Try to sample a safe action before using ``safe_control_fn``
    :param sampling_seed: default_rng seed used to sample from the safe action space
    :raises ValueError: If the action space is neither ``Discrete`` nor ``Box``, or if
        sampling is requested for a ``Box`` space without ``continuous_safe_space_fn``
    """

    def __init__(self,
                 env: gym.Env,
                 safe_region,
                 dynamics_fn,
                 safe_control_fn,
                 punishment_fn=None,
                 alter_action_space=None,
                 transform_action_space_fn=None,
                 continuous_safe_space_fn=None,
                 generate_wrapper_tuple=False,
                 inv_transform_action_space_fn=None,
                 use_sampling=False,
                 sampling_seed=None):

        super().__init__(env)
        self._safe_region = safe_region
        self._generate_wrapper_tuple = generate_wrapper_tuple
        self._use_sampling = use_sampling
        self._dynamics_fn = fetch_fn(self.env, dynamics_fn)
        self._punishment_fn = fetch_fn(self.env, punishment_fn)
        self._safe_control_fn = fetch_fn(self.env, safe_control_fn)
        self._continuous_safe_space_fn = fetch_fn(self.env, continuous_safe_space_fn)
        self._transform_action_space_fn = fetch_fn(self.env, transform_action_space_fn)
        # Always defined: ``step`` reads it whenever ``generate_wrapper_tuple`` is set, and
        # gym.Wrapper does not forward missing private attributes to the wrapped env.
        self._inv_transform_action_space_fn = fetch_fn(self.env, inv_transform_action_space_fn)
        self._rng = default_rng(seed=sampling_seed)

        if not hasattr(self.env, "action_space"):
            warnings.warn("Environment has no attribute ``action_space``")

        if alter_action_space is not None:
            self.action_space = alter_action_space
            if transform_action_space_fn is None:
                warnings.warn("Set ``alter_action_space`` but no ``transform_action_space_fn``")
            elif generate_wrapper_tuple and inv_transform_action_space_fn is None:
                warnings.warn("``generate_wrapper_tuple`` but no ``inv_transform_action_space_fn``")

        if not isinstance(self.action_space, (gym.spaces.Discrete, gym.spaces.Box)):
            raise ValueError(f"{type(self.action_space)} not supported")

        if use_sampling:
            if isinstance(self.action_space, gym.spaces.Discrete):
                # Candidate indices enumerated when sampling a safe discrete action
                self._num_actions = self.action_space.n
            elif self._continuous_safe_space_fn is None:
                raise ValueError(f"{type(self.action_space)} but no ``continuous_safe_space_fn``")

    def step(self, action):
        """
        Step the environment, replacing ``action`` if it would leave the safe region.

        :param action: Policy action (for ``Box`` action spaces it is unwrapped with
            ``.item()``, so it must hold exactly one element)
        :return: ``(obs, reward, done, info)``; ``info["replacement"]`` records the
            environment reward, the sampled / fail-safe action used (if any), the punished
            reward (if any) and the (transformed) policy action
        """
        if isinstance(self.action_space, gym.spaces.Box):
            action = action.item()

        # Optional action transformation
        if self._transform_action_space_fn is not None:
            action = self._transform_action_space_fn(action)

        if self._dynamics_fn(self.env, action) in self._safe_region:
            # Action is safe
            obs, reward, done, info = self.env.step(action)
            info["replacement"] = {"env_reward": reward, "sample_action": None,
                                   "fail_safe_action": None, "pun_reward": None}
        else:
            obs, reward, done, info = self._step_with_replacement(action)

        # Record the (transformed) action proposed by the policy, not the replacement
        info["replacement"]["policy_action"] = action
        return obs, reward, done, info

    def _sample_safe_action(self):
        """Return a randomly sampled safe action, or ``None`` if sampling is off or none exists."""
        if not self._use_sampling:
            return None

        if isinstance(self.action_space, gym.spaces.Discrete):
            # Safe (discrete) action space: keep every candidate whose successor is safe.
            # A dedicated name is used so the caller's policy action is never overwritten.
            safe_space = []
            for index in range(self._num_actions):
                candidate = index if self._transform_action_space_fn is None \
                    else self._transform_action_space_fn(index)
                if self._dynamics_fn(self.env, candidate) in self._safe_region:
                    safe_space.append(candidate)
            return self._rng.choice(safe_space) if safe_space else None

        # Safe (continuous) action space
        safe_space = self._continuous_safe_space_fn(self.env, self._safe_region)
        return self._rng.uniform(*safe_space) if safe_space is not None else None

    def _step_with_replacement(self, action):
        """Step the environment with a safe replacement for the unsafe ``action``."""
        safe_action = self._sample_safe_action()

        if safe_action is not None:
            # Fallback to sampled safe action
            obs, reward, done, info = self.env.step(safe_action)
            info["replacement"] = {"env_reward": reward, "sample_action": safe_action, "fail_safe_action": None}
        else:
            # Fallback to verified fail-safe control
            safe_action = self._safe_control_fn(self.env, self._safe_region)
            obs, reward, done, info = self.env.step(safe_action)
            info["replacement"] = {"env_reward": reward, "sample_action": None, "fail_safe_action": safe_action}

        if self._generate_wrapper_tuple:
            # Uses the environment reward, i.e. before any punishment is applied
            wrapper_action = safe_action if self._inv_transform_action_space_fn is None \
                else self._inv_transform_action_space_fn(safe_action)
            info["wrapper_tuple"] = (np.asarray([wrapper_action]), np.asarray([reward], dtype=np.float32))

        # Optional reward punishment
        if self._punishment_fn is not None:
            punishment = self._punishment_fn(self.env, action, reward, safe_action)
            info["replacement"]["pun_reward"] = punishment
            reward = punishment
        else:
            info["replacement"]["pun_reward"] = None

        return obs, reward, done, info

