import numpy as np

from gym import spaces
from pettingzoo.utils import wrappers
from pettingzoo.utils.env import AECEnv
from pettingzoo.utils import agent_selector
from pettingzoo.test import (
    api_test,
    performance_benchmark,
    test_save_obs,
    max_cycles_test,
)

from expground.logger import log
from expground.types import Tuple
import os
import pickle as pkl


class PayoffType:
    """Integer identifiers for the kinds of payoff table a matrix game can use."""

    # Consecutive integer codes, assigned in declaration order starting at 0.
    RANDOM_SYMMETRIC, RANDOM, CYCLIC = range(3)


def payoff_matrix_generate(num_player: int, payoff_type: int, dim: int = 3, rng=None):
    """Build one payoff table per player from the bundled ``AlphaStar.pkl`` file.

    The pickled matrix is loaded from the directory that contains this module.
    Players with an even index receive a copy of the matrix; players with an
    odd index receive a copy of its transpose.

    Args:
        num_player: Number of payoff tables to return.
        payoff_type: Accepted for interface compatibility; currently ignored
            because the table is always loaded from file.
        dim: Accepted for interface compatibility; currently ignored.
        rng: Accepted for interface compatibility; currently ignored.

    Returns:
        A list of ``num_player`` independent matrix copies.

    Raises:
        OSError: If ``AlphaStar.pkl`` cannot be opened.
    """
    # Force load from local file. `os.join` does not exist and `__path__` is
    # not defined for a plain module, so resolve the directory via `__file__`.
    module_dir = os.path.dirname(os.path.abspath(__file__))
    fpath = os.path.join(module_dir, "AlphaStar.pkl")

    # NOTE(security): unpickling executes arbitrary code; this is only safe
    # because the file ships with the package. Never point it at untrusted data.
    with open(fpath, "rb") as f:
        # presumably a 2-D numpy array (it must support `.copy()` and `.T`)
        mm = pkl.load(f)

    return [mm.copy() if i % 2 == 0 else mm.T.copy() for i in range(num_player)]


