import joblib
from joblib import Parallel, delayed
import numpy as np
import time


# Bandit environments and simulator

class GaussBandit(object):
  """Gaussian bandit.

  Each round, the realized reward of every arm is its mean plus
  sigma * standard normal noise.

  Args:
    mu: array of mean arm rewards (copied on construction).
    sigma: scale of the Gaussian reward noise.
  """

  def __init__(self, mu, sigma=1.0):
    self.mu = np.copy(mu)
    self.K = self.mu.size
    self.sigma = sigma
    self.best_arm = np.argmax(self.mu)  # arm with the highest mean reward
    self.randomize()

  def randomize(self):
    # draw the realized rewards of all arms for the next round
    noise = np.random.randn(self.K)
    self.rt = self.mu + self.sigma * noise

  def reward(self, arm):
    # realized reward of the arm in the current round
    return self.rt[arm]

  def regret(self, arm):
    # realized regret: best arm's realized reward minus this arm's
    best_reward = self.rt[self.best_arm]
    return best_reward - self.rt[arm]

  def pregret(self, arm):
    # expected regret: gap between the best mean and this arm's mean
    best_mean = self.mu[self.best_arm]
    return best_mean - self.mu[arm]

  def print(self):
    arms = ", ".join("%.3f" % s for s in self.mu)
    return "Gaussian bandit with arms (%s)" % arms


class BoxBandit(object):
  """Box noise bandit.

  Each round, the realized reward of every arm is its mean shifted by
  +sigma or -sigma, the sign chosen by whether a uniform draw exceeds 0.5.

  Args:
    mu: array of mean arm rewards (copied on construction).
    sigma: magnitude of the two-point reward noise.
  """

  def __init__(self, mu, sigma=1.0):
    self.mu = np.copy(mu)
    self.K = self.mu.size
    self.sigma = sigma
    self.best_arm = np.argmax(self.mu)  # arm with the highest mean reward
    self.randomize()

  def randomize(self):
    # draw a +1/-1 sign per arm and shift every mean by sigma in that direction
    signs = np.where(np.random.rand(self.K) > 0.5, 1, -1)
    self.rt = self.mu + self.sigma * signs

  def reward(self, arm):
    # realized reward of the arm in the current round
    return self.rt[arm]

  def regret(self, arm):
    # realized regret: best arm's realized reward minus this arm's
    best_reward = self.rt[self.best_arm]
    return best_reward - self.rt[arm]

  def pregret(self, arm):
    # expected regret: gap between the best mean and this arm's mean
    best_mean = self.mu[self.best_arm]
    return best_mean - self.mu[arm]

  def print(self):
    arms = ", ".join("%.3f" % s for s in self.mu)
    return "Box bandit with arms (%s)" % arms


class LinBandit(object):
  """Linear bandit.

  Arm i has mean reward X[i] . theta; each round its realized reward is that
  mean plus sigma * standard normal noise.

  Args:
    X: K x d matrix of arm features (copied on construction).
    theta: model parameter of length d (copied on construction).
    sigma: scale of the Gaussian reward noise.
  """

  def __init__(self, X, theta, sigma=1.0):
    self.X = np.copy(X)
    self.K = self.X.shape[0]  # number of arms
    self.d = self.X.shape[1]  # number of features
    self.theta = np.copy(theta)
    self.sigma = sigma

    # mean rewards of all arms and the arm that maximizes them
    self.mu = self.X.dot(self.theta)
    self.best_arm = np.argmax(self.mu)

    self.randomize()

  def randomize(self):
    # draw the realized rewards of all arms for the next round
    noise = np.random.randn(self.K)
    self.rt = self.mu + self.sigma * noise

  def reward(self, arm):
    # realized reward of the arm in the current round
    return self.rt[arm]

  def regret(self, arm):
    # realized regret: best arm's realized reward minus this arm's
    best_reward = self.rt[self.best_arm]
    return best_reward - self.rt[arm]

  def pregret(self, arm):
    # expected regret: gap between the best mean and this arm's mean
    best_mean = self.mu[self.best_arm]
    return best_mean - self.mu[arm]

  def print(self):
    dims_and_arms = (self.d, self.K)
    return "Linear bandit: %d dimensions, %d arms" % dims_and_arms


