import logging
import itertools

from open_spiel.python.policy import (
    Policy as OSPolicy,
    TabularPolicy,
    joint_action_probabilities_aux,
)
from open_spiel.python.algorithms import policy_aggregator
from open_spiel.python.algorithms.exploitability import nash_conv

try:
    import pyspiel
except ImportError as e:
    logging.warning(
        "Cannot import open spiel, if you wanna run meta game experiment, please install it before that."
    )

import numpy as np

from expground.types import Dict, AgentID, PolicyID, Sequence, LambdaType, List
from expground.algorithms.base_policy import Policy


# reference: https://github.com/JBLanier/pipeline-psro


class PolicyFromCallable(OSPolicy):
    """Adapter that exposes a plain callable as an OpenSpiel policy.

    Kept for backwards compatibility: the wrapped callable is invoked with a
    state and its result is converted to a ``dict`` of action probabilities.
    """

    def __init__(self, game, callable_policy: Policy):
        """Store the callable and register the players of ``game``.

        Args:
            game: The game the policy acts in, or ``None`` when no game is
                available (e.g. when wrapping a pyspiel policy).
            callable_policy: Object called as ``callable_policy(state)``.
        """
        # Without a game the set of players is unknown, so pass None through.
        players = None if game is None else list(range(game.num_players()))
        super().__init__(game, players)
        self._callable_policy = callable_policy

    def action_probabilities(self, state, player_id=None):
        """Return the callable's output for ``state`` as a dict.

        ``player_id`` is accepted for interface compatibility but is not
        forwarded to the wrapped callable.
        """
        probabilities = self._callable_policy(state)
        return dict(probabilities)


class NFSPPolicies(OSPolicy):
    """Joint two-player policy to be evaluated, backed by one policy per player."""

    def __init__(self, game, nfsp_policies: List[TabularPolicy]):
        """Register both players of ``game`` and keep their policies.

        Args:
            game: Game instance; it must report exactly two players.
            nfsp_policies: Per-player policies, indexed by player id.
        """
        player_ids = list(range(game.num_players()))
        assert len(player_ids) == 2
        super().__init__(game, player_ids)
        self.dynamics_type = game.get_type().dynamics
        self._policies = nfsp_policies
        self._obs = {"info_state": [None, None], "legal_actions": [None, None]}

    def action_probabilities(self, state, player_id=None):
        """Return a mapping from action to probability at ``state``.

        Args:
            state: The state to evaluate.
            player_id: Player whose policy is queried. When ``None`` the
                state's current player is used; at a simultaneous node the
                joint distribution over both players' actions is returned.

        Returns:
            dict: Action (or encoded joint action) -> probability.
        """
        # An explicit player simply delegates to that player's policy.
        if player_id is not None:
            return self._policies[player_id].action_probabilities(state, player_id)

        acting_player = state.current_player()
        # -2 is the current-player value this code treats as a simultaneous node.
        if acting_player != -2:
            return self._policies[acting_player].action_probabilities(
                state, acting_player
            )

        # Simultaneous node: combine the per-player distributions.
        actions_per_player, probs_per_player = joint_action_probabilities_aux(
            state, self
        )
        num_actions = self.game.num_distinct_actions()
        joint_probs = {}
        action_combos = itertools.product(*actions_per_player)
        prob_combos = itertools.product(*probs_per_player)
        for joint_action, joint_prob in zip(action_combos, prob_combos):
            # Encode the action pair as one integer: first * num_actions + second.
            encoded = joint_action[0] * num_actions + joint_action[1]
            joint_probs[encoded] = np.prod(joint_prob)
        return joint_probs


def build_open_spiel_policy(
    policy: Policy, game, observation_adapter, player_id: int = None
) -> OSPolicy:
    """Wrap ``policy`` so that it can be queried as an OpenSpiel policy.

    Args:
        policy: Policy providing ``_observation_space``, ``preprocessor`` and
            ``compute_action``.
        game: Game instance handed to the resulting ``PolicyFromCallable``.
        observation_adapter: Callable invoked as
            ``observation_adapter(state, policy._observation_space, player_id)``
            to build the raw observation for ``policy``.
        player_id: Player whose legal actions are queried. When ``None`` the
            state's argument-less ``legal_actions*`` accessors are used.

    Returns:
        OSPolicy: A ``PolicyFromCallable`` mapping each legal action of a
        state to the probability computed by ``policy``.
    """

    def policy_callable(state: pyspiel.State):
        # Only the legal-action mask and list are needed from the state; the
        # observation itself is produced by `observation_adapter`.
        if player_id is not None:
            legal_mask = state.legal_actions_mask(player_id)
            legal_actions = state.legal_actions(player_id)
        else:
            legal_mask = state.legal_actions_mask()
            legal_actions = state.legal_actions()

        observation = observation_adapter(state, policy._observation_space, player_id)
        obs = policy.preprocessor.transform(observation)
        _, action_probs, _ = policy.compute_action(
            observation=obs,
            action_mask=np.asarray(legal_mask, dtype=np.float32),
            evaluate=True,
        )

        # Keep only the probabilities at positions flagged legal by the mask,
        # then pair them, in order, with the legal action ids.
        legal_action_probs = [
            action_probs[idx] for idx, flag in enumerate(legal_mask) if flag == 1.0
        ]
        return dict(zip(legal_actions, legal_action_probs))

    # pack current policy as an OSPolicy
    return PolicyFromCallable(game=game, callable_policy=policy_callable)


