from scipy import optimize
import numpy as np
from  tqdm import tqdm
import random as random
from sklearn.linear_model import LinearRegression
import matplotlib.pyplot as plt
import pickle
import sys

class DBGD_LB:
    '''
    Implementation for DBGD Lower Bound Simulation.

    Runs Dueling Bandit Gradient Descent on the 2-D polytope
    {x : x >= 0, 0.5*x[0] + x[1] <= 0.25}. Deviation is measured against the
    fixed point (0.5, 0). NOTE(review): that point and the constant 0.5 used in
    the regret are hard-coded here; presumably they describe the optimum of the
    lower-bound instance — confirm against the user model.
    '''
    def __init__(self, feature_dim, delta, gamma, user):
        """
        Args:
            feature_dim: dimension of the action space (the projection below
                only supports 2).
            delta: exploration radius used to build the challenger point.
            gamma: step size applied to the incumbent when the challenger wins.
            user: user/environment object (stored; ``simulate`` uses the one
                passed to it).
        """
        self.d = feature_dim
        self.delta = delta
        self.gamma = gamma
        self.action = []
        self.a0 = np.zeros((self.d, 1))
        self.t = 0
        self.regret = []
        self.user = user
        self.deviation = []

    def projection(self, a):
        """Return the Euclidean projection of ``a`` onto the feasible polytope.

        Solves min ||x - a||^2 s.t. 0 <= x <= 1 and 0.5*x[0] + x[1] <= 0.25
        with SLSQP. Only the first two coordinates of ``a`` are used, so this
        assumes ``self.d == 2``.

        Returns:
            Array of shape (self.d, 1).
        """
        # Flatten first: indexing a (d, 1) column yields shape-(1,) arrays,
        # which would make the objective array-valued instead of scalar.
        target = np.asarray(a, dtype=float).reshape(-1)
        t0, t1 = target[0], target[1]

        def objective(x):
            return (x[0] - t0) ** 2 + (x[1] - t1) ** 2

        cons = ({'type': 'ineq', 'fun': lambda x: 0.25 - 0.5 * x[0] - x[1]})
        bnds = ((0, 1), (0, 1))
        res = optimize.minimize(objective, (0, 0), method='SLSQP', bounds=bnds,
                                constraints=cons, tol=1e-10).x
        return res.reshape(self.d, 1)

    def check_condition(self, a):
        """Return True if ``a`` lies outside the feasible polytope, else False."""
        x0, x1 = np.asarray(a, dtype=float).reshape(-1)[:2]
        return bool(
            (0.25 - 0.5 * x0 - x1) < 0
            or x0 > 0.5 or x1 > 0.25 or x0 < 0 or x1 < 0
        )

    def simulate(self, T, user):
        """Run ``T`` rounds of DBGD against ``user``.

        Appends one entry per round to ``self.regret`` and ``self.deviation``.
        ``user`` must provide ``update``, ``update2``, ``utility`` and
        ``compare``; a ``score`` of 1 from ``compare`` is treated as the
        challenger (right) winning.
        """
        optimum = np.array([0.5, 0]).reshape(2, -1)
        left = self.a0
        for t in tqdm(range(T)):
            self.t = t + 1
            self.action = left
            user.update(left)

            # Sample a unit vector uniformly at random and build the challenger.
            ut = np.random.randn(self.d, 1)
            ut /= np.linalg.norm(ut, axis=0)
            right = left + self.delta * ut
            if self.check_condition(right):
                right = self.projection(right)

            user.update2(right)

            # Update regret and distance to the reference point.
            self.regret.append(0.5 - user.utility(self.action) - user.utility(right))
            self.deviation.append(np.linalg.norm(self.action - optimum))

            # Verbose
            if t % 1000 == 0:
                print(0.5 - user.utility(self.action) - user.utility(right))
                print(np.linalg.norm(self.action - optimum))

            _, score = user.compare(left, right)

            # Move the incumbent toward the challenger direction only if it won.
            if score == 1:
                left = left + self.gamma * ut
                if self.check_condition(left):
                    left = self.projection(left)

