# -*- coding: utf-8 -*-


import numpy as np
import matplotlib.pyplot as plt
#import seaborn as sns
#from matplotlib import cm
import time
from scipy.special import rel_entr

#Epsilon=0.01

# Target mean of the second marginal: Init builds the constraint f from x - mean.
mean = 0.25
# Variance target; only referenced by the commented-out second constraint in Init.
var = 0.05 ** 2

def Init(Nt,Epsilon=0.01):
    """Build the discretised problem data on a uniform grid of ``Nt`` points in [0, 1].

    Parameters
    ----------
    Nt : int
        Number of grid points. Must be at least 2, because the grid step is
        ``1/(Nt-1)`` and the cost is normalised by its maximum.
    Epsilon : float
        Entropic regularisation parameter used to build ``E`` and ``K``.

    Returns
    -------
    mu : (Nt,) ndarray
        Uniform probability vector.
    f : (1, Nt) ndarray
        Moment constraint ``x - mean`` on the grid, rescaled so that
        ``max |f| == 1``.
    E : (1, Nt) ndarray
        ``exp(-f / Epsilon)``.
    c : (Nt, Nt) ndarray
        Squared index-distance cost ``(i - j)**2``, rescaled so its maximum is 1.
    K : (Nt, Nt) ndarray
        Gibbs kernel ``exp(-c / Epsilon)``.

    Raises
    ------
    ValueError
        If ``Nt < 2`` (the normalisations would divide by zero).
    """
    if Nt<2:
        raise ValueError("Init needs at least 2 grid points, got Nt=%r" % (Nt,))
    mu=np.ones(Nt)/Nt
    grid=np.arange(Nt)/(Nt-1)
    # Kept 2-D (one row per constraint) so further constraints can be stacked.
    f=(grid-mean)[np.newaxis,:]
    f=f/np.max(np.abs(f),axis=1,keepdims=True)
    E=np.exp(-f/Epsilon)
    idx=np.arange(Nt)
    # Vectorised outer difference; replaces the former Python double loop.
    c=(idx[:,np.newaxis]-idx[np.newaxis,:]).astype(float)**2
    c=c/np.max(c)
    K=np.exp(-c/Epsilon)
    return(mu,f,E,c,K)

def Algo(Klimit,Conv=10**(-3),Nt=100,plot=True,plotEnd=True,Epsilon=0.01):
    """Solve the moment-constrained entropic problem with a Newton update on the multiplier.

    Each iteration performs a Sinkhorn row-scaling update of ``u`` (so that the
    first marginal equals ``mu``). It then takes a Newton step on the
    multiplier ``lbda`` of the moment constraint ``f``, holding ``u`` fixed.

    Parameters
    ----------
    Klimit : int or float
        Maximum number of iterations (may be ``np.inf``).
    Conv : float
        Stop once the norm of the constraint residual ``G`` is <= ``Conv``.
    Nt : int
        Number of grid points, forwarded to ``Init``.
    plot : bool
        If True, plot the residual history after every iteration.
    plotEnd : bool
        If True, plot ``mu`` and the resulting second marginal at the end.
    Epsilon : float
        Entropic regularisation parameter, forwarded to ``Init``.

    Returns
    -------
    u : (Nt,) ndarray
        Row scaling computed from the final ``lbda``.
    c : (Nt, Nt) ndarray
        Cost matrix from ``Init``.
    K : (Nt, Nt) ndarray
        Gibbs kernel from ``Init``.
    El : (Nt,) ndarray
        Column scaling ``exp(lbda . log E)`` for the final ``lbda``; the
        coupling is ``diag(u) K diag(El)``.
    """
    mu,f,E,c,K=Init(Nt,Epsilon=Epsilon)
    # Init defines E = exp(-f/Epsilon), so log(E) is exactly -f/Epsilon. Using
    # it directly avoids E underflowing to 0 (log -> -inf -> NaN) for small Epsilon.
    logE=-f/Epsilon
    lbda=np.zeros(len(f))
    residuals=[]
    k=0
    G=np.inf
    while k<Klimit and np.linalg.norm(G)>Conv:
        El=np.exp(np.dot(lbda,logE))
        u=mu/np.dot(K,El)
        uK=np.dot(u,K)
        # Moment-constraint residual under the current second marginal uK*El.
        G=np.tensordot(uK,El*f,axes=[0,1])
        # Newton step (u fixed): dG/dlbda = -(1/Epsilon) * sum(uK*El*f**2).
        lbda=lbda+Epsilon*G/np.tensordot(uK,El*(f**2),axes=[0,1])
        residuals.append(np.linalg.norm(G))
        if plot:
            plt.plot(residuals)
            plt.show()
        k+=1
    # Recompute both scalings from the final lbda so the returned pair is
    # mutually consistent. It then no longer depends on plotEnd and exists
    # even if the loop never ran.
    El=np.exp(np.dot(lbda,logE))
    u=mu/np.dot(K,El)
    if plotEnd:
        Td=np.linspace(0,1,Nt)
        plt.plot(Td,Nt*mu,label=r"$\mu_1$")
        plt.plot(Td,Nt*np.dot(u,K)*El,label=r"$\pi^*_{\varepsilon,2}$")
        plt.xlabel("Interval [0,1]")
        plt.grid()
        plt.legend()
        plt.show()
    return(u,c,K,El)

