import pandas as pd
import numpy as np
import matplotlib
from matplotlib import pyplot as plt
import matplotlib.animation as animation
from scipy.integrate import odeint, ode, quad, trapz
from scipy import optimize
from scipy.spatial import distance
from mpl_toolkits.mplot3d import Axes3D
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import seaborn as sns
plt.style.use('seaborn')
from PIL import Image
from scipy.stats import entropy
from IPython.display import HTML
from tqdm import tqdm
import math
import scipy.sparse as sparse
import scipy.stats as stats

def OGDAupdate(x_current, x_prev, y_current, y_prev, eta, G, project=False):
    '''
    Perform one step of optimistic gradient descent ascent (OGDA) on the bilinear game x^T G y.

    The x player descends and the y player ascends. Each uses twice the current
    gradient minus the previous gradient:
        x_new = x_t - 2*eta*G y_t + eta*G y_{t-1}
        y_new = y_t + 2*eta*G^T x_t - eta*G^T x_{t-1}

        Parameters:
            x_current, y_current (array): current strategy vectors for both players
            x_prev, y_prev (array): previous strategy vectors for both players
            eta (float): stepsize
            G (array): bilinear game matrix
            project (bool): if True, pass both results through projection_simplex_bisection.
                (This does not work for sequence-form strategies: a projection onto
                the treeplex would be needed instead. Potential future work.)

        Returns:
            x_new, y_new (tuple): new strategy vectors for both players
    '''
    # Gradients of x^T G y at the current and previous iterates.
    grad_x_now, grad_x_before = G@y_current, G@y_prev
    grad_y_now, grad_y_before = G.T@x_current, G.T@x_prev

    x_next = x_current - 2*eta*grad_x_now + eta*grad_x_before
    y_next = y_current + 2*eta*grad_y_now - eta*grad_y_before

    if not project:
        return (x_next, y_next)
    return (projection_simplex_bisection(x_next), projection_simplex_bisection(y_next))
    
def GDAupdate(x_current, y_current, eta, G, project=False):
    '''
    Perform one step of vanilla (simultaneous) gradient descent ascent on x^T G y.

        x_new = x_t - eta*G y_t        (x player descends)
        y_new = y_t + eta*G^T x_t      (y player ascends)

        Parameters:
            x_current, y_current (array): current strategy vectors for both players
            eta (float): stepsize
            G (array): bilinear game matrix
            project (bool): if True, pass both results through projection_simplex_bisection

        Returns:
            x_new, y_new (tuple): new strategy vectors for both players
    '''
    descent_dir = G@y_current
    ascent_dir = G.T@x_current
    x_next = x_current - eta*descent_dir
    y_next = y_current + eta*ascent_dir

    if not project:
        return (x_next, y_next)
    return (projection_simplex_bisection(x_next), projection_simplex_bisection(y_next))

def runGDA2Player(G, x, y, numsteps, eta, project=False, optimistic=True, exponent=None):
    '''
    Function to run OGDA/GDA for 2 players.

        Parameters:
            G (array): bilinear payoff matrix
            x (array): initial conditions for x player; x[0] is used as the previous
                and x[1] as the current strategy. All rows of x are kept at the
                start of the returned trajectory.
            y (array): initial conditions for y player (same layout as x)
            numsteps (int): number of iterations
            eta (float): stepsize
            project (bool): projection back to simplex
            optimistic (bool): use optimistic update rule (OGDAupdate) if True,
                vanilla GDA (GDAupdate) otherwise
            exponent (float): only used when optimistic is False. If not None, the
                stepsize at iteration i (0-based) is 1/(i+1)^exponent and eta is
                ignored. The optimistic update always uses the constant eta.

        Returns:
           data (dict): 'x' and 'y' trajectories (initial rows followed by one row
               per iteration). The time averages of both trajectories are printed.
    '''
    x_prev, x_current = x[0], x[1]
    y_prev, y_current = y[0], y[1]

    # Collect rows and stack once at the end. Calling np.vstack inside the loop
    # re-copies the whole trajectory every step (quadratic in numsteps).
    x_rows = [x]
    y_rows = [y]

    for i in tqdm(range(numsteps)):
        if optimistic:
            x_new, y_new = OGDAupdate(x_current, x_prev, y_current, y_prev, eta, G, project=project)
        else:
            if exponent is not None:
                eta = 1/((i+1)**exponent)
            x_new, y_new = GDAupdate(x_current, y_current, eta, G, project=project)

        x_rows.append(x_new)
        y_rows.append(y_new)

        x_prev, y_prev = x_current, y_current
        x_current, y_current = x_new, y_new

    x = np.vstack(x_rows)
    y = np.vstack(y_rows)

    print('Time average x:', np.average(x, axis=0))
    print('Time average y:', np.average(y, axis=0))

    data = {}
    data['x'] = x
    data['y'] = y
    return data

