"""Multi-armed bandit environments: stochastic, linear contextual and dataset-backed."""
from typing import Callable

import numpy as np

import utils


class AbstractBandit:
  """Base class for multi-armed bandits with a fixed number of arms.

  Subclasses implement `sample` and `pull_arm`. When `means` is not given,
  `self.means` is left as None for the subclass to fill in.
  """

  def __init__(self,
               n_arms: int,
               means: np.ndarray = None,
               seed: int = None):
    """Initialize a bandit.

    Args:
      n_arms: The number of arms.
      means: The means of the reward distribution of each arm. Any array-like
        (list, tuple, ndarray) is accepted and converted with `np.asarray`.
        May be None.
      seed: Seed for the bandit's private `np.random.RandomState`.

    Raises:
      ValueError: If `means` is given and does not contain exactly `n_arms`
        elements.
    """
    if means is not None:
      # Accept plain sequences too; `.size` below requires an ndarray.
      means = np.asarray(means)
      if means.size != n_arms:
        raise ValueError(
            f'Expected {n_arms} means (one per arm), got {means.size}.')
    self.rng = np.random.RandomState(seed)
    self.means = means
    self.n_arms = n_arms

  def sample(self, n: int = 1) -> np.ndarray:
    """Draw `n` reward vectors (one entry per arm). Implemented by subclasses."""
    raise NotImplementedError()

  def pull_arm(self, arm: int) -> np.ndarray:
    """Return the reward of pulling `arm`. Implemented by subclasses."""
    raise NotImplementedError()

  def best_arm(self) -> np.integer:
    """Return the index of the arm with the highest mean reward.

    Ties are broken in favour of the lowest index (`np.argmax` semantics).
    """
    return np.argmax(self.means)


class BernoulliBandit(AbstractBandit):
  """Multi-armed bandit whose arms pay out 0/1 (Bernoulli) rewards."""

  def __init__(self, means: np.ndarray = None, **kwargs):
    """Create the bandit, drawing the arm means when none are supplied.

    Args:
      means: Success probability of each arm. When omitted, every arm's
          probability is drawn from Beta(1, 1) using the bandit's own RNG.
      kwargs: Forwarded unchanged to AbstractBandit (e.g. n_arms, seed).
    """
    super().__init__(means=means, **kwargs)
    if self.means is not None:
      return
    self.means = self.rng.beta(1, 1, size=self.n_arms)

  def sample(self, n: int = 1) -> np.ndarray:
    """Draw `n` independent rounds of rewards for all arms at once.

    Args:
      n: How many rounds to draw.

    Returns:
      Array of shape (n, number of arms) holding 0/1 rewards.
    """
    shape = (n, self.means.size)
    return self.rng.binomial(1, self.means, shape)

  def pull_arm(self, arm: int) -> np.ndarray:
    """Draw a single 0/1 reward for one arm.

    Args:
      arm: Index of the arm to pull.

    Returns:
      The sampled reward (0 or 1).
    """
    success_prob = self.means[arm]
    return self.rng.binomial(1, p=success_prob)


class GaussianBandit(AbstractBandit):
  """Multi-armed bandit whose arms pay Gaussian rewards with a shared std."""

  def __init__(self,
               means: np.ndarray = None,
               std: float = 1.0,
               **kwargs):
    """Create the bandit, drawing the arm means when none are supplied.

    Args:
      means: Mean reward of each arm. When omitted, each mean is drawn from a
          standard normal using the bandit's own RNG.
      std: Standard deviation of every arm's reward noise.
      kwargs: Forwarded unchanged to AbstractBandit (e.g. n_arms, seed).
    """
    super().__init__(means=means, **kwargs)
    if self.means is None:
      self.means = self.rng.normal(size=self.n_arms)
    self.std = std

  def sample(self, n: int = 1) -> np.ndarray:
    """Draw `n` independent rounds of rewards for all arms at once.

    Args:
      n: How many rounds to draw.

    Returns:
      Array of shape (n, number of arms): arm means plus scaled N(0, 1) noise.
    """
    standard_draws = self.rng.normal(size=(n, self.means.size))
    return np.atleast_2d(self.means) + self.std * standard_draws

  def pull_arm(self, arm: int) -> np.ndarray:
    """Draw a single noisy reward for one arm.

    Args:
      arm: Index of the arm to pull.

    Returns:
      The arm's mean plus `std` times a standard-normal draw.
    """
    eps = self.rng.normal()
    return self.means[arm] + self.std * eps


