# -*- coding: utf-8 -*-

import numpy as np
import matplotlib.pyplot as plt
import time
from scipy.special import lambertw

"""
To simplify the algorithm, the moment constraints on the consumption and on the consumption gradient are handled separately.
For the consumption, we denote by f, R and lbda the corresponding moment function, target signal and Lagrange multiplier (fT, R2 and lbda in the code).
For the gradient, we denote by g, S and beta the corresponding moment function, bound signal and Lagrange multiplier (gT, S0 and beta in the code).
The problem without gradient constraints corresponds to beta = 0.
"""


### Global Parameters

Epsilon=0.03 # Value of Epsilon for the 1S-RCMOT problem
v=0.25 # In h^(-1): charging speed of the EVs; an EV arriving with 0% battery is fully charged in 4h.

NTime=8 # Number of hours
Nt=3*NTime+1 # Number of points of the time discretization
Ns=20 # Number of points of the state-of-charge discretization (state of charge s/(Ns-1), from 0 to 1)

Td=np.linspace(0,NTime,Nt)+9 # Time grid in hours, from 9h to 9h+NTime; consecutive points are NTime/(Nt-1) apart

# Discretization of the cost function c in the 1S-RCMOT problem: squared time
# difference (in hours) between two points of the grid Td. It is built from Td
# itself so that the time step is the actual grid spacing NTime/(Nt-1); a step
# of NTime/Nt would shrink every cost by a factor ((Nt-1)/Nt)**2.
cT=(Td[:,None]-Td[None,:])**2

# Discretization of the general moment f in the 1S-RCMOT problem, representing the global consumption:
# fT[T,ta,s,tc]=1 when an EV with state-of-charge index s that starts charging at
# time index tc is still charging at time index T, and 0 otherwise. It does not
# depend on the arrival index ta.
fT=np.zeros((Nt,Nt,Ns,Nt))
for s in range(Ns):
    # Charging duration (1-s/(Ns-1))/v hours, expressed in grid steps of NTime/(Nt-1) hours.
    steps=(Nt-1)*(1-s/(Ns-1))/(v*NTime)
    for tc in range(Nt):
        fT[tc:min(Nt,int(tc+steps)),:,s,tc]=1

# Forward-difference operator: (D @ x)[i] = x[i+1] - x[i].
D=np.eye(Nt-1,Nt,k=1)-np.eye(Nt-1,Nt)
gT=np.tensordot(D,fT,axes=[-1,0])  # Discretization of the second general moment in the 1S-RCMOT problem, representing the gradient of the global consumption


### Definition of mu1 and mu2

# Initial law for the arrival of EVs: uniform over arrival time indices t < int(Nt/4)
# and state-of-charge indices int(Ns*0.05) <= s < int(Ns*0.95), normalized to total mass 1.
mu0=np.zeros((Nt,Ns))
mu0[:int(Nt/4),int(Ns*0.05):int(Ns*0.95)]=1
mu0=mu0/np.sum(mu0)

# mu1[t,s,tc]=mu0[t,s] if tc==t and 0 otherwise, i.e. charging starts at the arrival time.
mu1=mu0[:,:,None]*np.eye(Nt)[:,None,:]

# mu2[t,s,:] spreads the mass mu0[t,s] uniformly over the start indices tc with
# t <= tc < last, where last=int((NTime-h)/dt). Here h=(1-s/(Ns-1))/v is the
# charging duration in hours and dt=NTime/(Nt-1) is the grid step. These are the
# same duration and step used to build the charging profiles fT, so the admissible
# starts agree with the profiles.
mu2=np.zeros((Nt,Ns,Nt))
for s in range(Ns):
    last=int((Nt-1)*(NTime-(1-s/(Ns-1))/v)/NTime)
    for t in range(Nt):
        # max(t,last) keeps the window empty when last <= t, including a negative last.
        mu2[t,s,t:max(t,last)]=1
        total=np.sum(mu2[t,s])
        if total!=0:
            mu2[t,s]=mu2[t,s]*mu0[t,s]/total

