import pandas as pd
import numpy as np
import matplotlib
from matplotlib import pyplot as plt
import matplotlib.animation as animation
from mpl_toolkits import mplot3d
from scipy.integrate import odeint, solve_ivp
from scipy.stats import entropy
from scipy.linalg import sinm, cosm, logm
from scipy.linalg import expm
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from pathlib import Path
from sympy.physics.quantum import TensorProduct
from tqdm import tqdm
import cmath
import warnings
import os
import imageio
import copy
from pylab import *
from qutip import *

def diag_block_mat_boolindex(L):
    '''
    Assemble equally-shaped 2-D blocks into one complex block-diagonal matrix.

    L is a sequence of arrays that all share the shape of L[0]. The result has
    len(L) times as many rows and columns as one block, with the blocks placed
    along the diagonal in order and zeros everywhere else.
    '''
    n_blocks = len(L)
    block_shape = L[0].shape
    result = np.zeros(np.asarray(block_shape)*n_blocks, dtype=complex)
    # Boolean mask that is True exactly where a diagonal block lives.
    on_diagonal = np.kron(np.eye(n_blocks), np.ones(block_shape)) == 1
    # Row-major order of the mask matches the row-major order of the stacked blocks.
    result[on_diagonal] = np.concatenate(L).ravel()
    return result

def GetRMatrix(payoff):
    '''
    Build the square matrix that carries the classical payoff values on its diagonal.

    Each row of payoff is turned into a diagonal matrix, and those matrices
    are then stacked block-diagonally via diag_block_mat_boolindex.
    '''
    row_blocks = list(map(np.diag, payoff))
    return diag_block_mat_boolindex(row_blocks)

def PlotTrajectories(data):
    '''
    Draw the diagonal entries of each player's matrices over time.

    data is a dictionary (obtained via MMWU) whose 'rho' and 'sig' entries are
    sequences of square matrices for the x and y player respectively. One line
    per diagonal position is drawn (real part only).

    Returns (x_data, y_data): for each player, a list holding one list per
    diagonal position with that entry's (unmodified) value at every step.
    '''
    figure, (top_ax, bottom_ax) = plt.subplots(2, 1, figsize=(10,10))
    figure.suptitle('x and y strategies over time', size=15)

    trajectories = []
    panels = ((top_ax, 'rho', "x player"), (bottom_ax, 'sig', "y player"))
    for axis, key, title in panels:
        axis.set_title(title)
        diagonals = [np.diag(state) for state in data[key]]
        per_strategy = []
        for k in range(len(diagonals[0])):
            series = [diag[k] for diag in diagonals]
            per_strategy.append(series)
            # Only the real part is drawn; the returned series keep their original values.
            axis.plot(np.array(series).real, label=f'Strategy {k+1}')
        trajectories.append(per_strategy)

    top_ax.legend()
    bottom_ax.legend()
    x_data, y_data = trajectories
    return x_data, y_data

def sample_unit(npoints, ndim=3):
    '''
    Draw npoints random unit vectors of dimension ndim.

    Returns an (ndim, npoints) array of standard-normal samples whose columns
    have each been scaled to unit Euclidean norm.
    '''
    gaussian = np.random.randn(ndim, npoints)
    column_norms = np.linalg.norm(gaussian, axis=0)
    gaussian /= column_norms
    return gaussian

def sample_theta(npoints):
    '''
    Draw npoints angles uniformly at random from [0, 2*pi).
    '''
    return np.random.uniform(low=0.0, high=2*np.pi, size=npoints)