def projection_simplex_bisection(v, z=1, tau=0.0001, max_iter=1000):
    '''
    Project a vector onto the scaled simplex {w : w >= 0, sum(w) = z}.

    The projection has the form max(v - theta, 0). The threshold theta is found
    by bisection so that the entries sum to z (within tolerance tau).

        Parameters:
            v (array-like): vector to project (lists are accepted)
            z (float): target sum of the projected vector (expected > 0)
            tau (float): tolerance on |sum(max(v - theta, 0)) - z|
            max_iter (int): maximum number of bisection steps

        Returns:
            (array): the projected vector

        Raises:
            ValueError: if v is empty
    '''
    v = np.asarray(v)
    if v.size == 0:
        raise ValueError('projection_simplex_bisection: cannot project an empty vector')

    def excess(theta):
        # Sum of the clipped vector minus the target. This is non-increasing in theta.
        return np.sum(np.maximum(v - theta, 0)) - z

    # Bracket theta. At `lower`, every entry of v - lower is >= z/len(v), so the
    # excess is >= 0. At `upper` = max(v), every entry is clipped to 0, so the
    # excess is -z <= 0.
    lower = np.min(v) - z / len(v)
    upper = np.max(v)

    # Initialise before the loop so max_iter <= 0 still returns a result
    # instead of hitting an unbound local.
    midpoint = (upper + lower) / 2.0
    for _ in range(max_iter):
        midpoint = (upper + lower) / 2.0
        value = excess(midpoint)
        if abs(value) <= tau:
            break
        if value <= 0:
            upper = midpoint
        else:
            lower = midpoint

    return np.maximum(v - midpoint, 0)

def OGDAupdateNPlayer(x_current, x_prev, P_current, P_prev, eta, project=False):
    '''
    Perform one optimistic gradient step for a single player in the N-player setting.

        x_new = x_t - 2*eta*P_t + eta*P_{t-1}

    where P is the player's aggregated payoff gradient over all games it takes part in.

        Parameters:
            x_current (array): current strategy vector of the player
            x_prev (array): previous strategy vector of the player (not used by the update itself)
            P_current (array): current total utility across all games the player is involved in
            P_prev (array): previous total utility across all games the player is involved in
            eta (float): stepsize
            project (bool): if True, pass the result through projection_simplex_bisection.
                (This does not work for sequence-form strategies: a treeplex
                projection would be needed. Potential future work.)

        Returns:
            x_new (array): new strategy vector for the player
    '''
    step = 2*eta*(P_current)
    correction = eta*(P_prev)
    x_next = x_current - step + correction
    return projection_simplex_bisection(x_next) if project else x_next
    
def _weighted_payoff(G, weights, trajectory, t):
    '''Sum over players j of the weighted payoff vector against player j's strategy at time t.'''
    # A positive weight uses G[j]. A non-positive weight uses G[j].T. Both are
    # scaled by the weight, so a zero weight contributes nothing.
    return np.sum(np.array([
        (G[j]*w)@trajectory[j][t] if w > 0 else (G[j].T*w)@trajectory[j][t]
        for j, w in enumerate(weights)
    ]), axis=0)


def runGDANPlayer(G, vals, numsteps, eta, graph, N, optimistic=True, _print=True):
    '''
    Function to run OGDA/GDA for N players.

        Parameters:
            G (array): list of bilinear payoff matrices; G[j] is used for interactions
                with player j. The matrices must be square (d x d), since both G[j]
                and G[j].T are applied to d-dimensional strategies.
            vals (array): initial conditions for all players. vals[:, 0] holds the
                previous and vals[:, 1] the current strategy of every player.
            numsteps (int): total number of time steps stored, including the two
                initial ones (must be >= 2)
            eta (float): stepsize
            graph (array): graph representing the interactions between agents.
                graph[i][j] weights player j's game in player i's payoff; its sign
                selects G[j] (positive) or G[j].T (otherwise).
            N (int): number of players
            optimistic (bool): use optimistic update rule if True, plain
                gradient step x_new = x_t - eta*P_t otherwise
            _print (bool): prints time average strategies if True

        Returns:
           data (dict): 'vals' holds the player trajectories, shape (N, numsteps, d);
               'timeavg' holds their time averages, shape (N, d)
    '''
    return_data = np.zeros((N, numsteps, G[0].shape[0]))
    return_data[:, 0] = vals[:, 0]
    return_data[:, 1] = vals[:, 1]

    for k in tqdm(range(1, numsteps-1)):
        for i in range(N):
            x_current = return_data[i][k]
            x_prev = return_data[i][k-1]

            # Only step-k and step-(k-1) values are read, so all players update
            # simultaneously even though step k+1 is written player by player.
            P_current = _weighted_payoff(G, graph[i], return_data, k)
            if optimistic:
                P_prev = _weighted_payoff(G, graph[i], return_data, k-1)
                x_new = OGDAupdateNPlayer(x_current, x_prev, P_current, P_prev, eta)
            else:
                x_new = x_current - eta*P_current
            return_data[i][k+1] = x_new

    averages = np.average(return_data, axis=1)
    if _print:
        print('Time average values:', averages)

    data = {}
    data['vals'] = return_data
    data['timeavg'] = averages
    return data

