"""Variants of multi-armed bandit agents."""

from collections import deque
from typing import Callable, Sequence, Tuple

import numpy as np
from scipy.special import softmax


def max_idxs_with_tol(x: np.ndarray, tol: float = 1e-6) -> np.ndarray:
  """Compute indices of elements at most `tol` away from the maximum.

  Useful for preventing numerical inaccuracies from affecting the uniform tie
  breaking implementation used throughout this file.

  :param x: A 1D array to take the maximum over.
  :param tol: Tolerance for the distance from maximum.
  :return: A 1D integer array with the ascending indices `i` such that
    `x[i] >= max(x) - tol`. The maximum itself is always included (for a
    non-empty, NaN-free `x`), even when `max(x) - tol` rounds to `max(x)` for
    large magnitudes or infinite values.
  """
  x = np.asarray(x)
  # `>=` rather than `>`: for |max(x)| large (or inf) `max(x) - tol` rounds to
  # `max(x)`, and a strict comparison would return an empty index set
  return np.flatnonzero(x >= np.max(x) - tol)


class AbstractAgent:
  """Abstract agent class for vanilla multi-armed bandits.

  When `noop_option` is set, a special `NO-OP` arm is added at index
  `noop_arm_idx`. Its reward is assumed to be known (`noop_value`), so
  implementing classes should restrict exploratory pulls to `exp_arm_idxs`,
  which excludes the `NO-OP` arm.
  """
  noop_arm_idx = 0  # index of the NO-OP arm (only meaningful with noop_option)

  def __init__(self, n_arms: int,
               init_mean: Callable[[int], np.ndarray] = None,
               noop_option: bool = False,
               noop_value: float = 0.0,
               seed: int = None):
    """Initialize agent.

    Args:
      n_arms: Number of regular arms; the NO-OP arm, if any, comes on top.
      init_mean: A function mapping the total number of arms (NO-OP included)
          to the initial mean estimates. Defaults to all zeros. Its result is
          copied into a fresh float array, so it is never modified.
      noop_option: Whether or not a special no operation arm is added.
      noop_value: The fixed reward value of the noop arm.
          Ignored if `noop_option` is False.
      seed: Random seed.
    """
    self.rng = np.random.RandomState(seed)

    self.noop_option = noop_option
    self.noop_value = noop_value
    self.n_arms = n_arms + self.noop_option
    self.arm_idxs = np.arange(self.n_arms)
    if self.noop_option:
      # arms eligible for exploration: everything except the NO-OP arm
      self.exp_arm_idxs = np.delete(self.arm_idxs, self.noop_arm_idx)
    else:
      self.exp_arm_idxs = self.arm_idxs

    self.init_mean = init_mean
    if self.init_mean is None:
      self.init_mean = lambda _n_arms: np.zeros(_n_arms)

    self.counts = None  # initialized in reset()
    self.means = None  # initialized in reset()

    self.reset()

  def get_step(self) -> int:
    """Return the 1-based index of the upcoming step (total pulls + 1)."""
    return 1 + int(np.sum(self.counts))

  def select_arm(self) -> int:
    """Sample an arm from the distribution returned by `pull_probs`.

    Subclasses in this file override this with a direct implementation.
    """
    # TODO: this impl of `select_arm` leads to strange behaviour -> debug
    return self.rng.choice(self.arm_idxs, p=self.pull_probs())

  def pull_probs(self) -> np.ndarray:
    """Return the probability of selecting each arm at the current step."""
    raise NotImplementedError()

  def update(self, arm: int, reward: float):
    """Update the running count and mean reward for the given arm.

    Args:
      arm: The selected arm index.
      reward: The achieved reward.
    """
    self.counts[arm] += 1
    # incremental running mean: m_n = m_{n-1} + (r - m_{n-1}) / n
    self.means[arm] += (reward - self.means[arm]) / self.counts[arm]

  def reset(self):
    """Reset all counts and means."""
    self.counts = np.zeros(self.n_arms, dtype='int')
    # copy into a float array: an integer array from `init_mean` would
    # silently truncate the running means, and the in-place writes below and
    # in `update` must not leak into an array owned by the caller
    self.means = np.array(self.init_mean(self.n_arms), dtype=float)
    if self.noop_option:
      self.means[self.noop_arm_idx] = self.noop_value

  def __repr__(self):
    return f'{self.__class__.__name__} n_arms={self.n_arms} ' \
           f'noop_option={self.noop_option}'


class Greedy(AbstractAgent):
  """Explore-uniformly-then-greedy agent.

  During the first `delay` steps an arm is drawn uniformly at random from the
  exploratory arms (`exp_arm_idxs`). From then on, it is drawn uniformly at
  random among the arms whose empirical mean reward is (up to a small
  tolerance) the highest.
  """

  def __init__(self, delay: int, **kwargs):
    """Initialize a greedy agent.

    Args:
      delay: Number of steps of uniformly random exploration.
      kwargs: Additional keyword arguments passed to AbstractAgent.
    """
    self.delay = delay
    super().__init__(**kwargs)

  def _candidates(self) -> np.ndarray:
    """Arms among which the current step samples uniformly."""
    if self.get_step() <= self.delay:
      return self.exp_arm_idxs
    return max_idxs_with_tol(self.means)

  def select_arm(self) -> int:
    return self.rng.choice(self._candidates())

  def pull_probs(self) -> np.ndarray:
    weights = np.zeros_like(self.arm_idxs, dtype='float')
    weights[self._candidates()] = 1.0
    return weights / np.sum(weights)

  def __repr__(self):
    base = super().__repr__()
    cut = len(type(self).__name__)
    return f'{base[:cut]} delay={self.delay}{base[cut:]}'


