import collections
import copy
import random
from typing import Any, Dict, List, Sequence, Union

import dm_env
import gymnasium as gym
import numpy as np
import torch
from dm_env import TimeStep, specs
from gymnasium.spaces import Box
# from stable_baselines3.common.vec_env.base_vec_env import VecEnv


class DMGymWrapper(dm_env.Environment):
    """Expose a gym-style environment through the ``dm_env`` interface.

    The wrapped environment is expected to follow the 4-tuple convention:
    ``reset()`` returns the observation and ``step(action)`` returns
    ``(obs, reward, done, info)``. Observations are delivered to the caller
    as ``{"observation": obs}``.

    Args:
        env: Environment to wrap. Its ``observation_space`` and
            ``action_space`` must expose ``shape``, ``dtype``, ``low`` and
            ``high``.
        max_episode_steps: Number of steps after which an episode is ended
            even if the wrapped environment has not reported ``done``.
            Defaults to 500, the value that used to be hard-coded.
    """

    def __init__(self, env: gym.Env, max_episode_steps: int = 500) -> None:
        super().__init__()
        self._env = env
        self._max_episode_steps = max_episode_steps
        # Start in the "needs reset" state so the first step() call resets.
        self._reset_next_step = True
        # Number of steps taken in the current episode.
        self.i = 0

        obs_space = self._env.observation_space
        self._observation_spec = collections.OrderedDict()
        self._observation_spec["observation"] = specs.BoundedArray(
            shape=obs_space.shape,
            dtype=obs_space.dtype,
            minimum=obs_space.low,
            maximum=obs_space.high,
            name="observation",
        )

    def reset(self) -> TimeStep:
        """Reset the wrapped environment and return the first ``TimeStep``."""
        self._reset_next_step = False
        self.i = 0
        obs = self._env.reset()
        return dm_env.restart({"observation": obs})

    def step(self, action) -> TimeStep:
        """Advance one step, or reset if the previous step ended the episode.

        Returns a terminal ``TimeStep`` when the wrapped environment reports
        ``done`` or when ``max_episode_steps`` steps have been taken;
        otherwise returns a mid-episode transition.
        """
        if self._reset_next_step:
            return self.reset()

        obs, reward, done, info = self._env.step(action)
        self.i += 1
        observation = {"observation": obs}
        if done or self.i >= self._max_episode_steps:
            self._reset_next_step = True
            # NOTE(review): hitting the step limit is reported as a
            # termination, exactly like a genuine ``done``. dm_env also
            # offers ``truncation`` for time limits - confirm which one the
            # consumers of this wrapper expect before changing it.
            return dm_env.termination(reward=reward, observation=observation)
        return dm_env.transition(reward=reward, observation=observation)

    def observation_spec(self):
        """Return the ordered dict holding the ``"observation"`` spec."""
        return self._observation_spec

    def action_spec(self):
        """Build a ``BoundedArray`` spec from the wrapped env's action space."""
        act_space = self._env.action_space
        return specs.BoundedArray(
            shape=act_space.shape,
            dtype=act_space.dtype,
            minimum=act_space.low,
            maximum=act_space.high,
            name="action",
        )


class NewGymWrapper(gym.Env):
    """Adapt a 5-tuple ``step`` environment to the 4-tuple convention.

    ``reset`` passes a freshly drawn seed to the wrapped environment and
    returns only the first element of what it gives back. ``step`` folds
    ``terminated`` and ``truncated`` into a single flag and records the
    truncation flag under ``info["truncated"]``.

    Args:
        env: Environment whose ``step`` returns
            ``(observation, reward, terminated, truncated, info)`` and whose
            ``reset`` accepts a ``seed`` keyword.
        seed: Seed for the private generator that produces per-reset seeds.
    """

    def __init__(self, env, seed):
        self.env = env
        self.observation_space = env.observation_space
        self.action_space = env.action_space
        # Private RNG so the sequence of reset seeds is reproducible.
        self.seed_generator = random.Random(seed)

    def reset(self):
        """Reset the wrapped env with a new seed and return its first output."""
        episode_seed = self.seed_generator.randint(0, 2**32 - 1)
        return self.env.reset(seed=episode_seed)[0]

    def step(self, action):
        """Step the wrapped env and return ``(observation, reward, done, info)``."""
        obs, rew, terminated, truncated, info = self.env.step(action)
        # Keep the truncation flag available after the two flags are merged.
        info["truncated"] = truncated
        done = terminated or truncated
        return obs, rew, done, info