class NCSMD:
    """Comparison-feedback mirror-descent learner over a Euclidean ball.

    Each round the learner plays ``left`` and a challenger
    ``right = left + H^{-1/2} u`` where ``u`` is a random unit vector and ``H``
    is the Hessian of the regulariser

        psi_t(a) = -log(radius**2 - ||a||**2) + (lbd*eta*t/2 + mu) * ||a||**2.

    When the challenger wins, ``left`` is updated by one mirror-descent step,
    ``grad psi_t(new) = grad psi_t(old) - eta * g``, solved numerically.
    NOTE(review): the acronym itself is not expanded anywhere in this file.
    """

    def __init__(self, feature_dim, lmbd, radius, mu, eta, user):
        self.d = feature_dim
        self.lbd = lmbd
        self.radius = radius
        self.mu = mu
        self.eta = eta
        self.action = []
        self.a0 = np.zeros((self.d, 1))
        self.t = 0
        self.gradient = []
        self.regret = []
        self.user = user
        self.deviation = []

    def sqrt_hessian(self, t, a):
        """Return ``(H^{1/2}, H^{-1/2})`` of psi_t at point ``a``.

        ``H = 4 a a^T / (R^2 - |a|^2)^2 + (2/(R^2 - |a|^2) + lbd*eta*t + 2*mu) I``
        is symmetric positive definite inside the ball, so it is diagonalised
        with ``eigh`` and the square roots are formed as ``V diag(s) V^T``.
        """
        a = np.asarray(a, dtype=float).reshape(-1, 1)
        gap = self.radius**2 - np.linalg.norm(a)**2
        constant = 2 / gap + self.lbd * self.eta * t + 2 * self.mu
        mat = 4 * np.matmul(a, a.T) / gap**2 + constant * np.eye(self.d)
        eigenvalues, eigenvectors = np.linalg.eigh(mat)
        root = np.sqrt(eigenvalues)
        # Scaling columns of V by root gives V diag(root); then a true matrix
        # product with V^T (not an elementwise one) yields the matrix root.
        sqrt_mat = np.matmul(eigenvectors * root, eigenvectors.T)
        inv_sqrt_mat = np.matmul(eigenvectors / root, eigenvectors.T)
        return sqrt_mat, inv_sqrt_mat

    def _mirror_map(self, x):
        """Return grad psi_t(x) as a flat vector, using the current ``self.t``."""
        x = np.asarray(x, dtype=float).reshape(-1)
        scale = (2 / (self.radius**2 - np.linalg.norm(x)**2)
                 + self.lbd * self.eta * self.t + 2 * self.mu)
        return scale * x

    def fun(self, x):
        """Residual of the mirror-descent equation at candidate point ``x``.

        Returns ``|| grad psi_t(x) - (grad psi_t(action) - eta * gradient) ||``,
        which is zero at the mirror-descent update.
        """
        bt = self._mirror_map(self.action) - self.eta * np.asarray(self.gradient).reshape(-1)
        return np.linalg.norm(self._mirror_map(x) - bt)

    def simulate(self, T, user):
        """Run ``T`` rounds against ``user``.

        ``user`` must provide ``update``, ``update2``, ``cost``, ``compare`` and
        a ``theta`` attribute. Deviation is ``||action + theta||``, so the
        target point is presumably ``-theta`` (TODO confirm with the user
        model). A ``score`` of 1 from ``compare`` is treated as the challenger
        winning.
        """
        left = self.a0
        for t in tqdm(range(T)):
            self.t = t + 1
            self.action = left
            user.update(left)

            # Sample a unit vector uniformly at random.
            ut = np.random.randn(self.d, 1)
            ut /= np.linalg.norm(ut, axis=0)
            sqrt_h, inv_sqrt_h = self.sqrt_hessian(self.t, self.action)
            right = left + np.matmul(inv_sqrt_h, ut)
            user.update2(right)

            # Update regret and deviation.
            self.regret.append(user.cost(self.action) + self.radius**2 + user.cost(right))
            self.deviation.append(np.linalg.norm(self.action + user.theta))

            # Verbose: print the same regret expression that is recorded.
            if t % 1000 == 0:
                print(np.linalg.norm(self.action + user.theta))
                print(user.cost(self.action) + self.radius**2 + user.cost(right))

            _, score = user.compare(left, right)

            if score == 1:
                # Gradient estimate from the winning perturbation direction.
                self.gradient = score * self.d * np.matmul(sqrt_h, ut)
                left = optimize.minimize(self.fun, self.action.reshape(-1),
                                         method='Nelder-Mead', tol=1e-6).x
                left = left.reshape(self.d, 1)

