import numpy as np
from tqdm import tqdm
from scipy.stats import linregress
#import matplotlib.pyplot as plt
from scipy.optimize import minimize, fsolve
import random

class MultiAgentMirrorDescent(object):
  """Multi-player bandit mirror descent driven by one-point utility feedback.

  Every round each player blends a random +/-1 direction with a pull back
  toward the centre `p`, plays the perturbed point x_hat, observes a (possibly
  attacked) utility there, builds a gradient estimate from it, and updates
  x <- clip(x + gamma_t * v_hat, 0, B).  `set_order` must be called before
  `simulation` so that the decay exponents exist.
  """

  def __init__(self, n, r, p, d, x_init, delta, gamma, competition, B, sanity_check=False):
    # Geometry of the problem.
    self.n = n  # number of players
    self.d = d  # action dimension
    self.r = r  # radius of the reference ball
    self.p = p  # centre of the reference ball
    self.B = B  # actions are clipped to [0, B]
    # Learner state and base step-size constants.
    self.x = np.array(x_init)
    self.delta = delta  # base exploration radius, decayed by update_delta
    self.gamma = gamma  # base step size, decayed by update_gamma
    # Environment and debug switch.
    self.competition = competition
    self.sanity_check = sanity_check  # when True, always query attacked_utility

  def reset(self):
    """Re-draw every player's action uniformly from [0, B)."""
    self.x = self.B * np.random.rand(self.n)

  def set_order(self, gamma_order, delta_order):
    """Store the polynomial decay exponents for the step size and exploration radius."""
    self.gamma_order, self.delta_order = gamma_order, delta_order

  def projection(self, y):
    """Return self.x + y clipped coordinate-wise to [0, B]; self.x itself is not modified."""
    shifted = self.x + y
    return np.clip(shifted, 0, self.B)

  def update_delta(self, t):
    """Exploration radius at round t: delta / (t + 1) ** delta_order."""
    return self.delta / (t + 1) ** self.delta_order

  def update_gamma(self, t):
    """Step size at round t: gamma / (t + 1) ** gamma_order."""
    return self.gamma / (t + 1) ** self.gamma_order

  def simulation(self, n_iter, corrupted=False):
    """Run `n_iter` learning rounds, then reset the actions.

    Args:
      n_iter: number of rounds.
      corrupted: forwarded to `competition.utility` on the un-attacked path.

    Returns:
      (convergence, divergence): arrays of length `n_iter` holding, per round,
      the squared Euclidean distance from the played point x_hat to the
      equilibrium from `competition.ne()` and to the manipulated equilibrium
      from `competition.ne(method='pgd', ..., corruption=True)`.
    """
    attack_delta = 0.2 * self.B  # perturbation size handed to the attacker
    NE = self.competition.ne()
    MNE = self.competition.ne(method='pgd', delta=attack_delta, i=0, corruption=True)
    convergence = np.zeros(n_iter)
    divergence = np.zeros(n_iter)

    for t in range(n_iter):
      if not t % 1000:
        print(f"current iteration is {t+1}")

      # Random +/-1 direction per player, combined with a pull back toward p.
      u = np.random.choice([-1, 1], size=self.n)
      radius = self.update_delta(t)
      x_hat = self.x + radius * (u - (self.x - self.p) / self.r)

      convergence[t] = np.linalg.norm(x_hat - NE) ** 2
      divergence[t] = np.linalg.norm(x_hat - MNE) ** 2

      attacked = (self.competition.corruption_budget > 0 and self.competition.rho > 0) or self.sanity_check
      if attacked:
        utility = self.competition.attacked_utility(x=x_hat, delta=attack_delta)
      else:
        utility = self.competition.utility(x_hat, corrupted=corrupted)

      # One-point gradient estimate followed by a clipped step of size gamma_t.
      v_hat = self.d * utility * u / radius
      self.x = self.projection(self.update_gamma(t) * v_hat)

    self.reset()
    return convergence, divergence