def PlotDist(data, players, nash=None):
    '''
    Plot, for each player, the log distance between its strategy iterates and a reference (Nash) point.

        Parameters:
            data (array): data[i] is the sequence of strategy vectors of player i
            players (int): number of players
            nash (array): nash[i] is the reference strategy of player i. If None,
                each player's last iterate data[i][-1] is used instead.

        Returns:
            vals (list): per player, log(||x_t - x*|| + 1e-7) for every iterate x_t
    '''
    vals = []
    labels = []
    for i in tqdm(range(players)):
        reference = data[i][-1] if nash is None else nash[i]
        # The small offset keeps log() finite when an iterate equals the reference.
        dist = np.log([np.linalg.norm(point - reference) + 0.0000001 for point in data[i]])
        vals.append(dist)
        labels.append('Player ' + str(i+1))
    plt.figure(figsize=(12, 9), dpi=100)
    plt.xlabel('Time', fontsize=15)
    # Raw string: '\l' and '\m' are invalid escapes (SyntaxWarning on Python 3.12+).
    # The plotted quantity is the log of the unsquared distance.
    plt.ylabel(r'$\log(dist(x_t, \mathcal{X}^*))$', fontsize=15)
    plt.xticks(fontsize=12)
    plt.yticks(fontsize=12)
    plt.plot(np.array(vals).T)
    plt.legend(labels, loc="best", fontsize=15)
    plt.show()

    return(vals)

def GetRandGame(dim, seed, density=0.2):
    '''
    Function to get a random symmetric payoff matrix.

    Draws a dim x dim sparse matrix whose nonzero entries are standard normal,
    densifies it, and symmetrises it as (A + A.T)/2.

        Parameters:
            dim (int): number of rows/columns
            seed (int): seed for NumPy's global RNG. Because the global state is
                reseeded, later np.random calls in the caller (e.g. the initial
                conditions drawn in RandomGameSimulation) are reproducible too.
            density (float): fraction of nonzero entries before symmetrisation

        Returns:
            (ndarray): dense symmetric dim x dim matrix
    '''
    rvs = stats.norm().rvs
    np.random.seed(seed)
    # .toarray() instead of the .A shorthand, which SciPy deprecated in 1.11
    # and removed in 1.14.
    game = sparse.random(dim, dim, density=density, data_rvs=rvs).toarray()
    return (game + game.T)/2

