import numpy as np

class CournotCompetition(object):
    """N-player Cournot game with linear inverse demand and linear costs.

    The market price is ``p = a - b * sum(x)`` and player ``j`` earns
    ``p * x[j] - c[j] * x[j]``.  The marginal cost ``c`` may be a single
    number shared by all players or a length-``n`` sequence with one cost
    per player.

    The class can also perturb the observed utilities, either with random
    noise (``corruption``) or with a targeted attack on a single player
    (``attacked_utility``), and tracks how much of the corruption budget
    the targeted attack has consumed.
    """

    def __init__(
        self,
        a,
        b,
        c,
        n,
        rho=0,
        iter=0,
        sigma=0,
        corruption=False,
        corruption_constant=1,
        n_iter=100000,
    ):
        """Store the game parameters and initialise the corruption budget.

        Args:
            a: Demand intercept (a single number).
            b: Demand slope (a single number).
            c: Marginal cost; a scalar or a length-``n`` sequence.
            n: Number of players.
            rho: Exponent controlling both the total corruption budget and
                the magnitude of the random corruption bias.
            iter: Current iteration counter (see ``update_iter``).
            sigma: Standard deviation of the random corruption noise.
            corruption: Flag stored as ``self.corrupted``; no method of this
                class reads it.
            corruption_constant: Multiplicative constant of the budget.
            n_iter: Horizon used to size the corruption budget.
        """
        self.a = a
        self.b = b
        self.c = c
        self.n = n
        self.corrupted = corruption
        self.rho = rho
        self.iter = iter
        self.sigma = sigma
        self.corruption_constant = corruption_constant
        # Total budget is corruption_constant * n_iter ** rho.
        self.corruption_budget = n_iter ** (self.rho) * self.corruption_constant
        # One entry per call to ``attacked_utility``.
        self.corruption_budget_used = []

    def _costs(self):
        """Return the marginal costs broadcast to a length-``n`` array."""
        return np.broadcast_to(np.asarray(self.c), (self.n,))

    def update_iter(self, iter):
        """Set the iteration counter used by the random corruption."""
        self.iter = iter

    def utility(self, x, corrupted=False):
        """Return every player's utility at the action profile ``x``.

        Args:
            x: Action profile (one quantity per player).
            corrupted: When true, add a fresh sample of ``corruption`` to
                the utilities.

        Returns:
            Array with one utility per player.
        """
        x = np.asarray(x)
        price = self.a - self.b * np.sum(x)
        # Elementwise product works for both a scalar cost and a per-player
        # cost vector; stacking ``c`` n times would yield an (n, n) matrix
        # for vector costs.
        payoff = price * x - np.asarray(self.c) * x
        if corrupted:
            return payoff + self.corruption(self.rho, self.iter, self.sigma)
        return payoff

    def attacked_utility(self, x, delta, i=0):
        """Return the utilities after a targeted attack on player ``i``.

        With ``BR`` the best response of player ``i`` to the other players'
        quantities in ``x``, the result is
        ``utility(x with x[i] + delta) - utility(x with x[i] = BR)
        + utility(x with x[i] = BR - delta)``.

        Side effects: the absolute gap between the attacked and the true
        utility of player ``i`` is subtracted from ``corruption_budget`` and
        appended to ``corruption_budget_used``.

        Args:
            x: Action profile; it is not modified.
            delta: Shift applied to player ``i``'s quantity.
            i: Index of the attacked player.

        Returns:
            Array with one (attacked) utility per player.
        """
        x = np.array(x, dtype=float)  # private float copy of the profile

        # Closed-form best response of player i against the others' total.
        others_total = np.sum(x) - x[i]
        best_response = (self.a - self._costs()[i] - self.b * others_total) / 2 / self.b

        x_shifted = x.copy()
        x_shifted[i] += delta
        x_best = x.copy()
        x_best[i] = best_response
        x_best_back = x.copy()
        x_best_back[i] = best_response - delta

        attacked = (
            self.utility(x_shifted) - self.utility(x_best) + self.utility(x_best_back)
        )

        spent = abs(attacked[i] - self.utility(x)[i])
        self.corruption_budget = self.corruption_budget - spent
        self.corruption_budget_used.append(spent)
        return attacked

    def grad(self, x, delta=0.0, i=0, corruption=False):
        """Return each player's utility gradient w.r.t. its own quantity.

        Entry ``j`` is ``a - b * sum(x) - b * x[j] - c[j]``.  When
        ``corruption`` is true, entry ``i`` is additionally shifted by
        ``-2 * delta * b``.
        """
        x = np.asarray(x)
        total = np.sum(x)
        g = self.a - self.b * total - self.b * x - np.asarray(self.c)
        if corruption:
            g[i] = g[i] - 2 * delta * self.b
        return g

    def corruption(self, rho, iter, sigma):
        """Sample a random utility perturbation, one value per player.

        Each player gets a Gaussian sample with standard deviation ``sigma``
        centred at ``+/- iter ** (rho - 1)``, the sign being drawn uniformly
        at random.

        NOTE: ``iter`` must be non-zero when ``rho < 1``; otherwise the
        power raises ``ZeroDivisionError`` for plain Python numbers.
        """
        signs = np.array(np.random.choice([-1, 1], size=self.n))
        bias = signs * iter ** (rho - 1)
        return np.random.normal(bias, sigma)

    def ne(self, method='exact', delta=0.0, i=0, corruption=False):
        """Compute a Nash equilibrium of the game.

        Args:
            method: ``'exact'`` uses the closed form; any other value runs
                projected gradient ascent using ``grad``.
            delta, i, corruption: Forwarded to ``grad``; ignored by the
                closed form.

        Returns:
            Array with one equilibrium quantity per player.
        """
        if method == 'exact':
            if np.ndim(self.c) == 0:
                # Symmetric case: every player produces (a - c) / ((n + 1) b).
                return np.array([(self.a - self.c) / (self.n + 1) / self.b] * self.n)
            # Heterogeneous costs: interior solution of the first-order
            # conditions, x_j = (a - (n + 1) c_j + sum(c)) / ((n + 1) b).
            # It is the equilibrium only when every entry is non-negative.
            costs = self._costs()
            return (self.a - (self.n + 1) * costs + np.sum(costs)) / (self.n + 1) / self.b

        # Projected gradient ascent on each player's own utility, projecting
        # onto non-negative quantities.
        x = np.ones(self.n)
        T, lr = 2000, 0.1
        for _ in range(T):
            g = self.grad(x, delta=delta, i=i, corruption=corruption)
            x = np.clip(x + g * lr, 0, np.inf)
            # Complementarity residual: zero when each player either has a
            # zero gradient or sits at the zero boundary.
            if np.linalg.norm(x * g) < 1e-4:
                return x

        print('\nWarning: exit before converging. Residual: ', np.linalg.norm(x*g))
        return x