import numpy as np
import torch
from gymnasium import spaces
from env_matrix_single import TeamGameEnv
from typing import List


class TeamGameVectorEnv:
    """
    VMAS-style vectorized wrapper around `TeamGameEnv`.
    This mimics VMAS API so that existing training code can be reused.
    """
    def __init__(
        self,
        num_envs: int,
        num_agents: int,
        num_actions: int,
        num_stages: int = 25,
        reward_decay: float = 1.0,
        device: str = "cpu",
        seed: int | None = None,
        share_reward: bool = True,
    ):
        self.num_envs = num_envs
        self.num_agents = num_agents
        self.device = torch.device(device)
        self.share_reward = share_reward

        rng = np.random.RandomState(seed)
        self.envs: List[TeamGameEnv] = [
            TeamGameEnv(
                seed=seed,
                num_agents=num_agents,
                num_actions=num_actions,
                num_stages=num_stages,
                reward_decay=reward_decay,
            )
            for _ in range(num_envs)
        ]

        self.agents = [f"agent_{i}" for i in range(num_agents)]
        self.observation_space = tuple(
            spaces.Box(
                low=0.0,
                high=float(num_actions),
                shape=(num_agents + 2,),
                dtype=np.float32,
            )
            for _ in range(num_agents)
        )
        self.action_space = tuple(spaces.Discrete(num_actions) for _ in range(num_agents))

    def reset(self):
        """
        Returns:
            list[Tensor]: length = num_agents, each shape = [num_envs, obs_dim]
        """
        obs_stack = [
            torch.as_tensor(env.reset(), device=self.device).unsqueeze(0)
            for env in self.envs
        ]
        obs_mat = torch.cat(obs_stack, dim=0)
        obs_list = [obs_mat] * self.num_agents
        return obs_list

    def step(self, action_list):
        """
        Args:
            action_list (list[Tensor]): shape [num_envs] per agent

        Returns:
            next_obs_list: list[Tensor]
            reward_list: list[Tensor]
            dones_tensor: Tensor[bool]
            info: dict
        """
        action_mat = torch.stack(action_list, dim=1).cpu().numpy()

        next_obs_rows, reward_rows, done_flags = [], [], []

        for env_idx, env in enumerate(self.envs):
            actions_this_env = action_mat[env_idx].tolist()
            next_state, reward, done, _ = env.step(actions_this_env)

            next_obs_rows.append(torch.as_tensor(next_state, device=self.device))
            reward_rows.append(reward)
            done_flags.append(done)

        next_obs_mat = torch.stack(next_obs_rows, dim=0)
        rewards_tensor = torch.as_tensor(reward_rows, dtype=torch.float32, device=self.device)
        dones_tensor = torch.as_tensor(done_flags, dtype=torch.bool, device=self.device)

        reward_list = [rewards_tensor] * self.num_agents
        next_obs_list = [next_obs_mat] * self.num_agents
        info = {}

        return next_obs_list, reward_list, dones_tensor, info

    def get_random_action(self, agent_name):
        """Returns [num_envs] Tensor of random actions for the specified agent"""
        idx = int(agent_name.split("_")[-1])
        rand_actions = torch.randint(
            low=0,
            high=self.action_space[idx].n,
            size=(self.num_envs,),
            device=self.device,
        )
        return rand_actions


def make_env(
    scenario: str = "matrix_game",
    num_envs: int = 16,
    n_agents: int = 4,
    num_actions: int = 5,
    num_stages: int = 25,
    reward_decay: float = 1.0,
    device: str = "cpu",
    share_reward: str | bool = "True",
    seed: int | None = None,
    **_ignored,
):
    if scenario != "matrix_game":
        raise ValueError(f"Unknown scenario '{scenario}'. Only 'matrix_game' is supported.")
    share_reward = True if str(share_reward).lower() == "true" else False
    return TeamGameVectorEnv(
        num_envs=num_envs,
        num_agents=n_agents,
        num_actions=num_actions,
        num_stages=num_stages,
        reward_decay=reward_decay,
        device=device,
        seed=seed,
        share_reward=share_reward,
    )

# --------------- Optional test block ---------------
# if __name__ == "__main__":
#     env = make_env()
#     obs = env.reset()
#     for step in range(10):
#         actions = [env.get_random_action(agent) for agent in env.agents]
#         obs, rewards, dones, info = env.step(actions)
#         if torch.any(dones):
#             break
