# -*- coding: utf-8 -*-

import numpy as np
import matplotlib.pyplot as plt

### Global Parameters 

Epsilon = 0.1  # Value of Epsilon for the 1S-RCMOT problem; the cost enters through exp(-cT/Epsilon)
v = 0.25  # Charging speed in h^(-1): an EV arriving with 0% of battery is fully charged in 1/v = 4 h.

NTime = 8  # Number of hours in the time horizon
Nt = 3*NTime + 1  # Number of points of the time discretization (3 per hour, both ends included)
Ns = 20  # Number of points for the discretization in space (the index s used with mu0, mu2 and fT)
Td = np.linspace(0, NTime, Nt) + 9  # Time grid in hours, from 9 to 9 + NTime (presumably clock time)

# Discretization of the cost function c in the 1S-RCMOT problem:
#     cT[tx, ty, t1, t2] = (Td[tx] - Td[ty])**2 + (Td[tx] - Td[t2])**2
# Built with NumPy broadcasting instead of a 4-fold Python loop over Nt**4 entries.
# NOTE(review): the third index t1 does not enter the cost (cT is constant along
# that axis); confirm this is intended.
cT = np.zeros((Nt, Nt, Nt, Nt))
cT[...] = ((Td[:, None, None, None] - Td[None, :, None, None])**2
           + (Td[:, None, None, None] - Td[None, None, None, :])**2)

# Discretization of the generalized moment f in the 1S-RCMOT problem, representing
# the global consumption. fT[T, ta, s, tc, td1, tc2] is 1 when time index T lies in
#     [tc, min(Nt, int(tc + D), td1))            (presumably a first charging session)
#  or [tc2, min(Nt, int(tc2 + tc - td1 + D)))    (presumably the resumed session)
# with D = Nt*(1 - s/(Ns-1))/(v*NTime), and 0 otherwise.
# The indicator does not depend on the arrival index ta, so it is filled once for
# ta = 0 and then copied along that axis (Nt times less work than looping over ta).
fT = np.zeros((Nt, Nt, Ns, Nt, Nt, Nt))
for s in range(Ns):
    print("Computing the discretization of the generalized moment function: " + str(s) + "/" + str(Ns-1))
    # Hours to full charge, (1 - s/(Ns-1))/v, scaled to time indices by Nt/NTime.
    # Depends on s only, so it is hoisted out of the inner loops.
    # NOTE(review): the grid has (Nt-1)/NTime steps per hour (Td spacing is
    # NTime/(Nt-1)), so the factor Nt/NTime slightly overestimates; confirm.
    D = Nt*(1 - s/(Ns-1))/(v*NTime)
    for tc in range(Nt):
        for td1 in range(Nt):
            for tc2 in range(Nt):
                for T in range(tc, min(Nt, int(tc + D), td1)):
                    fT[T, 0, s, tc, td1, tc2] = 1
                for T in range(tc2, min(Nt, int(tc2 + tc - td1 + D))):
                    fT[T, 0, s, tc, td1, tc2] = 1
# Replicate the ta = 0 slab along the (unused) arrival axis.
fT[:, 1:] = fT[:, :1]

### Definition of mu0, mu1, mu1S and mu2

# Initial law mu0 of the arrival of EVs on the (time index, space index) grid:
# constant on the first int(Nt/4) time indices and on the space indices from
# int(0.05*Ns) (included) to int(0.95*Ns) (excluded), zero elsewhere,
# then normalized to total mass 1.
mu0 = np.zeros((Nt, Ns))
mu0[:int(Nt/4), int(Ns*0.05):int(Ns*0.95)] = 1
mu0 = mu0 / mu0.sum()

# mu1 puts the mass mu0[t, s] on the single point (t, s, t, t, t), i.e. tc = td1 = tc2 = t.
# It is the law DisplayConsumption contracts with fT to plot the "Nominal consumption".
# mu1S is the same mass on (t, s, t), keeping only the first time coordinate.
mu1 = np.zeros((Nt, Ns, Nt, Nt, Nt))
mu1S = np.zeros((Nt, Ns, Nt))
for t in range(Nt):
    row = mu0[t]  # masses for every space index s at arrival time index t
    mu1[t, :, t, t, t] = row
    mu1S[t, :, t] = row