def PlotTrajectoriesComplex(data):
    '''
    Plot trajectories of replicator data on complex plane.

    data is a dictionary whose 'rho' and 'sig' entries are sequences of square
    matrices for the x and y player. For every matrix the diagonal is removed
    and the remaining entries are summed along each row; each row's sum is
    then traced over time as a 3-D curve (time, real part, imaginary part).

    Returns (x_data, y_data): for each player, a list holding one list per
    matrix row with that row's complex off-diagonal sum at every step.
    '''
    fig = plt.figure(figsize=(20,20))
    ax1 = fig.add_subplot(2, 1, 1, projection='3d')
    ax2 = fig.add_subplot(2, 1, 2, projection='3d')
    fig.suptitle("Argand Diagram for x and y player's strategy", size=15)
    # Label BOTH panels identically (the y panel previously never got its
    # 'Time' label because the x panel's label was set twice instead).
    for ax, title in ((ax1, "x player"), (ax2, "y player")):
        ax.set_title(title)
        ax.set_xlabel('Time')
        ax.set_ylabel('Real')
        ax.set_zlabel('Imaginary')

    # The time axis is taken from the x player's history and shared by both panels.
    t = np.arange(len(data['rho']))

    def plot_player(ax, matrices):
        # Row-wise sum of the off-diagonal part of every matrix in the history.
        offdiag_sums = [np.sum(m-np.diag(np.diag(m)), axis=1) for m in matrices]
        per_row = []
        for k in range(len(offdiag_sums[0])):
            series = [row[k] for row in offdiag_sums]
            per_row.append(series)
            values = np.array(series)
            ax.plot3D(t, values.real, values.imag, label='Strategy '+str(k+1))
        ax.grid()
        return per_row

    x_data = plot_player(ax1, data['rho'])
    y_data = plot_player(ax2, data['sig'])

    ax1.view_init(45, -90)
    ax2.view_init(45, -90)
    ax1.legend()
    ax2.legend()
    return x_data, y_data

def GetDensityInit(num_inits, dim=2):
    '''
    Generate num_inits random rank-one matrices of size dim x dim.

    For each one, a random unit vector supplies non-negative magnitudes and
    independent uniform angles supply phases; the resulting complex vector v
    is turned into the outer product of v with its conjugate.

    Returns an array of shape (num_inits, dim, dim).
    '''
    states = []
    for _ in range(num_inits):
        radii = abs(sample_unit(1, ndim=dim))
        phases = sample_theta(dim)
        amplitudes = np.array(
            [cmath.rect(radius, phase) for radius, phase in zip(radii[:, 0], phases)]
        )
        states.append(np.outer(amplitudes, amplitudes.conjugate()))
    return np.array(states)

def GetDensityOperator(probs=np.array([1/2, 1/2]), dim=2):
    '''
    Mix dim random rank-one matrices (from GetDensityInit) with weights probs.

    probs must provide at least dim entries; entry k weights the k-th random
    matrix. Returns the weighted sum as a dim x dim array.
    '''
    pure_states = GetDensityInit(dim, dim=dim)
    weighted = [state*probs[k] for k, state in enumerate(pure_states)]
    return np.sum(weighted, axis=0)

def num_after_point(x):
    '''
    Count the characters that follow the first '.' in str(x).

    Returns 0 when the string form of x contains no '.'.
    '''
    _, dot, tail = str(x).partition('.')
    return len(tail) if dot else 0

def split_list(_list):
    '''
    Cut a sequence into two halves and return them as (front, back).

    For odd lengths the extra element goes to the back half.
    '''
    midpoint = len(_list) // 2
    front = _list[:midpoint]
    back = _list[midpoint:]
    return front, back
    
def qre(rho, sigma):
    '''
    Quantum relative entropy of rho with respect to sigma:
    trace(rho @ (logm(rho) - logm(sigma))), using the matrix logarithm.

    The value is returned as produced by np.trace, without taking the real part.
    '''
    log_gap = logm(rho) - logm(sigma)
    return np.trace(rho @ log_gap)

