import random
import pytest
import numpy as np

from expground.types import RolloutConfig, TrainingConfig
from expground.envs.matrix import creator, env_desc_gen
from expground.envs.matrix.env import PayoffType
from expground.algorithms.base_policy import pack_action_to_policy
from expground.utils import rollout
from expground.learner.do import DOLearner


@pytest.mark.parametrize("ray_mode,seed", [(False, 1)])
class TestDO:
    """Tests for ``DOLearner`` on a matrix-game environment.

    The environment is described with ``env_id="random_symmetric"``,
    ``PayoffType.RANDOM_SYMMETRIC``, two players and ``dim=3``. Every test
    gets a freshly built learner from the autouse ``prepare`` fixture.
    """

    @pytest.fixture(autouse=True)
    def prepare(self, ray_mode: bool, seed: int):
        """Build the environment description, configs and learner.

        Args:
            ray_mode: Forwarded to ``DOLearner(ray_mode=...)``.
            seed: Forwarded to ``DOLearner(seed=...)``, used to seed the
                reference environment, and used to seed the test-local RNG.
        """
        # Dedicated RNG: makes the initial policy pick in ``_init_subset``
        # reproducible for a given ``seed`` without touching the global
        # ``random`` state that other code may rely on.
        self._rng = random.Random(seed)

        self.env_desc = env_desc_gen(
            env_id="random_symmetric",
            scenario_config={
                "num_players": 2,
                "payoff_type": PayoffType.RANDOM_SYMMETRIC,
                "dim": 3,
                "max_cycles": 2,
            },
        )
        self.rollout_config = RolloutConfig(
            caller="sequential", fragment_length=100, max_step=10, num_simulation=20
        )
        self.training_config = TrainingConfig(
            trainer_cls=rollout.sequential_rollout,
            hyper_params={
                "batch_size": 32,
                "total_timesteps": 10000,
                "learning_rate": 0.1,
            },
        )
        self.learner = DOLearner(
            experiment="test_do_learner",
            env_description=self.env_desc,
            rollout_config=self.rollout_config,
            training_config=self.training_config,
            ray_mode=ray_mode,
            seed=seed,
            evaluation_worker_num=2,
            agent_mapping=lambda agent: agent,
        )

        # Independently created environment with the same seed; its payoff
        # matrix is the reference the learner's matrix is compared against.
        _env = self.env_desc["creator"](**self.env_desc["config"])
        _env.seed(seed)
        _env.reset()
        self.target_env_payoff_matrix = _env.unwrapped.payoff_matrix

    def _init_subset(self):
        """Register one fixed ``policy_0`` per agent and expand the payoff table.

        For each agent one element of its full policy set is drawn with the
        seeded test RNG, and ``policy_0`` is built as a fixed policy whose
        distribution is one-hot on that element.
        """
        for agent in self.learner.agents:
            _agent = self.learner.agent_mapping(agent)
            sset = self.learner.full_policy_set[_agent]
            chosen = self._rng.choice(sset)
            # 1.0 at the chosen element, 0.0 everywhere else.
            one_hot = np.array([float(e == chosen) for e in sset], dtype=np.float32)
            self.learner.populations[_agent]["policy_0"] = pack_action_to_policy(
                sset,
                self.learner.agent_interfaces[_agent].observation_space,
                is_fixed=True,
                distribution=one_hot,
            )

        self.learner.payoff_manager.expand(
            {agent: ["policy_0"] for agent in self.learner.agents}
        )

    def test_env_payoff_matrix_checking(self):
        """The learner's payoff matrix equals the one of a same-seed environment."""
        expected = self.target_env_payoff_matrix
        actual = self.learner.env_payoff_matrix

        assert isinstance(expected, dict), type(expected)
        assert isinstance(actual, dict), type(actual)
        assert expected.keys() == actual.keys(), (expected.keys(), actual.keys())
        for k, v in expected.items():
            assert np.array_equal(v, actual[k]), k

    def test_run_simulation(self):
        """One simulation result is returned and it maps agents to ``policy_0``."""
        self._init_subset()
        results = self.learner.run_simulation()
        assert len(results) == 1

        # Each result is indexed as (policy_mapping, feedback); the feedback
        # is unpacked only to confirm the pair is present.
        policy_mapping, _feedback = results[0][0], results[0][1]
        assert set(policy_mapping.keys()) == set(self.learner.agents), (
            policy_mapping.keys(),
            self.learner.agents,
        )
        assert policy_mapping[self.learner.agents[0]] == "policy_0"

    def test_payoff_manager_update(self):
        """Smoke test: update payoffs, solve the meta-game, compute NashConv.

        No value is asserted; the test passes when the whole pipeline runs
        without raising. The resulting ``nash_conv`` is printed for inspection.
        """
        self._init_subset()
        results = self.learner.run_simulation()
        self.learner.payoff_manager.update_payoff_and_simulation_status(results)

        # Policy ids currently registered for each agent's population.
        policy_sets = {
            agent: list(
                self.learner.populations[self.learner.agent_mapping(agent)].keys()
            )
            for agent in self.learner.agents
        }
        utilities = self.learner.payoff_manager.get_sub_matrix(policy_sets)

        equilibrium = self.learner.meta_solver.solve(utilities)
        # Pair every policy id with the corresponding entry of the solver output.
        meta_strategies = {
            k: dict(zip(policy_sets[k], v)) for k, v in equilibrium.items()
        }
        nash_conv = self.learner.compute_nash_conv(meta_strategies)

        print("nash_conv:", nash_conv)
