import numpy as np
import matplotlib.pyplot as plt
import os
import time

##INPUTS

# Number of modes; mode indices run over range(Ni) in every array below.
Ni=2
# Number of points of the time grid Td.
Nt=101
# Time horizon: Td spans [0, T].
T=10
# Number of points of the battery grid Sd, which spans [0, 1].
Ns=25
Td=np.linspace(0,T,Nt)
Sd=np.linspace(0,1,Ns)
dt=Td[1]-Td[0]
ds=Sd[1]-Sd[0]
# Convergence tolerance. NOTE(review): not read at module level; presumably meant as algo's eps argument.
eps=0.01
# Initial multiplier on the time grid; algo starts its iteration from a copy of it.
lbd0=np.zeros(Nt)

# Speed of the battery level in mode 1 while s < dropout (see b and nextS).
vita=2/5
# Battery level beyond which the mode-1 speed decreases linearly, reaching 0 at s = 1.
dropout=0.75

def b(i,s):
    """Return the drift of the battery level ``s`` in mode ``i``.

    The same value is used by ``conso`` as the instantaneous consumption.

    * mode 0: frozen, drift 0;
    * mode 1: constant speed ``vita`` while ``s < dropout``, then decreasing
      linearly to 0 at ``s = 1``;
    * mode 2: mirror image of mode 1, ``-b(1, 1 - s)``;
    * any other mode: None.
    """
    if i == 0:
        return 0
    if i == 2:
        return -b(1, 1 - s)
    if i != 1:
        return None
    if s < dropout:
        return vita
    slowdown = (s - dropout) / (1 - dropout)
    return vita * (1 - slowdown)

c1 = 0
c2 = 0


def c(i,t):
    """Return the time-dependent price coefficient of mode ``i`` at time ``t``.

    Schema adds it to ``lbda`` in the running term. Mode 0 gives 0, mode 1
    gives ``c2`` and every other mode gives ``-c2``. ``t`` is currently
    ignored.

    NOTE(review): ``c1`` is defined but never read; mode 1 uses ``c2``.
    Confirm this is intended.
    """
    if i == 0:
        return 0
    return c2 if i == 1 else -c2

def Sinit():
    """Draw an initial battery level: 0.15 + |N(0, 0.2)|, capped at 1."""
    draw = np.random.normal(0, 0.2)
    return min(1, 0.15 + abs(draw))

# Long-run level of the profile Dd2.
b0=0.0035
# Profile on the time grid Td, decaying exponentially from 0.2 at t=0 towards b0.
# Ilaw uses its mean. NOTE(review): presumably this is the target signal passed as Dd to algo/EstimJ;
# the rate 0.2803635831824254 looks calibrated, so confirm where it comes from.
Dd2=b0+(0.2-b0)*np.exp(-Td*0.2803635831824254)

# Schema divides the quadratic switching term for i -> j by CoefTransfert[i, j]; all ones means equal weights.
CoefTransfert=np.ones((Ni,Ni))


def indice(t,Td):
    """Return the grid cell index of ``t`` on the regular grid ``Td`` starting at 0.

    The index is ``floor(t * (len(Td) - 1) / Td[-1])`` (by ``int`` truncation),
    clamped so that it never exceeds the last index ``len(Td) - 1``.
    """
    last = len(Td) - 1
    raw = int(t * last / Td[-1])
    return raw if raw <= last else last