# mu2: law of the plans (tc, td1, tc2) for an arrival (t, s).
# For each (t, s), every index triple with t <= tc <= td1 <= tc2 < stop gets the
# same weight, and the slice mu2[t, s] is rescaled to total mass mu0[t, s].
# Slices with no admissible triple stay at zero.
# NOTE(review): in `stop`, td1 - tc is a number of grid indices, but it is added to
# NTime and to (1 - s/Ns)/v, which are in hours (v is in h^(-1)), before scaling by
# Nt/NTime. Also, this uses s/Ns, whereas fT uses s/(Ns-1). Confirm both are intended.
mu2 = np.zeros((Nt, Ns, Nt, Nt, Nt))

for t in range(Nt):
    for s in range(Ns):
        hours_to_full = (1 - s/Ns)/v
        for tc in range(t, Nt):
            for td1 in range(tc, Nt):
                stop = min(int(Nt*(NTime + td1 - tc - hours_to_full)/NTime), Nt)
                if stop > td1:  # guard: a negative stop would wrap around in a slice
                    mu2[t, s, tc, td1, td1:stop] = 1
        mass = np.sum(mu2[t, s])
        if mass != 0:
            mu2[t, s] = mu2[t, s]*mu0[t, s]/mass

### Gradient

def gradJ(lbda):
    """Return the aggregated consumption profile G induced by the multiplier ``lbda``.

    Algo uses ``gradJ(lbda) - R2`` as the update direction for ``lbda``
    (hence the name: gradient of the functional J).

    Parameters
    ----------
    lbda : ndarray of shape (Nt,)
        One multiplier per time index, contracted with the first axis of fT.

    Returns
    -------
    G : ndarray of shape (Nt,)
        With S = lbda . fT, K = exp(-cT/Epsilon),
        B[ta, s, tx] = sum_y exp(S)*mu2 [ta, s, y] * K[tx, y] and
        fB[T, ta, s, tx] = sum_y fT*exp(S)*mu2 [T, ta, s, y] * K[tx, y],
        G[T] = sum_{ta, s, tx} mu1S[ta, s, tx] * fB[T, ta, s, tx] / B[ta, s, tx],
        where terms with B == 0 contribute 0.
    """
    # The kernel and exp(S) are each needed twice; compute them once per call.
    K = np.exp(-cT/Epsilon)                  # axes (tx, ty, t1, t2)
    S = np.tensordot(lbda, fT, axes=[0, 0])  # axes (ta, s, tc, td1, tc2)
    E = np.exp(S)
    B = np.tensordot(E*mu2, K, axes=[(2, 3, 4), (1, 2, 3)])     # axes (ta, s, tx)
    # Same operation order as before: (fT*E)*mu2. The product has the size of fT.
    fB = np.tensordot(fT*E*mu2, K, axes=[(3, 4, 5), (1, 2, 3)])  # axes (T, ta, s, tx)
    ratio = np.divide(fB, B, out=np.zeros_like(fB), where=B != 0)
    G = np.tensordot(ratio, mu1S, axes=[(1, 2, 3), (0, 1, 2)])
    return G

### Algo

lbd0 = np.zeros(Nt)     # Starting multiplier lambda, one entry per time index
R2 = np.full(Nt, 0.2)   # Target value for the overall consumption at each time index
# Time indices 3..9 (Td[3] = 10 to Td[9] = 12, three indices per hour) on which
# the overall consumption must reach the target R2.
rangeR = range(3, 10)

def prho(k):
    """Step size used by Algo at iteration ``k``: 3000 / (k + 1)**0.7, capped at 100."""
    step = 3000 / (k + 1) ** 0.7
    return min(100, step)

