"""Gymnasium wrappers for frame stacking (plain and agent-controlled) and observation reshaping/hiding."""
from collections import deque
import numpy as np
import gymnasium as gym
from gymnasium.spaces import Box, Discrete, MultiDiscrete, Dict
from collections import defaultdict


class FrameStack(gym.ObservationWrapper, gym.utils.RecordConstructorArgs):
    """Stack the most recent ``num_stack`` observations along a new leading axis.

    After :meth:`reset` the buffer holds ``num_stack - 1`` zero frames followed by
    the first observation. Each :meth:`step` appends the newest observation and the
    bounded deque discards the oldest one.

    Args:
        env: Environment to wrap; its ``observation_space`` must expose ``low``,
            ``high`` and ``dtype`` (e.g. a ``Box``).
        num_stack: Number of frames kept in the buffer (must be > 0).
    """

    def __init__(
        self,
        env: gym.Env,
        num_stack: int,
    ):
        gym.utils.RecordConstructorArgs.__init__(
            self, num_stack=num_stack,
        )
        gym.ObservationWrapper.__init__(self, env)

        # Same validation as the other stacking wrappers in this module.
        assert num_stack > 0, "num_stack should be greater than 0"
        self.num_stack = num_stack
        self.frames = deque(maxlen=num_stack)

        # Every stacked frame shares the wrapped space's bounds.
        low = np.repeat(self.observation_space.low[np.newaxis, ...], num_stack, axis=0)
        high = np.repeat(
            self.observation_space.high[np.newaxis, ...], num_stack, axis=0
        )
        self.observation_space = Box(
            low=low, high=high, dtype=self.observation_space.dtype
        )

    def observation(self, observation):
        """Return the current frame buffer as a single array.

        Args:
            observation: Ignored; frames are read from :attr:`self.frames`.

        Returns:
            ``np.ndarray`` with the ``num_stack`` buffered frames stacked on axis 0,
            oldest first.
        """
        assert len(self.frames) == self.num_stack, (len(self.frames), self.num_stack)
        return np.array(self.frames)

    def step(self, action):
        """Step the environment and push the new observation onto the buffer.

        Args:
            action: The action to step through the environment with.

        Returns:
            Stacked observations, reward, terminated, truncated, and info from the environment.
        """
        observation, reward, terminated, truncated, info = self.env.step(action)
        self.frames.append(observation)
        return self.observation(None), reward, terminated, truncated, info

    def reset(self, **kwargs):
        """Reset the environment and refill the buffer.

        Args:
            **kwargs: Forwarded to the wrapped environment's ``reset``.

        Returns:
            The stacked observations (``num_stack - 1`` zero frames, then the reset
            observation) and the reset info.
        """
        obs, info = self.env.reset(**kwargs)

        self.frames.clear()
        for _ in range(self.num_stack - 1):
            # Zero padding with the observation's shape and dtype.
            self.frames.append(obs * 0)
        self.frames.append(obs)

        return self.observation(None), info