def maxL(L,Td,Sd,t0,tf):
    """Return an upper bound of the intensities ``L`` over the time window ``[t0, tf]``.

    ``L`` is indexed as ``L[i, j, time_index, battery_index]``. The bound is the
    maximum of ``L[:, :, i0..iF]``, inclusive, where ``i0 = indice(t0, Td)`` and
    ``iF = indice(tf, Td)``. When ``i0 > iF`` the global maximum of ``L`` is
    returned instead.

    ``Sd`` is unused; it is kept so existing call sites still work.
    """
    i0=indice(t0,Td)
    iF=indice(tf,Td)
    if i0<=iF:
        # Include column iF: Jump2 evaluates the intensity at times uk in [t0, tf]
        # whose grid index indice(uk, Td) can equal iF. The former slice i0:iF
        # left that column out, so the thinning bound could be too small. When
        # i0 == iF, the one-column slice is still a valid (tighter) bound.
        return(np.max(L[:,:,i0:iF+1]))
    else:
        return(np.max(L))

def nextS(i,s,Ti,Tf):
    """Battery level at time ``Tf`` along ``ds/dt = b(i, s)``, started from ``s`` at ``Ti``.

    Closed-form flow of the drift ``b``:

    * mode 0: the level is unchanged;
    * mode 1: linear growth at speed ``vita`` until ``dropout``; after that
      ``1 - s`` decays exponentially at rate ``vita / (1 - dropout)``;
    * mode 2: mirror of mode 1, computed on ``1 - s``;
    * any other mode: None.
    """
    if i == 0:
        return s
    if i == 2:
        return 1 - nextS(1, 1 - s, Ti, Tf)
    if i == 1:
        if s < dropout:
            linear = s + vita * (Tf - Ti)
            if linear < dropout:
                return linear
            # dropout is reached during the step: restart the exponential phase from there.
            start_level, start_time = dropout, Ti + (dropout - s) / vita
        else:
            start_level, start_time = s, Ti
        return 1 - (1 - start_level) * np.exp(vita * (-Tf + start_time) / (1 - dropout))

def Jump2(i,A,t,alphaL,s,Td,Sd):
    """Sample the next mode switch of the process in mode ``i`` at time ``t`` with battery ``s``.

    Uses thinning over successive windows ``[a, a+A]``. On each window the
    intensities are bounded by ``L = maxL(alphaL, ...)``. For every target mode
    ``j`` (including ``j == i``, whose candidates are always rejected), a
    Poisson(L*A) number of candidate points ``(uk, vk)`` is drawn uniformly in
    ``[a, a+A] x [0, L]``. A candidate is kept when
    ``vk < alphaL[i, j, indice(uk, Td), indice(battery at uk, Sd)]``.
    If no candidate is kept, the battery is advanced to ``a+A`` with ``nextS``
    and the next window is tried, as long as ``a < T`` (module-level horizon).

    Returns
    -------
    (time, mode, battery)
        The earliest accepted switch time, capped at ``T``; the new mode (``i``
        when no switch occurs before ``T``); and the battery at that time,
        obtained with ``nextS`` in the old mode ``i``.

    NOTE(review): the last window may extend past T. An accepted candidate
    with uk > T is still returned, as time T with the new mode. The horizon
    is the module-level ``T``, not a parameter.
    """
    a=t
    J=[]
    S=s
    while J==[]:
        # Bound on the intensities over the current window.
        L=maxL(alphaL,Td,Sd,a,a+A)
        for j in range(Ni):
            m=np.random.poisson(L*A)
            if m>0:
                for k in range(1,m+1):        # m candidate points for target mode j
                    uk=np.random.uniform(a,a+A)
                    vk=np.random.uniform(0,L)
                    # Battery grid index at the candidate time, following the flow of mode i.
                    ind=indice(nextS(i,S,a,uk),Sd)
                    if vk<alphaL[i,j,indice(uk,Td),ind] and i!=j:
                        J.append([j,uk])
        if J==[] and a<T:
            # Nothing accepted: move to the next window.
            S=nextS(i,S,a,a+A)
            a=a+A
        else:
            if J!=[]:
                # Column 1 holds the candidate times: keep the earliest accepted one.
                indi=np.argmin(J,0)[1]
                j,t=J[indi]
            else:
                j,t=i,T
            return(min(T,t),j,nextS(i,S,a,min(t,T)))


