import operator
from typing import List

import numpy as np
import torch
from gymnasium import spaces


# ===========================================================
# ① Environment for multi-peak matrix game
# ===========================================================
class TeamGameMultiOptEnv:
    """
    Multi-agent matrix game with multiple reward peaks.

    - Background rewards: integers drawn uniformly from [-10, 0]
    - n_peaks distinct joint actions are sampled as peaks; the peak of rank r
      (0-based) pays peak_base + (n_peaks - 1 - r), so rank 0 is the global
      optimum (given peak_base > 0, every peak beats the background)
    - The reward at stage t is multiplied by reward_decay ** t
    - Observation = [stage] + last_actions + [last_reward]; shape = (num_agents + 2,).
      Before the first step, last_actions is filled with num_actions (one past
      the largest valid action) and last_reward is 0.
    """

    def __init__(
        self,
        seed: int,
        num_agents: int,
        num_actions: int,
        n_peaks: int = 5,
        peak_base: float = 15.0,
        num_stages: int = 10,
        reward_decay: float = 1.0,
    ):
        """
        Build the payoff tensor and initial episode state.

        Args:
            seed: seed for the ``np.random.RandomState`` that draws the payoffs.
            num_agents: number of players (payoff tensor rank).
            num_actions: actions available to each player.
            n_peaks: number of distinct peak joint actions (>= 1).
            peak_base: payoff of the lowest-ranked peak.
            num_stages: episode length; ``done`` becomes True at this stage.
            reward_decay: per-stage multiplicative reward factor.

        Raises:
            ValueError: if ``n_peaks`` < 1 or exceeds the number of joint actions.
        """
        if n_peaks < 1:
            raise ValueError("n_peaks must be ≥ 1")
        self.rng = np.random.RandomState(seed)

        self.num_agents = num_agents
        self.num_actions = num_actions
        self.num_stages = num_stages
        self.reward_decay = reward_decay

        shape = [num_actions] * num_agents
        rewards = self.rng.randint(-10, 1, size=shape).astype(np.float32)

        if n_peaks > rewards.size:
            raise ValueError(
                f"n_peaks={n_peaks} exceeds the {rewards.size} available joint actions"
            )

        # Draw flat (C-order) cell indices. This is the same RNG call as choosing
        # from an enumeration of np.ndindex(*shape), without materializing it.
        flat_peaks = self.rng.choice(rewards.size, size=n_peaks, replace=False)
        # Rank r pays peak_base + (n_peaks - 1 - r): strictly decreasing, rank 0 highest.
        rewards.flat[flat_peaks] = peak_base + np.arange(n_peaks - 1, -1, -1)

        self.global_rewards = rewards

        self.stage = 0
        # num_actions is an out-of-range placeholder meaning "no action taken yet".
        self.prev_actions = [num_actions] * num_agents
        self.prev_reward = 0.0

        self.action_space = spaces.MultiDiscrete([num_actions] * num_agents)
        self.observation_space = spaces.Box(
            low=-np.inf, high=np.inf, shape=(num_agents + 2,), dtype=np.float32
        )

    def reset(self):
        """Restart the episode (the payoff tensor is kept) and return the first observation."""
        self.stage = 0
        self.prev_actions = [self.num_actions] * self.num_agents
        self.prev_reward = 0.0
        return self._obs()

    def step(self, actions):
        """
        Play one joint action.

        Args:
            actions: sequence of ``num_agents`` integer actions (list, tuple,
                1-D integer array, ...), each in ``[0, num_actions)``.

        Returns:
            (obs, reward, done, info). ``reward`` is the payoff scaled by
            ``reward_decay ** stage``. ``done`` is True once ``num_stages`` steps
            have been taken. ``info`` is an empty dict. There is no automatic reset.

        Raises:
            ValueError: on wrong length, non-integer, or out-of-range actions.
        """
        if len(actions) != self.num_agents:
            raise ValueError(
                f"expected {self.num_agents} actions, got {len(actions)}"
            )
        try:
            # operator.index accepts Python/numpy/torch integers but rejects floats.
            joint = [operator.index(a) for a in actions]
        except TypeError as err:
            raise ValueError(f"actions must be integers, got {actions!r}") from err
        if any(a < 0 or a >= self.num_actions for a in joint):
            raise ValueError(
                f"actions must lie in [0, {self.num_actions}), got {joint}"
            )

        r = float(self.global_rewards[tuple(joint)])
        r *= self.reward_decay ** self.stage

        # Keep a private list of plain ints. This way later mutation of the
        # caller's container cannot leak into observations, and _obs always
        # sees a list.
        self.prev_actions = joint
        self.prev_reward = r
        self.stage += 1
        done = self.stage >= self.num_stages
        return self._obs(), r, done, {}

    def _obs(self):
        """Return ``[stage, *prev_actions, prev_reward]`` as a float32 vector."""
        return np.asarray(
            [self.stage, *self.prev_actions, self.prev_reward], dtype=np.float32
        )

    def render(self, mode="human"):
        """Print the current stage, last joint action and last reward."""
        print(f"Stage {self.stage}, Actions {self.prev_actions}, Reward {self.prev_reward:.1f}")