###Definition of the gradient functions


def gradJL(lbda,beta):
    """Return the aggregated consumption (moment fT) under the coupling tilted by (lbda, beta).

    The reference law mu2 is reweighted by exp(-(lbda.fT + beta.gT)/Epsilon).
    Each (ta, s) slice is renormalized against mu1 through the kernel
    exp(-cT/Epsilon); zero normalizers give a zero contribution. The result is
    a vector indexed by the consumption time T. Algo uses GL - R2 from this
    function as its update direction for lbda.
    """
    kernel=np.exp(-cT/Epsilon)
    tilt=np.exp(-(np.tensordot(lbda,fT,axes=[0,0])+np.tensordot(beta,gT,axes=[0,0]))/Epsilon)
    Z=np.tensordot(tilt*mu2,kernel,axes=(2,1))
    moment=np.tensordot(fT*tilt*mu2,kernel,axes=(3,1))
    ratio=np.divide(moment,Z,out=np.zeros_like(moment),where=Z!=0)
    return np.tensordot(ratio,mu1,axes=((1,2,3),(0,1,2)))

def gradJB(lbda,beta):
    """Return the update direction for the gradient multiplier beta.

    The result is the expected consumption gradient gT under the coupling tilted
    by (lbda, beta), plus the term -sign(beta)*S0, which is a subgradient of
    -S0*|beta|. Algo uses it as an ascent direction for beta.
    """
    kernel=np.exp(-cT/Epsilon)
    tilt=np.exp(-(np.tensordot(lbda,fT,axes=[0,0])+np.tensordot(beta,gT,axes=[0,0]))/Epsilon)
    Z=np.tensordot(tilt*mu2,kernel,axes=(2,1))
    moment=np.tensordot(gT*tilt*mu2,kernel,axes=(3,1))
    ratio=np.divide(moment,Z,out=np.zeros_like(moment),where=Z!=0)
    expected=np.tensordot(ratio,mu1,axes=((1,2,3),(0,1,2)))
    penalty=-np.sign(beta)*S0
    return penalty+expected

### Algorithm

R2=np.full(Nt,0.2) # Target value of the overall consumption
rangeR=range(3,10) # Time indices 3..9 (Td from about 10h to 12h) on which the overall consumption must reach R2
S0=np.full(Nt-1,0.2*NTime/(Nt-1)) # Bounds on the gradient of the overall consumption, one per pair of consecutive time points

lbd0=np.zeros(Nt) # Starting lbda
beta0=np.zeros(Nt-1) # Starting beta

def prho(k):
    """Step size at iteration k: 50/(1+k)**0.7, capped at 1."""
    step=50/(1+k)**0.7
    return step if step<1 else 1

def Algo(R2,rangeR,Klimit=np.inf,plot=True,prin=True,GradientControl=True,Conv=10**(-3)):
    """Fit the aggregated consumption to R2 on rangeR by gradient ascent on the multipliers.

    At each iteration the consumption GL = gradJL(lbda, beta) is computed.
    lbda is then moved on rangeR along GL - R2. If GradientControl is set,
    beta is first moved along gradJB(lbda, beta) and then forced to zero on
    rangeR[:-1]. Both moves use the step size prho(k).

    Parameters
    ----------
    R2 : array indexed by time
        Target aggregated consumption; only the entries on rangeR are used.
    rangeR : range or index sequence
        Time indices on which the target is enforced.
    Klimit : number, optional
        Maximum number of iterations (unbounded by default).
    plot : bool, optional
        Plot the current consumption every 10 iterations.
    prin : bool, optional
        Print the residual norm after each iteration.
    GradientControl : bool, optional
        Also update the gradient multiplier beta.
    Conv : float, optional
        Stop as soon as the residual norm on rangeR is no longer above Conv.

    Returns
    -------
    (lbda, beta) : the final consumption and gradient multipliers.
    """
    lbda=lbd0.copy()
    beta=beta0.copy()
    # NOTE(review): beta is pinned to 0 on these indices, presumably because
    # consumption is already fixed inside the window; confirm.
    pinned=rangeR[:-1]
    GL=lbd0.copy()+1 # Dummy consumption so that the first residual check passes
    err=np.linalg.norm(GL[rangeR]-R2[rangeR])
    k=0
    while k<Klimit and err>Conv:
        GL=gradJL(lbda,beta)
        err=np.linalg.norm(GL[rangeR]-R2[rangeR])
        step=prho(k)
        if GradientControl:
            beta+=step*gradJB(lbda,beta)
            beta[pinned]*=0
        lbda[rangeR]+=step*(GL[rangeR]-R2[rangeR])
        k+=1
        if plot and k%10==0:
            DisplayConsumption(GL,R2,rangeR)
        if prin:
            print("Step "+str(k)+" "+str(err))
    print(str(err))
    return(lbda,beta)