def PDMP2(i,A,T,alphaL,s):
    """Simulate one trajectory from mode ``i`` and battery ``s`` at time 0 until time ``T``.

    Calls ``Jump2`` repeatedly, each time from the last recorded state, until
    the last recorded time reaches ``T``.

    Returns an array whose rows are ``[time, mode, battery]``: the initial
    state followed by the state after each call to ``Jump2``.
    """
    first_time, mode, level = Jump2(i, A, 0, alphaL, s, Td, Sd)
    path = [[0, i, s], [first_time, mode, level]]
    while path[-1][0] < T:
        next_time, mode, level = Jump2(mode, A, path[-1][0], alphaL, level, Td, Sd)
        path.append([next_time, mode, level])
    return np.array(path)

def PDMP2D(P):
    """Resample a jump trajectory onto a regular time grid of step ``dt``.

    ``P`` holds rows ``[time, mode, battery]`` as built by ``PDMP2``: the
    initial state followed by the state after each jump. The result has
    ``Nt`` rows ``[t, mode, battery]``. It starts at ``P[0]``, and each next
    time is obtained by adding ``dt`` to the previous one. Between jumps the
    battery is advanced with ``nextS`` in the current mode.

    NOTE(review): times are accumulated as ``t0+dt`` rather than read from
    ``Td``, so they may drift by rounding. Also, ``P[ind]`` is read on the next
    step after the inner while-loop may have reached ``len(P)``; this is safe
    only when that happens at the last grid step (last row of ``P`` at time T).
    """
    p=np.zeros((Nt,3))
    p[0]=P[0]
    # Index of the next row of P not yet applied.
    ind=0
    for t in range(1,Nt):
        t0,i,s=p[t-1]
        if t0+dt<P[ind][0]:
            # No jump before the next grid time: just follow the flow of mode i.
            s=nextS(i,s,t0,t0+dt)
            p[t]=[t0+dt,i,s]
        else:
            # Apply every recorded state up to the next grid time, keeping the last one.
            while ind<len(P) and t0+dt>=P[ind][0]:
                s=P[ind][2]
                i=P[ind][1]
                ind=ind+1
            # Flow from the time of the last applied row to the grid time.
            s=nextS(i,s,P[ind-1][0],t0+dt)
            p[t]=[t0+dt,i,s]
    return(p)

def Ilaw():
    """Draw an initial mode: 1 with probability ``mean(Dd2) / 0.4``, otherwise 0.

    NOTE(review): this is only a valid probability while mean(Dd2) <= 0.4.
    """
    p_one = np.mean(Dd2) / 0.4
    return np.random.binomial(1, p_one)

def conso(P):
    """Return a list of ``b(mode, battery)``, one value per row of the 2-D array ``P``.

    Each row is ``[time, mode, battery]``, as produced by ``PDMP2D``.
    """
    return [b(P[row, 1], P[row, 2]) for row in range(len(P[:, 0]))]

def IPDMP2(n,Ilaw,Sinit,T,alphaL,A,Nt):
    """Monte Carlo estimate of the mean consumption per mode on the time grid.

    Simulates ``n`` trajectories. Each starts with a battery drawn by
    ``Sinit()`` and a mode drawn by ``Ilaw()``, and is resampled onto the grid
    with ``PDMP2D``. At every grid step, ``b(mode, battery) / n`` is added to
    the row of the mode occupied at that step.

    Returns an array of shape ``(Ni, Nt)``.
    """
    totals = np.zeros((Ni, Nt))
    for _ in range(n):
        start_level = Sinit()
        start_mode = Ilaw()
        grid_path = PDMP2D(PDMP2(start_mode, A, T, alphaL, start_level))
        consumption = conso(grid_path)
        for step in range(Nt):
            mode = int(grid_path[step][1])
            totals[mode, step] += consumption[step] / n
    return totals

