import numpy as np
from open_spiel.python.algorithms.psro_v2 import optimization_oracle
from open_spiel.python.algorithms.psro_v2 import utils

def update_episodes_per_oracles(episodes_per_oracle, played_policies_indexes):
    """Increments the episode counter of every policy that just played.

    The counters are modified in place; the same container is handed back to
    the caller for convenience.

    Args:
      episodes_per_oracle: Per-player sequences of episode counts, indexed as
        episodes_per_oracle[player_index][policy_index].
      played_policies_indexes: Iterable of (player_index, policy_index) pairs
        designating the policies whose counter must be incremented by one.

    Returns:
      The updated `episodes_per_oracle`.
    """
    for played in played_policies_indexes:
        player_index, policy_index = played
        player_counts = episodes_per_oracle[player_index]
        player_counts[policy_index] += 1
    return episodes_per_oracle


def freeze_all(policies_per_player):
    """Calls `freeze()` on every policy of every player.

    Args:
      policies_per_player: List of list of policies, one list per player. Each
        policy must expose a `freeze()` method.
    """
    every_policy = (
        pol for policies in policies_per_player for pol in policies)
    for pol in every_policy:
        pol.freeze()


def random_count_weighted_choice(count_weight):
    """Samples an index i with probability proportional to 1 / (count_weight[i] + 1).

    Indexes with a low count are therefore more likely to be drawn, which lets
    the caller prioritize entries that have been sampled less often so far.

    Args:
      count_weight: A list of counts to sample an index from.

    Returns:
      Randomly-sampled index, drawn with `np.random.choice`.
    """
    inverse_counts = np.array([1 / (count + 1) for count in count_weight])
    # Normalize the inverse counts into a probability vector.
    probabilities = inverse_counts / np.sum(inverse_counts)
    candidates = list(range(len(count_weight)))
    return np.random.choice(candidates, p=probabilities)