class EpsGreedy(AbstractAgent):
  """An epsilon-greedy agent with a decaying exploration rate.

  At step t, with probability eps_t = min(1, eps_scaling * n_exp / t), where
  n_exp = len(exp_arm_idxs), the agent picks an exploratory arm uniformly at
  random; otherwise it picks uniformly at random among the arms with the
  highest empirical mean reward.
  """

  def __init__(self, eps_scaling: float, **kwargs):
    """Initialize an epsilon greedy agent.

    Args:
      eps_scaling: The epsilon scaling factor. The effective epsilon at step t
          is min(1, eps_scaling * len(exp_arm_idxs) / t).
      kwargs: Additional keyword arguments passed to AbstractAgent.
    """
    self.eps_scaling = eps_scaling
    super().__init__(**kwargs)

  def _eps(self) -> float:
    """Exploration probability at the current step."""
    step = self.get_step()
    return np.minimum(1, self.eps_scaling * len(self.exp_arm_idxs) / float(step))

  def select_arm(self) -> int:
    if self.rng.binomial(1, self._eps()) == 0:
      # exploit: uniform tie breaking among the best empirical means
      return self.rng.choice(max_idxs_with_tol(self.means))
    return self.rng.choice(self.exp_arm_idxs)

  def pull_probs(self) -> np.ndarray:
    eps = self._eps()
    top_idxs = max_idxs_with_tol(self.means)

    probs = np.zeros_like(self.arm_idxs, dtype='float')
    probs[self.exp_arm_idxs] = eps / len(self.exp_arm_idxs)
    probs[top_idxs] += (1 - eps) / len(top_idxs)
    return probs

  def __repr__(self):
    s = super().__repr__()
    idx = len(self.__class__.__name__)
    return s[:idx] + f' eps_scaling={self.eps_scaling}' + s[idx:]


class ForgetfulGreedy(AbstractAgent):
  """A forgetful explore-then-greedy agent.

  This agent first chooses the exploratory arms one after another (round
  robin) until each was selected `n_exp_rounds` times. Afterwards it chooses
  uniformly at random among the arms with the highest mean reward, where each
  arm's mean is computed over only its `slide_window` most recent rewards; it
  'forgets' older rewards.
  """

  def __init__(self, n_exp_rounds: int, slide_window: int, **kwargs):
    """Initialize a forgetful greedy agent.

    Args:
      n_exp_rounds: Number of exploration rounds up front. This is the number
          of times each exploratory arm is chosen (not the number of steps).
          That is, exploration lasts for n_exp_rounds * len(exp_arm_idxs)
          steps.
      slide_window: The number of latest rewards per arm used to compute that
          arm's empirical mean reward. Must be positive (`None` keeps all
          rewards).
      kwargs: Additional keyword arguments passed to AbstractAgent.

    Raises:
      ValueError: If `slide_window` is not positive.
    """
    if slide_window is not None and slide_window < 1:
      # an empty window would make every updated mean NaN
      raise ValueError(f'slide_window must be positive, got {slide_window}')
    self.n_exp_rounds = n_exp_rounds
    self.slide_window = slide_window
    self.window_rewards = None  # initialized in reset()
    # reset() (run by the base initializer) reads `slide_window`, so it must
    # be set before this call
    super().__init__(**kwargs)

  def _in_exploration(self, step: int) -> bool:
    """Whether `step` belongs to the initial round-robin phase."""
    return step <= self.n_exp_rounds * len(self.exp_arm_idxs)

  def select_arm(self) -> int:
    step = self.get_step()
    if self._in_exploration(step):
      return self.exp_arm_idxs[(step - 1) % len(self.exp_arm_idxs)]
    return self.rng.choice(max_idxs_with_tol(self.means))

  def pull_probs(self) -> np.ndarray:
    step = self.get_step()
    probs = np.zeros_like(self.arm_idxs, dtype='float')
    if self._in_exploration(step):
      probs[self.exp_arm_idxs[(step - 1) % len(self.exp_arm_idxs)]] = 1.0
    else:
      probs[max_idxs_with_tol(self.means)] = 1.0
    return probs / np.sum(probs)

  def update(self, arm: int, reward: float):
    """Record `reward` for `arm` and recompute its windowed mean."""
    self.counts[arm] += 1
    window = self.window_rewards[arm]
    # a bounded deque drops its oldest element when appending at capacity
    window.append(reward)
    self.means[arm] = np.mean(window)

  def reset(self):
    """Reset counts, means and all per-arm reward windows."""
    super().reset()
    self.window_rewards = {
      arm: deque(maxlen=self.slide_window) for arm in range(self.n_arms)}

  def __repr__(self):
    s = super().__repr__()
    idx = len(self.__class__.__name__)
    return s[:idx] + f' n_exp_rounds={self.n_exp_rounds}' \
                     f' slide_window={self.slide_window}' + s[idx:]


