import numpy as np
import rlcard
from gymnasium import spaces

from pettingzoo import AECEnv


class RLCardBase(AECEnv):
    def __init__(self, name: str, num_players: int, obs_shape: tuple):
        super().__init__()
        self.name = name
        self.num_players = num_players
        config = {
            "allow_step_back": False,
            "seed": None,
            "game_num_players": num_players,
        }

        self.env = rlcard.make(name, config)
        self.screen = None
        if not hasattr(self, "agents"):
            self.agents = [f"player_{i}" for i in range(num_players)]
        self.possible_agents = self.agents[:]

        dtype = self.env.reset()[0]["obs"].dtype
        if dtype == np.dtype(np.int64):
            self._dtype = np.dtype(np.int8)
        elif dtype == np.dtype(np.float64):
            self._dtype = np.dtype(np.float32)
        else:
            self._dtype = dtype

        self.observation_spaces = self._convert_to_dict(
            [
                spaces.Dict(
                    {
                        "observation": spaces.Box(
                            low=0.0, high=1.0, shape=obs_shape, dtype=self._dtype
                        ),
                        "action_mask": spaces.Box(
                            low=0, high=1, shape=(self.env.num_actions,), dtype=np.int8
                        ),
                    }
                )
                for _ in range(self.num_agents)
            ]
        )
        self.action_spaces = self._convert_to_dict(
            [spaces.Discrete(self.env.num_actions) for _ in range(self.num_agents)]
        )

    def observation_space(self, agent):
        return self.observation_spaces[agent]

    def action_space(self, agent):
        return self.action_spaces[agent]

    def _seed(self, seed=None):
        config = {
            "allow_step_back": False,
            "seed": seed,
            "game_num_players": self.num_players,
        }
        self.env = rlcard.make(self.name, config)

    def _scale_rewards(self, reward):
        return reward

    def _int_to_name(self, ind):
        return self.possible_agents[ind]

    def _name_to_int(self, name):
        return self.possible_agents.index(name)

    def _convert_to_dict(self, list_of_list):
        return dict(zip(self.possible_agents, list_of_list))

    def observe(self, agent):
        obs = self.env.get_state(self._name_to_int(agent))
        observation = obs["obs"].astype(self._dtype)

        legal_moves = self.next_legal_moves
        action_mask = np.zeros(self.env.num_actions, "int8")
        for i in legal_moves:
            action_mask[i] = 1

        return {"observation": observation, "action_mask": action_mask}

    def step(self, action):
        if (
            self.terminations[self.agent_selection]
            or self.truncations[self.agent_selection]
        ):
            return self._was_dead_step(action)
        obs, next_player_id = self.env.step(action)
        next_player = self._int_to_name(next_player_id)
        self._last_obs = self.observe(self.agent_selection)
        if self.env.is_over():
            self.rewards = self._convert_to_dict(
                self._scale_rewards(self.env.get_payoffs())
            )
            self.next_legal_moves = []
            self.terminations = self._convert_to_dict(
                [True for _ in range(self.num_agents)]
            )
            self.truncations = self._convert_to_dict(
                [False for _ in range(self.num_agents)]
            )
        else:
            self.next_legal_moves = obs["legal_actions"]
        self._cumulative_rewards[self.agent_selection] = 0
        self.agent_selection = next_player
        self._accumulate_rewards()
        self._deads_step_first()

    def reset(self, seed=None, options=None):
        if seed is not None:
            self._seed(seed=seed)
        obs, player_id = self.env.reset()
        self.agents = self.possible_agents[:]
        self.agent_selection = self._int_to_name(player_id)
        self.rewards = self._convert_to_dict([0 for _ in range(self.num_agents)])
        self._cumulative_rewards = self._convert_to_dict(
            [0 for _ in range(self.num_agents)]
        )
        self.terminations = self._convert_to_dict(
            [False for _ in range(self.num_agents)]
        )
        self.truncations = self._convert_to_dict(
            [False for _ in range(self.num_agents)]
        )
        self.infos = self._convert_to_dict([{} for _ in range(self.num_agents)])
        self.next_legal_moves = list(sorted(obs["legal_actions"]))
        self._last_obs = obs["obs"]

    def render(self):
        raise NotImplementedError()

    def close(self):
        pass