class DBGD:
    """Dueling Bandit Gradient Descent over the Euclidean ball of ``radius``.

    Each round compares the incumbent ``left`` with ``left + delta * u``
    (``u`` a random unit vector) and, if the challenger wins, moves the
    incumbent by ``gamma * u``. Both points are radially projected back onto
    the ball.
    """

    def __init__(self, feature_dim, radius, delta, gamma, user):
        self.d = feature_dim
        self.delta = delta
        self.gamma = gamma
        self.radius = radius
        self.action = []
        self.a0 = np.zeros((self.d, 1))
        self.t = 0
        self.regret = []
        self.user = user
        self.deviation = []

    def _project_to_ball(self, x):
        """Radially project ``x`` onto the ball of radius ``self.radius``.

        Points already inside the ball are returned unchanged.
        """
        if np.linalg.norm(x) > self.radius:
            x = x / np.linalg.norm(x, axis=0)
            x = self.radius * x
        return x

    def simulate(self, T, user):
        """Run ``T`` rounds against ``user``.

        ``user`` must provide ``update``, ``update2``, ``utility``, ``compare``
        and a ``theta`` attribute. NOTE(review): the regret offset
        ``radius**2`` presumably equals the sum of the two optimal utilities —
        confirm with the user model.
        """
        left = self.a0
        for t in tqdm(range(T)):
            self.t = t + 1
            self.action = left
            user.update(left)

            # Sample a unit vector uniformly at random.
            ut = np.random.randn(self.d, 1)
            ut /= np.linalg.norm(ut, axis=0)
            right = self._project_to_ball(left + self.delta * ut)

            user.update2(right)

            # Update regret and deviation.
            self.regret.append(self.radius**2 - user.utility(self.action) - user.utility(right))
            self.deviation.append(np.linalg.norm(self.action + user.theta))

            # Verbose: print the same regret expression that is recorded.
            if t % 1000 == 0:
                print(np.linalg.norm(self.action + user.theta))
                print(self.radius**2 - user.utility(self.action) - user.utility(right))

            _, score = user.compare(left, right)

            # A score of 1 is treated as the challenger winning.
            if score == 1:
                left = self._project_to_ball(left + self.gamma * ut)

class SBM:
    """Single-point gradient-free learner used as a building block.

    ``recommend`` perturbs the current point ``left`` by ``delta`` along a
    random unit direction (projected onto the ball of ``radius``), and
    ``update`` takes a step of size ``nu * cost`` against that direction,
    keeping ``left`` inside the shrunken ball of radius ``radius*(1-alpha)``.
    """

    def __init__(self, feature_dim, radius, alpha, delta, nu):
        self.d = feature_dim
        self.radius = radius
        self.alpha = alpha
        self.delta = delta
        self.nu = nu
        self.ut = None
        self.left = np.zeros((self.d, 1))
        self.regret = []

    def reset(self):
        """Clear the search state; hyper-parameters are kept as they are."""
        self.ut = None
        self.left = np.zeros((self.d, 1))
        self.regret = []

    @staticmethod
    def _clip(vec, limit):
        """Rescale ``vec`` onto the sphere of radius ``limit`` if it lies outside."""
        if not np.linalg.norm(vec) > limit:
            return vec
        unit = vec / np.linalg.norm(vec, axis=0)
        return limit * unit

    def genTheta(self):
        """Draw a uniformly random point on the sphere of radius ``self.radius``."""
        direction = np.random.randn(self.d, 1)
        direction /= np.linalg.norm(direction, axis=0)
        return self.radius * direction

    def recommend(self):
        """Sample a fresh unit direction and return the perturbed point."""
        direction = np.random.randn(self.d, 1)
        direction /= np.linalg.norm(direction, axis=0)
        self.ut = direction
        return self._clip(self.left + self.delta * self.ut, self.radius)

    def update(self, cost):
        """Step against the last sampled direction, scaled by ``nu * cost``."""
        stepped = self.left - self.nu * cost * self.ut
        self.left = self._clip(stepped, self.radius * (1 - self.alpha))

    def simulate(self, n_iter, user):
        """Run ``n_iter`` rounds using ``-user.utility`` of each probe as cost."""
        for _ in range(n_iter):
            probe = self.recommend()
            self.update(-user.utility(probe))
            self.regret.append(self.radius**2 - 2 * user.utility(self.left))