class ExploreThenCommit(AbstractAgent):
  """An explore-then-commit agent.

  This agent chooses the exploratory arms one after another (round robin)
  until each was selected `n_exp_rounds` times. It then commits to the set of
  arms with the highest empirical mean reward at that moment and, from then
  on, always chooses uniformly at random within that fixed set.
  """

  def __init__(self, n_exp_rounds: int, **kwargs):
    """Initialize an explore-then-commit agent.

    Args:
      n_exp_rounds: Number of exploration rounds up front. This is the number
          of times each exploratory arm is chosen (not the number of steps).
          That is, exploration lasts for n_exp_rounds * len(exp_arm_idxs)
          steps.
      kwargs: Additional keyword arguments passed to AbstractAgent.
    """
    self.n_exp_rounds = n_exp_rounds
    self.commit_idx = None  # committed arm indices; set after exploration
    super().__init__(**kwargs)

  def _commit_round(self) -> int:
    """First step at which the agent is committed."""
    return 1 + self.n_exp_rounds * len(self.exp_arm_idxs)

  def _committed_arms(self) -> np.ndarray:
    """Return the committed arms, fixing them on the first call.

    Computing lazily (instead of only at exactly `step == commit_round`)
    ensures the commitment is made even if that step was never queried.
    """
    if self.commit_idx is None:
      self.commit_idx = max_idxs_with_tol(self.means)
    return self.commit_idx

  def select_arm(self) -> int:
    step = self.get_step()
    if step >= self._commit_round():
      return self.rng.choice(self._committed_arms())
    return self.exp_arm_idxs[(step - 1) % len(self.exp_arm_idxs)]

  def pull_probs(self) -> np.ndarray:
    step = self.get_step()
    probs = np.zeros_like(self.arm_idxs, dtype='float')
    if step >= self._commit_round():
      probs[self._committed_arms()] = 1.0
    else:
      probs[self.exp_arm_idxs[(step - 1) % len(self.exp_arm_idxs)]] = 1.0
    return probs / np.sum(probs)

  def reset(self):
    """Reset counts and means, and drop any previous commitment."""
    super().reset()
    self.commit_idx = None

  def __repr__(self):
    s = super().__repr__()
    idx = len(self.__class__.__name__)
    return s[:idx] + f' n_exp_rounds={self.n_exp_rounds}' + s[idx:]


class UpperConfidence(AbstractAgent):
  """An upper confidence bound (UCB) agent.

  This agent first selects each exploratory arm once (round robin over
  `exp_arm_idxs`). Afterwards, at step t, it selects uniformly at random among
  the arms maximising the index

      means[i] + sqrt(2 * log(1 / delta_func(t)) / counts[i]),

  where `counts[i]` is the number of pulls of arm i during the first t - 1
  steps (T_i(t - 1) in Lattimore & Szepesvari, "Bandit Algorithms", ch. 8).
  The NO-OP arm, whose reward is known, receives no exploration bonus.
  """

  def __init__(self, delta_func=None, **kwargs):
    """Initialize an upper confidence bound agent.

    Args:
      delta_func: The function determining delta from the round count t to be
          used in the upper confidence bound computation at each round.
          Defaults to delta(t) = 1 / (1 + t * log(t)^2).
      kwargs: Additional keyword arguments passed to AbstractAgent.
    """
    self.delta_func = delta_func or (lambda t: 1 / (1 + t * np.log(t) ** 2))
    super().__init__(**kwargs)

  def _upper_bounds(self, step: int) -> np.ndarray:
    """UCB index of every arm at `step` (used after the round robin)."""
    delta = self.delta_func(step)
    bonus = np.zeros(self.n_arms)
    # assumes every exploratory arm was updated at least once during the
    # round robin, so the division is finite; NO-OP keeps a zero bonus
    bonus[self.exp_arm_idxs] = np.sqrt(
        2 * np.log(1 / delta) / self.counts[self.exp_arm_idxs])
    return self.means + bonus

  def select_arm(self) -> int:
    step = self.get_step()
    if step > len(self.exp_arm_idxs):
      return self.rng.choice(max_idxs_with_tol(self._upper_bounds(step)))
    return self.exp_arm_idxs[step - 1]

  def pull_probs(self) -> np.ndarray:
    step = self.get_step()
    probs = np.zeros_like(self.arm_idxs, dtype='float')
    if step > len(self.exp_arm_idxs):
      probs[max_idxs_with_tol(self._upper_bounds(step))] = 1.0
    else:
      probs[self.exp_arm_idxs[step - 1]] = 1.0
    return probs / np.sum(probs)