class MaskedFrameStack(gym.ObservationWrapper, gym.utils.RecordConstructorArgs):
    """Frame stack in which the agent chooses which stored frame to drop each step.

    The wrapped action space is extended with a mask index in ``[0, num_stack)``.
    On every :meth:`step` the frame at that index is removed from the buffer and the
    new observation is appended, so the buffer always holds ``num_stack`` frames.

    Mask encoding, by wrapped action space:

    * ``Discrete`` with ``use_multidiscrete=True``: ``MultiDiscrete([n, num_stack])``.
    * ``MultiDiscrete``: one extra trailing dimension of size ``num_stack``.
    * ``Discrete``: flattened to ``Discrete(n * num_stack)`` as ``action + mask * n``.
    * ``Box``: one extra scalar in ``[low[0], high[0]]``, decoded to the nearest of
      ``num_stack`` evenly spaced values (assumes a 1-D ``Box``).

    Args:
        env: Environment to wrap.
        num_stack: Number of frames kept in the buffer (must be > 0).
        mod_obs: If True, observations and ``observation_space`` are the stacked
            frames; otherwise the raw observation is returned while the buffer is
            still maintained.
        use_multidiscrete: Use separate action heads for a ``Discrete`` env.

    Raises:
        ValueError: If the wrapped action space is not one of the types above.
    """

    def __init__(
        self,
        env: gym.Env,
        num_stack: int,
        mod_obs: bool = True,
        use_multidiscrete: bool = False,
    ):
        gym.utils.RecordConstructorArgs.__init__(
            self,
            num_stack=num_stack,
            mod_obs=mod_obs,
            use_multidiscrete=use_multidiscrete,
        )
        gym.ObservationWrapper.__init__(self, env)

        assert num_stack > 0, "num_stack should be greater than 0"
        self.num_stack = num_stack
        self.mod_obs = mod_obs
        self.use_multidiscrete = use_multidiscrete

        if self.mod_obs:
            low = np.repeat(
                self.observation_space.low[np.newaxis, ...], num_stack, axis=0
            )
            high = np.repeat(
                self.observation_space.high[np.newaxis, ...], num_stack, axis=0
            )
            self.observation_space = Box(
                low=low, high=high, dtype=self.observation_space.dtype
            )

        self.env_action_space = env.action_space
        if self.use_multidiscrete and isinstance(self.env_action_space, Discrete):
            # Separate heads for env action and mask index.
            self.action_space = MultiDiscrete([self.env_action_space.n, num_stack])
        elif isinstance(self.env_action_space, MultiDiscrete):
            # Mask index becomes one extra trailing dimension.
            self.action_space = MultiDiscrete(np.append(self.env_action_space.nvec, num_stack))
        elif isinstance(self.env_action_space, Discrete):
            # Single discrete space: env.n * num_stack.
            self.action_space = Discrete(self.env_action_space.n * num_stack)
        elif isinstance(self.env_action_space, Box):
            # Append one scalar mask dimension bounded like the first action dimension.
            # Slicing keeps the wrapped dtype; adding a scalar to a uint8 array could
            # round through float16 under legacy NumPy promotion.
            low = np.concatenate([self.env_action_space.low, self.env_action_space.low[:1]])
            high = np.concatenate([self.env_action_space.high, self.env_action_space.high[:1]])
            self.action_space = Box(low=low, high=high, dtype=self.env_action_space.dtype)
        else:
            raise ValueError(f"Env action space not supported: {self.env_action_space}")

    def _continuous_mask_values(self):
        """Scalar values encoding mask indices ``0..num_stack-1`` for ``Box`` actions."""
        return np.linspace(
            self.env_action_space.low[0], self.env_action_space.high[0], self.num_stack
        )

    def join_action_mask(self, action, mask):
        """Encode an env action and a mask index as one action of ``action_space``.

        This is the inverse of :meth:`split_action_mask`.

        Args:
            action: Action for the wrapped environment.
            mask: Index of the buffered frame to drop, in ``[0, num_stack)``.

        Returns:
            An element of ``self.action_space``.
        """
        mask = int(mask)
        if self.use_multidiscrete and isinstance(self.env_action_space, Discrete):
            # Pack into [env_action, mask_index].
            return np.array([action, mask], dtype=np.int64)
        if isinstance(self.action_space, MultiDiscrete):
            # Mask is the trailing dimension appended to the env's nvec.
            return np.append(np.asarray(action, dtype=np.int64), mask)
        if isinstance(self.action_space, Discrete):
            # Flatten: action + mask * orig_n.
            orig_n = self.action_space.n // self.num_stack
            return int(action) + mask * orig_n
        # Box: append the scalar that split_action_mask decodes back to `mask`.
        mask_value = self._continuous_mask_values()[mask]
        return np.append(action, mask_value).astype(self.action_space.dtype)

    def split_action_mask(self, action):
        """Decode an action of ``action_space`` into ``(env_action, mask_index)``."""
        if self.use_multidiscrete and isinstance(self.env_action_space, Discrete):
            # Unpack from [env_action, mask_index].
            return int(action[0]), int(action[1])
        if isinstance(self.action_space, MultiDiscrete):
            return action[:-1], int(action[-1])
        if isinstance(self.action_space, Discrete):
            # Unflatten.
            orig_n = self.action_space.n // self.num_stack
            return int(action) % orig_n, int(action) // orig_n
        # Box: the last entry is a scalar snapped to the nearest mask value.
        action = np.asarray(action)
        mask = int(np.abs(self._continuous_mask_values() - action[-1:]).argmin())
        return action[:-1], mask

    def observation(self, observation):
        """Return the frame buffer stacked on axis 0 (``observation`` is ignored)."""
        assert len(self.frames) == self.num_stack, (len(self.frames), self.num_stack)
        return np.array(self.frames)

    def step(self, action):
        """Drop the masked frame, step the env, and append the new observation.

        ``info`` additionally holds ``"pre_frames"`` (a shallow copy of the buffer
        before this step) and ``"obs"`` (a copy of the raw new observation).
        """
        info = dict(pre_frames=self.frames.copy())
        action, mask = self.split_action_mask(action)
        if mask == self.num_stack:
            # Clear the buffer to zero frames. Actions inside action_space never decode
            # to this value (mask < num_stack), so only out-of-space actions reach it.
            self.frames = [self.init_obs * 0 for _ in range(self.num_stack - 1)]
        else:
            del self.frames[mask]

        observation, reward, terminated, truncated, env_info = self.env.step(action)
        info.update(env_info)
        info["obs"] = observation.copy()
        self.frames.append(observation)

        state = self.observation(observation)
        if self.mod_obs:
            observation = state
        return observation, reward, terminated, truncated, info

    def reset(self, **kwargs):
        """Reset the env; the buffer becomes ``num_stack - 1`` zero frames plus the reset obs."""
        obs, info = self.env.reset(**kwargs)
        self.init_obs = obs * 0
        self.frames = [self.init_obs * 0 for _ in range(self.num_stack - 1)]
        self.frames.append(obs)
        state = self.observation(obs)
        if self.mod_obs:
            obs = state
        return obs, info