def tabular_policy_from_weighted_policies(
    game, policy_iterable, weights, player_id: int = None
) -> TabularPolicy:
    """Merge several policies into one tabular policy as a weighted mixture.

    Args:
        game: A game instance.
        policy_iterable: Iterable of policies for one agent; each must provide
            ``action_probabilities(state)``.
        weights: Iterable of mixture weights aligned with ``policy_iterable``.
            It is iterated twice (summed, then zipped), so it must not be a
            one-shot iterator. The weights must sum to 1.
        player_id: If given, only states whose current player equals this id
            are filled in and checked; other states keep all-zero rows.

    Returns:
        TabularPolicy: A tabular policy holding the weighted mixture.

    Raises:
        AssertionError: If the weights do not sum to 1, or if the merged
            probabilities of a processed state do not sum to 1.
    """

    assert np.isclose(1.0, sum(weights), rtol=1e-4)

    merged_policy = TabularPolicy(game)
    # Start from an all-zero table so weighted contributions can be accumulated.
    merged_policy.action_probability_array = np.zeros_like(
        merged_policy.action_probability_array
    )
    num_actions = game.num_distinct_actions()

    def is_relevant(state):
        # Compare player ids by value; identity (`is`) is not reliable for ints.
        return player_id is None or state.current_player() == player_id

    for policy, weight in zip(policy_iterable, weights):
        if weight == 0.0:
            continue
        for state_index, state in enumerate(merged_policy.states):
            if not is_relevant(state):
                continue
            accumulated = merged_policy.action_probabilities(
                state, player_id=player_id
            )
            contribution = policy.action_probabilities(state)
            merged_policy.action_probability_array[state_index, :] = [
                accumulated.get(action, 0.0) + contribution.get(action, 0.0) * weight
                for action in range(num_actions)
            ]

    # check that all action probs per state add up to one in the newly created policy
    for state in merged_policy.states:
        if not is_relevant(state):
            continue
        action_probabilities = merged_policy.action_probabilities(state)
        infostate_policy = [
            action_probabilities.get(action, 0.0) for action in range(num_actions)
        ]

        assert np.isclose(
            1.0, sum(infostate_policy), rtol=1e-4
        ), "INFOSTATE POLICY: {} {}".format(infostate_policy, sum(infostate_policy))

    return merged_policy


from expground.envs import open_spiel_adapters


def measure_exploitability(
    game_name_or_matrix_desc: str,
    populations: Dict[AgentID, Dict[PolicyID, Policy]],
    policy_mixture_dict: Dict[AgentID, Dict[PolicyID, float]],
):
    """Compute NashConv for the mixture of policies held by each agent.

    Args:
        game_name_or_matrix_desc: Either a game name, loaded through
            ``pyspiel.load_game``, or an already constructed game object,
            which is used as is.
        populations: Per-agent mapping from policy id to policy. Agent ids
            must end in ``_<player index>`` (e.g. ``xxx_0``).
        policy_mixture_dict: Per-agent mapping from policy id to its mixture
            weight.

    Returns:
        The result of ``nash_conv(..., return_only_nash_conv=False)`` for the
        joint policy built from the per-agent mixtures.
    """
    if isinstance(game_name_or_matrix_desc, str):
        game = pyspiel.load_game(game_name_or_matrix_desc)
    else:
        game = game_name_or_matrix_desc

    def player_index(agent: AgentID) -> int:
        # agent id must look like: xxx_0
        return int(agent.split("_")[-1])

    def policy_iterable(agent: AgentID):
        """Yield one OSPolicy per policy id in the agent's mixture.

        Args:
            agent (AgentID): Environment agent id whose suffix is the player
                index.

        Yields:
            Iterator[OSPolicy]: An OSPolicy instance.
        """
        for pid in policy_mixture_dict[agent]:
            yield build_open_spiel_policy(
                populations[agent][pid],
                game,
                open_spiel_adapters.observation_adapter,
                player_index(agent),
            )

    # NFSPPolicies looks policies up by player index, so the list must be
    # ordered by that index rather than by the insertion order of `populations`.
    policies: List[TabularPolicy] = [
        tabular_policy_from_weighted_policies(
            game,
            policy_iterable(aid),
            policy_mixture_dict[aid].values(),
            player_index(aid),
        )
        for aid in sorted(populations, key=player_index)
    ]
    open_spiel_policy = NFSPPolicies(game, policies)

    return nash_conv(game=game, policy=open_spiel_policy, return_only_nash_conv=False)


# def spiel_calculator(
#     game_name_or_matrix_desc: str,
#     populations: Dict[AgentID, Dict[PolicyID, Policy]],
#     policy_mixture_dict: Dict[AgentID, Dict[PolicyID, float]],
# ):

#     if isinstance(game_name_or_matrix_desc, str):
#         game = pyspiel.load_game(game_name_or_matrix_desc)
#     else:
#         game = game_name_or_matrix_desc

#     aggregator = policy_aggregator.PolicyAggregator(game)

#     # policies is a list of list
#     policies = []
#     agents = list(populations.keys())
#     pids = {a: [] for a in agents}
#     meta_probabilities = []
#     for aid in agents:
#         tmp = []
#         probs = []
#         _policies = populations[aid]
#         mixture = policy_mixture_dict[aid]
#         for pid in pids:
#             tmp.append(_policies)
#             probs.append(mixture[pid])
#         policies.append(tmp)
#         probs = np.asarray(probs)
#         meta_probabilities.append(probs)

#     aggr_policies = aggregator.aggregate(range(len(populations)), policies, meta_probabilities)
#     value = nash_conv(game, aggr_policies, return_only_nash_conv=False)
#     return value