class NestedAgent:
  """A nested agent has two components: 'proposers' and a 'chooser'.

  Every recommender system (proposer) recommends an arm; the player (chooser)
  then picks which recommender to follow, or the NO-OP arm if enabled.
  """

  def __init__(self, player: AbstractAgent, *recsys):
    """Initialize a nested agent.

    Args:
      player: A MAB agent inheriting form AbstractAgent representing the player
          (i.e., the chooser in the nested agent).
      recsys: Any number of recommender systems, also instances of
          AbstractAgent, representing the competing recommender systems, i.e.,
          the proposers.

    Raises:
      ValueError: If the player does not have exactly one arm per recommender
          system (plus one for NO-OP when `player.noop_option` is set).
    """
    expected_arms = len(recsys) + int(player.noop_option)
    if player.n_arms != expected_arms:
      raise ValueError('Player must have one arm per recsys, +1 for NO-OP')
    self.player = player
    self.recsys = recsys

  def select_arms(self):
    """Select arms for the player and the recommender systems.

    At the moment, the player selects its action based on past experience only,
    but we could also consider a contextual bandit (?!) where its action could
    be conditioned on the arms that are being presented.

    N: This set-up is similar to the one used in Aridor et al. (2019)

    Returns:
      `(recs, player_choice)`: the arm recommended by each recommender system
      (in order) and the player's arm.
    """
    recs = [rs.select_arm() for rs in self.recsys]
    player_choice = self.player.select_arm()
    return recs, player_choice

  def update(self,
             recs: Sequence[int],
             player_choice: int,
             rewards: Sequence[float]):
    """Update the running counts and means for the selected arms.

    The player is updated according to its choice of recommender systems and
    the achieved reward. Among the recommender systems, only the one selected
    by the player is updated.

    Args:
      recs: The arm recommended by each recommender system, as returned by
          `select_arms`.
      player_choice: The arm chosen by the player.
      rewards: Per-arm rewards, indexed by arm: `rewards[recs[i]]` is the
          reward for following recommender system i.
    """
    if self.player.noop_option:
      self._update_w_noop(recs, player_choice, rewards)
    else:
      self._update_wo_noop(recs, player_choice, rewards)

  def _update_wo_noop(self, recs, player_choice, rewards, select_rec=None):
    """Update the player and the followed recommender system."""
    select_rec = player_choice if select_rec is None else select_rec
    rec_arm = recs[select_rec]

    # update estimate of how good is the recommender
    self.player.update(player_choice, rewards[rec_arm])
    # update **only** the selected recommender system's information
    self.recsys[select_rec].update(rec_arm, rewards[rec_arm])

  def _update_w_noop(self, recs, player_choice, rewards):
    """Like `_update_wo_noop`, mapping player arms around the NO-OP arm."""
    noop_idx = self.player.noop_arm_idx
    if player_choice == noop_idx:
      # no recommender was followed; only keep the player's counts correct
      self.player.update(player_choice, self.player.noop_value)
      return
    # player arms after the NO-OP arm are shifted by one w.r.t. `recs`
    select_rec = player_choice - 1 if player_choice > noop_idx else player_choice
    self._update_wo_noop(recs, player_choice, rewards, select_rec)

  def reset(self):
    """Reset the player and every recommender system."""
    self.player.reset()
    for rs in self.recsys:
      rs.reset()

  def __repr__(self):
    return f'{self.__class__.__name__} player={self.player} ' \
           f'recsys={self.recsys}'


class AbstractContextual:
  """Abstract agent class for contextual bandits.

  As used by the subclasses in this file, `ctxt` holds one row of features
  per arm (`len(ctxt)` arms, features on the last axis); only the
  `active_features` columns are used.
  """

  def __init__(self,
               n_features: int,
               init_weights: Callable[[int], np.ndarray] = None,
               active_features: np.ndarray = None,
               exp_arms: np.ndarray = None,
               seed: int = None):
    """Initialise agent.

    :param n_features: Dimension of the feature space.
    :param init_weights: Returns initial value for weights of given dimension.
      Defaults to zeros. The result is copied into a fresh float array.
    :param active_features: Indices of the features that are to be used by the
      algorithm (all others are ignored). Uses all features if `None`.
    :param exp_arms: Indices of the arms to be considered when `select_arm` is
      called. If `None`, all arms are considered. Converted to an array.
    :param seed: Seed for the instance's own random number generator.
    """
    self.rng = np.random.RandomState(seed)

    self.n_features = n_features
    if init_weights is None:
      init_weights = lambda d: np.zeros(d)
    self.init_weights = init_weights

    # arrays (not lists) so subclasses can fancy-index them, e.g.
    # `self.exp_arms[max_idxs_with_tol(...)]`
    self.exp_arms = None if exp_arms is None else np.asarray(exp_arms)
    if active_features is None:
      active_features = np.arange(self.n_features)
    self.active_features = np.asarray(active_features)

    self.step = None  # initialized in reset()
    self.weights = None  # initialized in reset()
    self.reset()

  def pull_probs(self, ctxt: np.ndarray) -> np.ndarray:
    """Return the probability of selecting each arm given `ctxt`."""
    raise NotImplementedError()

  def select_arm(self, ctxt: np.ndarray) -> int:
    """Return the index of the arm to pull given `ctxt`."""
    raise NotImplementedError()
    # TODO: could be reimplemented as one-liner on top of `pull_probs`

  def update(self, arm: int, reward: float, ctxt: np.ndarray):
    """Incorporate the `reward` observed for `arm` under `ctxt`."""
    raise NotImplementedError()

  def reset(self):
    """Reset the step counter and the weights."""
    self.step = 1
    # subclasses update `weights` in place with float values: copy to a float
    # array so integer initialisers work and the caller's array is untouched
    self.weights = np.array(
        self.init_weights(len(self.active_features)), dtype=float)

  def __repr__(self):
    return f'{self.__class__.__name__} n_features={self.n_features} ' \
           f'n_active_fs={len(self.active_features)}'