class DemirFrameStack(gym.ObservationWrapper, gym.utils.RecordConstructorArgs):
    """Frame stack where the agent decides whether to keep the newest frame.

    The action gains a binary memory component:

    * ``mask == 0`` (push): ``memory_used`` is incremented (capped at ``num_stack``),
      the oldest frame is dropped, and the new observation is appended.
    * ``mask != 0`` (overwrite): the most recent frame is removed, so the new
      observation replaces it.

    Mask encoding, by wrapped action space: ``MultiDiscrete([n, 2])`` for a
    ``Discrete`` env with ``use_multidiscrete``; ``Discrete(2 * n)`` as
    ``action + mask * n`` for a plain ``Discrete`` env; one extra scalar for a
    ``Box`` env, snapped to ``low[0]`` (mask 0) or ``high[0]`` (mask 1).

    When ``intrinsic_rewards`` is set, visit counts ``n`` keyed by ``tuple(obs)``
    are used to add a bonus to the env reward. This keying presumably assumes 1-D
    observations (higher-rank arrays give unhashable tuples); confirm.

    Args:
        env: Environment to wrap.
        num_stack: Number of frames kept in the buffer (must be > 0).
        mod_obs: If True, observations and ``observation_space`` are the stacked frames.
        use_multidiscrete: Use separate action heads for a ``Discrete`` env.
        intrinsic_rewards: Add the count-based bonus to the reward.
        beta: Scale of the intrinsic bonus (defaults to the previous hard-coded 1).

    Raises:
        ValueError: If the wrapped action space is not ``Discrete`` or ``Box``.
    """

    def __init__(
        self,
        env: gym.Env,
        num_stack: int,
        mod_obs: bool = True,
        use_multidiscrete: bool = False,
        intrinsic_rewards: bool = False,
        beta: float = 1,
    ):
        gym.utils.RecordConstructorArgs.__init__(
            self,
            num_stack=num_stack,
            mod_obs=mod_obs,
            use_multidiscrete=use_multidiscrete,
            intrinsic_rewards=intrinsic_rewards,
            beta=beta,
        )
        gym.ObservationWrapper.__init__(self, env)

        assert num_stack > 0, "num_stack should be greater than 0"
        self.num_stack = num_stack
        self.mod_obs = mod_obs
        self.use_multidiscrete = use_multidiscrete
        self.frames = deque(maxlen=num_stack)
        self.intrinsic_rewards = intrinsic_rewards
        self.beta = beta

        if self.mod_obs:
            low = np.repeat(
                self.observation_space.low[np.newaxis, ...], num_stack, axis=0
            )
            high = np.repeat(
                self.observation_space.high[np.newaxis, ...], num_stack, axis=0
            )
            self.observation_space = Box(
                low=low, high=high, dtype=self.observation_space.dtype
            )

        self.env_action_space = env.action_space
        if self.use_multidiscrete and isinstance(self.env_action_space, Discrete):
            # Separate heads for env action and binary memory action.
            self.action_space = MultiDiscrete([self.env_action_space.n, 2])
        elif isinstance(self.env_action_space, Discrete):
            # Single discrete space: env.n * 2.
            self.action_space = Discrete(self.env_action_space.n * 2)
        elif isinstance(self.env_action_space, Box):
            # Append one scalar memory dimension bounded like the first action dimension.
            # Slicing keeps the wrapped dtype; adding a scalar to a uint8 array could
            # round through float16 under legacy NumPy promotion.
            low = np.concatenate([self.env_action_space.low, self.env_action_space.low[:1]])
            high = np.concatenate([self.env_action_space.high, self.env_action_space.high[:1]])
            self.action_space = Box(low=low, high=high, dtype=self.env_action_space.dtype)
        else:
            raise ValueError(f"Env action space not supported: {self.env_action_space}")

    def _memory_mask_values(self):
        """Scalar values encoding mask 0 and mask 1 for ``Box`` actions."""
        return np.linspace(self.env_action_space.low[0], self.env_action_space.high[0], 2)

    def join_action_mask(self, action, mask):
        """Encode an env action and a memory mask (0 or 1) as one action of ``action_space``.

        This is the inverse of :meth:`split_action_mask`.
        """
        if self.use_multidiscrete and isinstance(self.action_space, MultiDiscrete):
            # Pack into [env_action, mask].
            return np.array([action, mask], dtype=np.int64)
        if isinstance(self.action_space, Discrete):
            # Flatten: action + mask * orig_n.
            orig_n = self.action_space.n // 2
            return int(action) + int(mask) * orig_n
        # Box: append the scalar that split_action_mask decodes back to `mask`.
        mask_value = self._memory_mask_values()[int(mask)]
        return np.append(action, mask_value).astype(self.action_space.dtype)

    def split_action_mask(self, action):
        """Decode an action of ``action_space`` into ``(env_action, mask)``."""
        if self.use_multidiscrete and isinstance(self.env_action_space, Discrete):
            # Unpack from [env_action, mask].
            return int(action[0]), int(action[1])
        if isinstance(self.action_space, MultiDiscrete):
            return action[:-1], int(action[-1])
        if isinstance(self.action_space, Discrete):
            # Unflatten.
            orig_n = self.action_space.n // 2
            return int(action) % orig_n, int(action) // orig_n
        # Box: the last entry is a scalar snapped to the nearer of low[0] / high[0].
        action = np.asarray(action)
        mask = np.abs(self._memory_mask_values() - action[-1:]).argmin()
        return action[:-1], mask

    def observation(self, observation):
        """Return the frame buffer stacked on axis 0 (``observation`` is ignored)."""
        assert len(self.frames) == self.num_stack, (len(self.frames), self.num_stack)
        return np.array(self.frames)

    def step(self, action):
        """Apply the memory action, step the env, and append the new observation.

        ``info`` additionally holds ``"pre_frames"`` (a copy of the buffer before this
        step) and ``"obs"`` (a copy of the raw new observation).
        """
        info = dict(pre_frames=self.frames.copy())
        action, mask = self.split_action_mask(action)
        if mask == 0:
            # Push: count memory usage; once it saturates, evict the oldest frame.
            self.memory_used += 1
            if self.memory_used >= self.num_stack:
                self.memory_used = self.num_stack
                del self.frames[0]
        else:
            # Overwrite: drop the most recent frame so the new one replaces it.
            del self.frames[-1]

        observation, reward, terminated, truncated, env_info = self.env.step(action)
        info.update(env_info)
        info["obs"] = observation.copy()
        # If the deque is still full, maxlen makes this drop frames[0].
        self.frames.append(observation)
        self.n[tuple(observation)] += 1

        if self.intrinsic_rewards:
            if self.memory_used >= self.num_stack:
                self.h[tuple(observation)] += 1
            # The total visit count does not change inside the frames loop; compute it once.
            total_visits = sum(self.n.values())
            novelty = sum(1 - self.n[tuple(obs)] / total_visits for obs in self.frames)
            # NOTE(review): this subtracts (num_stack + 1); confirm the offset is intended.
            reward = reward + self.beta * (novelty - self.num_stack - 1)

        state = self.observation(observation)
        if self.mod_obs:
            observation = state
        return observation, reward, terminated, truncated, info

    def reset(self, **kwargs):
        """Reset the env, the visit counters, and the frame buffer.

        The buffer becomes ``num_stack - 1`` zero frames followed by the reset observation.
        """
        obs, info = self.env.reset(**kwargs)
        # int factory instead of a lambda so the counters stay picklable.
        self.n = defaultdict(int)
        self.h = defaultdict(int)

        self.n[tuple(obs)] += 1
        self.h[tuple(obs)] += 1

        self.memory_used = 0
        self.init_obs = obs * 0
        self.frames.clear()
        self.frames.extend([self.init_obs] * (self.num_stack - 1))
        self.frames.append(obs)
        state = self.observation(obs)
        if self.mod_obs:
            obs = state
        return obs, info

