import warnings

import gym
import numpy as np
from provably_safe_benchmark.sb3_contrib.common.utils import fetch_fn


class ActionMaskingWrapper(gym.Wrapper):
    """
    Gym wrapper that masks unsafe actions and falls back to a fail-safe controller.

    Discrete action spaces are extended by one auxiliary action. After every
    ``reset``/``step`` a boolean mask over the extended space is computed: entry ``i``
    is ``True`` if the result of ``dynamics_fn(env, action_i)`` is contained in
    ``safe_region``. If no regular action is safe, only the auxiliary (last) entry is
    ``True`` and the next ``step`` executes ``safe_control_fn`` instead of the policy action.

    For ``Box`` action spaces, ``continuous_safe_space_fn`` is queried on every ``step``.
    If it returns ``None`` the fail-safe action is executed, otherwise the policy action
    is linearly rescaled from the bounds of the action space into the returned safe bounds.

    Every ``step`` stores a dictionary under ``info["masking"]`` with the keys
    ``policy_action``, ``env_reward``, ``fail_safe_action``, ``pun_reward`` and ``safe_space``.

    :param env: Gym environment
    :param safe_region: Safe region instance
    :param dynamics_fn: Dynamics function
    :param safe_control_fn: Verified fail safe action function
    :param punishment_fn: Reward punishment function
    :param alter_action_space: Alternative gym action space
    :param transform_action_space_fn: Action space transformation function
    :param continuous_safe_space_fn: Safe (continuous) action space function
    :param generate_wrapper_tuple: Generate tuple (wrapper action, environment reward)
    :param inv_transform_action_space_fn: Inverse action space transformation function
    """

    def __init__(self,
                 env: gym.Env,
                 safe_region,
                 dynamics_fn,
                 safe_control_fn,
                 punishment_fn=None,
                 alter_action_space=None,
                 generate_wrapper_tuple=False,
                 transform_action_space_fn=None,
                 continuous_safe_space_fn=None,
                 inv_transform_action_space_fn=None):
        """
        Resolve the user supplied functions and validate/extend the action space.

        :raises ValueError: If the action space is neither ``Discrete`` nor ``Box``, or if
            it is a ``Box`` and no ``continuous_safe_space_fn`` was given.
        """

        super().__init__(env)

        self._mask = None
        self._safe_region = safe_region
        self._dynamics_fn = fetch_fn(self.env, dynamics_fn)
        self._generate_wrapper_tuple = generate_wrapper_tuple
        self._punishment_fn = fetch_fn(self.env, punishment_fn)
        self._safe_control_fn = fetch_fn(self.env, safe_control_fn)
        self._continuous_safe_space_fn = fetch_fn(self.env, continuous_safe_space_fn)
        self._transform_action_space_fn = fetch_fn(self.env, transform_action_space_fn)
        # Always resolve the inverse transformation: ``step`` reads this attribute whenever
        # ``generate_wrapper_tuple`` is set, regardless of ``alter_action_space``. Leaving it
        # unset would raise an ``AttributeError`` there instead of using the raw action.
        self._inv_transform_action_space_fn = fetch_fn(self.env, inv_transform_action_space_fn)

        if not hasattr(self.env, "action_space"):
            warnings.warn("Environment has no attribute ``action_space``")

        if alter_action_space is not None:
            self.action_space = alter_action_space
            if transform_action_space_fn is None and isinstance(self.action_space, gym.spaces.Discrete):
                warnings.warn("Set ``alter_action_space`` but no ``transform_action_space_fn``")
            elif generate_wrapper_tuple and inv_transform_action_space_fn is None:
                warnings.warn("``generate_wrapper_tuple`` but no ``inv_transform_action_space_fn``")

        if not isinstance(self.action_space, (gym.spaces.Discrete, gym.spaces.Box)):
            raise ValueError(f"{type(self.action_space)} not supported")

        if isinstance(self.action_space, gym.spaces.Discrete):
            # Extend action space with auxiliary action
            self._num_actions = self.action_space.n + 1
            self.action_space = gym.spaces.Discrete(self._num_actions)

        else:
            if self._continuous_safe_space_fn is None:
                raise ValueError(f"{type(self.action_space)} but no ``continuous_safe_space_fn``")

    def reset(self, **kwargs):
        """
        Reset the wrapped environment and, for discrete action spaces, compute the initial mask.

        :param kwargs: Forwarded to ``env.reset``
        :return: Whatever ``env.reset`` returns
        """
        obs = self.env.reset(**kwargs)
        if isinstance(self.action_space, gym.spaces.Discrete):
            self._mask = self._discrete_mask()
        return obs

    def _discrete_mask(self) -> np.ndarray:
        """
        Compute the action mask for the current state of the wrapped environment.

        :return: Boolean array of length ``n + 1``. The last entry belongs to the auxiliary
            action and is ``True`` only if none of the regular actions is safe.
        """
        mask = np.zeros(self._num_actions, dtype=bool)
        # The last index is the auxiliary action and is never checked against the dynamics
        for i in range(self._num_actions - 1):
            action = i if (self._transform_action_space_fn is None)\
                else self._transform_action_space_fn(i)
            if self._dynamics_fn(self.env, action) in self._safe_region:
                mask[i] = True

        if not mask.any():
            mask[-1] = True

        return mask

    def action_masks(self) -> np.ndarray:
        """
        Return the mask computed by the last ``reset``/``step``.

        :return: Current boolean action mask; ``None`` if no mask has been computed
            (before the first ``reset`` or for continuous action spaces).
        """
        return self._mask

    def step(self, action):
        """
        Execute either the (possibly transformed/rescaled) policy action or the fail-safe action.

        :param action: Action selected by the policy
        :return: Tuple ``(obs, reward, done, info)`` of the wrapped environment, where ``reward``
            is replaced by the punishment if the fail-safe action was used and a
            ``punishment_fn`` is set
        """
        if isinstance(self.action_space, gym.spaces.Discrete):
            return self._step_discrete(action)
        return self._step_continuous(action)

    def _step_discrete(self, action):
        """Step with a discrete action, using the mask computed for the current state."""
        if self._mask[-1]:
            # Only the auxiliary action is allowed: fallback to verified fail-safe control
            safe_action = self._safe_control_fn(self.env, self._safe_region)
            obs, reward, done, info = self.env.step(safe_action)
            info["masking"] = {"policy_action": None, "env_reward": reward,
                               "fail_safe_action": safe_action, "safe_space": None}

            # Optional reward punishment
            if self._punishment_fn is not None:
                punishment = self._punishment_fn(self.env, safe_action, reward, self._mask)
                info["masking"]["pun_reward"] = punishment
                reward = punishment
            else:
                info["masking"]["pun_reward"] = None
        else:
            # Optional action transformation
            if self._transform_action_space_fn is not None:
                action = self._transform_action_space_fn(action)
            # Policy action is safe
            obs, reward, done, info = self.env.step(action)
            info["masking"] = {"policy_action": action, "env_reward": reward,
                               "fail_safe_action": None, "pun_reward": None, "safe_space": self._mask[:-1]}

        # Compute next mask
        self._mask = self._discrete_mask()
        return obs, reward, done, info

    def _step_continuous(self, action):
        """Step with a continuous action, rescaled into the current safe action bounds."""
        # Safe (continuous) action space function
        safe_space = self._continuous_safe_space_fn(self.env, self._safe_region)
        if safe_space is None:
            # Fallback to verified fail-safe control
            safe_action = self._safe_control_fn(self.env, self._safe_region)
            obs, reward, done, info = self.env.step(safe_action)
            info["masking"] = {"policy_action": None, "env_reward": reward, "fail_safe_action": safe_action}
            if self._generate_wrapper_tuple:
                # The tuple carries the environment reward, i.e. before any punishment
                wrapper_action = self._inv_transform_action_space_fn(safe_action) \
                    if self._inv_transform_action_space_fn is not None else safe_action
                info["wrapper_tuple"] = (np.asarray([wrapper_action]), np.asarray([reward], dtype=np.float32))
            # Optional reward punishment
            if self._punishment_fn is not None:
                punishment = self._punishment_fn(
                    env=self.env,
                    action=action,
                    reward=reward,
                    safe_action=safe_action
                )
                info["masking"]["pun_reward"] = punishment
                reward = punishment
            else:
                info["masking"]["pun_reward"] = None
        else:
            # Scale policy action from [low, high] of the action space into the safe bounds
            if len(safe_space.shape) == 1:
                # One-dimensional ``safe_space``: index 0 is the lower, index 1 the upper bound
                scale = (safe_space[1] - safe_space[0]) / (self.action_space.high - self.action_space.low)
                action = (scale * (action - self.action_space.low) + safe_space[0]).item()
            else:
                # Two-dimensional ``safe_space``: column 0 holds lower, column 1 upper bounds
                scale = (safe_space[:, 1] - safe_space[:, 0]) / (self.action_space.high - self.action_space.low)
                action = scale * (action - self.action_space.low) + safe_space[:, 0]
            obs, reward, done, info = self.env.step(action)
            info["masking"] = {"policy_action": action, "env_reward": reward,
                               "fail_safe_action": None, "pun_reward": None}

            if self._generate_wrapper_tuple:
                wrapper_action = self._inv_transform_action_space_fn(action) \
                    if self._inv_transform_action_space_fn is not None else action
                info["wrapper_tuple"] = (np.asarray([wrapper_action]), np.asarray([reward], dtype=np.float32))

        info["masking"]["safe_space"] = safe_space

        return obs, reward, done, info
