# -*- coding: utf-8 -*-

import numpy as np
import matplotlib.pyplot as plt
import time
from scipy.special import lambertw

"""
To simplify the algorithm, the moment constraints on the consumption and on the consumption gradient are treated separately.
For the consumption, we denote by f, R and lbda the corresponding moment function, target signal and Lagrange multiplier.
For the gradient, we denote by g, S and beta the corresponding moment function, target signal and Lagrange multiplier.
The problem without gradient constraints corresponds to beta=0.
"""

### Global Parameters 

Epsilon=0.1 # Entropic regularisation parameter of the 1S-RCMOT problem (costs enter as exp(-c/Epsilon))

NTime=10 # Time horizon, in hours
Nt=101 # Number of time grid points (4*NTime+1 was an alternative choice)
dt=NTime/(Nt-1) # Time step
Ns=25 # Number of state-of-charge grid points on [0, 1]
ds=1/(Ns-1) # State-of-charge step

Td=np.linspace(0,NTime,Nt) # Time grid: Td[i] == i*dt

# Discretised cost c(tx, ty) = (tx - ty)**2 of the 1S-RCMOT problem, evaluated on
# the time grid Td. The previous double loop used tx*NTime/Nt as the time of index
# tx, i.e. a spacing NTime/Nt instead of dt = NTime/(Nt-1), which disagreed with Td
# and with the dt used to build the consumption moments.
cT=(Td[:,None]-Td[None,:])**2

vita=2/5 # Charging rate in the constant phase (state of charge gained per unit of time)
dropout=0.75 # State of charge above which the charging rate decreases linearly to 0 at 1

def b(i,s):
    """Charging-rate profile number i evaluated at state of charge s.

    i == 0 : always 0.
    i == 1 : vita while s < dropout, then vita*(1-(s-dropout)/(1-dropout)),
             i.e. decreasing linearly from vita at s=dropout to 0 at s=1.
    i == 2 : -b(1, 1-s), the profile of i == 1 mirrored in s and negated.
    Any other i returns None.
    """
    if i==1:
        # Keep the `s<dropout` test as written so a NaN s takes the linear branch.
        if s<dropout:
            return vita
        return vita*(1-(s-dropout)/(1-dropout))
    if i==0:
        return 0
    if i==2:
        return -b(1,1-s)
    return None

# fT[T,s,tc]: consumption at time index T of one EV with initial state of charge
# s*ds that starts charging at time index tc (0 before tc). Summed against a law on
# (s, tc), it gives the global consumption. The charging law matches b(1, .):
#   - constant rate vita until the state of charge reaches `dropout`,
#   - then rate vita*(1-SoC)/(1-dropout), so the consumption decays exponentially
#     with rate _decay*vita.
# _decay = 1/(1-dropout) was previously hard-coded as 4, which only holds for
# dropout = 0.75. For that value it equals 4.0 exactly, so fT is unchanged.
# Requires dropout < 1 (b(1, .) also divides by 1-dropout).
_decay=1/(1-dropout)
fT=np.zeros((Nt,Ns,Nt))
for s in range(Ns):
    for tc in range(Nt):
        if s*ds<=dropout:
            # First time index of the decaying phase. Charging at rate vita from
            # s*ds up to dropout takes (dropout-s*ds)/(vita*dt) steps; the
            # difference is non-negative in this branch.
            t0=int(tc+1+(dropout-s*ds)/(vita*dt))
            for T in range(tc,min(Nt,t0)):
                fT[T,s,tc]+=vita
            for T in range(t0,Nt):
                fT[T,s,tc]+=vita*np.exp(_decay*vita*(t0*dt-T*dt))
        else:
            # Already above dropout: the decaying phase starts immediately at tc.
            for T in range(tc,Nt):
                fT[T,s,tc]+=vita*_decay*(1-s*ds)*np.exp(_decay*vita*(tc*dt-T*dt))

# Initial law of the state of charge of arriving EVs. It is a Gaussian shape
# (mean mea, standard deviation sigma) kept only on grid indices above
# int(0.15/ds), then normalised to sum to 1.
mea=0.15
sigma=0.2
_gauss_norm=1/(np.sqrt(2*np.pi)*sigma) # cancels in the normalisation, kept for clarity

mu0=np.zeros(Ns)
for s in range(int(0.15/ds)+1,Ns):
    mu0[s]=_gauss_norm*np.exp(-((s/(Ns-1)-mea)**2)/(2*sigma**2))

mu0=mu0/np.sum(mu0)