def AlgoG(Klimit,Conv=10**(-3),Nt=100,plot=True,plotEnd=True,Epsilon=0.1):
    """Gradient-type variant of ``Algo`` with a decreasing step ``prho(k)/Epsilon``.

    Parameters
    ----------
    Klimit : int or float
        Maximum number of iterations (may be ``np.inf``).
    Conv : float
        Stop once the norm of the residual is <= ``Conv``.
    Nt : int
        Number of grid points, forwarded to ``Init``.
    plot : bool
        If True, plot the residual history after every iteration.
    plotEnd : bool
        If True, plot ``mu`` and both marginals of the final coupling.
    Epsilon : float
        Divides the step size of the multiplier update.

    Returns
    -------
    (u, lbda, k)
        Last row scaling, final multiplier and number of iterations performed.

    NOTE(review): ``Init`` is called without ``Epsilon``, so ``E`` and ``K``
    use Init's default while the step uses this function's ``Epsilon`` (whose
    default differs). Confirm that split is intended.

    NOTE(review): the residual is evaluated with the scaling from the previous
    iteration (initially uniform), before the scaling is refreshed.
    """
    mu,f,E,c,K=Init(Nt)
    log_E=np.log(E)
    n_iter=0
    dual=np.zeros(len(f))
    scaling=np.ones(Nt)/Nt
    grad=np.inf
    history=[]
    while n_iter<Klimit and np.linalg.norm(grad)>Conv:
        col=np.exp(np.dot(dual,log_E))
        grad=np.tensordot(np.dot(scaling,K),col*f,axes=[0,1])
        scaling=mu/np.dot(K,col)
        step=grad*prho(n_iter)
        dual=dual+step/Epsilon
        history.append(np.linalg.norm(grad))
        if plot:
            plt.plot(np.abs(history))
            plt.show()
        n_iter+=1
    if plotEnd:
        col=np.exp(np.dot(dual,log_E))
        plt.plot(mu)
        plt.plot(scaling*np.dot(K,col))
        plt.plot(np.dot(scaling,K)*col)
        plt.show()
    return(scaling,dual,n_iter)

def prho(k):
    """Return the decreasing step-size schedule ``1/(k+1)`` for iteration ``k``."""
    denominator=k+1
    return 1/denominator

