"""Games that let agents interact with bandits."""

from typing import Tuple, Union, Dict, Text

import numpy as np
from tqdm import tqdm

from agents import AbstractAgent, AbstractContextual, NestedAgent, TwoStage
from bandits import AbstractBandit, LinContextBandit, DatasetBandit


class StandardBanditGame:
  """A fixed collection of agents that all play one multi-armed bandit."""

  def __init__(self, bandit: AbstractBandit, *agents: AbstractAgent):
    """Pair a bandit with the agents that will play it.

    Args:
      bandit: The bandit that every agent plays.
      agents: Any number of agents; kept (as a tuple) in the order given.
    """
    self.bandit, self.agents = bandit, agents

  def play(self, n_rounds: int = 1):
    """Run the agents against the bandit for a number of rounds.

    In every round a single reward sample is drawn from the bandit and shared
    by all agents: each agent picks an arm and is updated with that arm's
    entry of the shared sample.

    Args:
      n_rounds: How many rounds to play.
    """
    for _ in tqdm(range(n_rounds)):
      round_rewards = self.bandit.sample().squeeze(0)
      for player in self.agents:
        choice = player.select_arm()
        player.update(choice, round_rewards[choice])

  def get_regrets(self) -> np.ndarray:
    """Compute every agent's regret from its pull counts.

    An agent's regret is the sum, over arms, of the gap between the best arm
    mean and that arm's mean, weighted by how often the agent pulled the arm.
    It is computed from the bandit's means, not from the sampled rewards.

    Returns:
      regrets: A 1D float array with one entry per agent, in agent order.
    """
    arm_means = self.bandit.means
    gaps = np.max(arm_means) - arm_means
    regrets = [np.sum(gaps * player.counts) for player in self.agents]
    return np.array(regrets)

  def get_pull_freqs(self) -> list:
    """Collect the empirical arm pull frequencies of every agent.

    Each agent's counts are normalised by their total. If an agent has no
    pulls yet, this is a 0/0 division (NaN under NumPy semantics).

    Returns:
      A list with one 1D float array of arm frequencies per agent.
    """
    freqs = []
    for player in self.agents:
      pulls = player.counts
      freqs.append(pulls / np.sum(pulls))
    return freqs

  def get_next_pull_probs(self) -> list:
    """Collect every agent's probabilities of pulling each arm next.

    Returns:
      A list with one 1D float array (from `pull_probs()`) per agent.
    """
    return [player.pull_probs() for player in self.agents]

  def reset(self):
    """Reset all agents."""
    for player in self.agents:
      player.reset()


class NestedBanditGame:
  """In a nested bandit game a set of nested agents interact with a MAB."""

  def __init__(self, bandit: AbstractBandit, *agents: NestedAgent):
    """Initialize a bandit with nested agent interaction.

    Args:
      bandit: The bandit with which the agents interact.
      agents: The nested agents that interact with the bandit. Each NestedAgent
          has its own player and set of recommendation engines. Each of the
          agents should have the same number of recommendation engines.

    Raises:
      ValueError: If no agents are given, or if the agents do not all have
          the same number of recommendation engines.
    """
    if not agents:
      # Fail with a clear message instead of an IndexError on agents[0].
      raise ValueError('At least one nested agent is required.')
    n_recsys = len(agents[0].recsys)
    if any(len(agent.recsys) != n_recsys for agent in agents):
      raise ValueError('All agents should be comparing the same number of '
                       'recommendation engines.')
    self.n_recsys = n_recsys
    self.bandit = bandit
    self.agents = agents

  def play(self, n_rounds: int = 1):
    """Let the nested agents interact with the bandit.

    In each round one reward sample is drawn from the bandit and shared by
    all agents. Each agent produces its recommendations and its player's
    choice via `select_arms` and is updated with the full reward sample.

    Args:
      n_rounds: The number of interaction rounds.
    """
    for _ in tqdm(range(n_rounds)):
      rewards = self.bandit.sample().squeeze(0)
      for agent in self.agents:
        recommendations, player_choice = agent.select_arms()
        agent.update(recommendations, player_choice, rewards)

  def get_regrets(self) -> Tuple[np.ndarray, np.ndarray]:
    """Collect the regrets suffered by all nested agents and their recsys.

    Regrets are computed from the bandit's arm means and the pull counts. A
    NO-OP slot is inserted into the per-arm vectors at the player's
    `noop_arm_idx`, so agents with and without the NO-OP option (and sets
    mixing both) are handled uniformly:

    * Player regret: if NO-OP is available it is a real option worth
      `noop_value`, so it competes for the best option and its pulls count.
      If it is unavailable, the player is benchmarked against the best arm
      only and the NO-OP slot carries zero pulls.
    * Recsys regret: a recommender that is not selected in a round receives
      zero reward, so those rounds are counted in the NO-OP slot, valued 0.

    Returns:
      player_regrets: A 1D float array with one entry per agent (player).
      recsys_regrets: A 2D float array of shape (num_agents, num_recsys).
    """
    def _stack_noop(arr, value, idx):
      # Insert the scalar `value` into `arr` at position `idx`.
      return np.hstack([arr[:idx], value, arr[idx:]])

    def _get_diff_vector(augm_means):
      # Gap of every option (incl. the NO-OP slot) to the best option.
      return np.max(augm_means) - augm_means

    means = np.asarray(self.bandit.means)
    player_regrets, recsys_regrets = [], []
    for agent in self.agents:
      player = agent.player
      noop_idx = player.noop_arm_idx

      if player.noop_option:
        diff_player = _get_diff_vector(
            _stack_noop(means, player.noop_value, noop_idx))
        # number of times the player chose NO-OP
        n_noop = player.counts[noop_idx]
      else:
        # An unavailable NO-OP must not raise the benchmark. Its slot gets
        # zero pulls below, so the gap value stored there is irrelevant.
        diff_player = _stack_noop(np.max(means) - means, 0, noop_idx)
        n_noop = 0

      # total number of pulls of each arm the player has executed (+ NO-OP)
      arm_pulls = np.sum(np.vstack([rs.counts for rs in agent.recsys]), axis=0)
      arm_pulls = _stack_noop(arm_pulls, n_noop, noop_idx)
      player_regrets.append(np.sum(diff_player * arm_pulls))

      # Rounds in which a recsys was not selected count as zero reward.
      total_pulls = np.sum(arm_pulls)
      diff_recsys = _get_diff_vector(_stack_noop(means, 0, noop_idx))
      rs_regrets = [
          np.sum(diff_recsys * _stack_noop(
              rs.counts, total_pulls - np.sum(rs.counts), noop_idx))
          for rs in agent.recsys]
      recsys_regrets.append(np.array(rs_regrets))

    return np.array(player_regrets), np.array(recsys_regrets)

  def get_pull_freqs(self) -> list:
    """Collect frequencies of user interactions with each recsys (incl. NO-OP).

    Returns:
      A list of 1D float arrays with the frequencies for each agent (the
      array sizes may vary depending on whether NO-OP is available).
    """
    return [a.player.counts / np.sum(a.player.counts) for a in self.agents]

  def get_next_pull_probs(self) -> list:
    """Collect each player's next-round recommender selection probabilities.

    Returns:
      A list of 1D float arrays with next arm pull probabilities for each
      agent, incl. NO-OP (the array sizes may vary depending on whether
      NO-OP is available).
    """
    return [a.player.pull_probs() for a in self.agents]

  def reset(self):
    """Reset all agents."""
    for agent in self.agents:
      agent.reset()


