from typing import Sequence, Optional, Any, Type, Union
import torch
import gym
from gym.spaces import Box, Tuple
from gym.vector.utils.spaces import batch_space_custom
import numpy as np


class TorchBox(Box):
    """A `Box` space whose samples are returned as torch tensors.

    Args:
        low: Lower bound(s) forwarded to `Box`.
        high: Upper bound(s) forwarded to `Box`.
        shape: Optional shape forwarded to `Box`. `None` lets `Box` infer
            it from `low`/`high`.
        dtype: Dtype forwarded to `Box`. Defaults to `np.float32`.
        device: Device on which sampled tensors are placed.
    """

    def __init__(
        self,
        low: Union[float, Sequence[float]],
        high: Union[float, Sequence[float]],
        # The defaults used to be `...` (Ellipsis), which `Box.__init__`
        # cannot interpret; use real values so the defaults are usable.
        shape: Optional[Any] = None,
        dtype: Optional[Type[Any]] = np.float32,
        device: torch.device = torch.device("cpu"),
    ) -> None:
        super().__init__(low, high, shape, dtype)
        self.device = device

    def sample(self):
        """Draw a sample from the underlying `Box` as a torch tensor.

        Returns:
            torch.Tensor: The sample, cast to `torch.float32` (regardless of
            the space dtype) and moved to `self.device`.
        """
        sample = super().sample()
        sample = torch.from_numpy(sample)
        sample = sample.type(torch.float32)
        sample = sample.to(self.device)
        return sample


class TorchTupleBox(Tuple):
    """A `Tuple` space made exclusively of `TorchBox` sub-spaces.

    Args:
        spaces: The `TorchBox` spaces composing the tuple.

    Raises:
        ValueError: If any element of `spaces` is not a `TorchBox`.
    """

    def __init__(self, spaces: Sequence[TorchBox]):
        super().__init__(spaces)
        # Checked after the parent constructor, so the parent's own
        # validation of `spaces` runs first.
        if not all(isinstance(member, TorchBox) for member in spaces):
            raise ValueError("All spaces should be `TorchBox`")

    def sample(self) -> torch.Tensor:
        """Sample every sub-space and stack the results into one tensor."""
        return torch.stack(super().sample())


def batch_torchbox_space(space: TorchBox, n=1):
    """Build a `TorchTupleBox` holding `n` references to the same `space`.

    Args:
        space: The `TorchBox` to repeat.
        n: Number of repetitions.

    Returns:
        TorchTupleBox: A tuple space with `space` repeated `n` times.
    """
    # The same object is reused for every slot; no copies are made.
    repeated = [space for _ in range(n)]
    return TorchTupleBox(tuple(repeated))


class VecBoxTorchWrapper(gym.Wrapper):
    """Torch wrapper for Box envs.

    Converts observations returned by the wrapped env into float32 torch
    tensors on `device`, and converts torch actions back to numpy before
    forwarding them to the wrapped env.

    Args:
        env (gym.Env): Gym env with spaces `Box`
        device (torch.device): Device on which observation tensors and
            sampled space tensors are placed.

    Raises:
        ValueError: If the observation space or the action space is not `Box`
    """

    def __init__(self, env: gym.Env, device: torch.device) -> None:
        super().__init__(env)
        self.is_vector_env = getattr(env, "is_vector_env", False)
        self.num_envs = getattr(env, "num_envs", 1)

        # NOTE(review): the observation space is taken as-is even for vector
        # envs and then batched again below -- confirm this is what
        # `batch_space_custom` expects.
        env_observation_space = env.observation_space
        if self.is_vector_env:
            # For vector envs the action space is indexable; use the first
            # sub-space as the per-env action space.
            env_action_space = env.action_space[0]
        else:
            env_action_space = env.action_space

        # Both spaces must be Box: reject the env as soon as either is not.
        # (Previously `and` was used, which let a single non-Box space slip
        # through and fail later on the `.low` attribute access.)
        if not isinstance(env_action_space, Box) or not isinstance(
            env_observation_space, Box
        ):
            raise ValueError("Env observation and action spaces should be Box")

        self.device = device
        self.observation_space = TorchBox(
            low=env_observation_space.low,  # type: ignore
            high=env_observation_space.high,  # type: ignore
            shape=env_observation_space.shape,  # type: ignore
            dtype=env_observation_space.dtype,  # type: ignore
            device=device,
        )
        self.action_space = TorchBox(
            low=env_action_space.low,  # type: ignore
            high=env_action_space.high,  # type: ignore
            shape=env_action_space.shape,  # type: ignore
            dtype=env_action_space.dtype,  # type: ignore
            device=device,
        )

        if self.is_vector_env:
            self.observation_space = batch_space_custom(
                self.observation_space, n=self.num_envs
            )
            self.action_space = batch_torchbox_space(self.action_space, n=self.num_envs)

    def reset(self):
        """Reset the wrapped env and return the observation as a tensor."""
        obs = self.env.reset()  # type: ignore
        return self._numpy_to_torch(obs=obs)

    def step(self, action: torch.Tensor):
        """Step the wrapped env with a torch action.

        Args:
            action: 1-D action (treated as a batch of one) or 2-D tensor
                whose first dimension indexes the envs.

        Returns:
            Tuple `(obs, reward, done, info)` where `obs` is a float32 tensor
            on `self.device`. For vector envs `reward` and `done` are
            converted to Python lists; otherwise they are returned unchanged.
        """
        if action.dim() == 1:
            action = action.unsqueeze(dim=0)

        B, _ = action.size()
        # tensor size: [num_env, action_dim]

        # Always detach and move to CPU: `.numpy()` raises on tensors that
        # require grad or live on a non-CPU device.
        action = action.detach().cpu()

        # Split into one flat numpy array per env.
        # HACK: round to 7 decimal places to avoid precision inconsistencies between cpu and gpu
        # TODO: Understand why this hack works here and not in the other wrappers
        action = [
            np.around(a.flatten(), 7) for a in np.vsplit(action.numpy(), B)
        ]  # type: ignore

        if self.is_vector_env:
            action = tuple(action)  # type: ignore
        else:
            action = action[0]

        obs, reward, done, info = self.env.step(action)  # type: ignore
        tensor_obs = self._numpy_to_torch(obs=obs)
        if self.is_vector_env:
            return tensor_obs, reward.tolist(), done.tolist(), info
        return tensor_obs, reward, done, info

    def _numpy_to_torch(self, obs: np.ndarray) -> torch.Tensor:
        """Convert a numpy observation to a float32 tensor on `self.device`.

        A 1-D observation gets a leading dimension of size one so the result
        always has at least two dimensions for non-scalar inputs.
        """
        tensor_obs = torch.from_numpy(obs)
        tensor_obs = tensor_obs.to(self.device)
        tensor_obs = tensor_obs.type(torch.float32)
        if tensor_obs.dim() == 1:
            tensor_obs = tensor_obs.unsqueeze(dim=0)
        return tensor_obs