def AlgoN(R2,rangeR,Klimit=np.inf,plot=True,prin=True,GradientControl=True,Conv=10**(-3)):
    """Fit the aggregated consumption to R2 on rangeR with a diagonal Newton scheme.

    The moment is centred on the target, f = fT - R2, aligned on the
    consumption time T. At each iteration the first moment GL = E[f] (the
    consumption minus the target) and the second moment E[f**2] are computed
    under the current tilted coupling. lbda is then updated on rangeR by
    lbda += Epsilon * GL / E[f**2].

    Parameters
    ----------
    R2 : array-like indexed by time
        Target aggregated consumption.
    rangeR : range or index sequence
        Time indices on which the target is enforced.
    Klimit : number, optional
        Maximum number of iterations (unbounded by default).
    plot : bool, optional
        Plot the current consumption every 10 iterations.
    prin : bool, optional
        Print the residual norm after each iteration.
    GradientControl : bool, optional
        Accepted for signature compatibility with Algo but currently ignored:
        the gradient multiplier beta is never updated by this solver.
    Conv : float, optional
        Stop as soon as the residual norm on rangeR is no longer above Conv.

    Returns
    -------
    (lbda, beta) : the final consumption multiplier and the unchanged beta0 copy.
    """
    lbda=lbd0.copy()
    beta=beta0.copy()
    k=0
    R2=np.asarray(R2)
    # The first axis of fT is the consumption time T, so R2 must be aligned with
    # axis 0. A plain fT-R2 would broadcast R2 along the last axis (tc) instead,
    # which only gives the intended result when R2 is constant.
    f=fT-R2[:,None,None,None]
    f2=f**2
    # Exponent of the Gibbs factor exp(-f/Epsilon), used directly. The former
    # log(exp(.)) round trip yields -inf when the exponential underflows, and
    # then nan through 0*(-inf) for the zero entries of lbda.
    logE=-f/Epsilon
    K=np.exp(-cT/Epsilon)
    GL=lbd0.copy()+1 # Dummy residual so that the loop is entered
    while k<Klimit and np.linalg.norm(GL[rangeR])>Conv :
        El=np.exp(np.tensordot(lbda,logE,axes=[0,0]))*mu2
        Und=np.tensordot(El,K,axes=[2,0])
        u=np.divide(mu1, Und, out=np.zeros_like(mu1), where=Und!=0)
        uK=np.dot(u,K) # Shared by the first and second moments below
        GL=np.tensordot(uK,El*f,axes=[(0,1,2),(1,2,3)])
        H=np.tensordot(uK,El*f2,axes=[(0,1,2),(1,2,3)])
        # Diagonal Newton step, restricted to the constrained times.
        lbda[rangeR]=lbda[rangeR]+Epsilon*GL[rangeR]/H[rangeR]
        k+=1
        if plot and k%10==0:
            DisplayConsumption(GL+R2,R2,rangeR)
        if prin:
            print("Step "+str(k)+ " " +str(np.linalg.norm(GL[rangeR])))
    print(str(np.linalg.norm(GL[rangeR])))
    return(lbda,beta)

