import datetime
import gym
import numpy as np
import uuid
# import torch #[todo]

class TimeLimit(gym.Wrapper):
    """Force episodes to end after a fixed number of steps.

    Once ``duration`` steps have been taken since the last ``reset``, the
    episode is reported as ``done``. If the wrapped env did not already put a
    ``discount`` into ``info``, a float32 discount of 1.0 is added there
    (presumably so consumers treat the cut-off as a truncation rather than a
    true terminal state — confirm against the consumer). After the cut-off,
    ``reset`` must be called before stepping again.
    """

    def __init__(self, env, duration):
        super().__init__(env)
        self._duration = duration
        # ``None`` marks "no active episode": set before the first reset and
        # again after the step budget has been exhausted.
        self._step = None

    def step(self, action):
        assert self._step is not None, "Must reset environment."
        obs, reward, done, info = self.env.step(action)
        self._step += 1
        if self._step < self._duration:
            # Still within budget: pass the env's result through untouched.
            return obs, reward, done, info
        # Budget exhausted: terminate and require a fresh reset.
        if "discount" not in info:
            info["discount"] = np.array(1.0).astype(np.float32)
        self._step = None
        return obs, reward, True, info

    def reset(self):
        self._step = 0
        return self.env.reset()


class NormalizeActions(gym.Wrapper):
    """Expose every action dimension to the agent in the range [-1, 1].

    Dimensions whose low and high bounds are both finite are linearly mapped
    from [-1, 1] back to their original range in ``step``. Dimensions with an
    infinite bound are forwarded to the wrapped env unchanged.
    """

    def __init__(self, env):
        super().__init__(env)
        space = env.action_space
        # True for dimensions with both bounds finite (the ones we rescale).
        self._mask = np.isfinite(space.low) & np.isfinite(space.high)
        # Unbounded dimensions get the placeholder range [-1, 1]; step() never
        # rescales them because the mask is False there.
        self._low = np.where(self._mask, space.low, -1)
        self._high = np.where(self._mask, space.high, 1)
        # Bounded dims map to [-1, 1] and unbounded dims already hold -1 / +1
        # above, so the exposed space is [-1, 1] on every dimension.
        unit = np.ones_like(self._low)
        self.action_space = gym.spaces.Box(-unit, unit, dtype=np.float32)

    def step(self, action):
        span = self._high - self._low
        rescaled = self._low + (action + 1) / 2 * span
        # Only bounded dims are rescaled; the rest pass through verbatim.
        return self.env.step(np.where(self._mask, rescaled, action))


class OneHotAction(gym.Wrapper):
    """Present a ``Discrete(n)`` action space as one-hot float vectors.

    The agent sees a float32 ``Box(0, 1, shape=(n,))`` tagged with
    ``discrete = True``. ``step`` validates that the incoming vector is one-hot
    and forwards the corresponding integer index to the wrapped env.
    """

    def __init__(self, env):
        assert isinstance(env.action_space, gym.spaces.Discrete)
        super().__init__(env)
        self._random = np.random.RandomState()
        box = gym.spaces.Box(
            low=0, high=1, shape=(self.env.action_space.n,), dtype=np.float32
        )
        # Marker attribute other code can check to detect a discrete action space.
        box.discrete = True
        self.action_space = box

    def step(self, action):
        index = np.argmax(action).astype(int)
        expected = np.zeros_like(action)
        expected[index] = 1
        if np.allclose(expected, action):
            return self.env.step(index)
        raise ValueError(f"Invalid one-hot action:\n{action}")

    def reset(self):
        return self.env.reset()

    def _sample_action(self):
        """Return a uniformly random one-hot float32 action vector."""
        size = self.env.action_space.n
        choice = self._random.randint(0, size)
        onehot = np.zeros(size, dtype=np.float32)
        onehot[choice] = 1.0
        return onehot


class RewardObs(gym.Wrapper):
    """Add the latest reward to every observation under the ``obs_reward`` key.

    ``reset`` fills the key with ``0.0``, and ``step`` fills it with the step's
    reward, both as a float32 array of shape ``(1,)``. Observations that
    already contain ``obs_reward`` are left alone. The wrapped env's
    observation space must be a ``gym.spaces.Dict``.
    """

    def __init__(self, env):
        super().__init__(env)
        # Copy the mapping before adding a key. Mutating it in place would also
        # alter the wrapped env's own observation_space, and the new Dict would
        # alias that env's mapping. ``.copy()`` keeps the container type (and
        # therefore key ordering) the same as the original.
        spaces = self.env.observation_space.spaces.copy()
        if "obs_reward" not in spaces:
            spaces["obs_reward"] = gym.spaces.Box(
                -np.inf, np.inf, shape=(1,), dtype=np.float32
            )
        self.observation_space = gym.spaces.Dict(spaces)

    def step(self, action):
        obs, reward, done, info = self.env.step(action)
        if "obs_reward" not in obs:
            obs["obs_reward"] = np.array([reward], dtype=np.float32)
        return obs, reward, done, info

    def reset(self):
        obs = self.env.reset()
        if "obs_reward" not in obs:
            # No reward has been received yet at the start of an episode.
            obs["obs_reward"] = np.array([0.0], dtype=np.float32)
        return obs