def PlotQRE(data, nash1=np.array([[0.5, 0.], [0., 0.5]]), nash2=np.array([[0.5, 0.], [0., 0.5]]), num=1000, discrete=False):
    '''
    Obtain and plot the total quantum relative entropy of the system, given the Nash equilibrium of the game.

    Parameters:
        data     -- dictionary of trajectories; the x/y player matrices are read
                    from keys 'x'/'y' when discrete is False and from
                    'rho'/'sig' when discrete is True.
        nash1    -- equilibrium matrix of the x player (normalized to unit trace here).
        nash2    -- equilibrium matrix of the y player (normalized to unit trace here).
        num      -- maximum number of leading time steps to evaluate and plot.
        discrete -- selects which pair of keys is read from data (see above).

    Every trajectory matrix is normalized to unit trace before the relative
    entropy qre(nash, state) is evaluated; only the real part is kept.
    The figure is shown with plotly; nothing is returned.
    '''
    if not discrete:
        x_data = data['x']
        y_data = data['y']
    else:
        x_data = data['rho']
        y_data = data['sig']
    nash1 = nash1/np.trace(nash1)
    nash2 = nash2/np.trace(nash2)

    def divergence_series(nash, states):
        # Relative entropy between the equilibrium and each trace-normalized state.
        return [qre(rho=nash, sigma=state/np.trace(state)).real for state in states[:num]]

    div_x = divergence_series(nash1, x_data)
    div_y = divergence_series(nash2, y_data)
    div_combined = np.add(div_x, div_y)

    # The series are already limited to `num` steps above, so they are plotted
    # in full (a fixed 1000-point cap here would silently ignore num > 1000).
    # The 'x' trace follows the combined curve and fills down to the 'y' trace,
    # so the shaded band between them is the x player's share of the total.
    fig = go.Figure([go.Scatter(y=div_y,
                    mode='lines', line=dict(width=0.5, color='darkslateblue'),
                    name='y', fill='tozeroy'),
                 go.Scatter(y=div_combined,
                    mode='lines',
                    name='x', line=dict(width=0.5, color='plum'), fill='tonexty'),
                 go.Scatter(y=div_combined,
                    mode='lines',
                    name='Sum of entropies', line = dict(width = 3, color='#440154'), opacity=1),
                ])

    # Edit layout
    fig.update_layout(title='Constant of Motion',
                      xaxis_title='Time Steps',
                      yaxis_title='Quantum Relative Entropy',
                      legend_orientation='h',
                      legend=dict( y=-0.2),
                      font=dict(size=15))
    fig.show()
    
def GetComplexR(basis1, basis2, payoff):
    '''
    Obtain complex R matrix via basis transform as per Lemma 5 in the Appendix of the paper.

    Parameters:
        basis1 -- square matrix whose column i is the i-th basis vector of the
                  first player (np.matrix, ndarray or nested list).
        basis2 -- same, for the second player.
        payoff -- square array of payoff values, indexed payoff[i][j].

    Returns a complex ndarray of shape (d**2, d**2), d = payoff.shape[0]:
    the sum over i, j of payoff[i][j] * v v^H with v = kron(basis1[:, i], basis2[:, j]).
    The matrix is also printed.
    '''
    payoff = np.asarray(payoff)
    # Work on plain 2-D arrays so that column slices are 1-D vectors regardless
    # of whether the caller passed np.matrix or ndarray bases.
    columns1 = np.asarray(basis1)
    columns2 = np.asarray(basis2)
    dimension = payoff.shape[0]
    R = np.zeros((dimension**2, dimension**2), dtype=complex)

    for i in range(dimension):
        for j in range(dimension):
            joint = np.kron(columns1[:, i], columns2[:, j])
            # Projector onto the joint basis state, weighted by its payoff.
            R += payoff[i][j]*np.outer(joint, joint.conj())
    print('R matrix: ', R)
    return R

def GetTransformedInit(basis1, basis2, s):
    '''
    Obtain initial conditions for simulation via same basis transform as R matrix.

    Parameters:
        basis1 -- square matrix whose column i is the i-th basis vector used for
                  the first player (np.matrix, ndarray or nested list).
        basis2 -- same, for the second player.
        s      -- flat sequence of 2*d weights: the first d weight the columns
                  of basis1, the last d weight the columns of basis2.

    For each player the matrix sum_i w_i * b_i b_i^H is formed. A warning is
    issued when either matrix has an eigenvalue whose real part is not
    strictly positive.

    Returns a flat Python list: the first matrix flattened row by row,
    followed by the second.
    '''
    dimension = len(s) // 2

    def mix(basis, weights):
        # Weighted sum of projectors onto the columns of `basis`; plain arrays
        # keep column slices 1-D for both np.matrix and ndarray inputs.
        columns = np.asarray(basis)
        terms = [w*np.outer(columns[:, k], columns[:, k].conj())
                 for k, w in enumerate(weights)]
        return np.sum(terms, axis=0)

    def is_positive_definite(matrix):
        return bool(np.all(np.linalg.eigvals(matrix).real > 0))

    # Each player uses its own half of s.
    first = mix(basis1, s[:dimension])
    second = mix(basis2, s[dimension:2*dimension])

    # Check the assembled matrices (not the individual rank-one terms, which
    # always have zero eigenvalues) and warn if EITHER one fails.
    if not (is_positive_definite(first) and is_positive_definite(second)):
        warnings.warn("Transformed init not PD")
    return np.concatenate([first.flatten(), second.flatten()]).tolist()