def Algo2(Klimit,Conv=10**(-3),Nt=100,plot=True,plotEnd=True,Epsilon=0.01):
    """Gradient ascent on the constraint multiplier with step ``prho(k)/Epsilon``.

    At each iteration, the row-normalised coupling is built from the kernel and
    the current multiplier. The gradient is the expectation of ``f`` under that
    coupling, averaged over ``mu``.

    Parameters
    ----------
    Klimit : int or float
        Maximum number of iterations (may be ``np.inf``).
    Conv : float
        Stop once the gradient norm is <= ``Conv``.
    Nt : int
        Number of grid points, forwarded to ``Init``.
    plot : bool
        If True, print the gradient norm and plot its history every iteration.
    plotEnd : bool
        If True, plot ``mu`` and the final second marginal.
    Epsilon : float
        Entropic regularisation parameter. It is forwarded to ``Init`` (so the
        kernel matches the step size) and used in the step.

    Returns
    -------
    lbda : ndarray
        Final multiplier (one entry per constraint row of ``f``).
    k : int
        Number of iterations performed.
    """
    # Forward Epsilon so the kernel honours the caller's value, as Algo does.
    mu,f,E,c,K=Init(Nt,Epsilon=Epsilon)
    # log(E) == -f/Epsilon exactly; avoids log(0) = -inf if E underflows.
    logE=-f/Epsilon
    lbda=np.zeros(len(f))
    residuals=[]
    grad=np.inf
    k=0
    while k<Klimit and np.linalg.norm(grad)>Conv:
        # El[i, j] = K[i, j] * exp(lbda . log E)[j]
        El=np.exp(np.dot(lbda,logE))*K
        # sum_i mu[i] * (sum_j f[j] El[i, j]) / (sum_j El[i, j])
        grad=np.tensordot(mu,np.tensordot(f,El,axes=[1,1])/np.sum(El,1),axes=[0,1])
        lbda=lbda+(1/Epsilon)*prho(k)*grad
        residuals.append(np.linalg.norm(grad))
        if plot:
            print(residuals[-1])
            plt.plot(residuals)
            plt.show()
        k+=1
    if plotEnd:
        El=np.exp(np.dot(lbda,logE))*K
        plt.plot(mu)
        plt.plot(np.dot(mu/np.sum(El,1),El))
        plt.show()
    return(lbda,k)

def Phi(t):
    """Return the quadratic divergence generator ``(t - 1)**2 / 2``."""
    centred=t-1
    return (centred**2)/2

def PhiStar(t):
    """Convex conjugate of ``Phi`` restricted to non-negative arguments.

    For ``t >= -1`` the maximiser is ``t + 1`` and the value is ``t + t**2/2``.
    Below ``-1`` the maximiser is clamped at 0, giving the constant ``-1/2``.
    Works elementwise on numpy arrays.
    """
    active=t>=-1
    saturated=t<-1
    quadratic_branch=t+(t**2)/2
    return quadratic_branch*active+(-1/2)*saturated

def DPhiStar(t):
    """Derivative of ``PhiStar``: ``1 + t`` where ``t >= -1``, zero elsewhere."""
    return (t+1)*(-1<=t)

def D2PhiStar(t):
    """Second derivative of ``PhiStar`` as a boolean mask (True where ``t >= -1``)."""
    return -1<=t

def AffichagePhi(mu,Sum,Nt,f):
    """Plot ``mu`` and the second marginal implied by ``Sum`` as densities on [0, 1].

    Parameters
    ----------
    mu : (Nt,) ndarray
        First marginal.
    Sum : 2-D ndarray
        Argument of ``DPhiStar``. The plotted marginal is
        ``pi2[j] = mu[j] * sum_i mu[i] * DPhiStar(Sum)[i, j]``.
    Nt : int
        Number of grid points; both curves are multiplied by ``Nt`` to show
        densities rather than point masses.
    f : ndarray
        Not used; kept so existing call sites remain valid.
    """
    pi2=np.dot(mu,DPhiStar(Sum))*mu
    Td=np.linspace(0,1,Nt)
    # Raw strings: "\m" is an invalid escape sequence (SyntaxWarning on 3.12+).
    plt.plot(Td,Nt*mu,label=r"$\mu_1$")
    plt.plot(Td,pi2*Nt,label=r"$\pi^*_{\varepsilon,2}$")
    plt.xlabel("Interval [0,1]")
    plt.grid()
    plt.legend()
    plt.show()
    