class SelectAction(gym.Wrapper):
    """Forward only ``action[key]`` from an indexable action to the wrapped env."""

    def __init__(self, env, key):
        super().__init__(env)
        self._key = key

    def step(self, action):
        selected = action[self._key]
        return self.env.step(selected)


class UUID(gym.Wrapper):
    """Attach an identifier to the env that is regenerated on every reset.

    ``self.id`` has the form ``<YYYYmmddTHHMMSS>-<uuid4 hex>``. The timestamp
    comes from ``datetime.datetime.now()``, i.e. naive local time.
    """

    def __init__(self, env):
        super().__init__(env)
        self.id = self._fresh_id()

    @staticmethod
    def _fresh_id():
        """Build a new timestamp-prefixed random identifier."""
        stamp = datetime.datetime.now().strftime("%Y%m%dT%H%M%S")
        return f"{stamp}-{uuid.uuid4().hex}"

    def reset(self):
        # A new id is drawn before the wrapped env is reset.
        self.id = self._fresh_id()
        return self.env.reset()

#[todo] start
class MultitaskWrapper(gym.Wrapper):
    """Pad one task's observations/actions to shared sizes for multi-task use.

    * When ``obs_type == "state"``, ``obs["state"]`` is zero-padded at the end
      to length ``obs_dim``, and the observation space reports that length.
    * The exposed action space has ``action_dim`` entries. ``step`` truncates
      each incoming action to the wrapped env's own action length.

    Args:
        env: Env whose observation space is a ``gym.spaces.Dict``.
        action_dim: Size of the padded (shared) action space.
        obs_dim: Length to pad ``obs["state"]`` to. If ``None`` and
            ``obs_type == "state"``, it defaults to the wrapped env's own
            state length (i.e. no padding).
        obs_type: Only ``"state"`` triggers observation padding.
    """

    def __init__(self, env, action_dim, obs_dim=None, obs_type="state"):
        super().__init__(env)
        self.obs_type = obs_type
        self.ori_observation_space = self.env.observation_space
        self.ori_action_space = self.env.action_space
        self.action_dim = action_dim
        if obs_dim is None and obs_type == "state":
            # Without this, the padded Box would get the invalid shape (None,).
            obs_dim = self.ori_observation_space.spaces["state"].shape[0]
        self._obs_shape = (obs_dim,)
        self.observation_space = self._observation_space()
        self.action_space = self._action_space()
        print("Wrap Multitask")

    def _observation_space(self):
        """Return a new Dict space with ``state`` resized to ``self._obs_shape``.

        The wrapped env's space object is not modified. Entries other than
        ``state`` are reused as-is.
        """
        # Copy the mapping so the wrapped env's space is left untouched.
        modified_spaces = dict(self.ori_observation_space.spaces)
        if self.obs_type == "state":
            # Note: bounds and dtype are NOT carried over from the original
            # ``state`` space; the padded space is always (-inf, inf) float32.
            modified_spaces["state"] = gym.spaces.Box(
                low=-np.inf, high=np.inf, shape=self._obs_shape, dtype=np.float32
            )
        return gym.spaces.Dict(modified_spaces)

    def _action_space(self):
        """Build the padded action space with ``self.action_dim`` entries.

        An original space flagged ``discrete`` (as ``OneHotAction`` does)
        becomes ``Discrete(action_dim)`` carrying the same flag. Otherwise the
        result is a float32 Box.
        """
        if hasattr(self.ori_action_space, "discrete"):
            padded = gym.spaces.Discrete(self.action_dim)
            if not hasattr(padded, "discrete"):
                padded.discrete = True
            return padded
        # NOTE(review): only element 0 of low/high is used for every dimension.
        # This assumes all dims share the same bounds (e.g. after
        # NormalizeActions); confirm for envs with per-dimension bounds.
        return gym.spaces.Box(
            low=self.ori_action_space.low[0],
            high=self.ori_action_space.high[0],
            shape=(self.action_dim,),
            dtype=np.float32,
        )

    def rand_act(self):
        """Sample a random action from the padded space as float32."""
        return self.action_space.sample().astype(np.float32)

    def _pad_obs(self, obs):
        """Zero-pad ``obs`` along its first axis to ``self._obs_shape[0]``.

        The padding keeps ``obs``'s dtype.

        Raises:
            ValueError: If ``obs`` is longer than the padded length.
        """
        if obs.shape == self._obs_shape:
            return obs
        missing = self._obs_shape[0] - obs.shape[0]
        if missing < 0:
            raise ValueError(
                f"State of shape {obs.shape} is larger than the padded "
                f"shape {self._obs_shape}."
            )
        return np.concatenate([obs, np.zeros(missing, dtype=obs.dtype)])

    def reset(self, task_idx=-1):
        """Reset the wrapped env and pad its state.

        ``task_idx`` is accepted for interface compatibility but is unused.
        """
        obs = self.env.reset()
        if self.obs_type == "state":
            obs["state"] = self._pad_obs(obs["state"])
        return obs

    def step(self, action):
        # Drop the padded tail so the wrapped env receives its native size.
        obs, reward, done, info = self.env.step(action[:self.env.action_space.shape[0]])
        if self.obs_type == "state":
            obs["state"] = self._pad_obs(obs["state"])
        return obs, reward, done, info

#[todo] end