import os
import os.path as osp
import time
import logging
import pytest
import yaml
import ray
import numpy as np

from expground import settings
from expground.settings import BASE_DIR
from expground.logger import Log
from expground.types import AgentID, PolicyConfig, Dict, List, Tuple
from expground.utils.preprocessor import get_preprocessor
from expground.learner.psro.bi_psro import BPSROLearner
from expground import cli


# Module-level pytest marker: every test collected from this file is tagged
# with the custom `bipsro` mark (selectable via `pytest -m bipsro`).
pytestmark = pytest.mark.bipsro

# Ray is started once, as an import-time side effect of this module, with
# `local_mode=True`.
# NOTE(review): local mode is a debugging aid in Ray and has been deprecated in
# newer releases -- confirm it is still supported by the pinned Ray version.
ray.init(local_mode=True)


@pytest.fixture
def yamls() -> List[str]:
    """Collect the BPSRO example configuration files.

    Returns:
        Paths (joined onto ``BASE_DIR``) of every ``*.yaml`` file located
        directly under ``examples/configs/bpsro``, sorted by file name.
    """
    config_dir = osp.join(BASE_DIR, "examples/configs/bpsro")
    # os.listdir() yields entries in arbitrary order. Sort them so that tests
    # slicing the derived instance list (e.g. ``[1:2]``) always exercise the
    # same configuration, independent of platform and filesystem.
    files = [
        osp.join(config_dir, name)
        for name in sorted(os.listdir(config_dir))
        if name.endswith(".yaml")
    ]
    logging.info("load yamls: {}".format(files))
    return files


def _build_bpsro_instance(config_path: str) -> tuple:
    """Build one ``(learner, sampler_config, stop_conditions)`` triple from a YAML file.

    Everything derived from a single configuration lives in this function's
    own scope. The closures created here (``sampler_config`` and the lambdas
    handed to ``PolicyConfig``) therefore stay bound to *this* configuration;
    building them inside a loop body would make every closure observe the
    values of the last iteration (late binding).

    Args:
        config_path: Path of the YAML experiment configuration to load.

    Returns:
        A 3-tuple of:
        * the constructed ``BPSROLearner``;
        * ``sampler_config``, a callable mapping an agent id to the result of
          ``env_lib.basic_sampler_config`` for that agent;
        * a dict holding ``"stop_conditions"`` and, only when the YAML
          provides a non-empty value, ``"inner_stop_conditions"``.

    Raises:
        AssertionError: If the configured learner class is not ``BPSROLearner``.
    """
    logging.info("parse yaml: {}".format(config_path))
    with open(config_path, "r") as f:
        raw_yaml = yaml.safe_load(f)

    env_desc, env_lib = cli.parse_env_config(raw_yaml["env_config"])
    learner_cls, other_params = cli.parse_learner_config(raw_yaml["learner_config"])
    assert (
        learner_cls == BPSROLearner
    ), "learner_cls should be {}, while: {} detected".format(
        BPSROLearner.__name__, learner_cls.__name__
    )
    rollout_config = cli.parse_rollout_config(raw_yaml["rollout_config"])
    training_config = cli.parse_training_config(raw_yaml)

    # Per-agent spaces come from the parsed environment description.
    env_config = env_desc["config"]
    action_spaces = env_config["action_spaces"]
    observation_spaces = env_config["observation_spaces"]
    preprocessors = {
        aid: get_preprocessor(observation_space)(observation_space)
        for aid, observation_space in observation_spaces.items()
    }

    # The "algorithm" section is optional; without it no policy/loss class is
    # loaded and the experiment name falls back to "default".
    algorithm_cfg = raw_yaml.get("algorithm")
    policy_cls = (
        cli.load_class_from_str("expground.algorithms", algorithm_cfg["policy"])
        if algorithm_cfg
        else None
    )
    loss_func = (
        cli.load_class_from_str("expground.algorithms", algorithm_cfg["loss"])
        if algorithm_cfg
        else None
    )

    def sampler_config(aid):
        """Return the basic sampler configuration for agent ``aid``."""
        # The YAML lookup is kept lazy so a config without a sampler section
        # only fails for tests that actually request a sampler config.
        sampler_params = raw_yaml["sampler_config"]["params"]
        return env_lib.basic_sampler_config(
            observation_spaces[aid],
            action_spaces[aid],
            preprocessors[aid],
            capacity=sampler_params["capacity"],
            learning_starts=sampler_params.get("learning_starts", -1),
        )

    # NOTE: the message is labelled "sampler_config" but the logged value is
    # the action-space mapping; kept as in the original output.
    Log.info("sampler_config: {}".format(action_spaces))

    policy_config = PolicyConfig(
        policy=policy_cls,
        mapping=lambda agent: agent,
        observation_space=lambda k: env_config["observation_spaces"][k],
        action_space=lambda k: env_config["action_spaces"][k],
        custom_config=raw_yaml.get("custom_config", {}),
        model_config=raw_yaml.get("model_config", {}),
    )

    algorithm = algorithm_cfg["name"] if algorithm_cfg else "default"

    general_stop_conditions = {
        "stop_conditions": raw_yaml["learner_config"]["stop_conditions"]
    }
    inner_stop_conditions = raw_yaml["learner_config"].get(
        "inner_stop_conditions", None
    )

    exp_prefix = raw_yaml.get("exp_prefix", None)

    learner = learner_cls(
        experiment=settings.EXP_NAME_FORMAT_LAMBDA(
            args=(
                exp_prefix,
                f"{env_desc['config']['env_id']}_{algorithm}",
                str(time.time()),
            )
        ),
        policy_config=policy_config,
        env_description=env_desc,
        rollout_config=rollout_config,
        training_config=training_config,
        loss_func=loss_func,
        **other_params,
    )

    if inner_stop_conditions:
        general_stop_conditions["inner_stop_conditions"] = inner_stop_conditions
    return learner, sampler_config, general_stop_conditions


