import d4rl.offline_env
import gymnasium as gym
import gym as classic_gym
import numpy as np
from inspect import getfullargspec
from rl.utils.base_wrapper import NormalizedGymnasiumBoxEnv
from rl.utils.replay_buffer import ReplayBuffer
import warnings

def _gym_release(version):
    """Return the leading ``(major, minor)`` integers of a dotted version string.

    Non-numeric suffixes (e.g. ``'26rc1'``) are ignored and missing or
    unparsable fields count as ``0``, so the result is always a 2-tuple.
    """
    release = []
    for field in str(version).split('.')[:2]:
        digits = ''
        for char in field:
            if char not in '0123456789':
                break
            digits += char
        release.append(int(digits) if digits else 0)
    release += [0] * (2 - len(release))
    return tuple(release)


# True when the installed classic gym uses the five-value ``step`` result and
# the ``(observation, info)`` ``reset`` result by default, which is the case
# from gym 0.26 onwards.  The comparison is numeric on purpose: comparing the
# raw strings orders '0.9.x' after '0.24'.
HIGH_VERSION_GYM = _gym_release(classic_gym.__version__) >= (0, 26)


class SpaceParser(object):
    """Convert classic ``gym`` spaces into their ``gymnasium`` counterparts.

    Every converter is a static method taking the classic space and returning
    a newly built ``gymnasium`` space.  ``parse`` is the entry point: it picks
    the converter from the module-level ``type_parser`` table using the exact
    type of the space.
    """

    @staticmethod
    def discrete(space: classic_gym.spaces.Discrete):
        """Convert a ``Discrete`` space, keeping ``n`` and (when present) ``start``."""
        # The second positional parameter of gymnasium's Discrete is ``seed``,
        # so the dtype must not be passed positionally.  Older classic gym
        # releases have no ``start`` attribute, hence the fallback to 0.
        return gym.spaces.Discrete(space.n, start=getattr(space, 'start', 0))

    @staticmethod
    def box(space: classic_gym.spaces.Box,
            cast_float32: bool = True):
        """Convert a ``Box`` space.

        Args:
            space: The classic Box space.
            cast_float32: When true the new space uses ``np.float32``;
                otherwise the dtype of ``space`` is kept.
        """
        dtype = np.float32 if cast_float32 else space.dtype
        return gym.spaces.Box(space.low, space.high, shape=space.shape, dtype=dtype)

    @staticmethod
    def multi_discrete(space: classic_gym.spaces.MultiDiscrete):
        """Convert a ``MultiDiscrete`` space, keeping ``nvec`` and ``dtype``."""
        return gym.spaces.MultiDiscrete(nvec=space.nvec, dtype=space.dtype)

    @staticmethod
    def multi_binary(space: classic_gym.spaces.MultiBinary):
        """Convert a ``MultiBinary`` space, keeping ``n``."""
        # gymnasium's MultiBinary takes no ``dtype`` argument.
        return gym.spaces.MultiBinary(space.n)

    @staticmethod
    def tuple(space: classic_gym.spaces.Tuple,
              cast_float32: bool = True):
        """Convert a ``Tuple`` space by converting each member with ``parse``."""
        return gym.spaces.Tuple(
            [SpaceParser.parse(member, cast_float32) for member in space.spaces]
        )

    @staticmethod
    def dict(space: classic_gym.spaces.Dict,
             cast_float32: bool = True):
        """Convert a ``Dict`` space by converting each value with ``parse``."""
        return gym.spaces.Dict(
            {key: SpaceParser.parse(member, cast_float32)
             for key, member in space.spaces.items()}
        )

    @staticmethod
    def parse(space, cast_float32: bool = True):
        """Convert ``space`` with the converter registered for its exact type.

        Args:
            space: A classic ``gym`` space.
            cast_float32: Forwarded to converters that declare a
                ``cast_float32`` parameter.

        Returns:
            The equivalent ``gymnasium`` space.

        Raises:
            NotImplementedError: If no converter is registered for the type
                of ``space`` (or of a space nested inside it).
        """
        try:
            parser = type_parser[type(space)]
        except KeyError:
            raise NotImplementedError(f"Type {type(space)} is not implemented yet.")
        # Look the parameter up by name; only some converters accept it.
        if 'cast_float32' in getfullargspec(parser).args:
            return parser(space, cast_float32)
        return parser(space)