def AlgoPhiDivergence(Klimit,Conv=10**(-3),Nt=100,plot=True,plotEnd=True,Epsilon=0.01):
    """Experimental solver for the moment-constrained problem with the quadratic phi-divergence.

    Each outer iteration runs an inner damped Newton loop on the potential
    ``gamma`` (stored as ``g = -gamma*Epsilon``). It then takes a Newton step
    on the constraint multiplier ``lbda``, using the ``PhiStar`` derivatives
    in place of the exponential of the entropic case.

    Parameters
    ----------
    Klimit : int or float
        Maximum number of outer iterations.
    Conv : float
        Tolerance for the inner loop (size of the gamma step) and for the
        outer stopping test.
    Nt : int
        Number of grid points passed to ``Init``.
    plot, plotEnd : bool
        Accepted but never read: ``AffichagePhi`` is called unconditionally,
        twice per outer iteration, and ``V`` is printed every iteration.
    Epsilon : float
        Scale dividing the potentials in ``Sum``; not forwarded to ``Init``.

    Returns
    -------
    lbda : ndarray
        Final multiplier (length ``len(f)``).
    mu : ndarray
        First marginal from ``Init``.
    pi2 : ndarray
        ``np.dot(mu, DPhiStar(Sum)) * mu`` for the last ``Sum``.

    NOTE(review): the outer loop continues while ``norm(lbda) > Conv``, i.e. it
    tests the size of the multiplier rather than the residual ``V``. Since
    ``lbda`` starts at 0 this clause is initially False. Confirm the intent.

    NOTE(review): the inner ``while`` has no iteration cap. A zero entry of
    ``np.dot(D2PhiStar(h-gamma),mu)`` makes ``s`` inf/nan.
    """
    mu,f,E,c,K=Init(Nt)
    g=np.zeros(Nt)
    lbda=np.zeros(len(f))
    # Sum[i, j] = (-c[i, j] + g[j] + (lbda . f)[j]) / Epsilon
    Sum=(-c+g+np.dot(lbda,f)[np.newaxis,...])/Epsilon
    k=0
    while k<Klimit and (np.linalg.norm(lbda)>Conv or np.abs(np.sum(np.dot(mu,DPhiStar(Sum))*mu)-1)>Conv):
        s=np.inf
        h=(-c+np.dot(lbda,f)[np.newaxis,...])/Epsilon
        # Initial guess: log of the mu-weighted sums of exp(h) over the last axis.
        gamma=np.log(np.dot(np.exp(h),mu))
        k1=0
        # Damped (step prho(k1)) Newton iteration solving np.dot(DPhiStar(h-gamma),mu) == 1.
        # NOTE(review): `h-gamma` broadcasts gamma along the last axis (column j),
        # whereas the residual np.dot(..., mu) is indexed by row i; confirm this pairing.
        while np.linalg.norm(s)>Conv:
            k1+=1
            s=-prho(k1)*(np.dot(DPhiStar(h-gamma),mu)-1)/np.dot(D2PhiStar(h-gamma),mu)
            gamma=gamma-s
            #print("Conv Gamma: " +str (np.linalg.norm(s)))
        g=-gamma*Epsilon
        Sum=(-c+g+np.dot(lbda,f)[np.newaxis,...])/Epsilon
        # H: second-order term sum_j f_j^2 mu_j sum_i mu_i D2PhiStar(Sum)[i, j].
        H=np.dot(f[:,np.newaxis,:]*f,mu*np.dot(mu,D2PhiStar(Sum)))
        # V: moment residual sum_j f_j pi2_j.
        V=np.dot(f,mu*np.dot(mu,DPhiStar(Sum)))
        lbda-=Epsilon*np.dot(np.linalg.inv(H),V)
        AffichagePhi(mu,Sum,Nt,f)
        k+=1
        Sum=(-c+g+np.dot(lbda,f)[np.newaxis,...])/Epsilon
        AffichagePhi(mu,Sum,Nt,f)
        print(V)
    pi2=np.dot(mu,DPhiStar(Sum))*mu
    return(lbda,mu,pi2)