class LinContextBandit:
  """Base class for contextual bandits with a linear mean reward.

  Subclasses implement `sample_rewards`. This class supplies context sampling
  and the linear score `ctxt @ true_param`, which it reports as the mean reward.
  """

  def __init__(self,
               n_arms: int,
               context_sampler: Callable[[], np.ndarray],
               true_param: np.ndarray,
               temperature: float = 1.0,
               seed: int = None):
    """Initialize a linear contextual bandit.

    Args:
      n_arms: The number of arms.
      context_sampler: Zero-argument callable returning the context for one
        round (presumably one feature row per arm, i.e. (n_arms, d) — confirm
        with callers).
      true_param: Parameter vector that the mean reward is linear in.
      temperature: Stored on the instance for subclasses; not used by this
        base class.
      seed: Seed for the bandit's private `np.random.RandomState`.
    """
    self.rng = np.random.RandomState(seed)
    self.n_arms = n_arms
    self.temperature = temperature
    self.true_param = true_param
    self.ctxt_sampler = context_sampler

  def sample_contexts(self, n: int = 1) -> np.ndarray:
    """Draw `n` contexts via `get_context` and stack them along a new axis 0."""
    return np.array([self.get_context() for _ in range(n)])

  def sample_rewards(self, ctxt: np.ndarray) -> np.ndarray:
    """Sample rewards for the given context(s). Implemented by subclasses."""
    raise NotImplementedError()

  def get_context(self) -> np.ndarray:
    """Return a fresh context from the user-supplied sampler."""
    return self.ctxt_sampler()

  def mean_reward(self, ctxt: np.ndarray) -> np.ndarray:
    """Return the linear score `ctxt @ true_param`.

    For a 2-D context this is an array with one value per row (per arm). For a
    1-D context it is a scalar. The temperature is not applied here.
    """
    return ctxt @ self.true_param


class LinGaussBandit(LinContextBandit):
  """Linear contextual bandit with additive Gaussian reward noise."""

  def __init__(self, reward_std: float = 0.1, **kwargs):
    """Create the bandit.

    Args:
      reward_std: Standard deviation of the Gaussian noise added to each reward.
      kwargs: Forwarded unchanged to LinContextBandit.
    """
    self.reward_std = reward_std
    super().__init__(**kwargs)

  def sample_rewards(self, ctxt: np.ndarray) -> np.ndarray:
    """Sample noisy linear rewards for every arm.

    Args:
      ctxt: A single context (2-D) or a batch of contexts (3-D). A 2-D input
        is treated as a batch of one.

    Returns:
      `ctxt @ true_param` plus noise of shape (batch, n_arms) scaled by
      `reward_std`.
    """
    batch = ctxt[None, ...] if ctxt.ndim == 2 else ctxt
    eps = self.rng.normal(size=(len(batch), self.n_arms))
    return batch @ self.true_param + self.reward_std * eps


class LinBernBandit(LinContextBandit):
  """Linear contextual bandit with Bernoulli rewards through a logistic link."""

  def sample_rewards(self, ctxt: np.ndarray) -> np.ndarray:
    """Sample 0/1 rewards with success probability sigmoid(ctxt @ theta / T).

    Args:
      ctxt: A single context (2-D) or a batch of contexts (3-D). A 2-D input
        is treated as a batch of one.

    Returns:
      0/1 rewards with the shape of the logits, cast to the logits' dtype.
    """
    if ctxt.ndim == 2:
      ctxt = ctxt[None, ...]
    logits = ctxt @ self.true_param / self.temperature
    # Overflow-free logistic: 1 / (1 + exp(-x)) == exp(-log(1 + exp(-x))).
    # The naive form overflows in np.exp (RuntimeWarning) for very negative
    # logits. np.logaddexp computes log(exp(0) + exp(-x)) stably.
    probs = np.exp(-np.logaddexp(0.0, -logits))
    return self.rng.binomial(1, p=probs).astype(logits.dtype)


