# -*- coding: utf-8 -*-

import numpy as np
import matplotlib.pyplot as plt
import time
import pandas as pd
from sklearn.model_selection import train_test_split
#import itertools
#from scipy.sparse import csr_matrix

###Parameters

Epsilon=0.1   # strength of the entropic term: plans are weighted by exp(-S/Epsilon)
NTime=24      # horizon in hours
dt=15         # time step in minutes
Nt=NTime*(60//dt)+1   # number of grid points, both ends of [0, NTime] included
NP=4          # number of power classes returned by Pow (0..3)
Td=np.linspace(0,NTime,Nt)   # grid times in hours

###Import Data
# Load the "open_transactions" sheet of the ElaadNL workbook and derive per-session columns.
# NOTE: later code reads rows positionally (row[0], row[2], row[3], row[4]), so the order in
# which columns are created and dropped in this script matters.

DataTransaction=pd.read_excel('./elaadnl_open_ev_datasets.xlsx',sheet_name="open_transactions")
#DataMeter=pd.read_excel('./elaadnl_open_ev_datasets.xlsx',sheet_name="open_metervalues")
#DataMeter['Hour'] = DataMeter['UTCTime'].dt.hour
#DataMeter['WeekDay'] = DataMeter['UTCTime'].dt.dayofweek
# Day of week of the session start (pandas convention: Monday=0 ... Sunday=6).
DataTransaction['WeekDay'] = DataTransaction['UTCTransactionStart'].dt.dayofweek
# Durations converted to a whole number of dt-minute steps (floor division). The *60 presumes
# the source columns are expressed in hours — TODO confirm against the dataset documentation.
DataTransaction['ConnectedTime4'] = ((DataTransaction['ConnectedTime']*60)//dt)
DataTransaction['ChargeTime4'] = ((DataTransaction['ChargeTime']*60)//dt)

def Pow(lP):
    """Bin maximum charging power values into the classes 0..3.

    Thresholds: below 2 -> 0, [2, 5) -> 1, [5, 10] -> 2, above 10 -> 3.
    Values that match none of these rules (e.g. NaN) map to 0.
    Accepts an array or a pandas Series and returns a NumPy integer array.
    """
    return np.where(lP > 10, 3,
                    np.where(lP >= 5, 2,
                             np.where(lP >= 2, 1, 0)))

# Start/stop of each session as a time-of-day index, in dt-minute steps since midnight (UTC).
DataTransaction['Start4'] = DataTransaction['UTCTransactionStart'].dt.hour*(60//dt)+((DataTransaction['UTCTransactionStart'].dt.minute)//dt)
# Fixed: the hour was previously taken from UTCTransactionStart while the minute came from UTCTransactionStop.
DataTransaction['Stop4'] = DataTransaction['UTCTransactionStop'].dt.hour*(60//dt)+((DataTransaction['UTCTransactionStop'].dt.minute)//dt)
# Snapshot of the frame before MaxPower is binned. A plain assignment would only alias the
# same DataFrame, so the in-place column updates below would also show up in DataTransactionOld.
DataTransactionOld=DataTransaction.copy()
# Replace the raw maximum power by its power class 0..3 (see Pow).
DataTransaction['MaxPower'] = Pow(DataTransaction['MaxPower'])
# NOTE(review): Day is the day of the year, so sessions on the same calendar day of different
# years share a Day value; this is only a correct day key if the data covers a single year.
DataTransaction['Day'] = DataTransaction['UTCTransactionStart'].dt.day_of_year
# Drop identifiers, raw timestamps and raw durations. The remaining columns are read positionally
# later; presumably MaxPower, WeekDay, ConnectedTime4, ChargeTime4, Start4, Stop4, Day given the
# sheet's column order — TODO confirm.
DataTransaction=DataTransaction.drop(['TransactionId', 'ChargePoint', 'Connector', 'UTCTransactionStart','UTCTransactionStop', 'StartCard', 'ConnectedTime', 'ChargeTime','TotalEnergy'],axis=1)
WeekendData = DataTransaction[DataTransaction['WeekDay'].isin([5, 6])]
WeekdayData = DataTransaction[~DataTransaction['WeekDay'].isin([5, 6])]

# Hold out a fraction `rate` of the weekdays (whole days, not individual sessions) as test set.
rate=0.1
TrainDay,TestDay=train_test_split(np.unique(WeekdayData.Day), test_size=rate, random_state=0)

WeekdayTest = WeekdayData[WeekdayData['Day'].isin(TestDay)]
WeekdayTrain = WeekdayData[WeekdayData['Day'].isin(TrainDay)]
WeekdayTest=WeekdayTest.drop(['Day'],axis=1)
WeekdayTrain=WeekdayTrain.drop(['Day'],axis=1)

#WeekendTrain, WeekendTest = train_test_split(WeekendData, test_size=0.1, random_state=42)
#WeekdayTrain, WeekdayTest = train_test_split(WeekdayData, test_size=0.1, random_state=42)

#rHist=np.zeros(Nt)+0.1

dtTrain=WeekdayTrain
dtTest=WeekdayTest

def SaverHist(): ##Need to set NTime=48 and comment rHist= before running this function
    """Save the part of the nominal aggregate profile from hour 24 onward to "rHist<dt>.npy".

    Meant to be run on a 48-hour horizon (NTime=48) with the ``rHist=`` load line disabled.
    Presumably the saved tail is the load carried over into the next day by sessions that
    are still charging — TODO confirm.
    """
    first_step_of_day2=24*(60//dt)
    profile=ProdFM(fT,mu1)
    np.save("rHist"+str(dt)+".npy",profile[first_step_of_day2:])
    
# Load the profile saved by SaverHist and scale it by the number of training sessions.
# AlgoOnTime subtracts it from the cap R2 and adds it to the plotted consumption.
# NOTE(review): the divisor 9 is unexplained — presumably the number of sessions/days behind the
# saved profile; confirm and give it a name.
rHist=np.load("rHist"+str(dt)+".npy")*len(dtTrain)/9

### PlotHeatmap of behaviors
#DataMeter=pd.read_excel('./elaadnl_open_ev_datasets.xlsx',sheet_name="open_metervalues")
#DataMeter['Hour'] = DataMeter['UTCTime'].dt.hour
#DataMeter['WeekDay'] = DataMeter['UTCTime'].dt.dayofweek
# DataTransaction['Hour'] = DataTransaction['UTCTransactionStart'].dt.hour
# Tab = DataTransaction.groupby(['ChargeTime4', 'Hour']).count()
# Mz=Tab.pivot_table(index='ChargeTime4', columns='Hour', values='TransactionId')
# Mz.fillna(0, inplace=True)
# plt.imshow(Mz,aspect="auto", cmap='Blues',origin='lower')
# plt.colorbar(label='number')
# plt.xlabel("Hour of the day")
# plt.ylabel("ChargeTime")
# plt.xticks(ticks=np.arange(len(Mz.columns))[::2], labels=Mz.columns[::2])
# plt.yticks(ticks=np.arange(len(Mz.index))[::4], labels=Mz.index[::4])

### Plot Consumption

#DataMeter=pd.read_excel('./elaadnl_open_ev_datasets.xlsx',sheet_name="open_metervalues")
#DataMeter['Hour'] = DataMeter['UTCTime'].dt.hour
#DataMeter['WeekDay'] = DataMeter['UTCTime'].dt.dayofweek
# total_consumption = DataMeter.groupby(['WeekDay', 'Hour'])["EnergyInterval"].sum()
# total_consumption = total_consumption.sort_index()
# plot_data = total_consumption.unstack(level=0)
# plt.plot(plot_data)
# plt.xlabel('Heure de la journée')
# plt.ylabel('Consommation d\'énergie totale')
# plt.title('Energy Consumption')
# plt.legend(title='Jour de la semaine', labels=["Monday","Tuesday","Wednesday","Thursday","Friday","Saturday","Sunday"])
# plt.show()

###Functions

# Index grid shared by the precomputed tensors below (step i <-> time Td[i] = i*NTime/(Nt-1) hours).
_steps=np.arange(Nt)

# cT[ta, tco, tch, p, tc]: squared delay, in hours, between arrival step ta and start step tc.
# Fixed: the step length is NTime/(Nt-1), the spacing of Td (and of S0 below); it was NTime/Nt.
_delay_hours=(_steps[:,None]-_steps[None,:])*(NTime/(Nt-1))  # [ta, tc]
cT=np.zeros((Nt,Nt,Nt,NP,Nt))
cT+=(_delay_hours**2)[:,None,None,None,:]

# fT[T, ta, tch, p, tc]: power drawn at step T by a session of power class p that needs tch steps
# of charge and starts at step tc: 4*p while tc <= T < tc+tch, 0 otherwise (the same for every ta).
# Broadcast form of the former nested Python loops. Those loops also filled a gT array that was
# overwritten below, so that work is dropped.
_charging=(_steps[:,None,None]>=_steps[None,None,:])&(_steps[:,None,None]<_steps[None,None,:]+_steps[None,:,None])  # [T, tch, tc]
fT=np.zeros((Nt,Nt,Nt,NP,Nt))
fT+=_charging[:,None,:,None,:]*(4*np.arange(NP))[None,None,None,:,None]

# D: (Nt-1) x Nt forward-difference matrix, (D @ x)[i] = x[i+1] - x[i].
D=np.eye(Nt-1,Nt,k=1)-np.eye(Nt-1,Nt)
gT=np.tensordot(D,fT,axes=[-1,0])  #Discretization of the second general moment in the 1S-RCMOT problem, representing the gradient of the global consumption

S0=np.zeros(Nt-1)+(120/726)*NTime/(Nt-1) # Bounds for the gradient overall consumption

### Definition of mu1 and mu2

def _session_histogram(frame):
    """Return the normalised histogram of the sessions in frame, indexed [ta, tco, tch, p].

    Rows are read positionally: row[4] -> ta, row[2] -> tco, row[3] -> tch, row[0] -> p
    (presumably Start4, ConnectedTime4, ChargeTime4 and the MaxPower class — TODO confirm the
    column order). Rows with an index not below its axis length, or with NaN, are skipped;
    negative values are not filtered out.
    """
    hist=np.zeros((Nt,Nt,Nt,NP))
    for row in frame.itertuples(index=False):
        if row[4] <Nt and row[2] < Nt and row[3]<Nt and row[0]<NP:
            hist[int(row[4]),int(row[2]),int(row[3]),int(row[0])]+=1
    return hist/np.sum(hist)


def _spread_over_starts(nu):
    """Spread nu[ta, tco, tch, p] uniformly over the start steps tc with ta <= tc <= ta+tco-tch.

    Returns an array indexed [ta, tco, tch, p, tc] with tc < Nt; entries with no admissible start
    are 0. Vectorised equivalent of a Python loop over (ta, tco, tch, p); each value is computed
    as nu/count exactly like the loop, so results are bit-identical.
    """
    steps=np.arange(Nt)
    ta=steps[:,None,None,None]
    tco=steps[None,:,None,None]
    tch=steps[None,None,:,None]
    tc=steps[None,None,None,:]
    admissible=(tc>=ta)&(tc<=ta+tco-tch)        # [ta, tco, tch, tc]
    n_admissible=admissible.sum(-1)[...,None]   # [ta, tco, tch, 1], identical for every p
    share=np.divide(nu,n_admissible,out=np.zeros_like(nu),where=n_admissible!=0)
    return admissible[:,:,:,None,:]*share[...,None]


nu0=_session_histogram(dtTrain)      # distribution over the training days
nuReal=_session_histogram(dtTest)    # distribution over the test days

# muReal / mu1: every session starts charging at its arrival step (tc == ta).
muReal=np.zeros((Nt,Nt,Nt,NP,Nt))

for t in range(0,Nt):
    muReal[t,:,:,:,t]=nuReal[t]

mu1=np.zeros((Nt,Nt,Nt,NP,Nt))

for t in range(0,Nt):
    mu1[t,:,:,:,t]=nu0[t]

# mu2: reference plan with the start step spread uniformly over [ta, ta+tco-tch].
mu2=_spread_over_starts(nu0)

###Useful functions

def ProdFM(f,m):
    """Contract a 5-axis array f with the tco-marginal of the plan m.

    m is indexed [ta, tco, tch, p, tc]. It is first summed over tco (axis 1), and the
    result is contracted with the last four axes of f. The returned array is indexed
    by the first axis of f (e.g. the time step T for fT).
    """
    marginal=np.sum(m,axis=1)
    return np.tensordot(f,marginal,axes=((1,2,3,4),(0,1,2,3)))

# Aggregate consumption profiles (indexed by time step), scaled to the number of test sessions:
# NominalConsumption from mu1 (training-day distribution, charging starts at arrival);
# PredictedConsumption from muReal (test-day distribution, charging starts at arrival).
# NOTE(review): "Predicted" is built from the actual test sessions — check the names are not swapped.
NominalConsumption=ProdFM(fT,mu1)*len(dtTest)
PredictedConsumption=ProdFM(fT,muReal)*len(dtTest)

# def gradJL(tOn,lbda,beta,nu_0,mu_2,rHist):
#     S=np.tensordot(lbda[tOn:]+np.dot(beta[max(tOn-1,0):],D[max(tOn-1,0):,tOn:]),fT[tOn:,tOn:,:,:,tOn:],axes=[0,0])[:,np.newaxis]+cT[tOn:,:,:,:,tOn:]
#     B=np.exp(-S/Epsilon)*mu_2
#     UB=np.sum(B,-1)
#     MuL=B*(np.divide(nu_0,UB,out=np.zeros_like(nu_0), where=UB!=0)[..., np.newaxis])
#     return(ProdFM(fT[:,tOn:,:,:,tOn:],MuL,rHist))

# def gradJB(tOn,lbda,beta,nu_0,mu_2):
#     S=np.tensordot(lbda[tOn:],fT[tOn:,tOn:,:,:,tOn:],axes=[0,0])[:,np.newaxis]+cT[tOn:,:,:,:,tOn:]+np.tensordot(beta[tOn:],gT[tOn:,tOn:,:,:,tOn:],axes=[0,0])[:,np.newaxis]+cT[tOn:,:,:,:,tOn:]
#     B=np.exp(-S/Epsilon)*mu_2
#     UB=np.sum(B,-1)
#     MuL=B*(np.divide(nu_0,UB,out=np.zeros_like(nu_0), where=UB!=0)[..., np.newaxis])
#     return(-np.sign(beta)*S0+ProdFM(gT[:,tOn:,:,:,tOn:],MuL))

# def gradJBFromMu(tOn,beta,MuL):
#     return(-np.sign(beta)*S0+ProdFM(gT[:,tOn:,:,:,tOn:],MuL))

def gradJLFromMu(tOn,MuL):
    """Return the aggregate consumption profile induced by the plan MuL computed at step tOn.

    MuL only covers start steps tc >= tOn (its last axis starts at tOn), so fT is restricted
    to the same start steps before being contracted with ProdFM.
    """
    fT_from_tOn=fT[...,tOn:]
    return ProdFM(fT_from_tOn,MuL)

def RangeR(tOn,rangeR):
    """Return, in increasing order, the steps in [tOn, Nt) that belong to rangeR."""
    return [step for step in range(tOn,Nt) if step in rangeR]

def NotRangeR(tOn,rangeR):
    """Return, in increasing order, the steps in [max(0, tOn-1), Nt-1) that are NOT in rangeR.

    The upper bound Nt-1 presumably refers to the Nt-1 consecutive differences (rows of D).
    """
    first=max(0,tOn-1)
    return [step for step in range(first,Nt-1) if step not in rangeR]

# Constant reference level: the time-average of the nominal (training-day) aggregate profile.
RTest=np.zeros(Nt)+np.mean(ProdFM(fT,mu1))

###Algo

# Initial dual variables: lbd0 for the per-step consumption caps (Nt values) and bet0 for the
# consumption-gradient terms (Nt-1 values, one per row of D).
lbd0=np.zeros(Nt)
bet0=np.zeros(Nt-1)

def prho(tOn,k):
    """Step size of the dual ascent at iteration k.

    Decays like (1+k)**-0.7. At the first decision step (tOn == 0) the rate is
    5*(1+k)**-0.7 capped at 0.5; at later steps it is 0.5*(1+k)**-0.7.
    """
    decay=(1+k)**0.7
    if tOn!=0:
        return 0.5/decay
    return min(0.5,5/decay)

def AlgoOnTimeAtStept(tOn,nu_0,mu_2,R,rangeR,Klimit=300,plot=True,prin=True,lbd0A=lbd0,bet0A=bet0,S0=[],NomC1=NominalConsumption,Pred=PredictedConsumption):
    """Projected gradient ascent on the cap multipliers lbda at decision step tOn.

    Iterates until the part of the induced consumption GL that exceeds the cap R on the
    constrained steps (rangeR restricted to [tOn, Nt)) has Euclidean norm <= 0.03, or until
    Klimit iterations have been done.

    Args:
        tOn: current decision step; only start steps tc >= tOn are optimised.
        nu_0, mu_2: marginal and reference plan for this step (as built by InitT).
        R: cap on the aggregate consumption, indexed by step.
        rangeR: steps on which the cap is enforced.
        Klimit: maximum number of iterations.
        plot: draw the profiles at every iteration (Affichage).
        prin: print the constraint violation at every iteration.
        lbd0A, bet0A: warm-start values for lbda / beta (copied, never modified).
        S0: bound on the consumption gradient. It is only used for the "GB:" diagnostic,
            which is skipped when S0 is empty (the default).
        NomC1, Pred: reference curves passed to Affichage.

    Returns:
        (lbda, beta). beta is returned equal to bet0A because its update is disabled.
    """
    lbda=lbd0A.copy()
    beta=bet0A.copy()
    rangeR2=RangeR(tOn, rangeR)

    def cap_violation(G):
        # Norm of the positive part of G - R on the constrained steps.
        G_r=G[rangeR2]
        R_r=R[rangeR2]
        return np.linalg.norm((G_r>R_r)*(G_r-R_r))

    k=0
    # Start from +inf so that the loop runs at least once whenever rangeR2 is non-empty.
    GL=lbd0.copy()+np.inf
    GB=bet0.copy()+np.inf
    while k<Klimit and cap_violation(GL)>0.03: # disabled extra criterion: norm of GB outside [-S0, S0] > 0.10
        MuL=MuLambda(tOn, lbda, beta, nu_0, mu_2)
        GL=gradJLFromMu(tOn, MuL)
        GB=np.dot(D,GL)
        lbda[rangeR2]+=prho(tOn,k)*(GL[rangeR2]-R[rangeR2])
        lbda=np.max([lbda,lbd0],0)  # projection onto lbda >= lbd0 (all zeros)
        #if len(S0)>0:
            #beta[max(0,tOn-1):]+=prho(tOn,k)*(-np.sign(beta)*S0+GB)[max(0,tOn-1):]
        if plot:
            Affichage(tOn,GL,R,rangeR2,it=k,Nom=NomC1,Pred=Pred)
        if prin:
            msg="Step "+str(k)+ " GL:" +str(cap_violation(GL))
            # Fixed: with the default S0=[] the gradient-bound term raised a broadcasting error.
            if np.size(S0)>0:
                GB_excess=((GB-S0)*(GB>S0)+(GB+S0)*(GB<-S0))[max(0,tOn-1):]
                msg+=" GB:"+str(np.linalg.norm(GB_excess))
            print(msg)
        k+=1
    return(lbda,beta)

def VehiclesAtStepT(t):
    """List the test sessions whose Start4 equals t, each as [ta, tco, tch, p] integer indices.

    Fields are read positionally (row[4], row[2], row[3], row[0]). Rows with an index not
    below its axis length (or NaN) are skipped.
    """
    arrivals=dtTest[dtTest['Start4']==t]
    return [[int(rec[4]),int(rec[2]),int(rec[3]),int(rec[0])]
            for rec in arrivals.itertuples(index=False)
            if rec[4]<Nt and rec[2]<Nt and rec[3]<Nt and rec[0]<NP]

def InitT(tOn,listX,N):
    """Build the marginal nu_0 and the reference plan mu_2 used at decision step tOn.

    nu_0 combines two parts and is normalised by their total weight N1t + N2t:
      - the pending vehicles in listX (entries [ta, tco, tch, p]), weight 1 each;
      - the expected future arrivals, i.e. the training distribution nu0 restricted to
        ta > tOn, with total weight N * nu0[tOn+1:].sum().

    mu_2[ta, tco, tch, p, c] spreads nu_0[ta, tco, tch, p] uniformly over the start steps
    tc = tOn + c with max(ta, tOn) <= tc <= ta + tco - tch and tc < Nt.

    Returns:
        (nu_0, mu_2) with shapes (Nt, Nt, Nt, NP) and (Nt, Nt, Nt, NP, Nt - tOn).
    """
    nu_0=np.zeros((Nt,Nt,Nt,NP))
    N1t=len(listX)
    N2t=np.sum(nu0[tOn+1:])*N
    # Expected future arrivals only exist after tOn; earlier mass comes from listX alone.
    nu_0[0:tOn+1]=nu0[0:tOn+1]*0
    nu_0[tOn+1:]=nu0[tOn+1:]*N2t/(N1t+N2t)
    for V in listX:
        nu_0[V[0],V[1],V[2],V[3]]+=1/(N1t+N2t)

    # Vectorised form of the former per-(ta, tco, tch, p) Python loop; c indexes tc = tOn + c.
    # Each entry is computed as nu_0/count exactly like the loop, so results are bit-identical.
    steps=np.arange(Nt)
    ta=steps[:,None,None,None]
    tco=steps[None,:,None,None]
    tch=steps[None,None,:,None]
    c=np.arange(Nt-tOn)[None,None,None,:]
    admissible=(c>=ta-tOn)&(c<=ta-tOn+tco-tch)   # [ta, tco, tch, c]
    n_admissible=admissible.sum(-1)[...,None]      # [ta, tco, tch, 1], identical for every p
    share=np.divide(nu_0,n_admissible,out=np.zeros_like(nu_0),where=n_admissible!=0)
    mu_2=admissible[:,:,:,None,:]*share[...,None]
    return(nu_0,mu_2)

def AlgoOnTime(R2,rangeR,N=len(dtTrain)*rate/(1-rate),Klimit=75,S0=[]):
    """Schedule the test-day sessions step by step under the consumption cap R2.

    At every step tOn, the sessions arriving at tOn (VehiclesAtStepT) join the pending list.
    The dual problem is solved for the pending vehicles plus the expected future arrivals
    (InitT / AlgoOnTimeAtStept). Each pending vehicle then draws a start offset from its
    conditional plan:
      - offset 0: the vehicle starts charging at tOn and is counted in muFinal;
      - otherwise: it stays pending.
    The running muFinal is saved to "MuFinal.npy" after every step.

    Args:
        R2: cap on the total aggregate consumption (rHist included), indexed by step.
        rangeR: steps on which the cap is enforced.
        N: weight of the expected future arrivals. The default scales the training-set size by
           rate/(1-rate), presumably the expected number of test sessions.
        Klimit: iteration limit passed to AlgoOnTimeAtStept.
        S0: gradient bound passed to AlgoOnTimeAtStept (diagnostic only).

    Returns:
        muFinal, indexed [ta, tco, tch, p, tc], counting the scheduled vehicles.
    """
    lbda_star=lbd0.copy()
    beta_star=bet0.copy()
    muFinal=np.zeros((Nt,Nt,Nt,NP,Nt))
    t0=time.time()
    listX=[]  # pending (not yet started) vehicles, each [ta, tco, tch, p]
    for tOn in range(Nt):
        print(str((tOn*dt)//60)+"h"+str((tOn*dt)%60))
        listX+=VehiclesAtStepT(tOn)
        # Normalisation: pending vehicles plus expected future arrivals (same weights as InitT).
        N1t=len(listX)
        N2t=np.sum(nu0[tOn+1:])*N
        NomC1=NominalConsumption/(N1t+N2t)
        Pred=PredictedConsumption/(N1t+N2t)
        # Cap left once already-scheduled vehicles and rHist are accounted for, per unit of mass.
        r=(R2-ProdFM(fT, muFinal)-rHist)/(N1t+N2t)
        nu_0,mu_2=InitT(tOn, listX,N)
        # Warm-start from the previous step's multipliers.
        lbda_star,beta_star=AlgoOnTimeAtStept(tOn, nu_0, mu_2, r, rangeR,prin=False,plot=True,Klimit=Klimit,lbd0A=lbda_star,bet0A=beta_star,S0=S0,NomC1=NomC1,Pred=Pred) #range(max(np.min(rangeR),tOn),np.max(rangeR))
        print("Running Time MuComputing: "+str(round(time.time()-t0))+"s")
        MuL=MuLambda(tOn, lbda_star, beta_star, nu_0, mu_2)
        listXNew=[]
        for V in listX:
            weights=MuL[V[0],V[1],V[2],V[3]]  # mass over start offsets tc - tOn
            if np.sum(weights)!=0:
                tc=np.random.choice(range(Nt-tOn),p=weights/np.sum(weights))
                if tc==0:
                    # Starts now: count one vehicle at (ta, tco, tch, p, tc=tOn). This replaces a
                    # full Nt^4*NP boolean indicator built in Python for a single cell.
                    muFinal[V[0],V[1],V[2],V[3],tOn]+=1
                else:
                    listXNew.append(V)
            else:
                # NOTE(review): a vehicle with no mass on any start offset is dropped from the
                # pending list and never scheduled. The draw below is unused but kept so that the
                # random-number stream is unchanged.
                tc=np.random.choice(range(Nt-tOn))
                print("Error")
        listX=listXNew
        Affichage(tOn,ProdFM(fT, muFinal)+rHist, R2, rangeR,Pred=PredictedConsumption+rHist,Nom=NominalConsumption+rHist)
        print("Running Time Vehicles Plugging: "+str(round(time.time()-t0))+"s")
        np.save("MuFinal.npy",muFinal)
    return(muFinal)

###Display

def Affichage(tOn,G,R2,rangeR,it="",Nom=NominalConsumption,Pred=PredictedConsumption):
    """Plot the nominal, predicted and optimized aggregate consumption together with the cap R2.

    The cap is drawn only on the steps in rangeR. When an iteration number `it` is given,
    the title shows the decision step tOn and that iteration.
    """
    curves=[(Nom,'Nominal consumption'),
            (Pred,"Predicted consumption"),
            (G,"Optimized consumption")]
    for values,name in curves:
        plt.plot(Td,values,label=name)
    plt.plot(Td[rangeR],R2[rangeR],label="Constraint")
    plt.xlabel("Time")
    plt.ylabel("Aggregated consumption")
    plt.grid()
    plt.legend()
    has_iteration=it!=""
    if has_iteration:
        plt.title("Time: "+str(tOn)+ " Iteration: "+str(it))
    plt.show()

def MuLambda(tOn,lbda,beta,nu_0,mu_2):
    """Return the plan at decision step tOn induced by the multipliers (lbda, beta).

    With r0 = max(tOn-1, 0) and w = lbda[tOn:] + beta[r0:] @ D[r0:, tOn:] (one weight per
    step T >= tOn), the potential is
        S[ta, tco, tch, p, c] = sum_T w[T] * fT[tOn+T', ta, tch, p, tOn+c] + cT[ta, tco, tch, p, tOn+c]
    (T' running over the steps from tOn on). The kernel B = exp(-S/Epsilon) * mu_2 is then
    rescaled row by row so that each (ta, tco, tch, p) row sums to nu_0[ta, tco, tch, p].
    Rows whose kernel sums to 0 are set to 0.

    Returns:
        MuL, with the shape of mu_2 (Nt, Nt, Nt, NP, Nt - tOn when mu_2 comes from InitT).
    """
    # The tensordot yields [ta, tch, p, c]; np.newaxis inserts the tco axis, on which fT does not depend.
    S=np.tensordot(lbda[tOn:]+np.dot(beta[max(tOn-1,0):],D[max(tOn-1,0):,tOn:]),fT[tOn:,:,:,:,tOn:],axes=[0,0])[:,np.newaxis]+cT[:,:,:,:,tOn:]
    # NOTE(review): with a small Epsilon, exp(-S/Epsilon) can underflow to 0 for large S; such rows
    # then get zero mass through the UB!=0 guard below.
    B=np.exp(-S/Epsilon)*mu_2
    UB=np.sum(B,-1)
    # Row normalisation: scale each row of B by nu_0 / (row sum), leaving zero-sum rows at 0.
    MuL=B*(np.divide(nu_0,UB,out=np.zeros_like(nu_0), where=UB!=0)[..., np.newaxis])
    return(MuL)   

def SaveWithTD(Name,Arr):
    """Write Arr to the text file Name as two columns: the time grid Td, then Arr."""
    columns=[Td,Arr]
    np.savetxt(Name,np.transpose(columns))
     