class ContextBanditGame:
  """A single (possibly two-stage) agent interacting with a contextual bandit.

  In every round the bandit supplies a context, the agent selects an arm and
  is updated with that arm's sampled reward. Per-round diagnostics are
  recorded from the bandit's mean rewards for the context.
  """

  def __init__(self,
               bandit: Union[LinContextBandit, DatasetBandit],
               agent: Union[AbstractContextual, TwoStage]):
    """Pair a contextual bandit with the agent that plays it.

    Args:
      bandit: Supplies contexts, per-arm mean rewards and reward samples.
      agent: The agent. A `TwoStage` agent's `select_arm` returns a pair
          (selected arm, arms nominated by each of its nominators), which
          enables extra diagnostics in `play`.
    """
    self.bandit = bandit
    self.agent = agent
    self.is_twostage = isinstance(agent, TwoStage)

  def play(self, n_rounds: int = 1) -> Dict[Text, np.ndarray]:
    """Play `n_rounds` rounds and record per-round diagnostics.

    The agent is updated with the sampled reward of its arm, while the
    recorded diagnostics use the bandit's mean rewards for the context.

    Args:
      n_rounds: Number of rounds to play.

    Returns:
      A dict of NaN-initialised arrays filled in round by round:
        'return': mean reward of the selected arm, shape (n_rounds,).
        'regret': best mean reward minus the selected arm's, shape
            (n_rounds,).
      For two-stage agents it additionally contains:
        'best-best_nom': best mean reward minus the best nominated arm's.
        'best_nom-selected': best nominated mean reward minus the selected
            arm's.
        'best_in_pool-nom': shape (n_rounds, n_nominators); per nominator,
            the best mean reward among its `exp_arms` minus the mean reward
            of the arm it nominated.
    """
    def _blank(shape):
      return np.full(shape, np.nan)

    results = {'return': _blank(n_rounds), 'regret': _blank(n_rounds)}
    if self.is_twostage:
      results['best-best_nom'] = _blank(n_rounds)
      results['best_nom-selected'] = _blank(n_rounds)
      results['best_in_pool-nom'] = _blank(
          (n_rounds, len(self.agent.nominators)))

    for t in tqdm(range(n_rounds)):
      context = self.bandit.get_context()
      expected = self.bandit.mean_reward(context)
      realized = self.bandit.sample_rewards(context).squeeze(0)
      optimum = np.max(expected)

      if not self.is_twostage:
        chosen = self.agent.select_arm(context)
      else:
        chosen, nominated = self.agent.select_arm(context)
        best_nominated = np.max(expected[nominated])
        results['best-best_nom'][t] = optimum - best_nominated
        results['best_nom-selected'][t] = best_nominated - expected[chosen]
        results['best_in_pool-nom'][t, :] = np.array([
            np.max(expected[nominator.exp_arms]) - expected[pick]
            for nominator, pick in zip(self.agent.nominators, nominated)])

      self.agent.update(chosen, realized[chosen], context)
      results['return'][t] = expected[chosen]
      results['regret'][t] = optimum - expected[chosen]

    return results

  def reset(self):
    """Reset the agent."""
    self.agent.reset()
