import random
from typing import List

import numpy as np
from gym import spaces
from pettingzoo.utils import wrappers
from pettingzoo.utils.env import AECEnv

from open_spiel.python.rl_environment import Environment as OPEN_SPIEL_ENV, TimeStep
from expground.envs import open_spiel_adapters


class Poker(AECEnv):
    """Turn-based (AEC) environment backed by an OpenSpiel RL environment.

    Agents are named ``player_<i>``.  Every ``reset`` selects a player map
    that translates between OpenSpiel player indices and positions in
    ``possible_agents``: the identity when ``fixed_player`` is set, otherwise
    a random pick between the identity and the swap ``p -> 1 - p``.

    NOTE(review): both the swap map and the terminal hand-over in ``step``
    rely on ``1 - p``, so a two-player game appears to be assumed -- confirm
    before configuring more players.
    """

    metadata = {"render.modes": ["human"]}

    def __init__(self, **kwargs):
        """Create the OpenSpiel environment and the per-agent spaces.

        Keyword Args:
            env_id: Passed to the OpenSpiel environment as ``game``.
            scenario_config: Mapping read for ``"players"`` (required),
                ``"reward_scale"`` (default ``1.0``) and ``"fixed_player"``
                (default ``False``).
        """
        super().__init__()
        scenario_config = kwargs.get("scenario_config")
        self._open_spiel_env = OPEN_SPIEL_ENV(
            game=kwargs["env_id"], players=scenario_config["players"]
        )
        self._kwargs = kwargs
        self._reward_scale = scenario_config.get("reward_scale", 1.0)

        player_count = self._open_spiel_env.num_players
        self.possible_agents = [f"player_{i}" for i in range(player_count)]
        self.agents = list(self.possible_agents)

        # Each agent gets its own space object built from the OpenSpiel specs.
        self.observation_spaces = {
            name: open_spiel_adapters.ObservationSpace(
                self._open_spiel_env.observation_spec()
            )
            for name in self.possible_agents
        }
        self.action_spaces = {
            name: open_spiel_adapters.ActionSpace(self._open_spiel_env.action_spec())
            for name in self.possible_agents
        }

        self._cur_time_step: TimeStep = None
        self._fixed_player = scenario_config.get("fixed_player", False)
        self._scenario_players = scenario_config["players"]
        # Assigned in ``reset``; the index helpers below need it to be set.
        self._player_map = None

    def seed(self, seed=None):
        """Rebuild the underlying OpenSpiel environment.

        The ``seed`` argument is not used: the environment is only
        re-created from the stored constructor keyword arguments.
        """
        settings = self._kwargs
        self._open_spiel_env = OPEN_SPIEL_ENV(
            game=settings["env_id"],
            players=settings["scenario_config"]["players"],
        )

    def _scale_rewards(self, reward, scale=1.0):
        """Return a new list with each entry of ``reward`` times ``scale``."""
        return [value * scale for value in reward]

    def _int_to_name(self, ind):
        """Map a player index through the player map to an agent name."""
        return self.possible_agents[self._player_map(ind)]

    def _name_to_int(self, name):
        """Map an agent name to its index, passed through the player map."""
        position = self.possible_agents.index(name)
        return self._player_map(position)

    def _convert_to_dict(self, data_list: List):
        """Key ``data_list`` by agent name.

        Entry ``i`` is stored under ``possible_agents[player_map(i)]``; only
        the first ``num_agents`` entries are considered.
        """
        return {
            self.possible_agents[self._player_map(idx)]: item
            for idx, item in zip(range(self.num_agents), data_list)
        }

    def observe(self, agent):
        """Return the adapted observation of ``agent`` for the current time step."""
        return open_spiel_adapters.observation_adapter(
            self._cur_time_step.observations,
            self.observation_spaces[agent],
            player_id=agent,
        )

    def step(self, action):
        """Apply ``action`` for the currently selected agent.

        If the selected agent is already done, the call is delegated to the
        inherited ``_was_done_step`` and its result is returned.  Otherwise
        the OpenSpiel environment is stepped, bookkeeping is refreshed and
        the selection moves to the next agent.
        """
        acting_agent = self.agent_selection
        if self.dones[acting_agent]:
            return self._was_done_step(action)

        time_step = self._open_spiel_env.step([action])
        self._cur_time_step = time_step

        if not time_step.last():
            upcoming = time_step.current_player()
            next_player = self._int_to_name(upcoming)
            self.next_legal_moves = time_step.observations["legal_actions"][upcoming]
        else:
            # Episode finished: publish scaled rewards, clear the legal
            # moves, flag every agent as done and hand over to the other
            # player.
            scaled = self._scale_rewards(time_step.rewards, self._reward_scale)
            self.rewards = self._convert_to_dict(scaled)
            self.infos[acting_agent]["legal_moves"] = []
            self.next_legal_moves = []
            self.dones = self._convert_to_dict([True] * self.num_agents)
            next_player = self._int_to_name(1 - self._name_to_int(acting_agent))

        self._cumulative_rewards[acting_agent] = 0
        self.agent_selection = next_player
        self._last_obs = self.observe(next_player)
        self.infos[next_player]["legal_moves"] = self.next_legal_moves
        # Inherited helpers; their behaviour is defined by the base class.
        self._accumulate_rewards()
        self._dones_step_first()

    def reset(self):
        """Start a new episode and re-initialise all per-agent bookkeeping."""

        def keep(p):
            return p

        def swap(p):
            return 1 - p

        # The random draw happens only when the seating is not fixed.
        if self._fixed_player:
            self._player_map = keep
        else:
            self._player_map = random.choice([keep, swap])

        first_step = self._open_spiel_env.reset()
        self._cur_time_step = first_step
        self.agents = list(self.possible_agents)
        self.agent_selection = self._int_to_name(first_step.current_player())

        agent_count = self.num_agents
        self.rewards = self._convert_to_dict([0.0] * agent_count)
        self._cumulative_rewards = self._convert_to_dict([0.0] * agent_count)
        self.dones = self._convert_to_dict([False] * agent_count)
        self.infos = self._convert_to_dict(
            [
                {"legal_moves": moves}
                for moves in first_step.observations["legal_actions"]
            ]
        )
        self.next_legal_moves = sorted(
            self.infos[self.agent_selection]["legal_moves"]
        )

        selected_index = self._name_to_int(self.agent_selection)
        self._last_obs = np.array(
            first_step.observations["info_state"][selected_index],
            dtype=np.int8,
        )

    def render(self, mode="human"):
        """Rendering is not implemented; always raises ``NotImplementedError``."""
        raise NotImplementedError()

    def close(self):
        """Nothing to release."""
        pass


def env(**kwargs):
    """Build a ``Poker`` environment wrapped in the PettingZoo utility wrappers.

    Args:
        **kwargs: Forwarded unchanged to ``Poker``.

    Returns:
        The ``Poker`` instance wrapped, innermost first, in
        ``CaptureStdoutWrapper``, ``TerminateIllegalWrapper`` (with
        ``illegal_reward=-1``), ``AssertOutOfBoundsWrapper`` and
        ``OrderEnforcingWrapper``.  The returned object also carries an
        identity ``agent_to_group`` and a no-op ``seed``.
    """
    wrapped = Poker(**kwargs)
    wrapped = wrappers.CaptureStdoutWrapper(wrapped)
    wrapped = wrappers.TerminateIllegalWrapper(wrapped, illegal_reward=-1)
    wrapped = wrappers.AssertOutOfBoundsWrapper(wrapped)
    wrapped = wrappers.OrderEnforcingWrapper(wrapped)
    wrapped.agent_to_group = lambda x: x
    # Seeding is deliberately a no-op.  The argument defaults to ``None`` so
    # that ``env.seed()`` works without a value, matching ``Poker.seed``.
    wrapped.seed = lambda x=None: None
    return wrapped