class Simulation(object):
  """Repeat an algorithm's simulation and aggregate its per-iteration metrics.

  `algorithm` must expose `simulation(n_iter=..., corrupted=...)`, returning a
  (convergence, divergence) pair of length-`n_iter` arrays, and a
  `sanity_check` flag.  `game_instance` must expose the corruption bookkeeping
  attributes that are reset between repeats: `corruption_budget_used`,
  `corruption_budget`, `rho` and `corruption_constant`.
  """

  def __init__(self, game_instance, algorithm, corruption_type=None):
    # NOTE: `corruption_type` is accepted for backward compatibility but is unused.
    self.competition = game_instance
    self.algorithm = algorithm

  def run_simulation(self, n_iter = 50000, n_repeats=10, corrupted=False, return_full_data=False):
    """Run `n_repeats` simulations of `n_iter` iterations and average them.

    No random seed is set here.

    Args:
      n_iter: iterations per repeat.
      n_repeats: number of repeats to average over.
      corrupted: forwarded to `algorithm.simulation`.
      return_full_data: also return the per-repeat arrays.

    Returns:
      If `return_full_data` is False: (ave_convergence, ave_divergence), the
      per-iteration means over repeats.
      Otherwise: (ave_convergence, all_convergence, ave_divergence,
      all_divergence), where the `all_*` arrays have shape (n_repeats, n_iter).
      When `algorithm.sanity_check` is set, (ave_corruption, all_corruption)
      are appended. These are built from `competition.corruption_budget_used`,
      presumably one entry per iteration.
    """
    track_corruption = self.algorithm.sanity_check
    all_convergence = np.zeros((n_repeats, n_iter))
    all_divergence = np.zeros((n_repeats, n_iter))
    if track_corruption:
      all_corruption = np.zeros((n_repeats, n_iter))

    print(all_divergence.shape)

    for i in tqdm(range(n_repeats)):
      all_convergence[i], all_divergence[i] = self.algorithm.simulation(n_iter=n_iter, corrupted=corrupted)
      if track_corruption:
        all_corruption[i] = self.competition.corruption_budget_used

      # Reset the game's corruption bookkeeping before the next repeat.
      # NOTE(review): the first repeat runs with the game's initial
      # corruption_budget; only later repeats use the value computed here.
      self.competition.corruption_budget_used = []
      # NOTE(review): `**` is right-associative, so this evaluates
      # n_iter ** (rho ** corruption_constant), NOT (n_iter ** rho) * corruption_constant.
      # The parentheses only make the existing evaluation explicit; confirm it is intended.
      self.competition.corruption_budget = n_iter ** (self.competition.rho ** self.competition.corruption_constant)

    ave_convergence = np.mean(all_convergence, axis=0)
    ave_divergence = np.mean(all_divergence, axis=0)

    if not return_full_data:
      return ave_convergence, ave_divergence
    if track_corruption:
      ave_corruption = np.mean(all_corruption, axis=0)
      return ave_convergence, all_convergence, ave_divergence, all_divergence, ave_corruption, all_corruption
    return ave_convergence, all_convergence, ave_divergence, all_divergence