def Schema(Nt,Ns,b,c,lbda,psi):
    """Solve the value function ``u[mode, time, battery]`` backward in time.

    Explicit backward-in-time finite differences with a one-sided difference in
    the battery variable. The terminal condition is ``u[i, Nt-1, l] = psi(Sd[l])``
    for every mode. Then, for ``k`` from ``Nt-2`` down to 0::

        u[i,k,l] = u[i,k+1,l] + dt * ( bd[i,l] * D_s u[i,k+1,l]
                                       + cd[i,k]*bd[i,l] + lbda[k]*bd[i,l]
                                       - sum_j (1/CoefTransfert[i,j])
                                               * max(0, u[i,k+1,l]-u[j,k+1,l])**2 / 2 )

    ``D_s`` is a backward difference when some drift of mode ``i`` is negative,
    and a forward difference otherwise. At the boundary where that difference
    is unavailable, the transport term is dropped.

    Parameters
    ----------
    Nt, Ns : numbers of time and battery grid points. The grids are rebuilt here
        over [0, T] (module-level T) and [0, 1].
    b : callable ``b(i, s)``, drift of mode i.
    c : callable ``c(i, t)``, multiplied by the drift in the running term.
    lbda : indexed by time step ``k``; only ``lbda[0]`` .. ``lbda[Nt-2]`` are read.
    psi : callable giving the terminal value from the battery level.

    Returns
    -------
    np.ndarray of shape ``(Ni, Nt, Ns)``.
    """
    ##INITIALISATION
    
    Td=np.linspace(0,T,Nt)
    Sd=np.linspace(0,1,Ns)
    dt=Td[1]-Td[0]
    ds=Sd[1]-Sd[0]

    u=np.zeros((Ni,Nt,Ns))
    
    # Terminal condition at the last time step, identical in every mode.
    for i in range(Ns):
        for lev in range(Ni):
            u[lev][Nt-1][i]=psi(Sd[i])

    # Tabulate c on (mode, time) and b on (mode, battery).
    bd=np.zeros((Ni,Ns))
    cd=np.zeros((Ni,Nt))

    for i in range(Ni):
        for t in range(Nt):
            cd[i,t]=c(i,Td[t])

    for i in range(Ni):
        for s in range(Ns):
            bd[i,s]=b(i,Sd[s])
    
    # Backward sweep. Each update reads only time step k+1, so the order of the l and i loops does not matter.
    for k in range(Nt-2,-1,-1):
        for l in range(Ns-1,-1,-1):
            for i in range(Ni):
                # Here s accumulates the quadratic switching term (it is not a battery index).
                s=0
                for j in range(Ni):
                    s+=(1/CoefTransfert[i,j])*((max(0,u[i,k+1,l]-u[j,k+1,l]))**2)/2
                if np.min(bd[i])<0:
                    # Some drift of mode i is negative: backward difference in s; no transport term at l == 0.
                    if l==0:
                        u[i,k,l]=u[i,k+1,l]+dt*(cd[i,k]*bd[i,l]+lbda[k]*bd[i,l]-s)
                    else:
                        u[i,k,l]=u[i,k+1,l]+dt*(bd[i,l]*(u[i,k+1,l]-u[i,k+1,l-1])/ds+cd[i,k]*bd[i,l]+lbda[k]*bd[i,l]-s)
                else:
                    # Drift of mode i is nonnegative: forward difference in s; no transport term at l == Ns-1.
                    if l==Ns-1:
                        u[i,k,l]=u[i,k+1,l]+dt*(cd[i,k]*bd[i,l]+lbda[k]*bd[i,l]-s)
                    else:
                        u[i,k,l]=u[i,k+1,l]+dt*(bd[i,l]*(u[i,k+1,l+1]-u[i,k+1,l])/ds+cd[i,k]*bd[i,l]+lbda[k]*bd[i,l]-s)
    return(u)

