from scipy import optimize
import numpy as np
from  tqdm import tqdm
import random as random
from sklearn.linear_model import LinearRegression
import matplotlib.pyplot as plt
import pickle
import sys

class User_LB:
  '''
  Simulated user for the DBGD lower-bound experiment.

  The user scores an action with a linear utility ``theta . a`` and answers
  pairwise comparisons at random through an affine link. Two histories are
  tracked: actions recorded via ``update`` and actions recorded via
  ``update2``.
  '''
  def __init__(self, theta, method, n_iter):
    '''
    Store the preference vector and run settings; start with empty histories.
    '''
    self.theta, self.method, self.n_iter = theta, method, n_iter
    self.consumed_action, self.seen_action = [], []

  def reset(self):
    '''
    Drop both action histories (fresh list objects are created).
    '''
    self.consumed_action, self.seen_action = [], []

  def update(self, action):
    '''
    Append ``action`` to the consumed-action history.
    '''
    self.consumed_action.append(action)

  def update2(self, action):
    '''
    Append ``action`` (the a_t' of the original comment) to the seen-action history.
    '''
    self.seen_action.append(action)

  def utility(self, a):
    '''
    Linear utility of action ``a``: the dot product with ``theta``.
    '''
    return np.dot(self.theta, a)

  def link(self, x):
    '''
    Affine link mapping a utility gap ``x`` to ``0.5 + 0.5 * x``.

    The value is not clipped, so gaps outside [-1, 1] give values outside
    [0, 1]; ``compare`` then behaves deterministically.
    '''
    return 0.5 + 0.5 * x

  def compare(self, left, right):
    '''
    Duel ``left`` against ``right``.

    Returns ``(left, 0)`` when a uniform draw on [0, 1] falls below the link
    of the utility gap, and ``(right, 1)`` otherwise.
    '''
    # The gap is evaluated before the random draw, as in the original flow.
    gap = self.utility(left) - self.utility(right)
    left_wins = random.uniform(0, 1) < self.link(gap)
    return (left, 0) if left_wins else (right, 1)