def evaluate_one(Alg, params, env, n, period_size=1):
  """One run of a bandit algorithm.

  Args:
    Alg: algorithm class, constructed as Alg(env, n, params).
    params: attribute overrides passed to the algorithm constructor.
    env: bandit environment; re-randomized at the start of every round.
    n: number of rounds.
    period_size: number of consecutive rounds whose realized regret is summed
      into one entry of the returned regret array.

  Returns:
    (regret, alg): array of length n // period_size holding the realized
    regret summed per period, and the algorithm instance after the run.

  Raises:
    ValueError: if period_size is not a positive divisor of n. Otherwise the
      last, partial period would index one past the end of the regret array
      (previously an IndexError raised only after the whole run).
  """
  if period_size < 1 or n % period_size != 0:
    raise ValueError(
      "period_size must be a positive divisor of n (n=%s, period_size=%s)"
      % (n, period_size))

  alg = Alg(env, n, params)

  regret = np.zeros(n // period_size)
  for t in range(n):
    # generate state
    env.randomize()

    # take action and update agent
    arm = alg.get_arm(t)
    alg.update(t, arm, env.reward(arm))

    # track performance, accumulated per period
    regret[t // period_size] += env.regret(arm)

  return regret, alg


def evaluate(Alg, params, env, n=1000, period_size=1, printout=True):
  """Multiple runs of a bandit algorithm, one per environment, in parallel.

  Args:
    Alg: algorithm class; Alg.print() supplies its display name.
    params: attribute overrides passed to every algorithm instance.
    env: sequence of bandit environments; one run per entry.
    n: number of rounds per run.
    period_size: number of rounds summed into each row of the regret array.
    printout: whether to print timing and a summary of total regret.

  Returns:
    (regret, alg): array of shape (n // period_size, len(env)) with one
    column per run, and the list of algorithm instances after their runs.
  """
  if printout:
    print("Evaluating %s" % Alg.print(), end="")
  start = time.time()

  num_exps = len(env)
  regret = np.zeros((n // period_size, num_exps))
  alg = [None] * num_exps

  # run all experiments in parallel on all available cores
  jobs = (delayed(evaluate_one)(Alg, params, env[ex], n, period_size)
    for ex in range(num_exps))
  output = Parallel(n_jobs=-1)(jobs)

  for ex, (run_regret, run_alg) in enumerate(output):
    regret[:, ex] = run_regret
    alg[ex] = run_alg

  if printout:
    print(" %.1f seconds" % (time.time() - start))

    # summary of per-run total regret; +/- is the standard error of the mean
    total_regret = regret.sum(axis=0)
    print("Regret: %.2f +/- %.2f (median: %.2f, max: %.2f, min: %.2f)" %
      (total_regret.mean(), total_regret.std() / np.sqrt(num_exps),
      np.median(total_regret), total_regret.max(), total_regret.min()))

  return regret, alg


# Bandit algorithms

class GaussBanditAlg:
  """Base class for multi-armed bandit algorithms with Gaussian models.

  Sets default prior and noise parameters, applies attribute overrides from
  `params`, and tracks per-arm pull counts and cumulative rewards.
  """

  def __init__(self, env, n, params):
    self.env = env  # bandit environment that the agent interacts with
    self.K = self.env.K  # number of arms
    self.n = n  # horizon
    self.mu0 = np.zeros(self.K)  # prior mean of mean arm rewards
    # per-arm prior standard deviation of mean arm rewards. Subclasses use it
    # elementwise (e.g. squared per arm), so the default is a length-K vector;
    # a K x K matrix such as np.eye(K) would turn per-arm scores into a
    # K x K array with 0/0 NaNs and make argmax return a flattened index
    self.sigma0 = np.ones(self.K)
    self.sigma = 1.0  # reward noise standard deviation

    # override default values; arrays are copied so later changes to the
    # caller's arrays do not leak into this instance
    for attr, val in params.items():
      if isinstance(val, np.ndarray):
        setattr(self, attr, np.copy(val))
      else:
        setattr(self, attr, val)

    self.pulls = np.zeros(self.K)  # number of pulls
    self.reward = np.zeros(self.K)  # cumulative reward

  def update(self, t, arm, r):
    # record one pull of `arm` with observed reward r
    self.pulls[arm] += 1
    self.reward[arm] += r


class UCB1(GaussBanditAlg):
  """UCB1 with a horizon-dependent exploration bonus.

  Plays arm t in each of the first K rounds, then the arm maximizing
  empirical mean + sigma * sqrt(2 log(n) / pulls).
  """

  def get_arm(self, t):
    if t < self.K:
      # each arm is initially pulled once, in index order
      self.mu = np.zeros(self.K)
      self.mu[t] = 1
    else:
      # UCBs; the bonus uses the horizon n (not the current round), so the
      # round index is not needed here
      self.mu = self.reward / self.pulls + \
        self.sigma * np.sqrt(2 * np.log(self.n) / self.pulls)

    arm = np.argmax(self.mu)
    return arm

  @staticmethod
  def print():
    return "UCB1"


class BayesUCB(GaussBanditAlg):
  """Bayes-UCB for Gaussian bandits with independent per-arm priors.

  Scores each arm by its posterior mean plus sqrt(2 log(1 / delta)) times
  its posterior standard deviation, and plays the highest score.
  """

  def __init__(self, env, n, params):
    # default confidence level; set before the base constructor so that
    # `params` can override it
    self.delta = 1 / n

    GaussBanditAlg.__init__(self, env, n, params)

  def get_arm(self, t):
    sigma0 = self.sigma0
    if np.ndim(sigma0) == 2:
      # A K x K prior (e.g. a np.eye(K) default) would make the elementwise
      # formulas below yield a K x K array with 0/0 NaNs, and argmax would
      # return a flattened index outside [0, K). Use its diagonal, read the
      # same way as a 1-D sigma0 (squared below).
      sigma0 = np.diag(sigma0)

    sigma2 = np.square(self.sigma)
    sigma02 = np.square(sigma0)
    # conjugate Gaussian posterior of each arm's mean reward
    post_var = 1.0 / (1.0 / sigma02 + self.pulls / sigma2)
    post_mean = post_var * (self.mu0 / sigma02 + self.reward / sigma2)

    # posterior UCBs
    self.mu = post_mean + np.sqrt(2 * np.log(1 / self.delta) * post_var)

    arm = np.argmax(self.mu)
    return arm

  @staticmethod
  def print():
    return "BayesUCB"


class LinBanditAlg:
  """Base class for linear bandit algorithms with a Gaussian prior on theta.

  Keeps two statistics. Lambda starts at inv(Sigma0) and accumulates
  x x^T / sigma^2. B starts at inv(Sigma0) theta0 and accumulates
  x r / sigma^2. Subclasses use inv(Lambda) as the posterior covariance and
  inv(Lambda) B as the estimate of theta.
  """

  def __init__(self, env, n, params):
    self.env = env  # bandit environment that the agent interacts with
    self.K = self.env.K  # number of arms
    self.d = self.env.d  # number of features
    self.n = n  # horizon
    self.theta0 = np.zeros(self.d)  # prior mean of the model parameter
    self.Sigma0 = np.eye(self.d)  # prior covariance of the model parameter
    self.sigma = 1.0  # reward noise

    # apply overrides; numpy arrays are copied rather than aliased
    for name, value in params.items():
      setattr(self, name,
        np.copy(value) if isinstance(value, np.ndarray) else value)

    # sufficient statistics, initialized from the prior
    self.Lambda = np.linalg.inv(self.Sigma0)
    self.B = self.Lambda.dot(self.theta0)

  def update(self, t, arm, r):
    # fold the observation (features of `arm`, reward r) into the statistics
    features = self.env.X[arm, :]
    noise_var = np.square(self.sigma)
    self.Lambda += np.outer(features, features) / noise_var
    self.B += features * r / noise_var


class LinUCB(LinBanditAlg):
  """LinUCB with the confidence-ellipsoid width of Abbasi-Yadkori (2011).

  Scores each arm by x^T thetahat + cew * ||x||_{V^{-1}}, where
  V = sigma^2 * Lambda (so V^{-1} = posterior covariance / sigma^2), and
  plays the highest score.
  """

  def __init__(self, env, n, params):
    LinBanditAlg.__init__(self, env, n, params)

    # the width is evaluated once, at the horizon n, and reused every round
    self.cew = self.confidence_ellipsoid_width(n)

  def confidence_ellipsoid_width(self, t):
    # Theorem 2 in Abbasi-Yadkori (2011)
    # Improved Algorithms for Linear Stochastic Bandits
    delta = 1 / self.n
    L = np.amax(np.linalg.norm(self.env.X, axis=1))  # max feature-vector norm
    Lambda = np.square(self.sigma) * np.linalg.eigvalsh(np.linalg.inv(self.Sigma0)).max()  # V = \sigma^2 (posterior covariance)^{-1}
    R = self.sigma
    S = np.sqrt(self.d)  # used as the bound S on ||theta||_2 in the theorem
    width = np.sqrt(Lambda) * S + \
      R * np.sqrt(self.d * np.log((1 + t * np.square(L) / Lambda) / delta))
    return width

  def get_arm(self, t):
    # linear model posterior
    Sigmahat = np.linalg.inv(self.Lambda)
    thetahat = Sigmahat.dot(self.B)

    # UCBs: the ellipsoid norm must use V^{-1}, not the posterior covariance;
    # using Sigmahat directly inflated the bonus by a factor of sigma
    invV = Sigmahat / np.square(self.sigma)  # V^{-1} = posterior covariance / \sigma^2
    self.mu = self.env.X.dot(thetahat) + self.cew * \
      np.sqrt(np.einsum("ij,jk,ik->i", self.env.X, invV, self.env.X))

    arm = np.argmax(self.mu)
    return arm

  @staticmethod
  def print():
    return "LinUCB"


class LinBayesUCB(LinBanditAlg):
  """Bayes-UCB for linear bandits.

  Scores each arm by its posterior mean reward plus
  sqrt(2 log(K / delta)) times its posterior standard deviation, and plays
  the highest score.
  """

  def __init__(self, env, n, params):
    # confidence level; assigned before the base constructor so `params`
    # may override it
    self.delta = 1 / n
    LinBanditAlg.__init__(self, env, n, params)

  def get_arm(self, t):
    features = self.env.X
    post_cov = np.linalg.inv(self.Lambda)  # posterior covariance of theta
    post_theta = post_cov.dot(self.B)  # posterior mean of theta

    # x_i^T post_cov x_i for every arm i: posterior variance of its mean reward
    arm_var = np.einsum("ij,jk,ik->i", features, post_cov, features)
    scale = 2 * np.log(self.K / self.delta)
    self.mu = features.dot(post_theta) + np.sqrt(scale * arm_var)

    return np.argmax(self.mu)

  @staticmethod
  def print():
    return "LinBayesUCB"