class Doubler:
    """Doubler reduction from dueling feedback to a single SBM learner.

    The horizon is split into epochs of length 1, 2, 4, ... . During epoch ``i``
    the left arm is frozen at ``L``, the SBM proposes the right arm, and the
    new ``L`` is the average of the per-round winners of that epoch.

    ``user`` must provide ``update``, ``update2``, ``utility``, ``link``,
    ``corrupted`` and, depending on ``corruption_mode``, ``corruption`` ("LU")
    or ``rho`` ("G").
    """

    def __init__(self, feature_dim, user, radius, alpha, delta, nu, corruption_mode="LU"):
        self.d = feature_dim
        self.user = user
        self.radius = radius
        self.alpha = alpha
        self.delta = delta
        self.nu = nu
        self.sbm = SBM(feature_dim=self.d, radius=self.radius, alpha=self.alpha, delta=self.delta, nu=self.nu)
        self.corruption_mode = corruption_mode
        self.L = []

    def reset(self):
        self.sbm = SBM(feature_dim=self.d, radius=self.radius, alpha=self.alpha, delta=self.delta, nu=self.nu)
        self.L = []

    def recommend(self):
        return self.sbm.recommend()

    def update(self, cost):
        self.sbm.update(cost)

    @staticmethod
    def _num_epochs(T):
        """Number of doubling epochs needed to cover ``T`` rounds.

        Epochs of lengths 1, 2, ..., 2**(n-1) cover 2**n - 1 rounds, so ``n``
        is the bit length of ``ceil(T)``. Integer arithmetic avoids the float
        rounding of ``log(T)/log(2)`` and yields 0 for ``T <= 0``.
        """
        return max(int(np.ceil(T)), 0).bit_length()

    def _right_win_probability(self, left, right, left_s, right_s):
        """Probability that the right arm wins, after applying any corruption.

        "LU": the utility gap is shrunk toward the wrong side by
        ``user.corruption(left, right)``. "G": while the corruption budget
        lasts, the true winner is deterministically flipped and the budget is
        decremented.
        """
        gap = right_s - left_s
        if not self.user.corrupted:
            return self.user.link(gap)
        if self.corruption_mode == "LU":
            shift = self.user.corruption(left, right)
            if right_s > left_s:
                return self.user.link(gap - shift)
            return self.user.link(gap + shift)
        if self.corruption_mode == "G" and self.corruption_count > 0:
            self.corruption_count = self.corruption_count - 1
            return 0 if right_s > left_s else 1
        return self.user.link(gap)

    def simulate_lift(self, T=10000):
        """Run ``T`` rounds; fills ``self.regret`` and appends epoch anchors to ``self.L``."""
        if self.corruption_mode == "G":
            self.corruption_count = T**(0.5 + self.user.rho)
        self.sbm = SBM(feature_dim=self.d, radius=self.radius, alpha=self.alpha, delta=self.delta, nu=self.nu)
        self.regret = []
        L = self.sbm.genTheta()
        self.L.append(L)
        t = 0
        for i in range(self._num_epochs(T)):
            self.sbm.reset()
            epoch_len = 2**i
            next_L = 0
            for _ in range(epoch_len):
                if t >= T:
                    return
                left = L
                right = self.recommend()
                self.user.update(left)
                self.user.update2(right)

                left_s, right_s = self.user.utility(left), self.user.utility(right)
                right_p = self._right_win_probability(left, right, left_s, right_s)
                score = 1.0 if np.random.rand() < right_p else -1.0

                self.update(-self.radius * score)
                instant_regret = self.sbm.radius**2 - left_s - right_s
                self.regret.append(instant_regret)

                # Accumulate the winner of this round for the next anchor.
                next_L += right if score == 1 else left

                t += 1
                if t % 1000 == 0:
                    print(f"current regret is {instant_regret}")

            L = next_L / epoch_len
            if np.linalg.norm(L) > self.sbm.radius:
                L = L / np.linalg.norm(L) * self.sbm.radius
            self.L.append(L)