class PolicyGradient(AbstractContextual):
  """Linear softmax policy trained with a policy-gradient style update.

  The logits of the eligible arms (`exp_arms`, or all arms if it is None) are
  `ctxt @ weights`. At each step, with probability
  eps = min(1, eps_scaling * n_eligible / step) an eligible arm is drawn
  uniformly at random. Otherwise the arm is drawn from the softmax over the
  eligible logits (`random_greedy=True`), or uniformly among the eligible arms
  with the largest logit (`random_greedy=False`).
  """

  @staticmethod
  def log_lr_schedule(scaling: float = 2.0):
    """Return the learning-rate schedule `t -> scaling / log(1 + t)`."""
    return lambda t: scaling / np.log(1.0 + t)

  @staticmethod
  def root_lr_schedule(scaling: float = 2.0):
    """Return the learning-rate schedule `t -> scaling / sqrt(1 + t)`."""
    return lambda t: scaling / np.sqrt(1.0 + t)

  def __init__(self,
               lr_schedule: Callable[[int], float] = None,
               random_greedy: bool = True,
               eps_scaling: float = 0.0,
               **kwargs):
    """Initialise a policy-gradient agent.

    :param lr_schedule: Maps the step count to the learning rate. Defaults to
      `log_lr_schedule()`.
    :param random_greedy: If `True`, sample from the softmax policy; otherwise
      act greedily w.r.t. the logits (uniform tie breaking).
    :param eps_scaling: Scaling of the decaying uniform-exploration rate; `0`
      disables uniform exploration.
    :param kwargs: Additional keyword arguments passed to `AbstractContextual`.
    """
    if lr_schedule is None:
      lr_schedule = self.log_lr_schedule()

    self.random_greedy = random_greedy
    self.eps_scaling = eps_scaling
    self.lr_schedule = lr_schedule
    super().__init__(**kwargs)

  def _eligible_idxs(self, n_arms: int) -> np.ndarray:
    """Indices of the arms that may be selected out of `n_arms` arms."""
    if self.exp_arms is None:
      return np.arange(n_arms)
    return np.asarray(self.exp_arms)

  def _eps(self, n_eligible: int) -> float:
    """Uniform-exploration probability at the current step."""
    return np.minimum(1, self.eps_scaling * n_eligible / float(self.step))

  def pull_probs(self, ctxt: np.ndarray) -> np.ndarray:
    """Return the probability with which `select_arm` picks each arm."""
    ctxt = ctxt[..., self.active_features]
    idxs = self._eligible_idxs(len(ctxt))
    eps = self._eps(len(idxs))

    probs = np.zeros(len(ctxt))
    probs[idxs] += eps / len(idxs)

    logits = ctxt[idxs] @ self.weights
    if self.random_greedy:
      # normalise over the eligible arms only, so the result sums to one and
      # matches the masked sampling in `select_arm`
      probs[idxs] += (1 - eps) * softmax(logits)
    else:
      top_idxs = idxs[max_idxs_with_tol(logits)]
      probs[top_idxs] += (1 - eps) / len(top_idxs)

    return probs

  def select_arm(self, ctxt: np.ndarray) -> int:
    """Sample an eligible arm according to `pull_probs`."""
    ctxt = ctxt[..., self.active_features]
    idxs = self._eligible_idxs(len(ctxt))

    explore = 0
    if self.eps_scaling > 0:
      explore = self.rng.binomial(1, self._eps(len(idxs)))
    if explore:
      # uniform over the eligible arms only, consistent with `pull_probs`
      return self.rng.choice(idxs)

    logits = ctxt[idxs] @ self.weights
    if self.random_greedy:
      # Gumbel-max trick: argmax(logits + Gumbel noise) ~ softmax(logits)
      noisy = logits + self.rng.gumbel(size=logits.shape)
      return int(idxs[np.argmax(noisy)])
    return self.rng.choice(idxs[max_idxs_with_tol(logits)])

  def update(self, arm: int, reward: float, ctxt: np.ndarray):
    """Take one gradient step on the softmax policy after observing `reward`.

    The softmax is taken over `exp_arms` with `arm` added if not one of them
    (or over all arms if `exp_arms` is None). The step is
    `weights += lr(step) * reward * pi(arm) * (ctxt[arm] - E_pi[ctxt])`,
    i.e. `lr * reward * grad pi(arm)`; note this differs from the REINFORCE
    estimator `reward * grad log pi(arm)` by the factor `pi(arm)`.
    """
    ctxt = ctxt[..., self.active_features]

    if reward != 0:  # zero gradient if reward is zero
      if self.exp_arms is None:
        idxs = np.arange(len(ctxt))
      else:
        idxs = np.unique(np.concatenate((self.exp_arms, [arm])))

      probs = softmax(ctxt[idxs] @ self.weights)
      g = ctxt[arm] - np.sum(probs[:, None] * ctxt[idxs], axis=0)
      g *= -reward * probs[arm == idxs]

      self.weights -= self.lr_schedule(self.step) * g

    self.step += 1

  def __repr__(self):
    s = super().__repr__()
    idx = len(self.__class__.__name__)
    return s[:idx] + f' random_greedy={self.random_greedy}' \
                     f' eps_scaling={self.eps_scaling}' + s[idx:]