def Algo(R2,rangeR,Klimit=1000,plot=True,prin=True,tol=0.0001):
    """Fit the multiplier lambda so that gradJ(lambda) matches R2 on rangeR.

    Starting from a copy of the global lbd0, each step evaluates
    G = gradJ(lbda), stops if the residual (G - R2) restricted to rangeR
    has Euclidean norm <= tol, and otherwise updates
    lbda[rangeR] -= prho(k) * (G - R2)[rangeR]. Entries outside rangeR are never changed.

    Parameters
    ----------
    R2 : ndarray of shape (Nt,)
        Target consumption; only R2[rangeR] is used.
    rangeR : sequence of int
        Time indices on which the target is enforced.
    Klimit : int
        Maximum number of gradient evaluations.
    plot : bool
        If True, call DisplayConsumption every 10 updates.
    prin : bool
        If True, print the residual norm at every step.
    tol : float
        Convergence threshold on the residual norm.

    Returns
    -------
    ndarray of shape (Nt,)
        On convergence, the multiplier whose residual passed the test.
        Otherwise, the multiplier after Klimit updates.
    """
    lbda=lbd0.copy()
    for k in range(Klimit):
        G=gradJ(lbda)
        residual=G[rangeR]-R2[rangeR]
        err=np.linalg.norm(residual)
        if prin:
            print("Step "+str(k+1)+ " " +str(err))
        # Check before updating, so the returned lbda is the one that was verified.
        # Also stop on a non-finite residual instead of propagating NaN/inf into lbda.
        if err<=tol or not np.isfinite(err):
            break
        lbda[rangeR]+=-prho(k)*residual
        if plot and (k+1)%10==0:
            DisplayConsumption(G,R2,rangeR)
    return(lbda)

###Display

def DisplayConsumption(G,R2,rangeR):
    """Plot the nominal consumption, the consumption G and the target R2 over Td.

    The nominal curve contracts fT with mu1 over every axis but the first,
    giving one value per time index. The target curve is drawn only on rangeR.
    Ends with plt.show(), which may block depending on the matplotlib backend.
    """
    nominal = np.tensordot(fT, mu1, axes=[(1, 2, 3, 4, 5), (0, 1, 2, 3, 4)])
    curves = (
        (Td, nominal, 'Nominal consumption'),
        (Td, G, "Optimized consumption"),
        (Td[rangeR], R2[rangeR], "Constraint"),
    )
    for xs, ys, name in curves:
        plt.plot(xs, ys, label=name)
    plt.xlabel("Time")
    plt.ylabel("Aggregated consumption")
    plt.grid()
    plt.legend()
    plt.show()

def MuLambda(lbda):
    """Return the measure on (ta, s, tc, td1, tc2) built from the multiplier ``lbda``.

    With S = lbda . fT, K = exp(-cT/Epsilon),
    B[ta, s, tx] = sum_y exp(S)*mu2 [ta, s, y] * K[tx, y] and
    D = mu1S / B (0 where B == 0), the result is
    MuL[ta, s, y] = mu2[ta, s, y] * exp(S)[ta, s, y] * sum_tx D[ta, s, tx] * K[tx, y].

    Parameters
    ----------
    lbda : ndarray of shape (Nt,)
        One multiplier per time index, contracted with the first axis of fT.

    Returns
    -------
    MuL : ndarray with the shape of mu2, (Nt, Ns, Nt, Nt, Nt).
    """
    # The kernel and exp(S)*mu2 are each needed twice; compute them once per call.
    K = np.exp(-cT/Epsilon)                  # axes (tx, ty, t1, t2)
    S = np.tensordot(lbda, fT, axes=[0, 0])  # axes (ta, s, tc, td1, tc2)
    W = np.exp(S)*mu2  # bitwise equal to mu2*np.exp(S): float multiplication is commutative
    B = np.tensordot(W, K, axes=[(2, 3, 4), (1, 2, 3)])  # axes (ta, s, tx)
    D = np.divide(mu1S, B, out=np.zeros_like(mu1S), where=B != 0)
    MuL = W*np.tensordot(D, K, axes=[2, 0])
    return MuL
    