# ===========================================================
# ② VMAS-compatible vectorized wrapper
# ===========================================================
class TeamGameVectorEnv:
    def __init__(
        self,
        num_envs: int,
        num_agents: int,
        num_actions: int,
        n_peaks: int,
        peak_base: float,
        num_stages: int,
        reward_decay: float,
        device: str,
        seed: int | None,
        share_reward: bool,
    ):
        self.num_envs = num_envs
        self.num_agents = num_agents
        self.device = torch.device(device)
        self.share_reward = share_reward

        self.envs: List[TeamGameMultiOptEnv] = [
            TeamGameMultiOptEnv(
                seed=seed,
                num_agents=num_agents,
                num_actions=num_actions,
                n_peaks=n_peaks,
                peak_base=peak_base,
                num_stages=num_stages,
                reward_decay=reward_decay,
            )
            for _ in range(num_envs)
        ]

        self.agents = [f"agent_{i}" for i in range(num_agents)]
        self.observation_space = tuple(
            spaces.Box(
                low=-np.inf, high=np.inf, shape=(num_agents + 2,), dtype=np.float32
            )
            for _ in range(num_agents)
        )
        self.action_space = tuple(spaces.Discrete(num_actions) for _ in range(num_agents))

    def reset(self):
        obs_rows = [
            torch.as_tensor(env.reset(), device=self.device).unsqueeze(0)
            for env in self.envs
        ]
        obs_mat = torch.cat(obs_rows, dim=0)
        return [obs_mat] * self.num_agents

    def step(self, action_list):
        action_mat = torch.stack(action_list, dim=1).cpu().numpy()

        obs_rows, rew_rows, done_rows = [], [], []
        for env_idx, env in enumerate(self.envs):
            o, r, d, _ = env.step(action_mat[env_idx].tolist())
            obs_rows.append(torch.as_tensor(o, device=self.device).unsqueeze(0))
            rew_rows.append(r)
            done_rows.append(d)

        obs_mat = torch.cat(obs_rows, dim=0)
        rew_tensor = torch.as_tensor(rew_rows, dtype=torch.float32, device=self.device)
        done_tensor = torch.as_tensor(done_rows, dtype=torch.bool, device=self.device)

        return [obs_mat] * self.num_agents, [rew_tensor] * self.num_agents, done_tensor, {}

    def get_random_action(self, agent_name: str):
        idx = int(agent_name.split("_")[-1])
        return torch.randint(
            0, self.action_space[idx].n, (self.num_envs,), device=self.device
        )


# ===========================================================
# ③ Factory function
# ===========================================================
def make_env(
    scenario: str = "matrix_game_multiopt",
    num_envs: int = 8,
    n_agents: int = 2,
    num_actions: int = 3,
    n_peaks: int = 5,
    peak_base: float = 15.0,
    num_stages: int = 10,
    reward_decay: float = 1.0,
    device: str = "cpu",
    share_reward: str | bool = "True",
    seed: int | None = None,
    **_,
):
    if scenario != "matrix_game_multiopt":
        raise ValueError("Only 'matrix_game_multiopt' supported.")
    return TeamGameVectorEnv(
        num_envs=num_envs,
        num_agents=n_agents,
        num_actions=num_actions,
        n_peaks=n_peaks,
        peak_base=peak_base,
        num_stages=num_stages,
        reward_decay=reward_decay,
        device=device,
        seed=seed,
        share_reward=str(share_reward).lower() == "true",
    )


# ===========================================================
# ④ Quick test
# ===========================================================
if __name__ == "__main__":
    # Smoke test: three parallel copies of a seeded 2-agent, 3-action game,
    # stepped four times with uniformly random joint actions.
    demo_env = make_env(
        scenario="matrix_game_multiopt",
        num_envs=3,
        n_agents=2,
        num_actions=3,
        n_peaks=5,
        peak_base=20,
        num_stages=6,
        seed=0,
    )

    initial_obs = demo_env.reset()
    print("reset shapes:", [tensor.shape for tensor in initial_obs])
    for step_no in range(1, 5):
        joint_action = [demo_env.get_random_action(name) for name in demo_env.agents]
        step_obs, step_rews, step_dones, _ = demo_env.step(joint_action)
        print(
            f"t={step_no} obs_shape={step_obs[0].shape} "
            f"reward={step_rews[0].tolist()} done={step_dones.tolist()}"
        )