# Law on (initial state of charge s, start time index) whose consumption
# DisplayConsumption draws as 'Nominal consumption': every EV starts at the first
# time index, with state of charge distributed according to mu0.
# The former loops over (t, s) wrote column 0 Nt times without using t; a single
# column assignment gives the same array.
mu1=np.zeros((Ns,Nt))
mu1[:,0]=mu0


# Reference law on (initial state of charge s, start time index tc). It multiplies
# exp(-S/Epsilon) in gradJL and MuLambda. An EV with initial state of charge s*ds
# may start at any tc < Nt - n_charge, where n_charge is the number of steps needed
# to charge up to `treshold` at rate vita. Each admissible row receives the mass
# mu0[s], spread uniformly.
mu2=np.zeros((Ns,Nt))

treshold=0.75
for s in range(Ns):
    n_charge=max(0,int((treshold-s*ds)/(vita*dt)))
    n_start=max(Nt-n_charge,0) # clamp so a negative bound never wraps the slice
    mu2[s,:n_start]=1
    row_mass=np.sum(mu2[s])
    if row_mass!=0:
        mu2[s]=mu2[s]*mu0[s]/row_mass

###Definition of the gradient functions

def gradJL(lbda,beta):
    """Aggregated consumption induced by the consumption multipliers lbda.

    With S = sum_T lbda[T] * fT[T] (shape (Ns, Nt), indexed [s, tc]), the mass
    mu1[s, tx] is spread over the start times tc with weights

        P[s, tx, tc] proportional to mu2[s, tc] * exp(-(S[s, tc] + cT[tx, tc]) / Epsilon).

    The result is sum over (s, tx, tc) of mu1[s, tx] * P[s, tx, tc] * fT[:, s, tc],
    an array of length Nt. Algo uses GL - R2 as its update direction.

    beta (gradient-constraint multipliers) is accepted but currently unused.

    The weights are normalised in the log domain (log-sum-exp). This prevents
    exp(-S/Epsilon) from overflowing (inf, then nan) or the normaliser from
    underflowing to 0. The old code turned an underflowed normaliser into a
    silent zero contribution. Rows where mu2[s, :] is identically zero still
    contribute 0.
    """
    S=np.tensordot(lbda,fT,axes=[0,0]) # S[s, tc]
    # log of the reference weights, -inf outside the support of mu2
    log_mu2=np.log(mu2,out=np.full_like(mu2,-np.inf),where=mu2>0)
    logits=log_mu2[:,None,:]-(S[:,None,:]+cT[None,:,:])/Epsilon # [s, tx, tc]
    shift=logits.max(axis=2,keepdims=True)
    shift[~np.isfinite(shift)]=0.0 # empty support: weights stay exactly 0
    W=np.exp(logits-shift)
    Z=W.sum(axis=2,keepdims=True)
    P=np.divide(W,Z,out=np.zeros_like(W),where=Z>0)
    # Contract mu1 first (marginal of the plan on (s, tc)), then the moments fT.
    marginal=np.einsum('sx,sxc->sc',mu1,P)
    return(np.tensordot(fT,marginal,axes=[(1,2),(0,1)]))

### Algorithm

R2=np.full(Nt,0.2) # Target level for the aggregated consumption
rangeR=range(3,4+3*2) # Time indices (3..9) on which the consumption must reach R2
S0=np.full(Nt-1,0.2*NTime/(Nt-1)) # Bounds on the per-step variation of the aggregated consumption
# An alternative profile Dd2 used to be loaded from Save/Dd2Test2Sinus.npy (every 5th sample).
b0=0.0035
Dd2=b0+(0.2-b0)*np.exp(-Td*0.2803635831824254) # Decays from 0.2 at t=0 towards b0
# NOTE(review): S0 and Dd2 are not referenced by the functions in this file.

lbd0=np.zeros(Nt) # Starting consumption multipliers lbda
beta0=np.zeros(Nt-1) # Starting gradient multipliers beta

def prho(k):
    """Step size used by Algo at iteration k.

    Decreases like (1+k)**-0.67. The base step 25 is multiplied by 101/Nt,
    presumably so the step stays comparable when the time grid differs from
    Nt=101.
    """
    grid_scale=101/Nt
    return grid_scale*25/(1+k)**0.67