def AlgoPhiDivergence2(Klimit,Conv=10**(-3),Nt=100,plot=True,plotEnd=True,Epsilon=0.01):
    """Second experimental variant of the phi-divergence solver.

    Each outer iteration first adjusts the potential ``g`` until the total mass
    ``sum(np.dot(mu, DPhiStar(Sum)) * mu)`` is within ``Conv`` of 1. It then
    takes Newton steps on ``lbda`` until the moment residual ``V`` has norm
    <= ``Conv``.

    Parameters
    ----------
    Klimit : int or float
        Number of outer iterations to run (the only outer exit condition).
    Conv : float
        Tolerance for both inner loops.
    Nt : int
        Number of grid points passed to ``Init``.
    plot, plotEnd : bool
        Accepted but never read: ``AffichagePhi`` is called unconditionally
        inside every inner iteration and after each outer one.
    Epsilon : float
        Scale dividing the potentials in ``Sum`` and the g-update step.

    Returns
    -------
    lbda : ndarray
        Final multiplier.
    mu : ndarray
        First marginal from ``Init``.
    pi2 : ndarray
        ``np.dot(mu, DPhiStar(Sum)) * mu`` after the last outer iteration.

    NOTE(review): with ``Klimit=np.inf`` the outer loop never terminates. With
    ``Klimit <= 0`` the body never runs and the ``return`` raises
    UnboundLocalError, because ``pi2`` is only assigned inside the loop.

    NOTE(review): both inner ``while`` loops have no iteration cap, and a zero
    entry of ``np.dot(D2PhiStar(Sum),mu)`` makes the g-update inf/nan.
    """
    mu,f,E,c,K=Init(Nt)
    g=np.zeros(Nt)
    lbda=np.zeros(len(f))
    k=0
    while k<Klimit: #and np.linalg.norm(lbda)>Conv
        # np.transpose moves the (lbda . f) term onto the first axis (rows);
        # g is then broadcast along the last axis (columns).
        Sum=(np.transpose(-c+np.dot(lbda,f))+g)/Epsilon
        g=np.log(np.dot(np.exp(Sum),mu))
        # Mass-normalisation loop on g. The test is evaluated on the Sum built
        # before the most recent g update.
        while np.abs(np.sum(np.dot(mu,DPhiStar(Sum))*mu)-1)>Conv:
            Sum=(np.transpose(-c+np.dot(lbda,f))+g)/Epsilon
            g+=-Epsilon*(np.dot(DPhiStar(Sum),mu)-1)/np.dot(D2PhiStar(Sum),mu)
            AffichagePhi(mu,Sum,Nt,f)
        V=np.dot(f,mu*np.dot(mu,DPhiStar(Sum)))
        k1=0  # counted but never read
        # Newton iterations on lbda until the moment residual V is small.
        while np.linalg.norm(V)>Conv:
            Sum=(np.transpose(-c+np.dot(lbda,f))+g)/Epsilon
            H=np.dot(f[:,np.newaxis,:]*f,mu*np.dot(mu,D2PhiStar(Sum)))
            V=np.dot(f,mu*np.dot(mu,DPhiStar(Sum)))
            lbda-=Epsilon*np.dot(np.linalg.inv(H),V)
            AffichagePhi(mu,Sum,Nt,f)
            k1+=1
        k+=1
        Sum=(np.transpose(-c+np.dot(lbda,f))+g)/Epsilon
        AffichagePhi(mu,Sum,Nt,f)
        pi2=np.dot(mu,DPhiStar(Sum))*mu
    return(lbda,mu,pi2)

def CompareTime(rangeNt,Conv=10**(-3)):
    """Time ``Algo`` (Newton) and ``Algo2`` (gradient descent) over grid sizes and plot.

    Parameters
    ----------
    rangeNt : sequence of numbers
        Grid sizes to test; each is cast to ``int`` before use.
    Conv : float
        Convergence tolerance passed to both solvers.

    Returns
    -------
    (l1, l2) : tuple of lists of float
        Elapsed seconds for ``Algo`` and ``Algo2`` respectively, one entry per
        element of ``rangeNt``.
    """
    l1=[]
    l2=[]
    for value in rangeNt:
        Nt=int(value)
        print(Nt)

        # Only the run time is recorded, so the solvers' return values are not
        # unpacked (Algo returns 4 values, not 5).
        t0=time.perf_counter()
        Algo(np.inf,Nt=Nt,plot=False,plotEnd=False,Conv=Conv)
        l1.append(time.perf_counter()-t0)

        t0=time.perf_counter()
        Algo2(np.inf,Nt=Nt,plot=False,plotEnd=False,Conv=Conv)
        l2.append(time.perf_counter()-t0)

    plt.plot(rangeNt,l2,label="Gradient Descent")
    plt.plot(rangeNt,l1,label="Newton Method")
    plt.xlabel("Number of discretization points")
    plt.ylabel("Time (in seconds)")
    plt.legend()
    plt.grid()
    plt.show()
    return(l1,l2)