def GetBasisTransform(payoff, basis1=np.matrix([[1, 0], [0, 1]]), basis2=np.matrix([[1, 0], [0, 1]]), 
                      basis3=np.matrix([[1, 0], [0, 1]]), basis4=np.matrix([[1, 0], [0, 1]]),
                      s=[0.25+0j, 0.75+0j, 0.75+0j, 0.25+0j], nash1 = np.array([1/2, 1/2]), 
                      nash2 = np.array([1/2, 1/2])):
    '''
    Bundle the transformed R matrix and the transformed initial condition.

    basis1/basis2 are passed to GetComplexR together with payoff;
    basis3/basis4 are passed to GetTransformedInit together with s.
    nash1 and nash2 are accepted but not used here.

    Returns a dictionary with keys 'R' and 's'.
    '''
    return {
        'R': GetComplexR(basis1, basis2, payoff),
        's': GetTransformedInit(basis3, basis4, s),
    }

def CreateFolder(directory):
    '''
    Create directory (including missing parents) unless the path already exists.

    An OSError raised while doing so is reported with a printed message
    instead of being propagated.
    '''
    try:
        if os.path.exists(directory):
            return
        os.makedirs(directory)
    except OSError:
        print ('Error: Creating directory. ' +  directory)
        
def PrintBloch(states1, states2, filename='bloch', dirc = 'tmp', timestep=0, save=False):
    '''
    Show a single Bloch-sphere snapshot of two trajectories at a given time-step.

    For each of states1 and states2, the states up to and including timestep
    are added as points and the state at timestep is added with the default
    kind. When save is True the image is also written to
    dirc/filename.png (dirc is created if needed).
    '''
    sphere = Bloch()
    sphere.view = [-40,30]
    sphere.point_color = ['cornflowerblue','crimson']
    sphere.point_marker = ['o']
    sphere.point_size = [20]
    sphere.vector_color = ['cornflowerblue', 'crimson']
    sphere.vector_width = 6

    sphere.clear()
    upto = timestep+1
    # Trail of past states first (both players), then the current states.
    for history in (states1, states2):
        sphere.add_states([history[:upto]], 'point')
    for history in (states1, states2):
        sphere.add_states(history[timestep])

    if save:
        CreateFolder(dirc)
        sphere.save(dirc + '/' + filename + '.png')
    sphere.show()
    
def AnimateBloch(states1, states2, duration=0.1, save_all=False, filename='bloch_anim', dirc='tmp', length=50):
    '''
    Saves an animation of quantum replicator trajectories on a Bloch sphere.

    Parameters:
        states1, states2 -- sequences of states for the two trajectories.
        duration         -- passed through to imageio.mimsave as `duration`.
        save_all         -- when True every frame is kept as
                            dirc/bloch_<i>.png; otherwise each frame is written
                            to (and overwrites) 'temp_file.png'.
        filename         -- name of the output GIF without the '.gif' suffix.
        dirc             -- folder for the individual frames when save_all is True.
        length           -- number of frames requested; limited to the number of
                            states available in the shorter trajectory.

    The GIF is written to filename + '.gif'; nothing is returned.
    '''
    file_name = filename + '.gif'
    b = Bloch()
    b.view = [-40,30]
    b.point_color = ['cornflowerblue','crimson']
    b.point_marker = ['o']
    b.point_size = [20]
    b.vector_color = ['cornflowerblue', 'crimson']
    b.vector_width = 6

    # Never index past the end of either trajectory.
    length = min(length, len(states1), len(states2))

    # The frame folder only needs to be created once, not on every frame.
    if save_all:
        CreateFolder(dirc)

    images = []
    for i in range(length):
        b.clear()
        b.add_states([states1[:(i+1)]], 'point')
        b.add_states([states2[:(i+1)]], 'point')
        b.add_states(states1[i])
        b.add_states(states2[i])
        if save_all:
            imgfilename = dirc+'/bloch_%01d.png' % i
        else:
            imgfilename = 'temp_file.png'
        b.save(imgfilename)
        # Read the frame back from disk so it can be appended to the GIF.
        images.append(imageio.imread(imgfilename))
    imageio.mimsave(file_name, images, duration=duration)