def Algo(R2,rangeR,Klimit=np.inf,plot=True,prin=True,GradientControl=True,Conv=10**(-3),plotEnd=False):
    """Gradient iteration on the consumption multipliers lbda.

    At each step lbda[rangeR] += prho(k) * (GL - R2)[rangeR], where
    GL = gradJL(lbda, beta) is the aggregated consumption induced by lbda.

    Parameters
    ----------
    R2 : target consumption, one value per time index.
    rangeR : time indices on which the consumption must match R2.
    Klimit : maximum number of updates of lbda.
    plot : show the current consumption every 10 updates.
    prin : print the residual after every update.
    GradientControl : accepted for interface compatibility; unused.
    Conv : tolerance on ||GL[rangeR] - R2[rangeR]|| / sqrt(Nt).
    plotEnd : show the final consumption.

    Returns
    -------
    lbda : the final multipliers; the residual printed last is theirs.
    """
    lbda=lbd0.copy()
    beta=beta0.copy()

    def residual(GL):
        return np.linalg.norm(GL[rangeR]-R2[rangeR])/np.sqrt(Nt)

    # Evaluate the residual of the current lbda before deciding to iterate.
    # The old sentinel GL = lbd0 + 1 skipped the solve whenever that sentinel
    # already met the tolerance (e.g. R2 == 1 on rangeR). It also returned a
    # lbda updated once more than the residual that was reported.
    k=0
    GL=gradJL(lbda,beta)
    err=residual(GL)
    while k<Klimit and err>Conv:
        lbda[rangeR]+=prho(k)*(GL[rangeR]-R2[rangeR])
        k+=1
        GL=gradJL(lbda,beta)
        err=residual(GL)
        if plot and k%10==0:
            DisplayConsumption(GL,R2,rangeR)
        if prin:
            print("Step "+str(k)+ " " +str(err))
    if plotEnd:
        DisplayConsumption(GL,R2,rangeR)
    print(str(err))
    return(lbda)


def AlgoN(R2,rangeR,Klimit=np.inf,plot=True,plotEnd=False,prin=True,GradientControl=True,Conv=10**(-3)):
    """Newton-type iteration on lbda (labelled 'SinkN' in CompareTime).

    Each iteration does two things:
      - one scaling u = mu1 / (El K), with El = exp(-sum_T lbda[T] f[T] / Epsilon) * mu2;
      - a per-coordinate step lbda[rangeR] += Epsilon * GL / H.
    GL and H are the first and second moments of f = fT - R2 under the current plan.

    Parameters are those of Algo; GradientControl is accepted for interface
    compatibility but unused. Stops when ||GL[rangeR]|| / sqrt(Nt) <= Conv or
    after Klimit updates.

    Returns (lbda, beta); beta is returned unchanged (a copy of beta0).
    """
    lbda=lbd0.copy()
    beta=beta0.copy()
    # Centred moment f[T, s, tc] = fT[T, s, tc] - R2[T]. R2 is indexed by the
    # consumption time T (axis 0 of fT). A plain `fT - R2` broadcast it along the
    # last axis (tc), which was only harmless for constant R2. A scalar R2 still works.
    f=fT-np.reshape(R2,(-1,1,1))
    # log(exp(-f/Epsilon)) computed directly: avoids log(0) = -inf when the
    # exponential underflows, and the resulting 0 * -inf = nan in the tensordot.
    logE=-f/Epsilon
    K=np.exp(-cT/Epsilon)

    def moments(lbda):
        # First (GL) and second (H) moments of f under the plan for this lbda.
        El=np.exp(np.tensordot(lbda,logE,axes=[0,0]))*mu2
        Und=np.tensordot(El,K,axes=[1,0])
        u=np.divide(mu1, Und, out=np.zeros_like(mu1), where=Und!=0)
        uK=np.dot(u,K)
        GL=np.tensordot(uK,El*f,axes=[(0,1),(1,2)])
        H=np.tensordot(uK,El*(f**2),axes=[(0,1),(1,2)])
        return GL,H

    def residual(GL):
        return np.linalg.norm(GL[rangeR])/np.sqrt(Nt)

    # Evaluate the current lbda before iterating; the old sentinel GL = lbd0 + 1
    # could skip the loop entirely for a loose Conv.
    k=0
    GL,H=moments(lbda)
    err=residual(GL)
    while k<Klimit and err>Conv:
        lbda[rangeR]=lbda[rangeR]+Epsilon*GL[rangeR]/H[rangeR]
        k+=1
        GL,H=moments(lbda)
        err=residual(GL)
        if plot and k%10==0:
            DisplayConsumption(GL+R2,R2,rangeR)
        if prin:
            print("Step "+str(k)+ " " +str(err))
    if plotEnd:
        DisplayConsumption(GL+R2,R2,rangeR)
    # Same normalised metric as the stopping test (the old final print omitted
    # the 1/sqrt(Nt) factor).
    print(str(err))
    return(lbda,beta)