class User:
  '''
  Simulated user with a quadratic cost and optionally corrupted feedback.

  The cost of an action ``a`` is ``theta . a + ||a||^2 / 2`` and the utility
  is its negation. Pairwise comparisons are answered through a logistic
  link. Two corruption behaviours are implemented in ``compare``:

  * ``corruption_mode == "G"``: while the budget ``corruption_times`` is
    positive, the answer is deterministic and the budget is decremented by
    one; afterwards the uncorrupted logistic answer is used.
  * any other value (default ``"LU"``): the cost/utility gap is perturbed by
    ``corruption`` before it is passed to the link.
  '''
  def __init__(self, theta, V0, corrupted, rho, method, n_iter, corruption_mode = "LU"):
    '''
    theta           : preference vector used by ``cost`` / ``utility``.
    V0              : matrix added to the history term inside ``corruption``.
    corrupted       : stored as given; not read by any method of this class.
    rho             : exponent parameter; sets both the corruption budget
                      and the power used in ``corruption``.
    method          : ``"DBGD"`` compares utilities, anything else compares costs.
    n_iter          : horizon; the budget is ``n_iter ** (0.5 + rho)``.
    corruption_mode : ``"G"`` for the budgeted mode, otherwise the
                      history-dependent mode.
    '''
    self.theta = theta
    self.consumed_action = []
    self.seen_action = []
    self.V0 = V0
    self.rho = rho
    self.corrupted = corrupted
    self.method = method
    self.n_iter = n_iter
    # Remaining number of corrupted answers in "G" mode (a float in general).
    self.corruption_times = n_iter**(0.5 + self.rho)
    self.corruption_mode = corruption_mode

  def reset(self):
    '''
    Clear both action histories.

    NOTE(review): ``corruption_times`` is not restored here - confirm that
    the budget is meant to persist across resets.
    '''
    self.consumed_action = []
    self.seen_action = []

  def update(self, action):
    '''
    Append ``action`` to the consumed-action history.
    '''
    self.consumed_action.append(action)

  def update2(self, action):
    '''
    Append ``action`` (a_t') to the seen-action history.
    '''
    self.seen_action.append(action)

  def cost(self, a):
    '''
    Quadratic cost ``theta . a + ||a||^2 / 2`` (both inputs are flattened).
    '''
    return np.dot(self.theta.reshape(-1), a.reshape(-1)) + np.linalg.norm(a)**2/2

  def utility(self, a):
    '''
    Utility of ``a``, defined as the negated cost.
    '''
    return -self.cost(a)

  def noisy_cost_diff(self, x, y):
    '''
    Cost gap ``cost(x) - cost(y)`` pushed towards the opposite sign.

    The corruption magnitude is subtracted when the gap is positive and
    added otherwise.
    '''
    cost_x, cost_y = self.cost(x), self.cost(y)
    if cost_x > cost_y:
      return cost_x - cost_y - self.corruption(x, y)
    return cost_x - cost_y + self.corruption(x, y)

  def noisy_utility_diff(self, x, y):
    '''
    Utility gap ``utility(x) - utility(y)`` pushed towards the opposite sign.

    The corruption magnitude is subtracted when the gap is positive and
    added otherwise.
    '''
    util_x, util_y = self.utility(x), self.utility(y)
    if util_x > util_y:
      return util_x - util_y - self.corruption(x, y)
    return util_x - util_y + self.corruption(x, y)

  def link(self, x):
    '''
    Logistic Link Function
    '''
    return 1/(1 + np.exp(-x))

  def corruption(self, x, y):
    '''
    Learning User Corruption, Mathematical Details see Appendix A

    Computes ``10 * ((x - y)^T V^{-1} (x - y)) ** (0.5 - rho)`` flattened to
    1-D, where ``V = V0 + C C^T`` and ``C`` stacks, column by column, the
    differences between consumed and seen actions. Actions are concatenated
    along axis 1, so they must be 2-D arrays.

    With no recorded history the sum ``C C^T`` is empty and ``V`` is ``V0``;
    previously this case raised ``ValueError`` inside ``np.concatenate``.
    '''
    if not self.consumed_action and not self.seen_action:
      V = self.V0
    else:
      # A mismatch between the two histories still fails here, as before.
      A = np.concatenate(self.consumed_action, axis = 1)
      B = np.concatenate(self.seen_action, axis = 1)
      C = A-B
      V = self.V0 + np.matmul(C, C.T)
    diff = x-y
    quad = np.matmul(np.matmul(diff.T, np.linalg.inv(V)), diff)
    return 10*(quad**(0.5 - self.rho)).reshape(-1)

  def update_corruption(self):
    '''
    Spend one unit of the "G"-mode corruption budget.
    '''
    self.corruption_times = self.corruption_times-1

  def compare(self, left, right):
    '''
    Duel ``left`` against ``right``; return ``(left, 0)`` or ``(right, 1)``.

    See the class docstring for the two corruption modes. Exactly one
    uniform draw is consumed per call, except in the deterministic
    budgeted branch, which draws nothing.
    '''
    use_utility = self.method == "DBGD"

    if self.corruption_mode == "G":
      if self.corruption_times > 0:
        self.update_corruption()
        if use_utility:
          # Deterministically report the lower-utility side as the winner.
          pick_left = self.utility(left) < self.utility(right)
        else:
          # NOTE(review): this picks the higher-cost side, which is also the
          # side the uncorrupted link below favours for non-DBGD methods -
          # confirm the direction is intended.
          pick_left = self.cost(left) > self.cost(right)
        return (left, 0) if pick_left else (right, 1)

      # Budget exhausted: uncorrupted logistic feedback.
      if use_utility:
        prob = self.link(self.utility(left) - self.utility(right))
      else:
        prob = self.link(self.cost(left) - self.cost(right))
    else:
      # History-dependent corruption applied to the gap itself.
      if use_utility:
        prob = self.link(self.noisy_utility_diff(left, right))
      else:
        prob = self.link(self.noisy_cost_diff(left, right))

    if random.uniform(0, 1) < prob:
      return left,  0
    return right, 1