class LinUniform(AbstractContextual):
  """Uniform baseline, useful for sanity checks.

  Ignores the context values: arms are drawn uniformly from `exp_arms` (or
  from all arms when `exp_arms` is None), and `update` only advances the step
  counter.
  """
  def pull_probs(self, ctxt: np.ndarray) -> np.ndarray:
    n_arms = len(ctxt[..., self.active_features])
    if self.exp_arms is not None:
      dist = np.zeros(n_arms)
      dist[self.exp_arms] = 1.0 / len(self.exp_arms)
      return dist
    return np.ones(n_arms) / n_arms

  def select_arm(self, ctxt: np.ndarray) -> int:
    n_arms = len(ctxt[..., self.active_features])
    pool = np.arange(n_arms) if self.exp_arms is None else self.exp_arms
    return self.rng.choice(pool)

  def update(self, arm: int, reward: float, ctxt: np.ndarray):
    # nothing to learn; just keep the step counter in sync
    self.step += 1


class AbstractRidge(AbstractContextual):
  """Abstract agent class for linear contextual bandits.

  Keeps `covar = (regulariser * I + X^T X)^(-1)` over the active features,
  updated with the Sherman-Morrison formula, and fits `weights` by recursive
  least squares on the observed (feature, reward) pairs.
  """
  def __init__(self,
               n_features: int,
               regulariser: float = 1.0,
               scale_regulariser_by_dim: bool = False,
               active_features: np.ndarray = None,
               **kwargs):
    """Initialise an instance of abstract ridge contextual bandit.

    :param n_features: Dimension of the feature space.
    :param regulariser: Ridge regression regularisation parameter.
    :param scale_regulariser_by_dim: If `True`, multiplies the regulariser by
      the number of used features (`len(active_features)`, or `n_features`
      if `active_features` is `None`). Useful when experimenting with varying
      `n_features`.
    :param active_features: Indices of the features that are to be used by the
      algorithm (all others are ignored). Uses all features if `None`.
    :param kwargs: Parameters for the `AbstractContextual` initialiser.
    """
    if scale_regulariser_by_dim:
      regulariser *= (
          n_features if active_features is None else len(active_features))
    self.scale_regulariser_by_dim = scale_regulariser_by_dim
    self.regulariser = regulariser

    super().__init__(n_features, active_features=active_features, **kwargs)

  def select_arm(self, ctxt: np.ndarray) -> int:
    raise NotImplementedError()

  def update(self, arm: int, reward: float, ctxt: np.ndarray):
    x = ctxt[arm][..., self.active_features]

    # Sherman-Morrison rank-one update of the inverse Gram matrix
    cx = self.covar @ x
    self.covar -= np.outer(cx, cx) / (1 + np.sum(x * cx))
    # recursive least squares step using the prediction error of the old fit
    residual = reward - self.weights @ x
    self.weights += (self.covar @ x) * residual

    self.step += 1

  def reset(self):
    super().reset()
    dim = len(self.active_features)
    self.covar = np.eye(dim) / self.regulariser


class GreedyRidge(AbstractRidge):
  """Greedy ridge agent: picks (uniformly among) the arms with the highest
  predicted reward `ctxt @ weights`, restricted to `exp_arms` if given."""
  def _top_arms(self, ctxt: np.ndarray) -> np.ndarray:
    """Indices of the (eligible) arms with maximal predicted reward."""
    pred = self.get_indices(ctxt)
    if self.exp_arms is None:
      return max_idxs_with_tol(pred)
    return self.exp_arms[max_idxs_with_tol(pred[self.exp_arms])]

  def select_arm(self, ctxt: np.ndarray) -> int:
    return self.rng.choice(self._top_arms(ctxt[..., self.active_features]))

  def pull_probs(self, ctxt: np.ndarray) -> np.ndarray:
    ctxt = ctxt[..., self.active_features]
    winners = self._top_arms(ctxt)
    dist = np.zeros(len(ctxt))
    dist[winners] = 1.0 / len(winners)
    return dist

  def get_indices(self, ctxt: np.ndarray):
    # accept both raw contexts and contexts already restricted to the active
    # features (detected by the size of the last axis)
    if ctxt.shape[-1] != len(self.active_features):
      ctxt = ctxt[..., self.active_features]
    return ctxt @ self.weights