class LinCategBandit(LinContextBandit):
  """Linear contextual bandit that rewards exactly one sampled arm per round."""

  def sample_rewards(self, ctxt: np.ndarray) -> np.ndarray:
    """Sample a one-hot reward vector per context.

    The rewarded arm is drawn from softmax(ctxt @ true_param / temperature)
    using the Gumbel-max trick. Each logit is perturbed with Gumbel noise,
    and the arm(s) attaining the maximum get reward 1.

    Args:
      ctxt: A single context (2-D) or a batch of contexts (3-D). A 2-D input
        is treated as a batch of one.

    Returns:
      Array shaped like the logits: 1 at the winning arm, 0 elsewhere, cast to
      the logits' dtype.
    """
    batch = ctxt if ctxt.ndim != 2 else ctxt[None, ...]
    scores = batch @ self.true_param / self.temperature
    perturbed = scores + self.rng.gumbel(size=scores.shape)
    is_winner = perturbed == np.max(perturbed, axis=-1, keepdims=True)
    return is_winner.astype(scores.dtype)


class LinPoissBandit(LinContextBandit):
  """Linear contextual bandit with exponentially distributed rewards."""

  def sample_rewards(self, ctxt: np.ndarray) -> np.ndarray:
    """Sample Exponential rewards with rate exp(ctxt @ true_param / T).

    NOTE(review): despite the class name, this draws Exponential (not Poisson)
    variables. Their mean, exp(-ctxt @ true_param / T), decreases as the
    linear score grows (for T > 0). The base-class `mean_reward`
    (ctxt @ true_param) increases with it. Confirm this is intended.

    Args:
      ctxt: A single context (2-D) or a batch of contexts (3-D). A 2-D input
        is treated as a batch of one.

    Returns:
      Exponential draws with the shape of the linear scores.
    """
    batch = ctxt if ctxt.ndim != 2 else ctxt[None, ...]
    log_rate = batch @ self.true_param / self.temperature
    # numpy's exponential is parameterized by scale = 1 / rate.
    inverse_rate = np.exp(-log_rate)
    return self.rng.exponential(scale=inverse_rate)


class HalfGaussBandit(LinGaussBandit):
  """LinGaussBandit whose sampled rewards are clipped below at zero."""

  def sample_rewards(self, ctxt: np.ndarray) -> np.ndarray:
    """Sample Gaussian rewards from the parent class and replace negatives with 0."""
    gaussian_rewards = super().sample_rewards(ctxt)
    return np.maximum(0, gaussian_rewards)