def rho(K):
    """Step size at iteration ``K``: ``500 / (1 + K)``, capped at 30."""
    step = 500 / (1 + K)
    return step if step < 30 else 30

def alphaFromU(u):
    """Switching intensities derived from the value function ``u``.

    ``alpha[i, j, k, l] = max(u[i, k, l] - u[j, k, l], 0)`` for every pair of
    modes. Returns an array of shape ``(Ni, Ni, Nt, Ns)``.
    """
    alpha = np.zeros((Ni, Ni, Nt, Ns))
    for src in range(Ni):
        # Broadcast u[src] against every target mode at once.
        alpha[src] = positiv(u[src] - u[:Ni])
    return alpha

# Weight of the squared gap between aggregate consumption and the target (used by solv and by EstimJ's distance cost).
coefA=10**3

def solv(lbda,D):
    """Return ``D + lbda / (2 * coefA)``, computed as ``(2*coefA*D + lbda) / (2*coefA)``."""
    denom = 2 * coefA
    return (denom * D + lbda) / denom

def algo(eps,n,Nt, Ns, b, c, psi,Ilaw,A,Dd,plot=True,prin=True):
    """Iterate on the multiplier ``lbdk`` until simulated consumption matches ``solv(lbdk, Dd)``.

    Starting from a copy of ``lbd0``, each iteration:

    1. solves ``Schema`` for the current multiplier;
    2. converts the value function into intensities with ``alphaFromU``;
    3. estimates the mean consumption per mode with ``IPDMP2`` (``n`` paths);
    4. computes ``vStar = solv(lbdk, Dd)``;
    5. updates ``lbdk += rho(K) * (total consumption - vStar)``.

    The loop stops once ``||total - vStar|| / sqrt(Nt)`` is no longer above ``eps``.

    Parameters
    ----------
    plot : unused; kept for backward compatibility.
    prin : when true, print the gap and the elapsed time after each iteration.

    Returns
    -------
    The last multiplier ``lbdk`` (same shape as ``lbd0``).

    Side effect: after every iteration the whole multiplier is saved to
    ``"lbdaSaveNbModes<Ni>.npy"``.
    """
    K = 0
    t0 = time.time()
    lbdk = lbd0.copy()
    # Do-while loop: always run at least one iteration. This replaces the former
    # Sum=0 / vStar sentinel and the redundant Schema call before the loop.
    while True:
        u = Schema(Nt, Ns, b, c, lbdk, psi)
        alphaL = alphaFromU(u)
        Sum = IPDMP2(n, Ilaw, Sinit, T, alphaL, A, Nt)
        vStar = solv(lbdk, Dd)
        gap = np.sum(Sum, 0) - vStar
        lbdk = lbdk + rho(K) * gap
        K = K + 1
        # Save the whole profile. The former code saved lbdk[-1], a scalar, which
        # NpyToTxt(..., AddTime=True) cannot pair with a time grid.
        np.save("lbdaSaveNbModes" + str(Ni) + ".npy", lbdk)
        err = np.linalg.norm(gap) / np.sqrt(Nt)
        if prin:
            print(err)
            print("Iteration " + str(K) + " au bout de " + str(time.time() - t0)[:5] + "s")
        # "not err > eps" mirrors the original while-condition: a NaN gap stops the loop.
        if not err > eps:
            return lbdk

def positiv(x):
    """Positive part ``max(x, 0)``, computed elementwise as ``(|x| + x) / 2``.

    Always returns a floating result, even for integer input.
    """
    magnitude = np.abs(x)
    total = magnitude + x
    return total / 2

def psi(s):
    """Terminal cost of battery level ``s``: ``30 * max(1 - exp(5*(s - 0.75)), 0)``.

    The cost is zero for ``s >= 0.75`` and grows as ``s`` decreases below it.
    """
    shortfall = 1 - np.exp(5 * (s - 0.75))
    return 30 * positiv(shortfall)