class LinUCB(AbstractRidge):
  """UCB for linear contextual bandits.

  Uses an ellipsoidal confidence set with a regularised empirical covariance
  estimate defining the distance `[a^T (regulariser * I + X^T @ X)^(-1) b]^0.5`
  between any `a, b` in `R^n_features`, and `upper_bound_fn` returning the
  boundary of a ball around the ridge regression parameter estimate in each
  round. The action is then selected by maximising inner product between the
  action-context feature vector and a parameter in this ellipsoid.
  """
  @staticmethod
  def adapt_ub_schedule(d, reg):
    """See eq. (19.8) in Bandit algorithms by Lattimore and Szepesvari."""
    def schedule(t):
      ret = 2 * np.log(t) + d * np.log(1 + t / (d * reg))
      return np.sqrt(reg) + np.sqrt(ret)
    return schedule

  @staticmethod
  def const_ub_schedule(alpha):
    """Return a schedule that always yields `alpha`."""
    return lambda _: alpha

  def __init__(
      self, upper_bound_fn: str = 'adapt', alpha: float = 1.0, **kwargs):
    """Initialise a LinUCB agent.

    :param upper_bound_fn: Either `'adapt'` or `'const'`, referring to the type
      of the function called at each round to obtain the uncertainty bonus
      multiplier.
    :param alpha: Only used when `upper_bound_fn == 'const'`. It is the constant
      by which the uncertainty bonus is multiplied.
    :param kwargs: Additional keyword arguments passed to `AbstractRidge`.
    :raises NotImplementedError: For any other `upper_bound_fn`.
    """
    super().__init__(**kwargs)

    self.alpha = alpha
    if upper_bound_fn == 'adapt':
      self.upper_bound_fn = \
        LinUCB.adapt_ub_schedule(len(self.active_features), self.regulariser)
    elif upper_bound_fn == 'const':
      self.upper_bound_fn = LinUCB.const_ub_schedule(alpha)
    else:
      raise NotImplementedError(upper_bound_fn)

  def _ucb(self, ctxt):
    """Upper confidence bound of every row (arm) of `ctxt`."""
    if ctxt.shape[-1] != len(self.active_features):
      ctxt = ctxt[..., self.active_features]
    b = self.upper_bound_fn(self.step)
    # x^T covar x >= 0 in exact arithmetic; clip rounding noise from the
    # Sherman-Morrison updates so the sqrt never produces NaN
    sq_width = np.sum((ctxt @ self.covar) * ctxt, axis=-1)
    width = np.sqrt(np.maximum(sq_width, 0.0))
    return ctxt @ self.weights + b * width

  def _top_arms(self, ctxt: np.ndarray) -> np.ndarray:
    """Indices of the (eligible) arms with maximal upper confidence bound."""
    bound = self._ucb(ctxt)
    if self.exp_arms is None:
      return max_idxs_with_tol(bound)
    return self.exp_arms[max_idxs_with_tol(bound[self.exp_arms])]

  def pull_probs(self, ctxt: np.ndarray) -> np.ndarray:
    ctxt = ctxt[..., self.active_features]
    top_idxs = self._top_arms(ctxt)
    probs = np.zeros(len(ctxt))
    probs[top_idxs] = 1.0 / len(top_idxs)
    return probs

  def select_arm(self, ctxt: np.ndarray) -> int:
    ctxt = ctxt[..., self.active_features]
    return self.rng.choice(self._top_arms(ctxt))


class LinThompson(AbstractRidge):
  """Thompson sampling for linear contextual bandits.

  Assumes Gaussian prior and likelihood: the parameter is sampled from
  N(weights, var * covar), with `weights`/`covar` maintained by
  `AbstractRidge`.
  """

  def __init__(self, var: float = 1.0, **kwargs):
    """Initialise a Thompson sampling agent.

    :param var: Variance of the rewards.
    :param kwargs: Additional keyword arguments passed to `AbstractRidge`.
      In particular `regulariser` is the prior precision _multiplied by_
      `var`; it is further multiplied by the feature dimension if
      `scale_regulariser_by_dim=True`.
    """
    self.var = var
    super().__init__(**kwargs)

  def pull_probs(self, ctxt: np.ndarray) -> np.ndarray:
    raise NotImplementedError()

  def select_arm(self, ctxt: np.ndarray) -> int:
    """Sample parameter from the posterior, then take argmax."""
    ctxt = ctxt[..., self.active_features]

    sample = self.rng.multivariate_normal(self.weights, self.var * self.covar)
    sample = ctxt @ sample  # sampled reward of every arm
    if self.exp_arms is None:
      return self.rng.choice(max_idxs_with_tol(sample))
    else:
      top_idxs = self.exp_arms[max_idxs_with_tol(sample[self.exp_arms])]
      return self.rng.choice(top_idxs)


