from typing import SupportsFloat, Any

import gymnasium as gym

from gymnasium import spaces
from gymnasium.core import Wrapper, ObservationWrapper, ActionWrapper
from gymnasium.wrappers import TransformReward, TimeLimit

import numpy as np


class pitfall_wrapper(Wrapper):
    """Reward-shaping wrapper that flattens every negative reward to -0.01.

    Non-negative rewards, observations, the ``terminated`` / ``truncated``
    flags and the ``info`` dict are forwarded from the wrapped environment
    untouched.
    """

    def step(self, action):
        """Step the wrapped environment and clamp negative rewards.

        Args:
            action: Action forwarded as-is to the wrapped environment.

        Returns:
            The usual ``(observation, reward, terminated, truncated, info)``
            tuple, where ``reward`` is ``-0.01`` whenever the wrapped
            environment returned a value below zero.
        """
        obs, raw_reward, done, cut_off, extras = self.env.step(action)
        # Any penalty, whatever its magnitude, becomes the same small cost.
        shaped_reward = -0.01 if raw_reward < 0 else raw_reward
        return obs, shaped_reward, done, cut_off, extras


class VizDoomScreenObsWrapper(ObservationWrapper):
    """Expose only the ``'screen'`` entry of a dict-style observation.

    The wrapped environment's observation space must provide a ``spaces``
    mapping containing a ``'screen'`` key; that sub-space becomes this
    wrapper's observation space.
    """

    def __init__(self, env):
        """Wrap *env* and narrow the observation space to its screen part."""
        super().__init__(env)
        screen_space = self.observation_space.spaces["screen"]
        self.observation_space = screen_space

    def observation(self, observation):
        """Return the ``'screen'`` item of *observation*."""
        screen = observation["screen"]
        return screen


class VizDoomAddChanelWrapper(ObservationWrapper):
    """Prepend a singleton axis to every observation.

    An observation of shape ``S`` is returned with shape ``(1,) + S`` and the
    declared observation space is updated to match.
    """

    def __init__(self, env):
        """Wrap *env* and declare the expanded observation space.

        Args:
            env: Environment whose observation space exposes ``shape``,
                ``low``, ``high`` and ``dtype`` (e.g. a ``spaces.Box``).
        """
        super().__init__(env)

        inner_space = self.observation_space
        # Expand the bound arrays themselves instead of collapsing them to
        # their global min / max: this keeps any per-element limits of the
        # wrapped space rather than silently loosening them. For spaces with
        # uniform bounds the resulting space is the same as before.
        self.observation_space = spaces.Box(
            low=np.expand_dims(inner_space.low, axis=0),
            high=np.expand_dims(inner_space.high, axis=0),
            shape=(1,) + inner_space.shape,
            dtype=inner_space.dtype,
        )

    def observation(self, observation):
        """Return *observation* with a new leading axis of length one."""
        return np.expand_dims(observation, axis=0)


class RemovePickUpActionWrapper(ActionWrapper):
    """
    Action wrapper that shrinks a discrete action space by one action.

    Actions are forwarded to the wrapped environment unchanged, so the action
    that becomes unreachable is the one with the highest index.
    NOTE(review): this presumes the pickup action is the last index of the
    wrapped environment's action space -- confirm against the environment.
    """

    def __init__(self, env):
        """Wrap *env* and advertise an action space with one fewer action."""
        super().__init__(env)
        # The reduced space simply omits the final index of the original one.
        reduced_count = env.action_space.n - 1
        self.action_space = spaces.Discrete(reduced_count)

    def action(self, act):
        """Return *act* unchanged; no index remapping is performed."""
        return act


def _binary_reward(threshold=0):
    """Build a ``TransformReward`` spec mapping rewards above *threshold* to 1, else 0."""
    # NOTE(review): the ``f`` keyword must match the parameter name of the
    # installed ``TransformReward``; newer Gymnasium releases may name it
    # differently -- confirm against the pinned version.
    return {
        "wrapper": TransformReward,
        "kwargs": {"f": lambda x: 1 if x > threshold else 0},
    }


# Maps an environment id to the ordered list of wrapper specs for it.
# Each spec is a dict with the wrapper class under "wrapper" and its keyword
# arguments under "kwargs".
env_wrapper_list = {
    "ALE/Frogger-v5": [_binary_reward()],
    "ALE/Solaris-v5": [
        _binary_reward(),
        {"wrapper": TimeLimit, "kwargs": {"max_episode_steps": 4000}},
    ],
    "ALE/BeamRider-v5": [_binary_reward()],
    "VizdoomDefendCenter-v0": [_binary_reward()],
    "VizdoomDefendLine-v0": [_binary_reward()],
    "HealthGatheringLevel0-v0": [_binary_reward()],
    # This environment uses a higher cut-off than the others.
    "SeekAndSlayLevel0-v0": [_binary_reward(0.8)],
}
