import gymnasium as gym
import numpy as np
from gymnasium import spaces


class TeamGameEnv(gym.Env):
    """Repeated team (shared-reward) matrix game.

    At construction a reward table with one axis per agent is filled with
    random integers in ``[-10, 20]``. Every stage each agent picks one action;
    the joint action indexes the table and the resulting value, scaled by
    ``reward_decay ** stage``, is the common reward.

    Observations are ``float32`` vectors of length ``num_agents + 2`` laid out
    as ``[stage, previous action of each agent..., previous reward]``.

    NOTE(review): ``reset`` returns only the observation and ``step`` returns
    a 4-tuple ``(obs, reward, done, info)``. This differs from the current
    Gymnasium signatures (``reset -> (obs, info)``, 5-tuple ``step``); the
    return shapes are kept as they are so existing callers keep working.
    """

    def __init__(self, seed, num_agents, num_actions, num_stages=10, reward_decay=1):
        """Create the game and draw its reward table.

        Args:
            seed: Seed passed to ``np.random.seed`` before the table is drawn.
            num_agents: Number of agents (one reward-table axis per agent).
            num_actions: Number of discrete actions available to each agent.
            num_stages: Number of stages after which ``step`` reports done.
            reward_decay: Base of the per-stage multiplicative reward factor.
        """
        super().__init__()

        self.num_agents = num_agents
        self.num_actions = num_actions
        self.num_stages = num_stages
        self.reward_decay = reward_decay

        shape = [num_actions] * num_agents
        # NOTE: this seeds NumPy's *global* RNG, which also affects any other
        # code using np.random afterwards. Kept for backward compatibility.
        np.random.seed(seed)
        self.global_rewards = np.random.randint(-10, 21, size=shape)

        # Placeholder values; reset() installs the real initial state.
        self.previous_actions = [0] * num_agents
        self.previous_reward = -200
        self.stage = 0
        self.current_state = None

        self.action_space = spaces.MultiDiscrete([num_actions] * num_agents)

        # Largest factor a reward can be scaled by: decay ** stage is applied
        # for stage in [0, num_stages - 1]; it only grows when |decay| > 1.
        try:
            growth = max(1.0, abs(float(reward_decay)) ** max(num_stages - 1, 0))
        except OverflowError:
            growth = float("inf")
        # 11 covers the -11 "no reward yet" sentinel written by reset().
        reward_bound = max(float(np.abs(self.global_rewards).max()), 11.0) * growth

        # Per-slot bounds matching what reset()/step() actually emit:
        # stage in [0, num_stages], action slots in [0, num_actions] (the
        # value num_actions is the "no action yet" sentinel), reward slot in
        # [-reward_bound, reward_bound].
        low = np.array([0.0] * (num_agents + 1) + [-reward_bound], dtype=np.float32)
        high = np.array(
            [float(num_stages)] + [float(num_actions)] * num_agents + [reward_bound],
            dtype=np.float32,
        )
        self.observation_space = spaces.Box(
            low=low,
            high=high,
            shape=(num_agents + 2,),
            dtype=np.float32,
        )
        self.max_value = np.max(self.global_rewards) * num_stages

    def _build_state(self):
        """Return the observation for the current stage/actions/reward."""
        return np.array(
            [self.stage, *self.previous_actions, self.previous_reward],
            dtype=np.float32,
        )

    def reset(self):
        """Restart the episode and return the initial observation.

        The action slots are set to ``num_actions`` (one past the largest
        valid action) and the reward slot to ``-11`` (below the smallest table
        entry) to mark "nothing played yet".
        """
        self.previous_actions = [self.num_actions] * self.num_agents
        self.previous_reward = -11
        self.stage = 0
        self.current_state = self._build_state()
        return self.current_state

    def step(self, actions):
        """Play one stage with the given joint action.

        Args:
            actions: Sequence (list, tuple or 1-D array) with one action index
                per agent, each in ``[0, num_actions - 1]``.

        Returns:
            Tuple ``(observation, reward, done, info)`` where ``done`` becomes
            true once ``num_stages`` stages have been played and ``info`` is
            an empty dict.

        Raises:
            ValueError: If the number of actions is wrong or an action is out
                of range.
        """
        if len(actions) != self.num_agents:
            raise ValueError(f"Expected {self.num_agents} actions, but got {len(actions)}")
        if any(action < 0 or action >= self.num_actions for action in actions):
            raise ValueError(f"Actions must be in the range [0, {self.num_actions - 1}]")

        joint_action = tuple(actions)
        reward = self.global_rewards[joint_action]

        if self.stage > 0:
            reward *= (self.reward_decay ** self.stage)

        # Store a private list copy: list concatenation with an ndarray would
        # broadcast-add instead of concatenate, and keeping the caller's
        # object would let later mutations leak into the environment state.
        self.previous_actions = [int(action) for action in joint_action]
        self.previous_reward = reward
        self.stage += 1

        self.current_state = self._build_state()

        done = self.stage >= self.num_stages
        return self.current_state, reward, done, {}

    def render(self, mode='human'):
        """Print the current stage, last joint action and last reward."""
        print(f"Stage: {self.stage}, Actions: {self.previous_actions}, Reward: {self.previous_reward}")