def RandomGameSimulation(num_simulations, dimension, graph, players, numsteps, eta, rand=True, seeds=None):
    '''
    Function to run simulations on random extensive form games in sequence form.

    For each simulation, a random game is generated, all players start from random
    normalised strategies, and OGDA is run. The log distance of player 1's
    iterates to its final iterate is then plotted for every game.

        Parameters:
            num_simulations (int): number of random simulations to run
            dimension (int): dimension of the game matrix
            graph (array): graph of interactions between players
            players (int): number of players
            numsteps (int): number of iterations per simulation
            eta (float or list/tuple/array): one stepsize for all games, or one stepsize per game
            rand (bool): if True, game i uses seed i; otherwise it uses seeds[i]
            seeds (array): seeds to use when rand is False

        Returns:
            data_all (list): runGDANPlayer output for every simulation
    '''
    per_game_eta = isinstance(eta, (list, tuple, np.ndarray))
    dists = []
    data_all = []
    used_etas = []
    for i in range(num_simulations):
        seed = i if rand else seeds[i]
        game = GetRandGame(dimension, seed=seed)
        game_list = [game for _ in range(players)]
        vals = np.random.rand(players, 1, dimension)
        for k in range(vals.shape[0]):
            vals[k] = vals[k]/(np.sum(vals[k]))
        vals = np.repeat(vals, 2, axis=1)

        eta_i = eta[i] if per_game_eta else eta
        used_etas.append(eta_i)
        data = runGDANPlayer(game_list, vals, numsteps=numsteps, eta=eta_i, graph=graph, N=players)

        data_1 = data['vals'][0]
        dist_vals = np.log([np.linalg.norm(point - data_1[-1]) + 0.000001 for point in data_1])
        dists.append(dist_vals)
        data_all.append(data)

    plt.figure(figsize=(12, 9), dpi=100)
    # Use the eta actually used for each game. Indexing eta directly crashed when
    # a scalar stepsize was passed.
    labels = ['Game ' + str(n) + r' ($\eta=$' + str(e) + ')' for n, e in enumerate(used_etas, start=1)]
    plt.xlabel('Time', fontsize=25)
    # Raw string avoids invalid-escape warnings; the plotted quantity is log of the unsquared distance.
    plt.ylabel(r'$\log(dist(x_t, \mathcal{X}^*))$', fontsize=25)
    plt.xticks(fontsize=20)
    plt.yticks(fontsize=20)
    plt.plot(np.array(dists).T)
    plt.legend(labels, loc="best", fontsize=25)
    plt.show()
    return data_all

def FindBestEta(num_etas, dimension, graph, players, numsteps, seed, endpoints = (0.001, 0.9)):
    '''
    Function to find the best stepsize for a given seed.

    Sweeps evenly spaced stepsizes on one random game. For each stepsize whose
    time averages stay inside (-10, 10), it plots player 1's log distance to its
    final iterate.

        Parameters:
            num_etas (int): number of stepsize values to run search on
            dimension (int): dimension of the game matrix
            graph (array): graph of interactions between players
            players (int): number of players
            numsteps (int): number of iterations per run
            seed (int): seed value passed to GetRandGame
            endpoints (tuple): range of values to obtain stepsize values from

        Returns:
            (array): kept log-distance curves, one column per retained stepsize
    '''
    shared_game = GetRandGame(dimension, seed=seed)
    games = [shared_game] * players

    curves, kept_etas = [], []
    for step in np.linspace(endpoints[0], endpoints[1], num_etas):
        init = np.random.rand(players, 1, dimension)
        for row in range(init.shape[0]):
            init[row] = init[row]/(np.sum(init[row]))
        init = np.repeat(init, 2, axis=1)
        result = runGDANPlayer(games, init, numsteps=numsteps, eta=step, graph=graph, N=players, _print=False)

        trajectory = result['vals'][0]
        final = trajectory[-1]
        curve = np.log([np.linalg.norm(point - final) +0.000001 for point in trajectory])

        # Keep only runs whose time averages stayed bounded. NaNs fail both
        # comparisons and are dropped too.
        averages = result['timeavg'].flatten()
        if all((averages > -10) & (averages < 10)):
            curves.append(curve)
            kept_etas.append(step)

    plt.figure()
    plt.plot(np.array(curves).T)

    plt.legend(kept_etas, loc="best")
    plt.show()
    return np.array(curves).T

def PlotTimeAvg(data, dimension, plot=True, player=0, num_to_plot=1000):
    '''
    Function to compute and plot the time average values of a given dataset.

    For each of the first `dimension` coordinates of one player's trajectory,
    computes the running mean (1/t) * sum_{s<=t} (x_s - xbar), where xbar is
    data['timeavg'][player]. Only the first `num_to_plot` values are kept.

        Parameters:
            data (dict): mapping with 'vals' (per-player trajectories) and 'timeavg'
            dimension (int): number of strategy coordinates to process
            plot (bool): draw the curves if True
            player (int): index of the player to analyse
            num_to_plot (int): number of leading time steps to keep

        Returns:
            (array): one row per coordinate
    '''
    def running_mean_of_deviation(col):
        deviation = data['vals'][player][:, col] - data['timeavg'][player][col]
        return (np.cumsum(deviation) / np.arange(1, len(deviation) + 1))[:num_to_plot]

    curves = np.array([running_mean_of_deviation(col) for col in range(dimension)])

    if plot:
        plt.figure(figsize=(10, 8), dpi=100)
        plt.xlabel('Time', fontsize=20)
        plt.ylabel('Time Average Strategy', fontsize=20)
        plt.xticks(fontsize=16)
        plt.yticks(fontsize=16)
        plt.plot(curves.T)
        plt.show()

    return curves