class PyTorchWrapper(gym.Env):
    """Wrap a 4-tuple ``step`` environment so it speaks torch tensors.

    Actions are accepted in the normalised range advertised by
    ``action_space`` (bounds of -1 and 1) and rescaled to the wrapped
    environment's own bounds before being forwarded. Observations, rewards,
    done flags and info values come back as tensors on ``device`` with a
    leading axis of size one (``vec_env`` is always ``False`` here).

    Args:
        env: Environment to wrap. ``reset()`` must return the observation
            and ``step(action)`` must return ``(obs, reward, done, info)``.
            Its action space must expose ``high`` and ``low`` arrays.
        device: Torch device that returned tensors are moved to.
        permute_3d_obs: When the observation space is a 3-D ``Box``, move
            the last axis to the front (presumably HWC -> CHW; confirm
            against the wrapped environment).
    """

    def __init__(
        self,
        env: gym.Env,
        device: str = "cpu",
        permute_3d_obs: bool = False,
    ):
        self.env = env
        self.vec_env = False
        self.num_envs = 1
        self.permute_3d_obs = permute_3d_obs
        self.device = device
        self._max_episode_steps = 500

        # Remember the wrapped env's real bounds, then advertise [-1, 1].
        # The space is shallow-copied so the object the wrapped env itself
        # holds keeps its original bounds.
        self.action_high = env.action_space.high
        self.action_low = env.action_space.low
        self.action_space = copy.copy(env.action_space)
        self.action_space.high = np.ones_like(self.action_high)
        self.action_space.low = -np.ones_like(self.action_low)

        # Affine map from [-1, 1] onto [action_low, action_high]. Dimensions
        # with an infinite bound get scale 1 / offset 0 (pass-through) so the
        # map never produces inf or nan.
        bounded = np.isfinite(self.action_low) & np.isfinite(self.action_high)
        low = np.where(bounded, self.action_low, -1).astype(self.action_low.dtype)
        high = np.where(bounded, self.action_high, 1).astype(self.action_high.dtype)
        self._action_scale = (high - low) / 2
        self._action_offset = (high + low) / 2

        self.observation_space = env.observation_space
        self.is_3d_observation = (
            isinstance(self.observation_space, Box)
            and len(self.observation_space.shape) == 3  # type: ignore
        )
        if self.is_3d_observation:
            # The shape is only unpacked once it is known to have three axes,
            # so ``permute_3d_obs=True`` is harmless for other spaces.
            dim0, dim1, dim2 = self.observation_space.shape  # type: ignore
            obs_shape = (dim2, dim0, dim1) if permute_3d_obs else (dim0, dim1, dim2)
            # 3-D observations are advertised as uint8 values in [0, 255].
            self.observation_space = Box(
                low=0,
                high=255,
                shape=obs_shape,
                dtype=np.uint8,
            )

    def step(self, action):
        """Rescale ``action``, step the wrapped env and return tensors.

        Args:
            action: Normalised action, as a tensor or numpy array.

        Returns:
            ``(obs, reward, done, info)`` where ``reward`` is float, ``done``
            is bool, ``obs`` is uint8 for 3-D observation spaces, and every
            ``info`` value is a tensor on ``self.device``. A zero int64
            ``"task"`` entry of shape ``(num_envs, 1)`` is added when the
            wrapped env does not supply one.
        """
        if isinstance(action, torch.Tensor):
            action = action.detach().cpu().numpy()
        if not self.vec_env:
            action = action.squeeze()
        # Map from the advertised [-1, 1] range to the wrapped env's bounds.
        action = action * self._action_scale + self._action_offset

        obs, reward, done, info = self.env.step(action)
        if self.is_3d_observation and self.permute_3d_obs:
            obs = wrap_3d_obs(obs)

        # A non-dict info is treated as a list of per-env dicts and stacked.
        raw_info = info if isinstance(info, dict) else stack_list_dict(info)
        info = {k: to_torch(v).to(self.device) for k, v in raw_info.items()}
        if "task" not in info:
            info["task"] = torch.zeros((self.num_envs, 1), dtype=torch.int64).to(
                self.device
            )

        if self.vec_env:
            obs = to_torch(obs).to(self.device)
            reward = to_torch(reward).to(self.device)
            done = to_torch(done).to(self.device)
        else:
            obs = expand_dim(to_torch(obs)).to(self.device)
            reward = expand_dim(torch.Tensor([reward])).to(self.device)
            done = expand_dim(torch.Tensor([done])).to(self.device)
        if self.is_3d_observation:
            obs = obs.to(torch.uint8)
        reward = reward.float()
        done = done.bool()
        return (obs, reward, done, info)

    def reset(self) -> torch.Tensor:
        """Reset the wrapped env and return the first observation as a tensor.

        3-D observations are returned as uint8, matching what ``step``
        returns and what ``observation_space`` advertises.
        """
        obs = self.env.reset()
        if self.is_3d_observation and self.permute_3d_obs:
            obs = wrap_3d_obs(obs)
        obs = to_torch(obs)
        if self.is_3d_observation:
            obs = obs.to(torch.uint8)
        if self.vec_env:
            return obs.to(self.device)
        return expand_dim(obs).to(self.device)

    def render(self, mode="human"):
        """Render the wrapped env and hand back whatever it returns."""
        return self.env.render(mode=mode)

    def reset_task(self, task_id: Union[List[int], int]):
        """No-op; this wrapper has no notion of tasks."""
        pass