###Display

def DisplayConsumption(G,R2,rangeR,AdditionnalLbda=None):
    """Plot an aggregated consumption against the nominal consumption and the target.

    Parameters
    ----------
    G : array indexed by time (same length as Td)
        Consumption drawn as "Optimized consumption".
    R2 : array indexed by time
        Target consumption; only its entries on rangeR are drawn.
    rangeR : range or index sequence
        Time indices on which the target is drawn.
    AdditionnalLbda : array-like, optional
        When given and non-empty, also draw the consumption of
        MuLambda(AdditionnalLbda, beta0), i.e. with this consumption multiplier
        and the starting gradient multiplier beta0.
    """
    nominal=np.tensordot(fT,mu1,axes=[(1,2,3),(0,1,2)])
    plt.plot(Td,nominal,label='Nominal consumption')
    plt.plot(Td,G,label="Optimized consumption")
    plt.plot(Td[rangeR],R2[rangeR],label="Constraint")
    # None replaces the former shared mutable default []; an empty sequence still means "no extra curve".
    if AdditionnalLbda is not None and len(AdditionnalLbda)>0:
        extra=np.tensordot(fT,MuLambda(AdditionnalLbda, beta0),axes=[(1,2,3),(0,1,2)])
        plt.plot(Td,extra,label='Without Gradient constraint')
    plt.xlabel("Time")
    plt.ylabel("Aggregated consumption")
    plt.grid()
    plt.legend()
    plt.show()

def CompareTime(rangeConv):
    """Time AlgoN against Algo for a sequence of convergence tolerances.

    For each tolerance, both solvers run on a constant target of 0.2 enforced
    on time indices 4..12, with plotting and per-step printing disabled. The
    elapsed times are recorded, and both timing curves are re-plotted
    (log-scaled x axis) after each tolerance.

    Parameters
    ----------
    rangeConv : iterable of float
        Tolerances passed as Conv to both solvers.

    Returns
    -------
    (l1, l2) : elapsed times in seconds for AlgoN and for Algo, respectively.
    """
    target=Td*0+0.2 # Neither solver modifies its R2 argument, so it can be shared
    l1=[]
    l2=[]
    tested=[]
    for Conv in rangeConv:
        print(Conv)
        tested.append(Conv)

        # perf_counter is the clock meant for measuring short intervals.
        t0=time.perf_counter()
        AlgoN(target,range(4,13),plot=False,prin=False,Conv=Conv)
        l1.append(time.perf_counter()-t0)

        t0=time.perf_counter()
        Algo(target,range(4,13),plot=False,prin=False,Conv=Conv)
        l2.append(time.perf_counter()-t0)

        plt.plot(tested,l1,label="SinkN")
        plt.plot(tested,l2,label="Classic")
        plt.xscale("log")
        plt.legend()
        plt.grid()
        plt.show()
    return(l1,l2)


def MuLambda(lbda,beta):
    """Return the coupling induced by the multipliers (lbda, beta).

    The reference law mu2 is reweighted by exp(-(lbda.fT + beta.gT)/Epsilon)
    and scaled by (mu1 / Z) propagated through the kernel exp(-cT/Epsilon),
    where Z is the tilted mass seen through the same kernel. Entries where
    Z is zero contribute 0. The result has the shape of mu2.
    """
    kernel=np.exp(-cT/Epsilon)
    gibbs=np.exp(-(np.tensordot(lbda,fT,axes=[0,0])+np.tensordot(beta,gT,axes=[0,0]))/Epsilon)
    Z=np.tensordot(gibbs*mu2,kernel,axes=(2,1))
    scaling=np.dot(np.divide(mu1,Z,out=np.zeros_like(mu1),where=Z!=0),kernel)
    return mu2*gibbs*scaling
