import itertools
import numpy as np
from open_spiel.python import policy
from open_spiel.python.algorithms.psro_v2 import abstract_meta_trainer
from open_spiel.python.algorithms.psro_v2 import strategy_selectors
from open_spiel.python.algorithms.psro_v2 import utils

TRAIN_TARGET_SELECTORS = {"": None,
                          "rectified": strategy_selectors.rectified_selector}

class PSROSolver(abstract_meta_trainer.AbstractMetaTrainer):
    def __init__(self,
                 game,
                 oracle,
                 sims_per_entry,
                 initial_policies=None,
                 rectifier="",
                 training_strategy_selector=None,
                 meta_strategy_method="alpharank",
                 sample_from_marginals=False,
                 number_policies_selected=1,
                 n_noisy_copies=0,
                 alpha_noise=0.0,
                 beta_noise=0.0,
                 **kwargs):
        self._sims_per_entry = sims_per_entry
        print("Using {} sims per entry.".format(sims_per_entry))

        self._rectifier = TRAIN_TARGET_SELECTORS.get(
            rectifier, None)
        self._rectify_training = self._rectifier
        print("Rectifier : {}".format(rectifier))

        self._meta_strategy_probabilities = np.array([])
        self._non_marginalized_probabilities = np.array([])

        print("Perturbating oracle outputs : {}".format(n_noisy_copies > 0))
        self._n_noisy_copies = n_noisy_copies
        self._alpha_noise = alpha_noise
        self._beta_noise = beta_noise

        self._policies = []  # A list of size `num_players` of lists containing the strategies of each player.
        self._new_policies = []

        # Alpharank is a special case here, as it's not supported by the abstract
        # meta trainer api, so has to be passed as a function instead of a string.
        if not meta_strategy_method or meta_strategy_method == "alpharank":
            meta_strategy_method = utils.alpharank_strategy

        print("Sampling from marginals : {}".format(sample_from_marginals))
        self.sample_from_marginals = sample_from_marginals

        super(PSROSolver, self).__init__(
            game,
            oracle,
            initial_policies,
            meta_strategy_method,
            training_strategy_selector,
            number_policies_selected=number_policies_selected,
            **kwargs)

    def _initialize_policy(self, initial_policies):
        self._policies = [[] for k in range(self._num_players)]
        self._new_policies = [([initial_policies[k]] if initial_policies else
                               [policy.UniformRandomPolicy(self._game)])
                               for k in range(self._num_players)]

    def _initialize_game_state(self):
        effective_payoff_size = self._game_num_players
        self._meta_games = [np.array(utils.empty_list_generator(effective_payoff_size)) for _ in range(effective_payoff_size)]
        self.update_empirical_gamestate(seed=None)

    def iteration(self, seed=None):
        """Main trainer loop.
        Args:
          seed: Seed for random BR noise generation.
        """
        self._iterations += 1
        best_utility = self.update_agents()  # Generate new, Best Response agents via oracle.
        self.update_empirical_gamestate(seed=seed)  # Update gamestate matrix.
        self.update_meta_strategies()  # Compute meta strategy (e.g. Nash)
        return best_utility

    def get_joint_policy_ids(self):
        """Returns a list of integers enumerating all joint meta strategies."""
        return utils.get_strategy_profile_ids(self._meta_games)

    def get_joint_policies_from_id_list(self, selected_policy_ids):
        """Returns a list of joint policies from a list of integer IDs.
        Args:
          selected_policy_ids: A list of integer IDs corresponding to the
            meta-strategies, with duplicate entries allowed.
        Returns:
          selected_joint_policies: A list, with each element being a joint policy
            instance (i.e., a list of policies, one per player).
        """
        policies = self.get_policies()

        selected_joint_policies = utils.get_joint_policies_from_id_list(
            self._meta_games, policies, selected_policy_ids)
        return selected_joint_policies

    def update_meta_strategies(self):
        """Recomputes the current meta strategy of each player.

        Given new payoff tables, we call self._meta_strategy_method to update the
        meta-probabilities.
        """
        if self.symmetric_game:
            self._policies = self._policies * self._game_num_players

        self._meta_strategy_probabilities, self._non_marginalized_probabilities = \
            self._meta_strategy_method(solver=self, return_joint=True)

        if self.symmetric_game:
            self._policies = [self._policies[0]]
            self._meta_strategy_probabilities = [self._meta_strategy_probabilities[0]]

    def get_policies_and_strategies(self):
        """Returns current policy sampler, policies and meta-strategies of the game.

        If strategies are rectified, we automatically switch to returning joint
        strategies.

        Returns:
          sample_strategy: A strategy sampling function
          total_policies: A list of list of policies, one list per player.
          probabilities_of_playing_policies: the meta strategies, either joint or
            marginalized.
        """
        sample_strategy = utils.sample_strategy_marginal
        probabilities_of_playing_policies = self.get_meta_strategies()
        if self._rectify_training or not self.sample_from_marginals:
            sample_strategy = utils.sample_strategy_joint
            probabilities_of_playing_policies = self._non_marginalized_probabilities

        total_policies = self.get_policies()
        return sample_strategy, total_policies, probabilities_of_playing_policies

    def _restrict_target_training(self,
                                  current_player,
                                  ind,
                                  total_policies,
                                  probabilities_of_playing_policies,
                                  restrict_target_training_bool,
                                  epsilon=1e-12):
        """Rectifies training.

        Args:
          current_player: the current player.
          ind: Current strategy index of the player.
          total_policies: all policies available to all players.
          probabilities_of_playing_policies: meta strategies.
          restrict_target_training_bool: Boolean specifying whether to restrict
            training. If False, standard meta strategies are returned. Otherwise,
            restricted joint strategies are returned.
          epsilon: threshold below which we consider 0 sum of probabilities.

        Returns:
          Probabilities of playing each joint strategy (If rectifying) / probability
          of each player playing each strategy (Otherwise - marginal probabilities)
        """
        true_shape = tuple([len(a) for a in total_policies])
        if not restrict_target_training_bool:
            return probabilities_of_playing_policies
        else:
            kept_probas = self._rectifier(
                self, current_player, ind)
            # Ensure probabilities_of_playing_policies has same shape as kept_probas.
            probability = probabilities_of_playing_policies.reshape(true_shape)
            probability = probability * kept_probas
            prob_sum = np.sum(probability)

            # If the rectified probabilities are too low / 0, we play against the
            # non-rectified probabilities.
            if prob_sum <= epsilon:
                probability = probabilities_of_playing_policies
            else:
                probability /= prob_sum

            return probability

    def update_agents(self):
        """Updates policies for each player at the same time by calling the oracle.

        The resulting policies are appended to self._new_policies.
        """

        used_policies, used_indexes = self._training_strategy_selector(
            self, self._number_policies_selected)

        (sample_strategy,
         total_policies,
         probabilities_of_playing_policies) = self.get_policies_and_strategies()
        # Contains the training parameters of all trained oracles.
        # This is a list (Size num_players) of list (Size num_new_policies[player]),
        # each dict containing the needed information to train a new best response.
        training_parameters = [[] for _ in range(self._num_players)]

        for current_player in range(self._num_players):
            if self.sample_from_marginals:
                currently_used_policies = used_policies[current_player]
                current_indexes = used_indexes[current_player]
            else:
                currently_used_policies = [
                    joint_policy[current_player] for joint_policy in used_policies
                ]
                current_indexes = used_indexes[current_player]

            for i in range(len(currently_used_policies)):
                pol = currently_used_policies[i]
                ind = current_indexes[i]

                new_probabilities = self._restrict_target_training(
                    current_player, ind, total_policies,
                    probabilities_of_playing_policies,
                    self._rectify_training)

                new_parameter = {
                    "policy": pol,
                    "total_policies": total_policies,
                    "current_player": current_player,
                    "probabilities_of_playing_policies": new_probabilities
                }
                training_parameters[current_player].append(new_parameter)

        if self.symmetric_game:
            self._policies = self._game_num_players * self._policies
            self._num_players = self._game_num_players
            training_parameters = [training_parameters[0]]

        # List of List of new policies (One list per player)
        self._new_policies, best_utility = self._oracle(
            self._game,
            training_parameters,
            strategy_sampler=sample_strategy,
            using_joint_strategies=self._rectify_training or
                                   not self.sample_from_marginals)

        if self.symmetric_game:
            # In a symmetric game, only one population is kept. The below lines
            # therefore make PSRO consider only the first player during training,
            # since both players are identical.
            self._policies = [self._policies[0]]
            self._num_players = 1
        return best_utility

    def update_empirical_gamestate(self, seed=None):
        """Given new agents in _new_policies, update meta_games through simulations.

        Args:
          seed: Seed for environment generation.

        Returns:
          Meta game payoff matrix.
        """
        if seed is not None:
            np.random.seed(seed=seed)
        assert self._oracle is not None

        if self.symmetric_game:
            # Switch to considering the game as a symmetric game where players have
            # the same policies & new policies. This allows the empirical gamestate
            # update to function normally.
            self._policies = self._game_num_players * self._policies
            self._new_policies = self._game_num_players * self._new_policies
            self._num_players = self._game_num_players

        # Concatenate both lists.
        updated_policies = [
            self._policies[k] + self._new_policies[k]
            for k in range(self._num_players)
        ]

        # Each metagame will be (num_strategies)^self._num_players.
        # There are self._num_player metagames, one per player.
        total_number_policies = [
            len(updated_policies[k]) for k in range(self._num_players)
        ]
        number_older_policies = [
            len(self._policies[k]) for k in range(self._num_players)
        ]
        number_new_policies = [
            len(self._new_policies[k]) for k in range(self._num_players)
        ]

        # Initializing the matrix with nans to recognize unestimated states.
        meta_games = [
            np.full(tuple(total_number_policies), np.nan)
            for k in range(self._num_players)
        ]

        # Filling the matrix with already-known values.
        older_policies_slice = tuple(
            [slice(len(self._policies[k])) for k in range(self._num_players)])
        for k in range(self._num_players):
            meta_games[k][older_policies_slice] = self._meta_games[k]

        # Filling the matrix for newly added policies.
        for current_player in range(self._num_players):
            # Only iterate over new policies for current player ; compute on every
            # policy for the other players.
            range_iterators = [
                                  range(total_number_policies[k]) for k in range(current_player)
                              ] + [range(number_new_policies[current_player])] + [
                                  range(total_number_policies[k])
                                  for k in range(current_player + 1, self._num_players)
                              ]
            for current_index in itertools.product(*range_iterators):
                used_index = list(current_index)
                used_index[current_player] += number_older_policies[current_player]
                if np.isnan(meta_games[current_player][tuple(used_index)]):
                    estimated_policies = [
                                             updated_policies[k][current_index[k]]
                                             for k in range(current_player)
                                         ] + [
                                             self._new_policies[current_player][current_index[current_player]]
                                         ] + [
                                             updated_policies[k][current_index[k]]
                                             for k in range(current_player + 1, self._num_players)
                                         ]

                    if self.symmetric_game:
                        # TODO(author4): This update uses ~2**(n_players-1) * sims_per_entry
                        # samples to estimate each payoff table entry. This should be
                        # brought to sims_per_entry to coincide with expected behavior.
                        utility_estimates = self.sample_episodes(estimated_policies,
                                                                 self._sims_per_entry)

                        player_permutations = list(itertools.permutations(list(range(
                            self._num_players))))
                        for permutation in player_permutations:
                            used_tuple = tuple([used_index[i] for i in permutation])
                            for player in range(self._num_players):
                                if np.isnan(meta_games[player][used_tuple]):
                                    meta_games[player][used_tuple] = 0.0
                                meta_games[player][used_tuple] += utility_estimates[
                                                                      permutation[player]] / len(player_permutations)
                    else:
                        utility_estimates = self.sample_episodes(estimated_policies,
                                                                 self._sims_per_entry)
                        for k in range(self._num_players):
                            meta_games[k][tuple(used_index)] = utility_estimates[k]

        if self.symmetric_game:
            # Make PSRO consider that we only have one population again, as we
            # consider that we are in a symmetric game (No difference between players)
            self._policies = [self._policies[0]]
            self._new_policies = [self._new_policies[0]]
            updated_policies = [updated_policies[0]]
            self._num_players = 1

        self._meta_games = meta_games
        self._policies = updated_policies
        return meta_games

    def get_meta_game(self):
        """Returns the meta game matrix."""
        return self._meta_games

    @property
    def meta_games(self):
        return self._meta_games

    def get_policies(self):
        """Returns a list, each element being a list of each player's policies."""
        policies = self._policies
        if self.symmetric_game:
            # For compatibility reasons, return list of expected length.
            policies = self._game_num_players * self._policies
        return policies

    def get_and_update_non_marginalized_meta_strategies(self, update=True):
        """Returns the Nash Equilibrium distribution on meta game matrix."""
        if update:
            self.update_meta_strategies()
        return self._non_marginalized_probabilities

    def get_strategy_computation_and_selection_kwargs(self):
        return self._strategy_computation_and_selection_kwargs