def unif():
    """Draw uniformly on [0, 0.7) (presumably an alternative initial battery law)."""
    low, high = 0, 0.7
    return np.random.uniform(low, high)

# Module-level table of the drift b on (mode, battery grid), used below for the stability check.
bd=np.zeros((Ni,Ns))
for i in range(Ni):
    for s in range(Ns):
        bd[i,s]=b(i,Sd[s])

# Table of c on (mode, time grid). NOTE(review): not read elsewhere in this file; Schema builds its own copy.
cd=np.zeros((Ni,Nt))

for i in range(Ni):
    for t in range(Nt):
        cd[i,t]=c(i,Td[t])

# CFL-type check for the explicit scheme in Schema: warn when max(bd)*dt/ds exceeds 1.
# NOTE(review): this uses max(bd), not max(abs(bd)), so negative drifts are not checked.
if np.max(bd)*dt/ds>1:
    print("Numerical Scheme may not converge: "+ str(np.max(bd)*dt/ds))

###VISUALISATION END

def plotAlphaFin(u,i,j,log=False):
    """Heat map of the switching intensity ``max(u[i] - u[j], 0)`` over (time, battery).

    Time runs along the x axis over [0, T] and battery along the y axis over
    [0, 1]. With ``log=True`` the natural log of the intensity is shown; zero
    entries become -inf. The figure is not shown; call ``plt.show()`` afterwards.
    """
    field = positiv(u[i] - u[j])
    if log:
        field = np.log(field)
    image = plt.imshow(
        np.transpose(field),
        cmap='Blues',
        interpolation="nearest",
        extent=[0, T, 0, 1],
        origin='lower',
        aspect='auto',
    )
    plt.colorbar(image)
    plt.xlabel('Time')
    plt.ylabel('Battery')

# Hand-built intensity fields alphaL[i, j, t_index, s_index] of shape (Ni, Ni, Nt, Ns),
# apparently test inputs for the simulator (judging by their names).

# Constant intensity 0.5 for switching from mode 0 to mode 1.
alphaLTest=np.zeros((Ni,Ni,Nt,Ns))
alphaLTest[0,1]+=0.5

# 0 -> 1 only at the lowest battery grid point; 1 -> 0 for battery indices >= 10.
alphaLTest2=np.zeros((Ni,Ni,Nt,Ns))
alphaLTest2[0,1,:,0] = 0.5
alphaLTest2[1,0,:,10:] = 0.5

# 0 -> 1 intensity growing linearly in time, equal to 2*t at every battery level.
alphaLTest3=np.zeros((Ni,Ni,Nt,Ns))

for t in range(Nt):
    alphaLTest3[0,1,t,:]=alphaLTest3[0,1,t,:]+2*Td[t]    

def dirac0():
    """Deterministic draw equal to 0 (usable where a sampling law such as Ilaw is expected)."""
    value = 0
    return value

def dirac14():
    """Deterministic draw equal to 1/4 (usable where a sampling law such as Sinit is expected)."""
    return 1 / 4