# Dispatch table used by SpaceParser: maps each supported classic gym space
# class to the static method that converts it.  Lookups use the exact type of
# the space, so a class that is not listed here has no converter.
type_parser = {
    classic_gym.spaces.Discrete: SpaceParser.discrete,
    classic_gym.spaces.Box: SpaceParser.box,

    classic_gym.spaces.MultiDiscrete: SpaceParser.multi_discrete,
    classic_gym.spaces.MultiBinary: SpaceParser.multi_binary,
    classic_gym.spaces.Tuple: SpaceParser.tuple,
    classic_gym.spaces.Dict: SpaceParser.dict,
}


class GymnasiumWrapper(gym.Env):
    """Expose a classic ``gym`` environment through the ``gymnasium`` API.

    The wrapped environment is kept in ``self.wrapped``.  Its observation and
    action spaces are converted with ``SpaceParser.parse``; ``reset`` and
    ``step`` return ``(observation, info)`` and
    ``(observation, reward, terminated, truncated, info)`` respectively.
    """
    render_mode: str

    def __init__(self,
                 wrapped_env,
                 *,
                 timeout_key: str = 'TimeLimit.truncated',
                 **kwargs,
                 ):
        """Wrap ``wrapped_env``.

        Args:
            wrapped_env: The classic gym environment to adapt.
            timeout_key: Key of the ``info`` dict that carries the time-limit
                flag when the wrapped ``step`` returns four values.
            **kwargs: Accepted and ignored.

        Raises:
            NotImplementedError: If a space of the wrapped environment cannot
                be converted.
        """
        self.wrapped = wrapped_env
        self.timeout_key = timeout_key
        try:
            self.observation_space = SpaceParser.parse(self.wrapped.observation_space)
            self.action_space = SpaceParser.parse(self.wrapped.action_space)
        except KeyError:
            raise NotImplementedError

    @property
    def unwrapped(self):
        """Return the wrapped classic gym environment."""
        return self.wrapped

    def render(self, render_mode: str = 'human'):
        """Render the wrapped environment, passing ``render_mode`` as ``mode``."""
        # NOTE(review): recent classic gym releases appear to take the render
        # mode at construction time instead of a ``mode`` keyword - confirm
        # against the gym versions this project supports.
        return self.wrapped.render(mode=render_mode)

    def reset(self, *, seed=None, options=None):
        """Reset the wrapped environment and return ``(observation, info)``.

        Args:
            seed: Optional seed for the wrapped environment.
            options: Forwarded to the wrapped ``reset`` when
                ``HIGH_VERSION_GYM`` is true.

        Raises:
            NotImplementedError: If ``options`` is given while
                ``HIGH_VERSION_GYM`` is false.
        """
        if HIGH_VERSION_GYM:
            result = self.wrapped.reset(seed=seed, options=options)
            # Some gym releases accept ``seed``/``options`` but still return
            # only the observation; normalise both shapes to (obs, info).
            if isinstance(result, tuple) and len(result) == 2 and isinstance(result[1], dict):
                return result
            return result, {}

        if seed is not None:
            # Silence the deprecation warning of ``seed`` for this call only,
            # leaving the process-wide warning filters untouched.
            with warnings.catch_warnings():
                warnings.simplefilter("ignore", category=DeprecationWarning)
                self.wrapped.seed(seed)
        if options is not None:
            raise NotImplementedError("Setting option is not implemented yet.")
        return self.wrapped.reset(), {}

    def step(self, actions):
        """Step the wrapped environment.

        Returns:
            ``(observation, reward, terminated, truncated, info)``.  A
            five-value result of the wrapped ``step`` is passed through; for a
            four-value result ``truncated`` is popped from
            ``info[self.timeout_key]`` (``False`` when absent).
        """
        result = self.wrapped.step(actions)
        # Dispatch on what the environment actually returned, so a mismatch
        # between the version flag and the step API cannot break unpacking.
        if len(result) == 5:
            observation, reward, done, timeout, info = result
            return observation, reward, done, timeout, info

        observation, reward, done, info = result
        timeout = info.pop(self.timeout_key, False)
        # NOTE(review): ``done`` is returned unchanged as ``terminated``, so it
        # stays true when the episode merely timed out; callers that bootstrap
        # on ``terminated`` should also check ``truncated``.
        return observation, reward, done, timeout, info