def CompareTimeConv(rangeConv,Nt=100):
    """Time ``Algo`` (Newton) and ``Algo2`` (gradient descent) over tolerances and plot.

    Parameters
    ----------
    rangeConv : sequence of float
        Convergence tolerances to test (plotted on a reversed log axis).
    Nt : int
        Number of grid points passed to both solvers.

    Returns
    -------
    (l1, l2) : tuple of lists of float
        Elapsed seconds for ``Algo`` and ``Algo2`` respectively, one entry per
        tolerance.
    """
    l1=[]
    l2=[]
    for Conv in rangeConv:
        print(Conv)

        # Only the run time is recorded, so the solvers' return values are not
        # unpacked (Algo returns 4 values, not 5).
        t0=time.perf_counter()
        Algo(np.inf,Nt=Nt,plot=False,plotEnd=False,Conv=Conv)
        l1.append(time.perf_counter()-t0)

        t0=time.perf_counter()
        Algo2(np.inf,Nt=Nt,plot=False,plotEnd=False,Conv=Conv)
        l2.append(time.perf_counter()-t0)

    plt.plot(rangeConv,l2,label="Gradient Descent")
    plt.plot(rangeConv,l1,label="Newton Method")
    # Raw string: "\k" is an invalid escape sequence (SyntaxWarning on 3.12+).
    plt.xlabel(r"$\kappa$")
    plt.ylabel("Time (in seconds)")
    plt.xscale("log")
    plt.yscale("log")
    plt.legend()
    plt.grid()
    # Reversed x-axis: tighter tolerances appear on the right.
    plt.xlim(np.max(rangeConv), np.min(rangeConv))
    plt.show()
    return(l1,l2)

def CompareIteration(rangeNt,Conv=10**(-3)):
    """Compare iteration counts of ``Algo`` (Newton) and ``Algo2`` (gradient) over grid sizes.

    Parameters
    ----------
    rangeNt : sequence of numbers
        Grid sizes to test; each is cast to ``int``.
    Conv : float
        Convergence tolerance passed to both solvers.

    Returns
    -------
    (l1, l2) : tuple of lists
        Iteration counts for ``Algo`` and ``Algo2`` respectively.

    NOTE(review): ``Algo`` as defined in this file returns four values
    ``(u, c, K, El)``, so the five-way unpack below raises ValueError. Fixing
    this needs ``Algo`` to also report its iteration count ``k``.

    NOTE(review): with an empty ``rangeNt`` the loop index ``i`` is never bound,
    so ``rangeNt[:i+1]`` raises NameError.
    """
    l1=[]
    l2=[]
    for i in range(len(rangeNt)):
        Nt=int(rangeNt[i])
        print(Nt)
        
        u,c,K,El,k=Algo(np.inf,Nt=Nt,plot=False,plotEnd=False,Conv=Conv)
        l1.append(k)
        
        lbda,k=Algo2(np.inf,Nt=Nt,plot=False,plotEnd=False,Conv=Conv)
        l2.append(k)
        
    plt.plot(rangeNt[:i+1],l1,label="With Newton Method")
    plt.plot(rangeNt[:i+1],l2,label="With Gradient Descend")
    plt.xlabel("Number of discretization points")
    plt.ylabel("Number of iterations")
    plt.legend()
    plt.grid()
    plt.show()
    return(l1,l2)

def CompareIterationConv(rangeConv,Nt=100):
    """Compare iteration counts of ``Algo`` (Newton) and ``Algo2`` (gradient) over tolerances.

    Parameters
    ----------
    rangeConv : sequence of float
        Convergence tolerances to test (plotted on a reversed log-log axis).
    Nt : int
        Number of grid points passed to both solvers.

    Returns
    -------
    (l1, l2) : tuple of lists
        Iteration counts for ``Algo`` and ``Algo2`` respectively.

    NOTE(review): ``Algo`` as defined in this file returns four values
    ``(u, c, K, El)``, so the five-way unpack below raises ValueError. Fixing
    this needs ``Algo`` to also report its iteration count ``k``.

    NOTE(review): ``l3`` is never used, and the non-raw ``"$\\kappa$"`` label
    contains an invalid escape sequence (SyntaxWarning on Python 3.12+).
    """
    l1=[]
    l2=[]
    l3=[]
    for i in range(len(rangeConv)):
        Conv=rangeConv[i]
        print(Conv)
        
        u,c,K,El,k=Algo(np.inf,Nt=Nt,plot=False,plotEnd=False,Conv=Conv)
        l1.append(k)
        
        lbda,k=Algo2(np.inf,Nt=Nt,plot=False,plotEnd=False,Conv=Conv)
        l2.append(k)
        
    plt.plot(rangeConv,l2,label="Gradient Descent")
    plt.plot(rangeConv,l1,label="Newton Method")
    plt.xlabel("$\kappa$")
    plt.ylabel("Number of iterations")
    plt.xscale("log")
    plt.yscale("log")
    plt.legend()
    plt.grid()
    # Reversed x-axis: tighter tolerances appear on the right.
    plt.xlim(np.max(rangeConv), np.min(rangeConv))
    plt.show()
    return(l1,l2)