class FlatObs(gym.ObservationWrapper):
    """Flatten array observations to 1-D.

    The new ``Box`` uses one scalar lower bound (the minimum of the wrapped lows) and
    one scalar upper bound (the maximum of the wrapped highs) for every element.
    """

    def __init__(self, env: gym.Env):
        gym.ObservationWrapper.__init__(self, env)
        space = self.observation_space
        # Derive the output shape from a zero template instead of space.sample(),
        # which would seed/advance the wrapped space's RNG as a side effect.
        template = np.zeros(space.shape, dtype=space.dtype)
        self.observation_space = Box(
            low=space.low.min(),
            high=space.high.max(),
            shape=self.observation(template).shape,
            dtype=space.dtype,
        )

    def observation(self, observation):
        """Return a flattened copy of ``observation``."""
        return observation.flatten()
    

class RGBObs(gym.ObservationWrapper):
    """Tile a stack of frames side by side into a single array.

    A 4-D observation of shape ``(N, A, B, C)`` (presumably N stacked image frames,
    ``(N, H, W, C)``) is concatenated along each frame's axis 1, giving
    ``(A, N * B, C)``. Observations of any other rank pass through unchanged.
    The new ``Box`` uses the scalar min/max of the wrapped bounds.
    """

    def __init__(self, env: gym.Env):
        gym.ObservationWrapper.__init__(self, env)
        space = self.observation_space
        # Derive the output shape from a zero template instead of space.sample(),
        # which would seed/advance the wrapped space's RNG as a side effect.
        template = np.zeros(space.shape, dtype=space.dtype)
        self.observation_space = Box(
            low=space.low.min(),
            high=space.high.max(),
            shape=self.observation(template).shape,
            dtype=space.dtype,
        )

    def observation(self, observation):
        """Concatenate the leading-axis frames of a 4-D observation along axis 1."""
        if len(observation.shape) == 4:
            return np.concatenate(observation, axis=1)
        return observation