@pytest.fixture
def bpsro_disable_ray(yamls) -> List[Tuple[BPSROLearner, Dict, Dict]]:
    """Instantiate one BPSRO learner per example configuration.

    Args:
        yamls: Paths of the YAML configuration files (the ``yamls`` fixture).

    Returns:
        One ``(learner, sampler_config, stop_conditions)`` tuple per file, in
        the order of ``yamls``. Note that ``sampler_config`` is a callable
        taking an agent id, even though the annotation spells it as ``Dict``.
    """
    return [_build_bpsro_instance(config_path) for config_path in yamls]


def test_add_policy(bpsro_disable_ray: List[Tuple[BPSROLearner, Dict, Dict]]):
    """Check the fixed-policy count after adding two support policies.

    For every learner, ``add_policy(n_support=2)`` is called, each
    sub-learner's ego policies are fixed, and every entry of the reported
    fixed policy ids is expected to hold exactly three ids.
    """
    for learner, _, _ in bpsro_disable_ray:
        learner.add_policy(n_support=2)
        for agent_learner in learner._agent_learners.values():
            # Direct calls in non-ray mode, remote calls resolved through
            # ray.get otherwise.
            if not learner.ray_mode:
                agent_learner.set_ego_policy_fixed()
                fixed_ids = agent_learner.get_ego_fixed_policy_ids()
            else:
                ray.get(agent_learner.set_ego_policy_fixed.remote())
                fixed_ids = ray.get(agent_learner.get_ego_fixed_policy_ids.remote())
            has_three_each = [len(pids) == 3 for pids in fixed_ids.values()]
            assert all(has_three_each), (fixed_ids,)


def test_partial_matrix_generation(
    bpsro_disable_ray: List[Tuple[BPSROLearner, Dict, Dict]]
):
    """Check the shape of the partial matrices built by ``_gen_partial_matrix``.

    After two support policies are added and fixed for every learner, an
    all-zero evaluation result is fed per agent (shape: the full support
    shape with that agent's axis removed). The i-th returned partial matrix
    must have the full support shape with axis ``i`` replaced by size 1.
    """
    for learner, _, _ in bpsro_disable_ray:
        learner.add_policy(n_support=2)
        for sub_learner in learner._agent_learners.values():
            # Same dispatch as test_add_policy: `.remote()` only exists on
            # the sub-learners when the learner runs in ray mode.
            if learner.ray_mode:
                ray.get(sub_learner.set_ego_policy_fixed.remote())
            else:
                sub_learner.set_ego_policy_fixed()

    for learner, _, _ in bpsro_disable_ray:
        # Support sizes ordered like `learner.agents`.
        full_shape_dict = learner.num_support
        full_shape = [full_shape_dict[agent] for agent in learner.agents]
        assert all(np.asarray(full_shape) > 0), (full_shape,)
        evaluation_results = {
            agent: np.zeros(full_shape[:i] + full_shape[i + 1 :])
            for i, agent in enumerate(learner.agents)
        }
        res = learner._gen_partial_matrix(evaluation_results)
        # Relies on `res` iterating in the same order as `learner.agents`.
        for i, (agent, _partial_matrix) in enumerate(res.items()):
            predict_shape = tuple(full_shape[:i] + [1] + full_shape[i + 1 :])
            assert _partial_matrix.shape == predict_shape, (
                _partial_matrix.shape,
                predict_shape,
                agent,
                i,
            )