class RLOracle(optimization_oracle.AbstractOracle):
    """Oracle handling Approximate Best Responses computation.

    Calling the oracle creates one trainable policy per training-parameter
    dictionary, repeatedly plays episodes in which at least one of those
    policies takes part, and returns the policies frozen once each of them has
    been selected for enough episodes.
    """

    def __init__(self,
                 env,
                 best_response_class,
                 best_response_kwargs,
                 number_training_episodes=1e3,
                 self_play_proportion=0.0,
                 **kwargs):
        """Init function for the RLOracle.

        Args:
          env: rl_environment instance.
          best_response_class: class of the best response.
          best_response_kwargs: kwargs of the best response.
          number_training_episodes: (Minimal) number of training episodes to run
            each best response through. May be higher for some policies.
          self_play_proportion: Float, between 0 and 1. Defines the probability that
            a non-currently-training player will actually play (one of) its
            currently training strategy (Which will be trained as well).
          **kwargs: Remaining keyword arguments, forwarded to the parent class
            constructor.
        """
        self._env = env

        self._best_response_class = best_response_class
        self._best_response_kwargs = best_response_kwargs

        self._self_play_proportion = self_play_proportion
        self._number_training_episodes = number_training_episodes

        super(RLOracle, self).__init__(**kwargs)

    def sample_episode(self, _, agents, is_evaluation=False):
        """Plays one episode in the environment and returns the summed rewards.

        Args:
          _: Unused.
          agents: List of agents, indexed by player id. Each agent exposes
            `step(time_step, is_evaluation=...)` returning an object with an
            `action` attribute.
          is_evaluation: Forwarded to every `agent.step` call made while the
            episode is running. When False, every agent is additionally stepped
            once on the terminal time step.

        Returns:
          Element-wise sum of `time_step.rewards` over all environment steps, or
          the scalar `0.0` if the initial time step is already terminal.
        """
        time_step = self._env.reset()
        cumulative_rewards = 0.0

        while not time_step.last():
            if time_step.is_simultaneous_move():
                # Every agent acts on the same time step.
                action_list = [
                    agent.step(time_step, is_evaluation=is_evaluation).action
                    for agent in agents
                ]
            else:
                player_id = time_step.observations["current_player"]

                # is_evaluation is a boolean that, when False, lets policies train. The
                # setting of PSRO requires that all policies be static aside from those
                # being trained by the oracle. is_evaluation could be used to prevent
                # policies from training, yet we have opted for adding frozen attributes
                # that prevents policies from training, for all values of is_evaluation.
                # Since all policies returned by the oracle are frozen before being
                # returned, only currently-trained policies can effectively learn.
                agent_output = agents[player_id].step(
                    time_step, is_evaluation=is_evaluation)
                action_list = [agent_output.action]

            time_step = self._env.step(action_list)
            cumulative_rewards += np.array(time_step.rewards)

        if not is_evaluation:
            # Let every agent observe the terminal time step.
            for agent in agents:
                agent.step(time_step)

        return cumulative_rewards

    def _has_terminated(self, episodes_per_oracle):
        """Tells whether every training policy has played enough episodes.

        Args:
          episodes_per_oracle: Per-player sequences of episode counts. Players
            may have different numbers of training policies.

        Returns:
          True if every count is at least `self._number_training_episodes`.
        """
        # The oracle has terminated when all policies have at least trained for
        # self._number_training_episodes. Given the stochastic nature of our
        # training, some policies may have more training episodes than that value.
        # Players are checked one at a time so that the per-player count
        # sequences do not need to share a common length.
        return all(
            np.all(np.asarray(player_episodes) >= self._number_training_episodes)
            for player_episodes in episodes_per_oracle)

    def sample_policies_for_episode(self, new_policies, training_parameters,
                                    episodes_per_oracle, strategy_sampler):
        """Randomly samples a set of policies to run during the next episode.

        Note : sampling is biased to select players & strategies that haven't
        trained as much as the others.

        Args:
          new_policies: The currently training policies, list of list, one per
            player.
          training_parameters: List of list of training parameters dictionaries, one
            list per player, one dictionary per training policy.
          episodes_per_oracle: List of list of integers, computing the number of
            episodes trained on by each policy. Used to weight the strategy
            sampling.
          strategy_sampler: Sampling function that samples a joint strategy given
            probabilities.

        Returns:
          Sampled list of policies (One policy per player), index of currently
          training policies in the list.
        """
        num_players = len(training_parameters)

        # Prioritizing players that haven't had as much training as the others.
        episodes_per_player = [sum(episodes) for episodes in episodes_per_oracle]
        chosen_player = random_count_weighted_choice(episodes_per_player)

        # Uniformly choose one of the chosen player's training policies.
        agent_chosen_ind = np.random.randint(
            0, len(training_parameters[chosen_player]))
        agent_chosen_dict = training_parameters[chosen_player][agent_chosen_ind]
        new_policy = new_policies[chosen_player][agent_chosen_ind]

        # Sample other players' policies.
        total_policies = agent_chosen_dict["total_policies"]
        probabilities_of_playing_policies = agent_chosen_dict[
            "probabilities_of_playing_policies"]
        episode_policies = strategy_sampler(total_policies,
                                            probabilities_of_playing_policies)

        live_agents_player_index = [(chosen_player, agent_chosen_ind)]

        for player in range(num_players):
            if player == chosen_player:
                episode_policies[player] = new_policy
                assert not new_policy.is_frozen()
            else:
                # Sample a bernoulli with parameter 'self_play_proportion' to determine
                # whether we do self-play with 'player'.
                if np.random.binomial(1, self._self_play_proportion):
                    # If we are indeed doing self-play on that round, sample among the
                    # trained strategies of current_player, with priority given to less-
                    # selected agents.
                    agent_index = random_count_weighted_choice(
                        episodes_per_oracle[player])
                    self_play_agent = new_policies[player][agent_index]
                    episode_policies[player] = self_play_agent
                    live_agents_player_index.append((player, agent_index))
                else:
                    assert episode_policies[player].is_frozen()

        return episode_policies, live_agents_player_index

    def _rollout(self, game, agents, **oracle_specific_execution_kwargs):
        """Plays one training episode with `agents`.

        `game` and the extra keyword arguments are accepted but not used here;
        the episode's rewards are discarded.
        """
        self.sample_episode(None, agents, is_evaluation=False)

    def generate_new_policies(self, training_parameters):
        """Generates new policies to be trained into best responses.

        Args:
          training_parameters: list of list of training parameter dictionaries, one
            list per player.

        Returns:
          List of list of the new policies, following the same structure as
          training_parameters.
        """
        new_policies = []
        for player, player_parameters in enumerate(training_parameters):
            new_pols = []

            # TODO: change best response class to mbrl
            for param in player_parameters:
                current_pol = param["policy"]
                if isinstance(current_pol, self._best_response_class):
                    # Start from a (possibly perturbed) copy of the existing policy.
                    # NOTE(review): self._kwargs is not assigned in this class;
                    # presumably the parent constructor sets it - confirm.
                    new_pol = current_pol.copy_with_noise(self._kwargs.get("sigma", 0.0))
                else:
                    new_pol = self._best_response_class(self._env, player,
                                                        **self._best_response_kwargs)
                    new_pol.unfreeze()
                new_pols.append(new_pol)
            new_policies.append(new_pols)
        return new_policies

    def __call__(self,
                 game,
                 training_parameters,
                 strategy_sampler=utils.sample_strategy,
                 **oracle_specific_execution_kwargs):
        """Trains one new policy per training parameter dictionary.

        Args:
          game: Forwarded to `_rollout`.
          training_parameters: List of list of training parameter dictionaries,
            one list per player. Each dictionary provides the keys "policy",
            "total_policies" and "probabilities_of_playing_policies".
          strategy_sampler: Sampling function that samples a joint strategy given
            probabilities.
          **oracle_specific_execution_kwargs: Forwarded to `_rollout`.

        Returns:
          List of list of the trained policies, all frozen, following the same
          structure as training_parameters.
        """
        # One count array per player: players may train different numbers of
        # policies, which a single rectangular array could not represent.
        episodes_per_oracle = [
            np.zeros(len(player_params), dtype=int)
            for player_params in training_parameters
        ]

        new_policies = self.generate_new_policies(training_parameters)

        while not self._has_terminated(episodes_per_oracle):
            agents, indexes = self.sample_policies_for_episode(
                new_policies, training_parameters, episodes_per_oracle,
                strategy_sampler)
            self._rollout(game, agents, **oracle_specific_execution_kwargs)
            episodes_per_oracle = update_episodes_per_oracles(
                episodes_per_oracle, indexes)

        freeze_all(new_policies)
        return new_policies