class DatasetBandit:
  """Bandit backed by a fixed dataset of (context, reward) rows.

  Each round offers `n_arms` dataset rows as arms. Rewards are looked up
  exactly, keyed by the raw bytes of the (whitened) context row.
  """

  def __init__(self,
               ctxts: np.ndarray,
               rewards: np.ndarray,
               n_arms: int,
               seed: int = None):
    """Store the whitened contexts and build the context -> reward lookup.

    Args:
      ctxts: One context row per dataset example. They are passed through
        `utils.whiten` (defined outside this file) before storage.
      rewards: Reward of each example, aligned with the rows of `ctxts`.
      n_arms: Number of dataset rows offered per round.
      seed: Seed for the bandit's private `np.random.RandomState`.
    """
    self.n_arms = n_arms
    self.rng = np.random.RandomState(seed)
    self.ctxts = utils.whiten(ctxts)
    # If two rows whiten to identical bytes, the later row's reward wins.
    lookup = {}
    for row_idx in range(ctxts.shape[0]):
      lookup[self.ctxts[row_idx].tobytes()] = rewards[row_idx]
    self.ctxt_to_reward = lookup

  def get_context(self) -> np.ndarray:
    """Pick `n_arms` stored rows uniformly at random, with replacement."""
    picked_rows = self.rng.choice(self.ctxts.shape[0], size=self.n_arms)
    return self.ctxts[picked_rows]

  def sample_rewards(self, ctxt: np.ndarray) -> np.ndarray:
    """Look up the stored reward of every context row.

    Args:
      ctxt: A single context (2-D) or a batch of contexts (3-D). A 2-D input
        is treated as a batch of one. Rows must be byte-identical to stored
        (whitened) rows; otherwise a KeyError is raised.

    Returns:
      Float array of shape (batch, arms) holding the looked-up rewards.
    """
    batch = ctxt[None, ...] if ctxt.ndim == 2 else ctxt
    out = np.zeros([batch.shape[0], batch.shape[1]])
    # np.ndindex walks (batch, arm) positions in row-major order.
    for pos in np.ndindex(batch.shape[0], batch.shape[1]):
      out[pos] = self.ctxt_to_reward[batch[pos].tobytes()]
    return out

  def mean_reward(self, ctxt: np.ndarray) -> np.ndarray:
    """Return the looked-up rewards for a single context, without a batch axis.

    Rewards are stored per row, so this equals `sample_rewards`. Squeezing
    axis 0 means a 3-D batch with more than one context raises ValueError.
    """
    return np.squeeze(self.sample_rewards(ctxt), axis=0)


class WikiBandit(DatasetBandit):
  """DatasetBandit built from the 'wiki10-31k' dataset via `utils.read_bandit_dataset`."""

  def __init__(self, n_arms: int, n_features: int, seed: int = None,
               reward_column: int = 22146):
    """Load the dataset and use one label column as the reward.

    Args:
      n_arms: Number of dataset rows offered per round.
      n_features: Number of leading feature columns to keep.
      seed: Seed for the bandit's random number generator.
      reward_column: 0-based index of the label column used as the reward.
        The default, 22146, was chosen because it results in roughly 10%
        density of positive rewards.

    Raises:
      ValueError: If `n_features` exceeds the number of features available in
        the dataset.
    """
    ctxts, rewards = utils.read_bandit_dataset('wiki10-31k')
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if n_features > ctxts.shape[1]:
      raise ValueError(
          f'Not enough features in dataset: requested {n_features}, '
          f'available {ctxts.shape[1]}')
    ctxts = ctxts[:, :n_features]
    # `.toarray()` suggests the label matrix is sparse — densify one column.
    rewards = rewards[:, reward_column].toarray().flatten()
    super().__init__(ctxts=ctxts,
                     rewards=rewards,
                     n_arms=n_arms,
                     seed=seed)


class AmazonBandit(DatasetBandit):
  """DatasetBandit built from the 'amazoncat-13k-bert' dataset via `utils.read_bandit_dataset`."""

  def __init__(self, n_arms: int, n_features: int, seed: int = None,
               reward_column: int = 7892):
    """Load the dataset and use one label column as the reward.

    Args:
      n_arms: Number of dataset rows offered per round.
      n_features: Number of leading feature columns to keep.
      seed: Seed for the bandit's random number generator.
      reward_column: 0-based index of the label column used as the reward.
        The default, 7892, was chosen because it results in roughly 10%
        density of positive rewards.

    Raises:
      ValueError: If `n_features` exceeds the number of features available in
        the dataset.
    """
    ctxts, rewards = utils.read_bandit_dataset('amazoncat-13k-bert')
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if n_features > ctxts.shape[1]:
      raise ValueError(
          f'Not enough features in dataset: requested {n_features}, '
          f'available {ctxts.shape[1]}')
    ctxts = ctxts[:, :n_features]
    # `.toarray()` suggests the label matrix is sparse — densify one column.
    rewards = rewards[:, reward_column].toarray().flatten()
    super().__init__(ctxts=ctxts,
                     rewards=rewards,
                     n_arms=n_arms,
                     seed=seed)
