import numpy as np
from gymnasium import spaces
from ray.rllib.examples.env.coin_game_non_vectorized_env import CoinGame


class CoinGameWrapper(CoinGame):
    """Adapt the dict-based multi-agent ``CoinGame`` to a sequence-based interface.

    Per-player dicts keyed by ``players_ids`` are converted to sequences that
    follow ``players_ids`` order:

    * observations are stacked into one array with a leading agent axis;
    * actions are accepted as a sequence with one entry per player;
    * rewards are returned as a list after the reshaping in ``_process_rewards``;
    * a single boolean ``done`` flag (the first player's) is returned.

    ``info`` is ``{'social_welfare': SW}``, where ``SW`` is the sum of all
    reshaped rewards accumulated since the last ``reset()``.

    Config keys read here (the whole config is also passed to ``CoinGame``):
        players_ids: if present, its length overrides ``NUM_AGENTS`` on the
            instance.
        other_coin_factor: reward scale used by ``_process_rewards``
            (default ``0.2``).
    """

    # Default reward scale for coin-stealing events; overridable per instance
    # through config['other_coin_factor'].
    other_coin_factor = .2

    def __init__(self, config=None):
        if config is None:
            config = {}
        if 'players_ids' in config:
            # Shadows the inherited class-level NUM_AGENTS on this instance before
            # the parent initializer runs (presumably so the parent uses the
            # configured player count — confirm against CoinGame).
            self.NUM_AGENTS = len(config['players_ids'])
        self.n_agents = self.NUM_AGENTS
        self.other_coin_factor = config.get('other_coin_factor', self.other_coin_factor)
        # Defined up-front so get_info() is usable before the first reset().
        self.SW = 0.
        super().__init__(config)
        # Stack the parent's per-agent observation space along a new leading axis.
        # NOTE(review): dtype is declared uint8; confirm the parent emits
        # uint8-compatible arrays, otherwise Box.contains() will reject them.
        self.observation_space = spaces.Box(
            low=0, high=1, shape=(self.n_agents, *self.observation_space.shape), dtype="uint8"
        )
        self.action_space = spaces.Tuple([spaces.Discrete(4) for _ in self.players_ids])

    def reset(self, *args, **kwargs):
        """Reset the episode and the social-welfare accumulator.

        All arguments are forwarded to ``CoinGame.reset``.

        Returns:
            ``(observations, info)``: the stacked per-player observation array
            and ``get_info()``. The parent's own info dict is discarded.
        """
        self.SW = 0.
        observations, _ = super().reset(*args, **kwargs)
        return self._process_observations(observations), self.get_info()

    def step(self, actions: list):
        """Advance the game one step.

        Args:
            actions: one action per player, in ``players_ids`` order.

        Returns:
            ``(observations, rewards, done, done, info)``. The same ``done``
            flag is reported as both terminated and truncated. The parent's
            truncated flags and info dict are discarded.

        Raises:
            ValueError: if ``len(actions)`` differs from the number of players.
        """
        observations, rewards, done, _, _ = super().step(self._process_actions(actions))
        observations = self._process_observations(observations)
        rewards = self._process_rewards(rewards)
        done = self._process_done(done)
        self.SW += np.sum(rewards).item()
        return observations, rewards, done, done, self.get_info()

    def _process_actions(self, actions):
        """Map a sequence of actions to the ``{player_id: action}`` dict the parent expects.

        Raises:
            ValueError: if the number of actions differs from the number of
                players (a bare ``zip`` would silently truncate).
        """
        if len(actions) != len(self.players_ids):
            raise ValueError(
                f"Expected {len(self.players_ids)} actions (one per player), got {len(actions)}"
            )
        return dict(zip(self.players_ids, actions))

    def _process_observations(self, obs):
        """Stack the per-player observation dict into one array, in ``players_ids`` order."""
        return np.array([obs[pid] for pid in self.players_ids])

    def _process_rewards(self, rewards):
        """Convert the parent's reward dict to a list and reshape negative-reward steps.

        Rules, with ``factor = self.other_coin_factor``:

        * no reward is negative: rewards are returned unchanged;
        * every reward is negative: every player receives ``factor``;
        * otherwise: each reward is multiplied by ``factor`` and clipped below
          at 0, so negative rewards become 0.
        """
        factor = self.other_coin_factor
        rewards = [rewards[pid] for pid in self.players_ids]
        negative = [r < 0 for r in rewards]
        if not any(negative):
            return rewards
        if all(negative):
            # both agents picked up each other's coins
            return [factor] * len(rewards)
        # one of the agents picked up a coin of another
        return [max(r * factor, 0) for r in rewards]

    def _process_done(self, done):
        """Return the first player's done flag from the parent's per-player done dict."""
        return done[self.players_ids[0]]

    def get_info(self):
        """Return the info dict holding the social welfare accumulated since the last reset."""
        return {'social_welfare': self.SW}

    def close(self):
        """Do nothing; overrides any close behaviour inherited from ``CoinGame``."""
        pass