def test_equilibrium_computation(
    bpsro_disable_ray: List[Tuple[BPSROLearner, Dict, Dict]]
):
    """Check that ``_compute_equilibrium`` returns a result for zero payoffs.

    Every learner first receives two support policies, which are fixed and
    synced. An all-zero evaluation result per agent is then passed to
    ``_compute_equilibrium`` together with the payoff sub-matrix of the fixed
    policies; the returned equilibrium must not be ``None``.
    """
    for learner, _, _ in bpsro_disable_ray:
        learner.add_policy(n_support=2)
        for sub_learner in learner._agent_learners.values():
            # Same dispatch as used for the id query further down:
            # `.remote()` is only available in ray mode.
            if learner.ray_mode:
                ray.get(sub_learner.set_ego_policy_fixed.remote())
            else:
                sub_learner.set_ego_policy_fixed()

        learner._sync_up()

    # update payoff manager
    for learner, _, _ in bpsro_disable_ray:
        # Support sizes ordered like `learner.agents`.
        full_shape_dict = learner.num_support
        full_shape = [full_shape_dict[agent] for agent in learner.agents]
        assert all(np.asarray(full_shape) > 0), (full_shape,)
        evaluation_results = {
            agent: np.zeros(full_shape[:i] + full_shape[i + 1 :])
            for i, agent in enumerate(learner.agents)
        }
        fixed_policy_ids = {}
        for sub_learner in learner._agent_learners.values():
            if learner.ray_mode:
                fixed_policy_ids.update(
                    ray.get(sub_learner.get_ego_fixed_policy_ids.remote())
                )
            else:
                fixed_policy_ids.update(sub_learner.get_ego_fixed_policy_ids())
        equilibrium = learner._compute_equilibrium(
            fixed_policy_ids,
            evaluation_results,
            ori_payoff_matrix=learner.payoff_matrix.get_sub_matrix(fixed_policy_ids),
        )
        assert equilibrium is not None, (equilibrium, fixed_policy_ids)


def test_evaluate_comb(bpsro_disable_ray: List[Tuple[BPSROLearner, Dict, Dict]]):
    """Check the number of results produced by ``_evaluate_comb``.

    Only the second configured learner (``[1:2]``) is exercised. ``n`` support
    policies are added and fixed, one more policy is added and left active,
    and the active policy is evaluated against the fixed ones. Each agent is
    expected to receive ``(n + 1) ** (num_agents - 1)`` results.
    """
    # add fixed policies
    n = 1
    for learner, _, _ in bpsro_disable_ray[1:2]:
        learner.add_policy(n_support=n)
        for sub_learner in learner._agent_learners.values():
            # `.remote()` is only available when the learner runs in ray
            # mode (same dispatch as in test_add_policy).
            if learner.ray_mode:
                ray.get(sub_learner.set_ego_policy_fixed.remote())
            else:
                sub_learner.set_ego_policy_fixed()

        learner._sync_up()

    # add the policy that stays active
    for learner, _, _ in bpsro_disable_ray[1:2]:
        learner.add_policy(n_support=1)

    for learner, _, _ in bpsro_disable_ray[1:2]:
        active_pids = {}
        fixed_pids = {}

        for sub_learner in learner._agent_learners.values():
            if learner.ray_mode:
                active_pids.update(
                    ray.get(sub_learner.get_ego_active_policy_ids.remote())
                )
                fixed_pids.update(
                    ray.get(sub_learner.get_ego_fixed_policy_ids.remote())
                )
            else:
                active_pids.update(sub_learner.get_ego_active_policy_ids())
                fixed_pids.update(sub_learner.get_ego_fixed_policy_ids())

        learner.payoff_matrix.expand(active_pids)
        # Keep only the first active policy id of each entry.
        active_pids = {k: v[0] for k, v in active_pids.items()}

        res = learner._evaluate_comb(active_pids, fixed_pids)
        expected_count = (n + 1) ** (len(learner.agents) - 1)
        for agent, _agent_res in res.items():
            assert len(_agent_res) == expected_count, (agent, _agent_res)