class BraxWrapper:
    """Thin pass-through around an environment that already uses tensors.

    Args:
        env: Environment to delegate to. Its ``observation_space`` and
            ``action_space`` are re-exported unchanged.
        device: Device that actions are moved to before stepping.
        num_envs: Number of environments; stored for callers to read.
    """

    def __init__(self, env: "gym.Env", device: str = "cpu", num_envs: int = 1) -> None:
        self.env = env
        self.observation_space = env.observation_space
        self.action_space = env.action_space
        self.device = device
        self.num_envs = num_envs

        self._max_episode_steps = 1000

    def step(self, action):
        """Move ``action`` to ``self.device`` and step the wrapped env."""
        return self.env.step(action.to(self.device))

    def reset(self):
        """Reset the wrapped env and return its result unchanged."""
        return self.env.reset()

    def render(self, mode="human"):
        """Render the wrapped env and hand back whatever it returns."""
        return self.env.render(mode=mode)

    def reset_task(self, task_id):
        """No-op; this wrapper has no notion of tasks."""
        pass


def wrap_3d_obs(obs):
    """Return ``obs`` with its last axis moved to the front.

    Calls ``obs.transpose(2, 0, 1)``, so an input of shape ``(a, b, c)``
    comes back with shape ``(c, a, b)``.
    """
    axis_order = (2, 0, 1)
    return obs.transpose(*axis_order)


def expand_dim(x: torch.Tensor):
    """Return ``x`` with a new leading axis of size one."""
    with_leading_axis = x.unsqueeze(dim=0)
    return with_leading_axis


def stack_list_dict(d: List[Dict[str, np.ndarray]]):
    """Stack a list of dicts of numpy arrays into a single dict of arrays.

    The keys are taken from the first dict; every other dict must contain
    them too. Each output value is ``np.stack`` of that key's values, so it
    gains a new leading axis whose length is ``len(d)``.

    Args:
        d: Dicts sharing the same keys, with stackable values per key.

    Returns:
        A dict mapping each key to the stacked array, or an empty dict when
        ``d`` is empty.
    """
    if not d:
        # Nothing to take keys from; an empty result beats an IndexError.
        return {}
    return {key: np.stack([entry[key] for entry in d]) for key in d[0]}


def to_torch(x: Union[Any, np.ndarray]):
    """Convert ``x`` to a torch tensor.

    * A tensor is returned as-is (same object).
    * A numpy array is copied and converted to a float tensor.
    * Any other ``Sequence`` is passed straight to ``torch.Tensor``.
    * Anything else is wrapped in a one-element list first.
    """
    if isinstance(x, torch.Tensor):
        return x
    if isinstance(x, np.ndarray):
        # Copy first so the tensor does not share memory with the array.
        return torch.from_numpy(x.copy()).float()
    payload = x if isinstance(x, Sequence) else [x]
    return torch.Tensor(payload)
