from typing import Dict, Optional, Union, Callable

import torch
import vmas
from tensordict.tensordict import TensorDict, TensorDictBase
from torchrl.envs.libs.vmas import VmasWrapper


class HallucVmasWrapper(VmasWrapper):
    """``VmasWrapper`` subclass with a ``_step`` that forwards extra counters.

    ``_step`` accepts ``total_env_steps``, ``total_halluc_env_steps`` and
    ``iteration`` in addition to the tensordict and passes all four values
    straight to ``self._env.step``.
    """

    def __init__(
        self,
        env = None,  # noqa
        **kwargs,
    ):
        """Hand ``env`` and every keyword argument to ``VmasWrapper.__init__``."""
        super().__init__(env, **kwargs)

    def _step(
        self,
        tensordict: TensorDictBase,
        total_env_steps: int,
        total_halluc_env_steps: int,
        iteration: int,
    ) -> TensorDictBase:
        """Step the wrapped environment and pack its outputs into a TensorDict.

        Args:
            tensordict: Passed as-is to ``self._env.step``.
                NOTE(review): no action is extracted here; presumably the
                wrapped environment reads it from the tensordict -- confirm.
            total_env_steps: Forwarded unchanged to ``self._env.step``.
            total_halluc_env_steps: Forwarded unchanged to ``self._env.step``.
            iteration: Forwarded unchanged to ``self._env.step``.

        Returns:
            A TensorDict holding ``"agents"`` (per-agent entries stacked along
            dim 1), ``"done"`` and ``"terminated"`` (a clone of ``"done"``).
        """
        obs, rews, dones, infos = self._env.step(
            tensordict, total_env_steps, total_halluc_env_steps, iteration
        )
        done_flags = self.read_done(dones)

        def build_agent_entry(index: int) -> TensorDict:
            # Read order (obs, reward, info) is kept deliberately.
            observation = self.read_obs(obs[index])
            reward = self.read_reward(rews[index])
            info = self.read_info(infos[index])

            entry = TensorDict(
                source={"observation": observation, "reward": reward},
                batch_size=self.batch_size,
                device=self.device,
            )
            # "info" is only attached when read_info returned something.
            if info is not None:
                entry.set("info", info)
            return entry

        stacked_agents = torch.stack(
            [build_agent_entry(index) for index in range(self.n_agents)],
            dim=1,
        )
        if not self.het_specs:
            # Presumably converts the stacked result into a regular
            # TensorDict when agent specs are homogeneous -- confirm.
            stacked_agents = stacked_agents.to_tensordict()

        return TensorDict(
            source={
                "agents": stacked_agents,
                "done": done_flags,
                "terminated": done_flags.clone(),
            },
            batch_size=self.batch_size,
            device=self.device,
        )


class HallucVmasEnv(HallucVmasWrapper):
    """``HallucVmasWrapper`` that builds its own environment via ``make_env``.

    The constructor gathers its named arguments into ``kwargs`` and passes
    them to the parent; ``_build_env`` then creates the environment through
    ``self.lib.make_env`` with ``hallucinate=True``.
    """

    def __init__(
        self,
        scenario: Union[str, "vmas.simulator.scenario.BaseScenario"],  # noqa
        num_envs: int,
        continuous_actions: bool = True,
        max_steps: Optional[int] = None,
        categorical_actions: bool = True,
        seed: Optional[int] = None,
        **kwargs,
    ):
        """Merge the named arguments into ``kwargs`` and call the parent init.

        Args:
            scenario: Scenario name or scenario object, stored under
                ``"scenario"``.
            num_envs: Stored under ``"num_envs"``.
            continuous_actions: Stored under ``"continuous_actions"``.
            max_steps: Stored under ``"max_steps"``.
            categorical_actions: Stored under ``"categorical_actions"``.
            seed: Stored under ``"seed"``.
            **kwargs: Any further options, forwarded alongside the above.
        """
        kwargs.update(
            scenario=scenario,
            num_envs=num_envs,
            continuous_actions=continuous_actions,
            max_steps=max_steps,
            seed=seed,
            categorical_actions=categorical_actions,
        )
        super().__init__(**kwargs)

    def _check_kwargs(self, kwargs: Dict):
        """Raise ``TypeError`` if ``"scenario"`` or ``"num_envs"`` is missing.

        ``"scenario"`` is checked first, so it is the one reported when both
        keys are absent.
        """
        required = (
            ("scenario", "Could not find environment key 'scenario' in kwargs."),
            ("num_envs", "Could not find environment key 'num_envs' in kwargs."),
        )
        for key, message in required:
            if key not in kwargs:
                raise TypeError(message)

    def _build_env(
        self,
        scenario: Union[str, "vmas.simulator.scenario.BaseScenario"],  # noqa
        num_envs: int,
        continuous_actions: bool,
        max_steps: Optional[int],
        seed: Optional[int],
        **scenario_kwargs,
    ) -> "vmas.simulator.environment.halluc_environment.HallucEnvironment":
        """Create the environment with ``make_env`` and pass it to the parent.

        ``"from_pixels"`` and ``"pixels_only"`` are removed from
        ``scenario_kwargs`` (defaulting to ``False``) and given to the
        parent ``_build_env``; the remaining entries go to ``make_env``.
        As a side effect, ``self.scenario_name`` is set to ``scenario``.

        Returns:
            Whatever the parent ``_build_env`` returns.
        """
        simulator = self.lib

        self.scenario_name = scenario
        from_pixels = scenario_kwargs.pop("from_pixels", False)
        pixels_only = scenario_kwargs.pop("pixels_only", False)

        halluc_env = simulator.make_env(
            scenario=scenario,
            num_envs=num_envs,
            device=self.device,
            continuous_actions=continuous_actions,
            max_steps=max_steps,
            seed=seed,
            wrapper=None,
            hallucinate=True,
            **scenario_kwargs,
        )
        return super()._build_env(
            env=halluc_env,
            pixels_only=pixels_only,
            from_pixels=from_pixels,
        )

    def __repr__(self):
        """Return the parent repr followed by the scenario name."""
        parent_repr = super().__repr__()
        return f"{parent_repr} (scenario={self.scenario_name})"