def EstimJ(n,Ilaw,Sinit,T,alphaL,A,Dd,plot=False):
    """Monte Carlo estimate of the total cost of the intensity field ``alphaL``.

    Simulates ``n`` trajectories and returns ``(total, cost_alpha, cost_final,
    cost_distance)``, where:

    * ``cost_alpha`` averages, over trajectories,
      ``sum_t sum_j dt * 0.5 * alphaL[mode_t, j, t, s_index_t] ** 2``;
    * ``cost_final`` averages ``psi`` of the final battery level;
    * ``cost_distance`` is ``dt * sum_t coefA * (mean consumption_t - Dd_t) ** 2``,
      with the mean consumption summed over modes.

    When ``plot`` is true, a pie chart of the three components is shown.
    """
    alpha_total = 0
    final_total = 0
    mode_consumption = np.zeros((Ni, Nt))
    for _ in range(n):
        start_level = Sinit()
        start_mode = Ilaw()
        grid_path = PDMP2D(PDMP2(start_mode, A, T, alphaL, start_level))
        consumption = conso(grid_path)
        for step in range(Nt):
            mode = int(grid_path[step][1])
            mode_consumption[mode, step] += consumption[step] / n
            level_idx = indice(grid_path[step][2], Sd)
            for target in range(Ni):
                alpha_total += dt * 0.5 * (alphaL[mode, target, step, level_idx]) ** 2
        final_total += psi(grid_path[-1][2])
    distance_cost = np.sum(coefA * (np.sum(mode_consumption, 0) - Dd) ** 2) * dt
    mean_alpha = alpha_total / n
    mean_final = final_total / n
    if plot:
        plt.pie(
            [mean_alpha, mean_final, distance_cost],
            labels=[
                "Cost alpha: " + str(round(mean_alpha, 2)),
                "Cost final: " + str(round(mean_final, 2)),
                "Cost Distance to Signal: " + str(round(distance_cost, 2)),
            ],
        )
        plt.title("Total cost: " + str(round(mean_alpha + mean_final + distance_cost, 2)))
        plt.show()
    return (mean_alpha + mean_final + distance_cost, mean_alpha, mean_final, distance_cost)

def EstimP(n,I,Sinit,T,alphaL,A,lbda,D,plot=False):
    """Monte Carlo estimate of ``sum_{i,t} dt * lbda[i, t] * (IP[i, t] - D(i, Td[t]))``.

    ``IP`` is the mean consumption per mode returned by ``IPDMP2`` (``n`` paths,
    initial mode drawn by ``I``). ``lbda`` is indexed ``[mode, time step]``.
    ``plot`` is unused.
    """
    occupancy = IPDMP2(n, I, Sinit, T, alphaL, A, Nt)
    total = 0
    for mode in range(Ni):
        for step in range(Nt):
            weight = dt * lbda[mode, step]
            total += weight * (occupancy[mode, step] - D(mode, Td[step]))
    return total

def Aggrandir(l,N):
    """Resample the non-empty sequence ``l`` onto ``N`` equally spaced points by linear interpolation.

    The first and last points of the result are ``l[0]`` and ``l[-1]``. When
    ``N <= 1`` the result is ``[l[-1]]``. A single-element ``l`` is repeated
    ``N`` times; the former code raised IndexError by reading ``l[1]``.

    Returns
    -------
    np.ndarray
    """
    n=len(l)
    if n==1:
        # Nothing to interpolate: repeat the single sample.
        return(np.array([l[0]]*max(N,1)))
    L=[]
    for i in range(N-1):
        pos=i*(n-1)/(N-1)
        # Clamp so rounding can never push the left sample onto the last one (l[i2+1] would overflow).
        i2=min(int(pos),n-2)
        L.append(l[i2]+(l[i2+1]-l[i2])*(pos-i2))
    return(np.array(L+[l[-1]]))

def NpyToTxt(nameListNpy,AddTime=False):
    """Convert a ``.npy`` file to a text file with the same base name and a ``.txt`` extension.

    With ``AddTime=True`` the stored 1-D array is written as two columns,
    ``(time, value)``; the times are equally spaced over [0, T]
    (module-level T). Otherwise the array is written as is with
    ``np.savetxt``. Prints "Done" when finished.
    """
    if AddTime:
        L=np.load(nameListNpy)
        l=np.zeros((2,len(L)))
        l[0]=np.linspace(0,T,len(L))
        l[1]=L
        l=np.transpose(l)
    else:
        l=np.load(nameListNpy)
    # Replace the extension whatever its length. Slicing off 3 characters only worked for
    # ".npy"-style names. splitext also accepts path-like objects.
    np.savetxt(os.path.splitext(nameListNpy)[0]+".txt",l)
    print("Done")
