# Copyright 2019 DeepMind Technologies Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Compute the exploitability of a bot / strategy in a 2p sequential game.

This computes the value that a policy achieves against a worst-case opponent.
The policy is used by both player 1 and player 2, so we effectively have a
2-player symmetric zero-sum game. The game value is therefore zero for both
players, and the value obtained against a best response equals the
exploitability.

We construct information sets, each consisting of a list of (state, probability)
pairs where probability is a counterfactual reach probability, i.e. the
probability that the state would be reached if the best responder (the current
player) played to reach it. This is the product of the probabilities of the
necessary chance events and opponent action choices required to reach the node.

These probabilities give us the correct weighting for possible states of the
world when considering our best response for a particular information set.

The values we calculate are values of being in the specific state. Unlike in a
CFR algorithm, they are not weighted by reach probabilities. These values
take into account the whole state, so they may depend on information which is
unknown to the best-responding player.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections

import numpy as np

from open_spiel.python.algorithms import best_response as pyspiel_best_response
import pyspiel


def _state_values(state, num_players, policy):
  """Value of a state for every player given a policy."""
  if state.is_terminal():
    return np.array(state.returns())
  else:
    p_action = (
        state.chance_outcomes() if state.is_chance_node() else
        policy.action_probabilities(state).items())
    return sum(prob * _state_values(state.child(action), num_players, policy)
               for action, prob in p_action)


def best_response(game, policy, player_id):
  """Computes the best unilateral deviation of one player against `policy`.

  Every player other than `player_id` is assumed to follow `policy`. This
  builds a `BestResponsePolicy` for `player_id` from the initial state. It
  evaluates the best-response value and the on-policy values for all players.
  It then makes sure a best-response action is cached for every information
  set known to the best responder.

  Args:
    game: An open_spiel game, e.g. kuhn_poker
    policy: A `policy.Policy` object. It is expected to depend only on the
      information state available to the acting player; this is not checked.
    player_id: Integer id of the player whose best response is computed.

  Returns:
    A dict with the following keys:
      best_response_action: Map from info state key to the best-response
        action id for `player_id`.
      best_response_state_value: The best responder's value cache, holding the
        value for `player_id` per state under the unilateral deviation.
      best_response_value: Value for `player_id` at the initial state when
        best-responding.
      info_sets: Dict mapping info state key to a list of
        `(state, counterfactual_reach_prob)` pairs.
      nash_conv: `best_response_value - on_policy_value`.
      on_policy_value: Value for `player_id` when everyone follows `policy`.
      on_policy_values: Per-player values when everyone follows `policy`.
  """
  initial = game.new_initial_state()
  responder = pyspiel_best_response.BestResponsePolicy(
      game, player_id, policy, initial)
  values_on_policy = _state_values(initial, game.num_players(), policy)
  br_value = responder.value(initial)

  # Fill in best-response actions for info states not yet in the action cache
  # (e.g. those not visited while computing `br_value` above).
  pending = set(responder.infosets) - set(responder.cache_best_response_action)
  for info_state_key in pending:
    responder.best_response_action(info_state_key)

  player_value = values_on_policy[player_id]
  return dict(
      best_response_action=responder.cache_best_response_action,
      best_response_state_value=responder.cache_value,
      best_response_value=br_value,
      info_sets=responder.infosets,
      nash_conv=br_value - player_value,
      on_policy_value=player_value,
      on_policy_values=values_on_policy,
  )


def exploitability(game, policy):
  """Returns the exploitability of `policy` in a 2-player constant-sum game.

  For each player, this sums the value of a best response to `policy` at the
  initial state. It subtracts the game's utility sum from the total and
  divides by the number of players, which is NashConv / num_players. Only
  2-player, sequential (turn-based), zero- or constant-sum games are accepted.
  Prefer `nash_conv` for general use.

  Args:
    game: An open_spiel game, e.g. kuhn_poker
    policy: A `policy.Policy` object. It is expected to depend only on the
      information state available to the acting player; this is not checked.

  Returns:
    The value this policy achieves against a worst-case non-cheating opponent,
    averaged over both seats. Its minimum is zero (assuming the supplied
    policy is non-cheating), and that bound is achievable in a 2p game.

  Raises:
    ValueError: If the game does not have exactly two players, is not
      sequential, or is neither zero-sum nor constant-sum.
  """
  if game.num_players() != 2:
    raise ValueError("Game must be a 2-player game")
  game_type = game.get_type()
  dynamics = game_type.dynamics
  if dynamics != pyspiel.GameType.Dynamics.SEQUENTIAL:
    raise ValueError("The game must be turn-based, not {}".format(dynamics))
  utility = game_type.utility
  accepted_utilities = (pyspiel.GameType.Utility.ZERO_SUM,
                        pyspiel.GameType.Utility.CONSTANT_SUM)
  if utility not in accepted_utilities:
    raise ValueError(
        "The game must be constant- or zero-sum, not {}".format(utility))

  initial = game.new_initial_state()
  total_br_value = 0
  for responder_id in range(game.num_players()):
    responder = pyspiel_best_response.CPPBestResponsePolicy(
        game, responder_id, policy)
    total_br_value += responder.value(initial)
  return (total_br_value - game.utility_sum()) / game.num_players()


_NashConvReturn = collections.namedtuple("_NashConvReturn",
                                         ["nash_conv", "player_improvements"])


def nash_conv(game, policy, return_only_nash_conv=True, use_cpp_br=False):
  r"""Returns a measure of closeness to Nash for a policy in the game.

  See https://arxiv.org/pdf/1711.00832.pdf for the NashConv definition.

  Args:
    game: An open_spiel game, e.g. kuhn_poker
    policy: A `policy.Policy` object. This policy should depend only on the
      information state available to the current player, but this is not
      enforced.
    return_only_nash_conv: Whether to only return the NashConv value, or a
      namedtuple containing additional statistics. Prefer using `False`, as we
      hope to change the default to that value.
    use_cpp_br: If True, compute the best responses with
      `CPPBestResponsePolicy`; otherwise use the Python `BestResponsePolicy`.

  Returns:
    If `return_only_nash_conv` is True (the default), the NashConv value alone.
    Otherwise a `_NashConvReturn` namedtuple with the following fields:
    - nash_conv: The sum over all players of the improvements in value that each
      player could obtain by unilaterally changing their strategy, i.e.
      sum(player_improvements).
    - player_improvements: A `[num_players]` numpy array of the improvement
      for players (i.e. value_player_p_versus_BR - value_player_p).
  """
  root_state = game.new_initial_state()
  # Both implementations are constructed as (game, player, policy) and expose
  # `value(state)`, so only the class differs between the two code paths.
  if use_cpp_br:
    best_response_cls = pyspiel_best_response.CPPBestResponsePolicy
  else:
    best_response_cls = pyspiel_best_response.BestResponsePolicy
  best_response_values = np.array([
      best_response_cls(game, best_responder, policy).value(root_state)
      for best_responder in range(game.num_players())
  ])
  on_policy_values = _state_values(root_state, game.num_players(), policy)
  player_improvements = best_response_values - on_policy_values
  nash_conv_ = sum(player_improvements)
  if return_only_nash_conv:
    return nash_conv_
  return _NashConvReturn(
      nash_conv=nash_conv_, player_improvements=player_improvements)