class TupleObs(gym.ObservationWrapper):
    """Turn each array observation into a flat tuple of its elements (C order).

    NOTE(review): ``observation_space`` is inherited unchanged from the wrapped env,
    so it does not describe the tuples this wrapper returns.
    """

    def __init__(self, env: gym.Env):
        gym.ObservationWrapper.__init__(self, env)

    def observation(self, observation):
        """Return the elements of ``observation`` as a tuple, flattened row-major."""
        flat_values = observation.reshape(-1)
        return tuple(flat_values)


class PartialObsGoal(gym.ObservationWrapper):
    """Progressively hide parts of a dict observation as an episode advances.

    Entries are zeroed in place (assumes NumPy array values):

    * always: ``observation["observation"][5:]``;
    * once ``steps >= visible_goal_steps``: ``observation["desired_goal"]``;
    * once ``steps >= 2 * visible_goal_steps``: ``observation["observation"][3:6]``.

    ``steps`` is 0 for the reset observation and counts completed ``step`` calls.
    The same rule is applied to both reset and step observations.

    Args:
        env: Environment returning dict observations with ``"observation"`` and
            ``"desired_goal"`` entries.
        visible_goal_steps: Number of steps before the goal is hidden.
        fetchhideblock: Stored on the wrapper but not read by this class.
    """

    def __init__(
        self,
        env: gym.Env,
        visible_goal_steps = 2,
        fetchhideblock = True,
    ):
        gym.ObservationWrapper.__init__(self, env)
        self.steps = 0
        self.visible_goal_steps = visible_goal_steps
        # NOTE(review): not used anywhere in this class; confirm whether it should
        # gate hiding of observation["observation"][3:6].
        self.fetchhideblock = fetchhideblock

    def _hide(self, observation):
        """Zero the hidden entries of ``observation`` for the current step count."""
        observation["observation"][5:] *= 0
        if self.steps >= self.visible_goal_steps:
            observation["desired_goal"] *= 0
            if self.steps >= self.visible_goal_steps * 2:
                observation["observation"][3:6] *= 0
        return observation

    def step(self, action):
        """Step the env and hide entries according to the new step count."""
        self.steps += 1
        observation, reward, terminated, truncated, info = self.env.step(action)
        return self._hide(observation), reward, terminated, truncated, info

    def reset(self, *args, **kwargs):
        """Reset the step counter and env, then hide entries as for step 0.

        This previously read an undefined ``self.goal_dims`` attribute when
        ``"desired_goal"`` was missing; reset now shares the step masking rule.
        """
        self.steps = 0
        observation, info = self.env.reset(*args, **kwargs)
        return self._hide(observation), info