def CompareEps(rangeEps,Conv=10**(-3),Nt=100):
    """Plot the cost and regularised value of ``Algo``'s solution as functions of epsilon.

    Parameters
    ----------
    rangeEps : sequence of float
        Regularisation values to sweep (converted to a float array).
    Conv : float
        Convergence tolerance for each sweep run of ``Algo``.
    Nt : int
        Number of grid points, used for every ``Algo`` call.

    Returns
    -------
    l : list
        Transport cost ``<c, pi_eps>`` for each epsilon.
    ld : list
        ``<c, pi_eps> + eps * sum(rel_entr(pi_eps, mu x mu))`` for each epsilon.
    """
    rangeEps=np.asarray(rangeEps,dtype=float)
    mu=np.ones(Nt)/Nt
    # Reference coupling pi0 for the "Bound" curve. Nt is forwarded so pi0
    # lives on the same grid as dotmu and the couplings of the sweep.
    u,c,K,El=Algo(1000,Nt=Nt,plot=False,plotEnd=False,Epsilon=np.min(rangeEps))
    if mean!=0.5:
        # diag(u) K diag(El), written with broadcasting instead of dense matmuls.
        pi0=u[:,np.newaxis]*K*El[np.newaxis,:]
        pi0=pi0/np.sum(pi0)
    else:
        pi0=np.diag(np.ones(Nt))/Nt
    # Product measure mu (x) mu, normalised to total mass 1.
    dotmu=np.outer(mu,mu)
    dotmu=dotmu/np.sum(dotmu)
    l=[]
    ld=[]
    for Eps in rangeEps:
        u,c,K,El=Algo(np.inf,Nt=Nt,plot=False,plotEnd=False,Conv=Conv,Epsilon=Eps)
        pi=u[:,np.newaxis]*K*El[np.newaxis,:]
        cost=np.dot(np.dot(u,c*K),El)
        l.append(cost)
        ld.append(cost+Eps*np.sum(rel_entr(pi,dotmu)))
    plt.plot(rangeEps,l,label=r"$\langle c,\pi^*_\varepsilon\rangle$")
    plt.plot(rangeEps,ld,label=r"$d^*_\varepsilon$")
    plt.plot(rangeEps,len(rangeEps)*[(0.5-mean)**2],label="$d^*$")
    plt.plot(rangeEps,np.tensordot(pi0,c,axes=[(0,1),(0,1)])+np.sum(rel_entr(pi0,dotmu))*rangeEps,label="Bound")
    plt.xlabel(r"$\varepsilon$")
    plt.ylabel("Cost")
    plt.xscale("log")
    plt.legend()
    plt.grid()
    # Reversed x-axis: smaller epsilon appears on the right.
    plt.xlim(np.max(rangeEps), np.min(rangeEps))
    plt.ylim(-0.02,np.max(l)*1.1)
    plt.show()
    return(l,ld)

def CompareEpsPlot(rangeEps,Conv=10**(-3),Nt=100):
    """Overlay ``mu`` and the second marginal of ``Algo``'s solution for several epsilons.

    Parameters
    ----------
    rangeEps : iterable of float
        Regularisation values; one curve is drawn per value.
    Conv : float
        Convergence tolerance for each ``Algo`` run.
    Nt : int
        Number of grid points; marginals are multiplied by ``Nt`` to plot densities.

    Returns
    -------
    tuple
        An empty tuple.
    """
    Td=np.linspace(0,1,Nt)
    mu=np.ones(Nt)/Nt
    # Raw string: "\m" is an invalid escape sequence (SyntaxWarning on 3.12+).
    plt.plot(Td,Nt*mu,label=r"$\mu_1$")
    for Eps in rangeEps:
        u,c,K,El=Algo(np.inf,Nt=Nt,plot=False,plotEnd=False,Conv=Conv,Epsilon=Eps)
        plt.plot(Td,Nt*np.dot(u,K)*El,label=r"$\varepsilon=$"+str(Eps))
    # Axis labels are loop-invariant; set them once.
    plt.xlabel("[0,1]")
    plt.ylabel("Density")
    plt.legend()
    plt.grid()
    plt.show()
    return()