class Sparring:
    """Sparring reduction: two SBM learners duel each other.

    Each round both learners propose an arm; the user comparison decides a
    winner, and the winner receives cost ``-radius`` while the loser receives
    ``+radius``. ``user`` must provide ``update``, ``update2``, ``utility``,
    ``link``, ``corrupted`` and, depending on ``corruption_mode``,
    ``corruption`` ("LU") or ``rho`` ("G").
    """

    def __init__(self, feature_dim, user, radius, alpha, delta, nu, corruption_mode="LU"):
        self.d = feature_dim
        self.user = user
        self.radius = radius
        self.alpha = alpha
        self.delta = delta
        self.nu = nu
        self.left_sbm, self.right_sbm = self._fresh_pair()
        self.corruption_mode = corruption_mode

    def _fresh_pair(self):
        """Build a new (left, right) pair of identically configured SBMs."""
        pair = []
        for _ in range(2):
            pair.append(SBM(feature_dim=self.d, radius=self.radius, alpha=self.alpha,
                            delta=self.delta, nu=self.nu))
        return pair[0], pair[1]

    def reset(self):
        self.left_sbm, self.right_sbm = self._fresh_pair()

    def recommend(self):
        """Return ``(left_arm, right_arm)`` proposed by the two learners."""
        return self.left_sbm.recommend(), self.right_sbm.recommend()

    def update(self, left, right, costs):
        """Feed ``costs[0]`` to the left learner and ``costs[1]`` to the right one."""
        left_cost, right_cost = costs[0], costs[1]
        self.left_sbm.update(left_cost)
        self.right_sbm.update(right_cost)

    def _right_win_probability(self, left, right, left_s, right_s):
        """Probability that the right arm wins, after applying any corruption."""
        if not self.user.corrupted:
            return self.user.link(right_s - left_s)
        if self.corruption_mode == "LU":
            if right_s > left_s:
                return self.user.link(right_s - left_s - self.user.corruption(left, right))
            return self.user.link(right_s - left_s + self.user.corruption(left, right))
        if self.corruption_mode == "G" and self.corruption_count > 0:
            self.corruption_count = self.corruption_count - 1
            return 0 if right_s > left_s else 1
        return self.user.link(right_s - left_s)

    def simulate(self, T=100):
        """Run ``T`` dueling rounds, recording per-round regret in ``self.regret``."""
        if self.corruption_mode == "G":
            self.corruption_count = T**(0.5 + self.user.rho)

        self.regret = []
        for _ in range(T):
            left, right = self.recommend()
            left_s = self.user.utility(left)
            right_s = self.user.utility(right)
            self.user.update(left)
            self.user.update2(right)

            right_p = self._right_win_probability(left, right, left_s, right_s)
            right_wins = np.random.rand() < right_p
            r = self.left_sbm.radius
            self.update(left, right, [r, -r] if right_wins else [-r, r])

            self.regret.append(self.left_sbm.radius**2 - left_s - right_s)