class D4RLGymnasiumEnv(GymnasiumWrapper):
    """A ``GymnasiumWrapper`` around ``classic_gym.make(env_id)``.

    The instance owns a NumPy random generator (``np_rng``) from which a
    fresh environment seed is drawn on every ``reset``.
    """

    def __init__(self, env_id, seed: int = 42):
        """Create the environment ``env_id`` and seed the internal generator."""
        # Imported only for its import-time effects (presumably registering
        # the D4RL environment ids with gym - confirm); the name is unused.
        import d4rl
        self.np_rng = np.random.default_rng(seed)
        classic_env = classic_gym.make(env_id)
        super().__init__(classic_env)

    def reset(self, *, seed=None, options=None):
        """Reset with a seed drawn from ``np_rng``.

        A non-``None`` ``seed`` first replaces ``np_rng`` with a generator
        built from it.  ``options`` is accepted but not forwarded.
        """
        rng = self.np_rng if seed is None else np.random.default_rng(seed)
        self.np_rng = rng
        episode_seed = int(rng.integers(0, 2 ** 31 - 1))
        return super().reset(seed=episode_seed)

    def get_dataset(self):
        """Return the result of ``get_dataset()`` on the wrapped environment."""
        classic_env = self.unwrapped
        return classic_env.get_dataset()


class D4RLPreprocessor(object):
    """Load a D4RL environment together with its Q-learning dataset.

    The dataset is obtained with ``d4rl.qlearning_dataset`` and kept in
    ``self.q_leanrning_dataset`` (the attribute name is part of the existing
    interface).  Per-dimension observation mean and standard deviation are
    always computed; they are applied to the dataset and to the environment
    only when ``normalize_obs`` is true.
    """

    def __init__(self,
                 env_id: str,
                 normalize_obs: bool = True,
                 normalize_reward: bool = False,
                 terminate_on_end=True,
                 seed: int = 42,
                 ):
        """Build the environment and dataset.

        Args:
            env_id: Environment id passed to ``D4RLGymnasiumEnv``.
            normalize_obs: Standardise ``observations`` and
                ``next_observations`` and wrap the environment in
                ``NormalizedGymnasiumBoxEnv`` with the same statistics.
            normalize_reward: Stored and forwarded to the replay buffer.
            terminate_on_end: Forwarded to ``d4rl.qlearning_dataset``.
            seed: Seed for the environment and for the replay buffer.
        """
        # Pass the seed on so the environment does not silently fall back to
        # its own default seed.
        self.gymnasium_env = D4RLGymnasiumEnv(env_id, seed=seed)
        self.normalize_reward = normalize_reward
        self.d4rl_gym_env = self.gymnasium_env.wrapped
        self.seed = seed
        # Score normalisation is optional; expose None when unavailable.
        self.normalized_score = getattr(self.d4rl_gym_env, 'get_normalized_score', None)

        self.q_leanrning_dataset = d4rl.qlearning_dataset(self.d4rl_gym_env,
                                                          terminate_on_end=terminate_on_end)
        observations = self.q_leanrning_dataset['observations']
        self.obs_mean = observations.mean(axis=0, keepdims=True)
        # Lower-bound the std so the division below never divides by zero.
        self.obs_std = observations.std(axis=0, keepdims=True).clip(1e-12)

        if normalize_obs:
            for key in ('observations', 'next_observations'):
                self.q_leanrning_dataset[key] = (
                    (self.q_leanrning_dataset[key] - self.obs_mean) / self.obs_std
                )
            self.gymnasium_env = NormalizedGymnasiumBoxEnv(self.gymnasium_env,
                                                           obs_mean=self.obs_mean, obs_std=self.obs_std)

    @property
    def env(self):
        """The gymnasium environment (normalised when ``normalize_obs`` was set)."""
        return self.gymnasium_env

    def get_replay_buffer(self) -> ReplayBuffer:
        """Build a ``ReplayBuffer`` from the stored Q-learning dataset."""
        return ReplayBuffer.from_qlearning_dataset(self.q_leanrning_dataset,
                                                   normalize_reward=self.normalize_reward,
                                                   seed=self.seed
                                                   )