###Display

def DisplayConsumption(G,R2,rangeR,AdditionnalLbda=None):
    """Plot the nominal consumption, the consumption G and the target R2.

    Parameters
    ----------
    G : aggregated consumption to display, one value per point of Td.
    R2, rangeR : target values and the time indices they apply to; only
        R2[rangeR] is drawn.
    AdditionnalLbda : optional multipliers. When given and non-empty, the
        consumption of MuLambda(AdditionnalLbda, beta0) is drawn as well.
        Defaults to None rather than a shared mutable [] default; passing []
        still works and draws nothing extra.
    """
    plt.plot(Td,np.tensordot(fT,mu1,axes=[(1,2),(0,1)]),label='Nominal consumption')
    plt.plot(Td,G,label="Optimized consumption")
    plt.plot(Td[rangeR],R2[rangeR],label="Constraint")
    if AdditionnalLbda is not None and len(AdditionnalLbda)>0:
        plt.plot(Td,np.tensordot(fT,MuLambda(AdditionnalLbda, beta0),axes=[(1,2),(0,1)]),label='Without Gradient constraint')
    plt.xlabel("Time")
    plt.ylabel("Aggregated consumption")
    plt.grid()
    plt.legend()
    plt.show()

def CompareTime(rangeConv):
    """Time AlgoN against Algo for each tolerance in rangeConv.

    Both solvers run on the constant target 0.2 over time indices 4..12, with
    plotting and per-step printing disabled. After each tolerance, the timings
    collected so far are re-plotted against rangeConv on a log-scaled x-axis.

    Returns (l1, l2): elapsed seconds of AlgoN and of Algo, per tolerance.
    """
    l1=[]
    l2=[]
    for i,Conv in enumerate(rangeConv):
        print(Conv)

        # perf_counter is monotonic and high-resolution, unlike time.time,
        # which can jump when the system clock is adjusted.
        t0=time.perf_counter()
        AlgoN(Td*0+0.2,range(4,13),plot=False,prin=False,Conv=Conv)
        l1.append(time.perf_counter()-t0)

        t0=time.perf_counter()
        Algo(Td*0+0.2,range(4,13),plot=False,prin=False,Conv=Conv)
        l2.append(time.perf_counter()-t0)

        plt.plot(rangeConv[:i+1],l1,label="SinkN")
        plt.plot(rangeConv[:i+1],l2,label="Classic")
        plt.xscale("log")
        plt.legend()
        plt.grid()
        plt.show()
    return(l1,l2)

def MuLambda(lbda,beta):
    """Marginal on (s, tc) of the plan associated with the multipliers lbda.

    With S = sum_T lbda[T] * fT[T] (indexed [s, tc]), returns

        MuL[s, tc] = sum_tx mu1[s, tx] * P[s, tx, tc],
        P[s, tx, :] proportional to mu2[s, :] * exp(-(S[s, :] + cT[tx, :]) / Epsilon),

    an array of shape (Ns, Nt). np.tensordot(fT, MuL, axes=[(1,2),(0,1)]) is
    the corresponding aggregated consumption. beta is accepted but unused.
    Rows where mu2[s, :] is identically zero get weight 0.
    """
    S=np.tensordot(lbda,fT,axes=[0,0]) # S[s, tc]
    # The normaliser must sum over the start-time axis (axis 1 of the 2-D
    # weights). The old code contracted axis 2, which does not exist, so every
    # call raised. It is now computed in the log domain (log-sum-exp) to avoid
    # overflow/underflow of exp(-S/Epsilon) and exp(-cT/Epsilon).
    log_mu2=np.log(mu2,out=np.full_like(mu2,-np.inf),where=mu2>0)
    logits=log_mu2[:,None,:]-(S[:,None,:]+cT[None,:,:])/Epsilon # [s, tx, tc]
    shift=logits.max(axis=2,keepdims=True)
    shift[~np.isfinite(shift)]=0.0 # empty support: weights stay exactly 0
    W=np.exp(logits-shift)
    Z=W.sum(axis=2,keepdims=True)
    P=np.divide(W,Z,out=np.zeros_like(W),where=Z>0)
    MuL=np.einsum('sx,sxc->sc',mu1,P)
    return(MuL)