class Matrix(AECEnv):
    """
    Matrix environment shapes a normal form game in a cubic matrix. That is, all players share the same action dim.
    Payoff tables may differ between agents and are rebuilt by `payoff_matrix_generate` on every `reset()`.

    NOTE: the constructor currently pins `num_players`, `payoff_type` and `dim` to fixed values
    (see `__init__`), so the corresponding constructor arguments have no effect.
    """

    metadata = {"render.modes": ["human"], "name": "matrix"}

    def __init__(
        self,
        num_players: int = 2,
        payoff_type: int = PayoffType.RANDOM,
        dim: int = 888,
        max_cycles: int = 1,
    ):
        """Create the game.

        Args:
            num_players: Currently ignored; overridden with 2 below.
            payoff_type: Currently ignored; overridden with
                `PayoffType.RANDOM_SYMMETRIC` below.
            dim: Currently ignored; overridden with 888 below.
            max_cycles: Number of full rounds (every agent acting once) after
                which all agents are marked done.
        """
        super().__init__()

        # Fixed value from pkl file.
        num_players = 2
        payoff_type = PayoffType.RANDOM_SYMMETRIC
        dim = 888

        self.possible_agents = ["player_{}".format(i) for i in range(num_players)]
        self.agents = self.possible_agents[:]
        log.debug("matrix game generated with {} players".format(num_players))
        self.agent_name_mapping = dict(
            zip(self.possible_agents, list(range(num_players)))
        )

        self.payoff_type = payoff_type
        self.dim = dim
        self.max_cycles = max_cycles
        # One index past the last valid action marks "no action observed yet".
        self.NONE_STATE = dim
        self.observation_spaces = {
            agent: spaces.Discrete(dim + 1) for agent in self.possible_agents
        }
        self.action_spaces = {
            agent: spaces.Discrete(dim) for agent in self.possible_agents
        }

        # agents share the same full policy set / action set.
        # different from `self.num_agents`, `self.num_players` is fixed.
        self.num_players = num_players
        self.num_moves = 0
        self.full_policy_set = tuple(range(dim))

        self.payoff_rng = np.random.RandomState()

    def seed(self, seed=None):
        """Record `seed` and re-create the payoff random state from it.

        The class previously defined `seed` twice; the later definition
        silently replaced the earlier one. Both intents are merged here.
        """
        super().seed(seed=seed)
        self._seed = seed
        self.payoff_rng = np.random.RandomState(seed)

    def observe(self, agent):
        """Return the stored observation of `agent` as a numpy array."""
        observation = np.array(self.observations[agent])
        return observation

    def state(self) -> Tuple[int]:
        """Return the joint state: one entry per live agent, in `self.agents` order."""
        return tuple(self._state[i] for i in self.agents)

    def step(self, action: int):
        """Apply `action` for the currently selected agent and advance the turn.

        Rewards, done flags and observations are only refreshed once the last
        agent of a round has acted. When the first agent of a round acts, the
        other agents' states are reset to `NONE_STATE` and rewards are cleared.
        """
        if self.dones[self.agent_selection]:
            return self._was_done_step(action)

        self._cumulative_rewards[self.agent_selection] = 0.0

        self._state[self.agent_selection] = action

        # collect reward if it is the last agent to act
        if self._agent_selector.is_last():
            self.num_moves += 1
            # The joint action does not change inside this branch; compute it once.
            joint_action = self.state()
            for agent in self.agents:
                self.rewards[agent] = self.payoff_matrix[agent][joint_action]
            self.dones = {
                agent: self.num_moves >= self.max_cycles for agent in self.agents
            }

            # observe the current state
            for agent in self.agents:
                self.observations[agent] = joint_action[self.agent_name_mapping[agent]]
        elif self._agent_selector.is_first():
            # reset other agents state to NONE
            for agent in self.agents:
                # Compare names by value; identity (`is not`) on strings is unreliable.
                if agent != self.agent_selection:
                    self._state[agent] = self.NONE_STATE
            self._clear_rewards()

        self.agent_selection = self._agent_selector.next()
        self._accumulate_rewards()

    def reset(self):
        """Restore all agents, rebuild the payoff tables and clear per-episode state."""
        self.agents = self.possible_agents[:]
        self.payoff_matrix = dict(
            zip(
                self.agents,
                payoff_matrix_generate(
                    self.num_players,
                    self.payoff_type,
                    self.dim,
                    rng=self.payoff_rng,
                ),
            )
        )
        self.rewards = {agent: 0.0 for agent in self.agents}
        self._cumulative_rewards = {agent: 0.0 for agent in self.agents}
        self.dones = {agent: False for agent in self.agents}
        self.infos = {agent: {} for agent in self.agents}
        self.num_moves = 0

        self._state = {agent: self.NONE_STATE for agent in self.agents}
        self.observations = {agent: self.NONE_STATE for agent in self.agents}

        self._agent_selector = agent_selector(self.agents)
        self.agent_selection = self._agent_selector.next()

    def render(self, mode="human"):
        """Rendering is not implemented; this method does nothing and returns None."""
        # NOTE(review): the previous body only assigned an unused empty tuple
        # when all players were still present, so nothing was ever displayed.
        return None

    def close(self):
        """Log that the environment has been closed."""
        log.info("Matrix game closed")


def env(**kwargs):
    """Build a `Matrix` game wrapped with the standard PettingZoo wrappers.

    Args:
        **kwargs: Only the `scenario_config` entry is read; it is a mapping of
            keyword arguments forwarded to `Matrix`. When it is missing or
            `None`, `Matrix` is built with its default arguments instead of
            raising `KeyError`.

    Returns:
        The environment wrapped, innermost first, by `CaptureStdoutWrapper`,
        `AssertOutOfBoundsWrapper` and `OrderEnforcingWrapper`.
    """
    scenario_config = kwargs.get("scenario_config") or {}
    # Use a distinct local name so the function name `env` is not shadowed.
    game = Matrix(**scenario_config)
    game = wrappers.CaptureStdoutWrapper(game)
    # game = wrappers.TerminateIllegalWrapper(game, illegal_reward=-1)
    game = wrappers.AssertOutOfBoundsWrapper(game)
    game = wrappers.OrderEnforcingWrapper(game)
    return game


if __name__ == "__main__":
    # `env` reads the constructor arguments from the `scenario_config` keyword;
    # pass an explicit empty config so the game is built with its defaults
    # (calling `env()` with no arguments raised KeyError).
    _env = env(scenario_config={})
    log.info("====== API test =====")
    api_test(_env, num_cycles=20, verbose_progress=True)
    log.info("")
    log.info("===== Performance Benchmark Test =====")
    performance_benchmark(_env)
    log.info("")
    log.info("===== Save Observation Test =====")
    test_save_obs(_env)
    log.info("")
    # log.info("===== Max Cycles Test =====")
    # max_cycles_test(_env)