class SquareCB:
  """Reduction of contextual bandits to online regression (SquareCB).

  See, e.g.,
   https://proceedings.neurips.cc/paper/2020/file/84c230a5b1bc3495046ef916957c7238-Paper.pdf

  The `oracle` must provide `get_indices(ctxt)` (one reward prediction per
  arm), `update(arm, reward, ctxt)` and `reset()`.
  """
  @staticmethod
  def get_abe_long(lr, n_arms):
    """Map reward predictions to the Abe-Long action distribution."""
    def abe_long(preds):
      best = np.argmax(preds)
      gaps = preds[best] - preds
      probs = 1 / (n_arms + lr * gaps)
      # the greedy arm absorbs the remaining probability mass
      probs[best] += 1 - np.sum(probs)
      return probs
    return abe_long

  @staticmethod
  def get_log_barrier(lr, n_arms):
    raise NotImplementedError()  # TODO

  def __init__(
      self, n_arms, oracle, lr_scale=1e1, dist_type='abe-long', seed=None):
    self.rng = np.random.RandomState(seed)

    self.n_arms = n_arms
    self.oracle = oracle

    self.lr = lr_scale * n_arms  # TODO: matches abe-long; log-barrier unclear
    self.dist_type = dist_type
    factories = {
        'abe-long': SquareCB.get_abe_long,
        'log-barrier': SquareCB.get_log_barrier,
    }
    factory = factories.get(self.dist_type)
    if factory is None:
      raise NotImplementedError(dist_type)
    self.pred2probs = factory(self.lr, self.n_arms)

  def pull_probs(self, ctxt: np.ndarray) -> np.ndarray:
    return self.pred2probs(self.oracle.get_indices(ctxt))

  def select_arm(self, ctxt: np.ndarray) -> int:
    dist = self.pull_probs(ctxt)
    return self.rng.choice(np.arange(self.n_arms), p=dist)

  def update(self, arm: int, reward: float, ctxt: np.ndarray):
    self.oracle.update(arm, reward, ctxt)

  def reset(self):
    self.oracle.reset()

  def __repr__(self):
    return (f'{self.__class__.__name__} dist_type={self.dist_type} '
            f'lr={self.lr} oracle={self.oracle}')


class DelayedAgent:
  """Wrapper for other contextual bandits enabling batch updates.

  Observations passed to `update` are buffered. Whenever `scheduler(step)` is
  truthy (`step` counts `update` calls, starting at 1), the whole buffer is
  replayed oldest-first into the wrapped agent's `update` and then emptied.
  Selection is delegated unchanged to the wrapped agent.

  Currently just delays the updates but still feeds the data points one-by-one.
  Can be easily modified to feed the whole batches into the update function at
  the same time (currently we have no agents implementing batch updates; batch
  and online updates will be the same for ridge agents though).
  """
  def __init__(self, agent, scheduler):
    self.agent = agent
    self.scheduler = scheduler

    self.buffer = []  # pending (arm, reward, ctxt) triples
    self.step = None  # initialized in reset()
    self.reset()

  def pull_probs(self, ctxt: np.ndarray) -> np.ndarray:
    return self.agent.pull_probs(ctxt)

  def select_arm(self, ctxt: np.ndarray) -> int:
    return self.agent.select_arm(ctxt)

  def update(self, arm: int, reward: float, ctxt: np.ndarray):
    self.buffer.append((arm, reward, ctxt))
    flush_now = self.scheduler(self.step)
    if flush_now:
      for pending in self.buffer:
        self.agent.update(*pending)
      self.buffer.clear()
    self.step += 1

  def reset(self):
    # pending observations are discarded together with the agent's state
    self.agent.reset()
    self.buffer.clear()
    self.step = 1

  def __repr__(self):
    return f'{type(self).__name__} agent={self.agent}'


class TwoStage:
  """Two-stage recommender: nominators shortlist arms, a ranker picks one.

  Each nominator selects an arm from the full context. The ranker then
  selects among the nominated arms, seeing only their feature rows.
  """
  def __init__(self, ranker, nominators):
    self.ranker = ranker
    self.nominators = nominators

    self.reset()

  def pull_probs(self, ctxt: np.ndarray) -> np.ndarray:
    raise NotImplementedError()

  def select_arm(self, ctxt: np.ndarray) -> Tuple[int, Sequence[int]]:
    """Return `(chosen_arm, nominated_arms)` for the given context."""
    nominated = [nominator.select_arm(ctxt) for nominator in self.nominators]
    shortlist_ctxt = np.array([ctxt[a] for a in nominated])
    # the ranker answers with a position in the shortlist
    position = self.ranker.select_arm(shortlist_ctxt)
    return nominated[position], nominated

  def update(self, arm: int, reward: float, ctxt: np.ndarray):
    # NOTE(review): the ranker selected in shortlist index space but is updated
    # with the global `arm` and full `ctxt`; fine for agents whose update only
    # reads `ctxt[arm]` (e.g. AbstractRidge) -- confirm for other rankers
    for agent in (self.ranker, *self.nominators):
      agent.update(arm, reward, ctxt)

  def reset(self):
    for agent in (self.ranker, *self.nominators):
      agent.reset()

  def __repr__(self):
    return (f'{type(self).__name__} ranker={self.ranker} '
            f'n_nominators={len(self.nominators)}')