class SelfConcordantBarrier(object):
  """Multi-player bandit learner regularized by a barrier-style function.

  Each round, every player's action x_i is perturbed to x_hat_i = x_i + A_i * u_i
  with u_i = +/-1 and A from `computeA`. A one-point gradient estimate is formed
  from the observed (possibly attacked) utility, and x is updated by the
  regularized Bregman step in `projection`.  Only action dimension d = 1 is
  supported.  `set_order` must be called before `simulation`.
  """
  def __init__(self, n, d, x_init, eta, beta, lam, concordant, grad_f, hessian_f, competition, B, sanity_check=False):
    self.n = n # number of players
    self.d = d # action dimension
    self.x = np.array(x_init)  # current actions, one per player
    self.eta = eta  # base step size, decayed by update_eta
    self.beta = beta  # scales the extra quadratic regularizer
    self.lam = lam  # divides the extra quadratic regularizer
    self.concordant = concordant  # regularizer R used in bregmanDivergence
    self.grad_f = grad_f  # used as the gradient of `concordant` in bregmanDivergence
    self.hessian_f = hessian_f  # presumably the 2nd derivative of `concordant`; used in computeAInverse
    self.competition = competition
    self.B = B  # upper end of the action interval
    self.sanity_check = sanity_check  # when True, always query attacked_utility

  def reset(self):
    """Re-draw every player's action uniformly from [0, B)."""
    x_init = np.random.rand(self.n)*self.B
    self.x = x_init

  def set_order(self, eta_order):
    """Store the polynomial decay exponent of the step size."""
    self.eta_order = eta_order

  def projection(self, v_hat, eta, t):
    """Regularized Bregman step for every player; returns the new actions as an ndarray.

    For each player with current action x_i and estimate v_i, minimizes over
    x in [1e-6, 0.99999 * B]:
        eta * <v_i, x_i - x> + D(x, x_i) + eta * beta * (t + 1) / (2 * lam) * (x - x_i) ** 2
    where D is `bregmanDivergence`.
    """
    # Presumably keeps actions strictly inside (0, B), where a barrier regularizer
    # stays finite -- TODO confirm against `concordant`.
    bounds = [(0.000001, 0.99999*self.B)]
    new_x = []
    for x_i, v_i in zip(self.x, v_hat):
      # Bind the loop values explicitly so the closure cannot see later iterations.
      def objective(x, x_i=x_i, v_i=v_i):
        return (eta*np.dot(v_i, x_i-x) + self.bregmanDivergence(x, x_i)
                + (x-x_i)**2*eta*self.beta*(t+1)/2/self.lam)

      new_x.append(minimize(objective, x_i, bounds=bounds)["x"][0])

    # Return an ndarray so self.x keeps the same type as after __init__/reset.
    return np.array(new_x)

  def bregmanDivergence(self, x, y):
    """Return R(x) - R(y) - <grad_f(y), x - y> with R = self.concordant."""
    return self.concordant(x)-self.concordant(y)-np.dot(self.grad_f(y), x-y)


  def computeAInverse(self, eta, t):
    '''
    Return, per player, sqrt(hessian_f(x_i) + d * eta * beta * (t + 2) / lam) as an ndarray.

    With d = 1 these are the diagonal entries of A^{-1} (see computeA).
    '''
    # Note this only works for d = 1 (one scalar per player).
    return np.array([np.sqrt(self.hessian_f(x_i) + self.d*eta*self.beta*(t+2)/self.lam)
                     for x_i in self.x])

  def computeA(self, eta, t):
    '''
    Return the per-player perturbation scale A = 1 / computeAInverse(eta, t) as an ndarray.
    '''
    return 1.0 / self.computeAInverse(eta, t)

  def update_eta(self, t):
    """Step size at round t: eta / (t + 1) ** eta_order / d."""
    return self.eta/((t+1)**self.eta_order)/self.d

  def simulation(self, n_iter, corrupted=False):
    """Run `n_iter` learning rounds, then reset the actions.

    Args:
      n_iter: number of rounds.
      corrupted: forwarded to `competition.utility` on the un-attacked path.

    Returns:
      (convergence, divergence): arrays of length `n_iter` holding, per round,
      the squared Euclidean distance from x_hat to the equilibrium from
      `competition.ne()` and to the manipulated equilibrium from
      `competition.ne(method='pgd', ..., corruption=True)`.
    """
    attack_delta = 0.2 * self.B  # perturbation size handed to the attacker
    NE = self.competition.ne()
    MNE = self.competition.ne(method='pgd', delta=attack_delta, i=0, corruption=True)
    convergence = np.zeros(n_iter)  # squared distance to the true Nash equilibrium
    divergence = np.zeros(n_iter)  # squared distance to the manipulated equilibrium

    for t in range(n_iter):
      if t % 1000 == 0:
        print(f"current iteration is {t+1}")
      u = np.array(np.random.choice([-1, 1], size=self.n))
      eta = self.update_eta(t)
      A = self.computeA(eta, t)
      x_hat = self.x + A*u  # elementwise product
      convergence[t] = np.linalg.norm(x_hat-NE)**2
      divergence[t] = np.linalg.norm(x_hat-MNE)**2

      if (self.competition.corruption_budget > 0 and self.competition.rho > 0) or self.sanity_check:
        utility = self.competition.attacked_utility(x=x_hat, delta=attack_delta)
      else:
        utility = self.competition.utility(x_hat, corrupted=corrupted)

      # One-point gradient estimate (elementwise products of ndarrays).
      v_hat = self.d*utility*self.computeAInverse(eta, t)*u
      self.x = self.projection(v_hat, eta, t)

    self.reset()
    return convergence, divergence