def test_br_optimization(bpsro_disable_ray: List[Tuple[BPSROLearner, Dict, Dict]]):
    """Check the shape of the results of ``_optimize_best_response``.

    Only the second configured learner (``[1:2]``) is exercised. An
    equilibrium is computed from all-zero evaluation results, one new policy
    per entry is added as the active one, and a best response is optimized
    with ``max_episode`` forced to 100. For the i-th agent the returned array
    must have the full support shape (captured before the new policy was
    added) with axis ``i`` removed.
    """
    for learner, sampler_config, general_conditions in bpsro_disable_ray[1:2]:
        # ---- equilibrium from a zero-valued evaluation ----
        support_sizes = learner.num_support
        full_shape = [support_sizes[agent] for agent in learner.agents]
        assert all(np.asarray(full_shape) > 0), (full_shape,)

        evaluation_results = {}
        for idx, agent in enumerate(learner.agents):
            evaluation_results[agent] = np.zeros(
                full_shape[:idx] + full_shape[idx + 1 :]
            )

        fixed_policy_ids = {}
        for sub_learner in learner._agent_learners.values():
            if learner.ray_mode:
                fetched = ray.get(sub_learner.get_ego_fixed_policy_ids.remote())
            else:
                fetched = sub_learner.get_ego_fixed_policy_ids()
            fixed_policy_ids.update(fetched)

        equilibrium = learner._compute_equilibrium(
            fixed_policy_ids,
            evaluation_results,
            ori_payoff_matrix=learner.payoff_matrix.get_sub_matrix(fixed_policy_ids),
        )

        # ---- one freshly added active policy per entry ----
        learner.add_policy(n_support=1)
        active_ids = {}
        for sub_learner in learner._agent_learners.values():
            if learner.ray_mode:
                fetched = ray.get(sub_learner.get_ego_active_policy_ids.remote())
            else:
                fetched = sub_learner.get_ego_active_policy_ids()
            active_ids.update(fetched)
        assert all(len(pids) == 1 for pids in active_ids.values()), (active_ids,)
        learner.payoff_matrix.expand(active_ids)
        ego_pids = {agent: pids[0] for agent, pids in active_ids.items()}

        # ---- best-response optimization with a capped episode count ----
        inner_conditions = general_conditions["inner_stop_conditions"]
        inner_conditions["max_episode"] = 100
        eval_result = learner._optimize_best_response(
            ego_pids,
            fixed_policy_ids,
            equilibrium=equilibrium,
            sampler_config=sampler_config,
            inner_stop_conditions=inner_conditions,
        )

        # eval_result maps each agent to a numpy array.
        for idx, (agent, payoff) in enumerate(eval_result.items()):
            expected_shape = tuple(full_shape[:idx] + full_shape[idx + 1 :])
            assert payoff.shape == expected_shape, (
                idx,
                agent,
                full_shape,
                payoff,
            )


def test_learn(bpsro_disable_ray: List[Tuple[BPSROLearner, Dict, Dict]]):
    """Smoke-test ``learn`` on the second configured learner (``[1:2]``).

    The stop conditions loaded from the configuration are replaced by
    ``None`` and ``max_bilevel_step`` is set to 2 before ``learn`` is invoked
    with the sampler config and the remaining stop-condition keyword
    arguments.
    """
    for bpsro_learner, make_sampler_config, stop_kwargs in bpsro_disable_ray[1:2]:
        # NOTE: adding one extra fixed policy per agent followed by a sync-up
        # was left disabled here.
        bpsro_learner.max_bilevel_step = 2
        # disable configured stop conditions
        stop_kwargs.update(stop_conditions=None)
        bpsro_learner.learn(make_sampler_config, **stop_kwargs)
