import Simulation
import DamageModel
import numpy as np
import MILPs.MILPModel as mMeta
import time
from sklearn.ensemble import RandomForestRegressor
import joblib

#np.random.seed(2)
from MetaEnvironment import Environment
from QLearning.model import *

import MILPs.MILP_MaximizeDiversity as mDiv
from matplotlib import pyplot as pl
import MILPs.MILP_MaxUncertainty as mUn
from Simulation import SimulationAllDamage


def calcDiff(prevA,prevO,a,e,numS=1):
    """Build difference features between two previous samples and candidate actions.

    The candidate action(s) ``a`` are passed through ``e.simDamage`` to get the
    corresponding outputs, then both the actions and the outputs are subtracted
    from the two columns of ``prevA`` / ``prevO``.

    Args:
        prevA: array whose columns 0 and 1 each hold 4 action entries.
        prevO: array whose columns 0 and 1 each hold 11 output entries.
        a: candidate action(s); reshaped to ``(4, -1)``.
        e: object exposing ``simDamage(a)``; its result is reshaped to ``(11, -1)``.
        numS: unused. Kept only so existing callers keep working (it previously
            sized two arrays that were overwritten before being read).

    Returns:
        2-D array with 30 columns: 4 + 4 action differences followed by
        11 + 11 output differences.
    """
    o = e.simDamage(a)
    a = a.reshape(4, -1)
    o = o.reshape(11, -1)

    # NOTE(review): reshape(-1, 4) / reshape(-1, 11) regroups the row-major
    # buffer; it equals a transpose only when there is a single candidate
    # column. Left unchanged so the features keep their current layout --
    # confirm this matches how the downstream model was trained.
    actionDiffs = [(prevA[:, k].reshape(-1, 1) - a).reshape(-1, 4) for k in range(2)]
    outputDiffs = [(prevO[:, k].reshape(-1, 1) - o).reshape(-1, 11) for k in range(2)]

    return np.hstack(actionDiffs + outputDiffs)

def calcU(a,e):
    """Return whatever ``e.getUncertainty`` reports for the action(s) ``a``."""
    return e.getUncertainty(a)

def findBestA(lal,e,step):
    """Pick the feasible candidate action that the ``lal`` regressor scores highest.

    A 30-point grid per control channel (elevator, thrust, aileron, rudder) is
    enumerated, every candidate is pushed through ``e.sim.simulateDamage`` four
    times starting from ``e.sim.currentState``, candidates whose resulting
    state leaves the box ``xref +/- radius`` are discarded, and the remaining
    ones are ranked with ``lal.predict``.

    Args:
        lal: fitted model exposing ``predict(features)``; it is given one row
            of 36 features per feasible candidate.
        e: environment exposing ``sim.simulateDamage``, ``sim.currentState``,
            ``d.actions``, ``d.states``, ``simDamage`` and ``getUncertainty``.
        step: current step counter, appended as the last feature column.

    Returns:
        1-D array of length 4: the selected action.

    Raises:
        ValueError: if no candidate stays inside the bounds.
    """
    numSamples = 30
    numCandidates = numSamples ** 4

    radius = np.array([10, 10, 100, 50, 50, 50, 50, 10, 10, 300]).reshape(10, -1)*3.5
    xref = np.array([5, 0, 0, 0, 0, 0, 0, 0, 0, 1000]).reshape(10, -1)

    elevator = np.linspace(-.1, .1, numSamples)
    thrust = np.linspace(7000, 7030, numSamples)
    aileron = np.linspace(-.1, .1, numSamples)
    rudder = np.linspace(-.1, .1, numSamples)
    # meshgrid stacks to (4, n, n, n, n); one row-major flatten gives a
    # (4, n**4) table with one candidate per column. (The earlier intermediate
    # reshape followed by ``.T.T`` produced exactly this same array.)
    possA = np.array(np.meshgrid(elevator, thrust, aileron, rudder)).reshape(4, -1)

    # Roll every candidate forward four simulator steps.
    state = np.tile(e.sim.currentState, (1, numCandidates))
    for _ in range(4):
        state = e.sim.simulateDamage(state, possA)

    # Only rows 0-8 and row 10 are bounded; row 9 is skipped.
    checked = np.vstack((state[0:9, :], state[10, :]))

    # The (10, 1) bounds broadcast across all candidate columns.
    insideBox = np.logical_and(checked < xref + radius, checked > xref - radius)
    feasible = np.all(insideBox, axis=0)
    possA = possA[:, feasible]
    if possA.shape[1] == 0:
        raise ValueError("findBestA: no candidate action keeps the state within bounds")

    uncertainty = calcU(possA, e)
    diff = calcDiff(e.d.actions[:, -2:], e.d.states[:, -2:], possA, e)

    # NOTE(review): possA.reshape(-1, 4) regroups the (4, k) buffer instead of
    # transposing it, so a feature row is not literally one candidate column.
    # Left unchanged to keep the feature layout the model currently receives.
    allFeatures = np.concatenate(
        (
            possA.reshape(-1, 4),
            diff.reshape(-1, 30),
            uncertainty.reshape(-1, 1),
            np.tile(step, (possA.shape[1], 1)),
        ),
        axis=1,
    )
    allFeatures = allFeatures.reshape(-1, 36)

    scores = lal.predict(allFeatures)
    index = np.argmax(scores)
    # Report the selected index (previously this printed an unrelated global ``i``).
    print("Index", index)
    print(possA[:, 0])

    return possA[:, index]




def calcValue(critic, state, prevSA, u0):
    """Average the critic's value over 100 random two-step action samples around ``u0``.

    Each sample is 8 numbers: two independent draws of (elevator, thrust,
    aileron, rudder), each uniform in ``u0[k] +/- halfWidth`` with half-widths
    5, 100, 5, 5.

    Args:
        critic: object whose ``forward(states, actions, prevSAs)`` returns a
            pair; the first element is a tensor of values.
        state: tensor repeated 100 times along the first dimension.
        prevSA: 3-D tensor repeated 100 times along the first dimension.
        u0: indexable holding the four nominal control values.

    Returns:
        Mean of the returned values as a NumPy scalar.
    """
    numA = 100
    halfWidths = (5, 100, 5, 5)  # elevator, thrust, aileron, rudder

    # Draw order is elevator, thrust, aileron, rudder for the first action and
    # then again for the second, so the global NumPy RNG stream is consumed in
    # a fixed sequence.
    draws = [
        np.random.uniform(u0[k] - width, u0[k] + width, (1, numA))
        for _ in range(2)
        for k, width in enumerate(halfWidths)
    ]
    actionBatch = torch.from_numpy(np.vstack(draws).transpose(1, 0)).float()

    tiledStates = state.repeat(numA, 1)
    tiledHistory = prevSA.repeat(numA, 1, 1)

    values, _ = critic.forward(tiledStates, actionBatch, tiledHistory)
    return np.mean(values.detach().numpy())

def uncertainty_test(damageNum,n):
    """Roll out the ``mUn.solve_br`` planner for ``n`` rounds on a fresh environment.

    Each round solves for a 12-element plan and applies it to the environment
    as three consecutive 4-element controls.

    Args:
        damageNum: damage specification forwarded to ``Environment``.
        n: number of planning rounds.

    Returns:
        Tuple ``(history, avgTime)``. ``history`` starts with ``0`` and gains
        ``(800 - err) / 800`` per round, ``err`` being the error returned by
        the round's third step. ``avgTime`` is the summed wall-clock time of
        all rounds divided by the fixed constant 5.0.
    """
    initialState = np.array([4.84, 0, 0, 0, 0, 0, 0, 0, 0, 100, 1001])
    e = Environment(initialState, damageNum=damageNum, simAll=True)
    xk, u0, _ = e.genInitial()

    history = [0]
    lastError = 1500
    rewardSum = 0  # accumulated but never returned
    elapsed = 0
    for _round in range(n):
        tic = time.time()
        plan = mUn.solve_br(
            e.sim.Anominal, e.sim.Bnominal, e.d.Asigma, e.d.Bsigma,
            xk=xk, u0=u0,
            A1=e.d.A1, A2=e.d.A2, A3=e.d.A3,
            B1=e.d.B1, B2=e.d.B2, B3=e.d.B3,
        )
        plan = np.asarray(plan)

        # Apply the plan's three 4-element controls one after another; only
        # the state and error from the last one are used below.
        for lo in (0, 4, 8):
            nextState, _, stepError = e.step(plan[lo:lo + 4].reshape(4, 1))
        toc = time.time()

        rewardSum += (lastError - stepError) / lastError
        history.append((800 - stepError) / 800)
        lastError = stepError
        elapsed += toc - tic

        xk = nextState
        u0 = plan[8:12]

    avgTime = elapsed / 5.0
    print("totalTimeU", avgTime)
    return history, avgTime


def lambda_test(damageNum,n):
    """Roll out ``mMeta.solve_br`` with ``lmbda=0, mu=5`` for ``n`` rounds.

    A ``Critic`` is loaded from ``InfoGainModels/critic528_1.pkl`` and moved to
    CUDA; its ``fc2``/``fc3`` weights and biases are handed to the solver as
    NumPy arrays together with the critic's second output for the current
    history.

    Args:
        damageNum: damage specification forwarded to ``Environment``.
        n: number of planning rounds.

    Returns:
        Tuple ``(history, avgProb)``. ``history`` starts with ``800`` and gains
        the error returned by each round's *first* step. ``avgProb`` is the sum
        of the solver's second return value divided by the fixed constant 5.0.
    """
    e = Environment(np.array([4.84, 0, 0, 0, 0, 0, 0, 0, 0, 100, 1001]), damageNum=damageNum, simAll=True)
    xk, u0, lastError = e.genInitial()

    hidden = 30
    critic = Critic(352, 8, h2=hidden)
    critic.load_state_dict(torch.load('InfoGainModels/critic528_1.pkl'))
    critic = critic.cuda()
    fC2 = critic.fc2.weight.data.cpu().numpy().reshape(hidden, -1)
    fB2 = critic.fc2.bias.data.cpu().numpy()
    fC3 = critic.fc3.weight.data.cpu().numpy().reshape(hidden)
    fB3 = critic.fc3.bias.data.cpu().numpy()

    history = [800]
    rewardSum = 0  # accumulated but never returned
    probSum = 0
    for _round in range(n):
        encoded = e.encoder.encode(e.d.states, e.d.actions, e.d.deltS)
        pastSA = np.vstack((e.d.states, e.d.actions)).reshape(1, -1, 15)
        zeroActions = torch.zeros((encoded.shape[0], 8)).cuda()
        _, Z = critic(encoded, zeroActions, pastSA)
        Z = Z.detach().cpu().numpy().reshape(50, -1)

        plan, prob = mMeta.solve_br(
            e.sim.Anominal, e.sim.Bnominal, e.d.Asigma, e.d.Bsigma,
            xk=xk, u0=u0,
            fC2=fC2, fC3=fC3, fB2=fB2, fB3=fB3,
            Z=Z, value=0, lmbda=0, mu=5,
        )
        plan = np.asarray(plan)

        # Apply the plan's three 4-element controls one after another.
        stepErrors = []
        for lo in (0, 4, 8):
            nextState, _, err = e.step(plan[lo:lo + 4].reshape(4, 1))
            stepErrors.append(err)

        rewardSum += (lastError - stepErrors[-1]) / lastError
        history.append(stepErrors[0])
        lastError = stepErrors[-1]

        xk = nextState
        u0 = plan[8:12]
        probSum += prob

    return history, probSum / 5.0

def meta_test(damageNum,n):
    """Roll out ``mMeta.solve_br`` (default weights) for ``n`` rounds and time it.

    A ``Critic`` is loaded from ``InfoGainModels/critic528_1.pkl`` and moved to
    CUDA; its ``fc2``/``fc3`` weights and biases are handed to the solver as
    NumPy arrays together with the critic's second output for the current
    history. The timer of each round starts after that critic pass and stops
    after the plan has been applied and printed.

    Args:
        damageNum: damage specification forwarded to ``Environment``.
        n: number of planning rounds.

    Returns:
        Tuple ``(history, avgTime)``. ``history`` starts with ``0`` and gains
        ``(800 - err) / 800`` per round, ``err`` being the error returned by
        the round's third step. ``avgTime`` is the summed wall-clock time
        divided by the fixed constant 5.0.
    """
    e = Environment(np.array([4.84, 0, 0, 0, 0, 0, 0, 0, 0, 100, 1001]), damageNum=damageNum, simAll=True)
    xk, u0, lastError = e.genInitial()

    hidden = 30
    critic = Critic(352, 8, h2=hidden)
    critic.load_state_dict(torch.load('InfoGainModels/critic528_1.pkl'))
    critic = critic.cuda()
    fC2 = critic.fc2.weight.data.cpu().numpy().reshape(hidden, -1)
    fB2 = critic.fc2.bias.data.cpu().numpy()
    fC3 = critic.fc3.weight.data.cpu().numpy().reshape(hidden)
    fB3 = critic.fc3.bias.data.cpu().numpy()

    history = [0]
    rewardSum = 0  # accumulated but never returned
    probSum = 0    # accumulated but never returned
    elapsed = 0
    for _round in range(n):
        encoded = e.encoder.encode(e.d.states, e.d.actions, e.d.deltS)
        pastSA = np.vstack((e.d.states, e.d.actions)).reshape(1, -1, 15)
        zeroActions = torch.zeros((encoded.shape[0], 8)).cuda()
        _, Z = critic(encoded, zeroActions, pastSA)
        Z = Z.detach().cpu().numpy().reshape(50, -1)

        tic = time.time()
        plan, prob = mMeta.solve_br(
            e.sim.Anominal, e.sim.Bnominal, e.d.Asigma, e.d.Bsigma,
            xk=xk, u0=u0,
            fC2=fC2, fC3=fC3, fB2=fB2, fB3=fB3,
            Z=Z, value=0,
        )
        plan = np.asarray(plan)
        print("METAACTION", plan)

        # Apply the plan's three 4-element controls one after another; only
        # the state and error from the last one are used below.
        for lo in (0, 4, 8):
            nextState, _, stepError = e.step(plan[lo:lo + 4].reshape(4, 1))
        print("META1", plan[0:4])
        print("META2", plan[4:8])
        print("META3", plan[8:12])
        toc = time.time()

        rewardSum += (lastError - stepError) / lastError
        history.append((800 - stepError) / 800)
        lastError = stepError

        xk = nextState
        u0 = plan[8:12]
        probSum += prob
        elapsed += toc - tic

    avgTime = elapsed / 5.0
    print("totalTimeM", avgTime)
    return history, avgTime

def diversity_test(damageNum,n):
    """Roll out the ``mDiv.solve_br`` planner for ``n`` rounds and time it.

    Each round solves for a 12-element plan and applies it to the environment
    as three consecutive 4-element controls.

    Args:
        damageNum: damage specification forwarded to ``Environment``.
        n: number of planning rounds.

    Returns:
        Tuple ``(history, avgTime)``. ``history`` starts with ``0`` and gains
        ``(800 - err) / 800`` per round, ``err`` being the error returned by
        the round's third step. ``avgTime`` is the summed wall-clock time of
        the rounds divided by the fixed constant 5.0.
    """
    e = Environment(np.array([4.84, 0, 0, 0, 0, 0, 0, 0, 0, 100, 1001]), damageNum=damageNum, simAll=True)
    xk, u0, _ = e.genInitial()

    allErrorDiv = [0]
    totalTime = 0
    for _round in range(n):
        start = time.time()
        a = mDiv.solve_br(e.sim.Anominal, e.sim.Bnominal, e.d.Asigma, e.d.Bsigma, xk=xk, u0=u0)
        a = np.asarray(a)
        print("actionDiv", a)

        # Apply the plan's three 4-element controls one after another; only
        # the state and error from the last one are used below.
        for lo in (0, 4, 8):
            nextState, _, stepError = e.step(a[lo:lo + 4].reshape(4, 1))

        allErrorDiv.append((800 - stepError) / 800)

        xk = nextState
        u0 = a[8:12]
        totalTime += time.time() - start

    # Every round is counted exactly once. The previous extra accumulation
    # after the loop double-counted the last round and failed when n == 0.
    # NOTE(review): 5.0 looks like a hard-coded number of rounds -- confirm
    # before replacing it with ``n``.
    avgTime = totalTime / 5.0
    print("totalTimeD", avgTime)
    return allErrorDiv, avgTime

def lal_test(damageNum, n):
    """Roll out the learned action selector (``findBestA``) for ``n`` rounds and time it.

    The regressor is loaded from ``lalPlaneNN.sav`` with ``joblib``. Each round
    picks one action and applies it to the environment twice.

    Args:
        damageNum: damage specification forwarded to ``Environment``.
        n: number of rounds.

    Returns:
        Tuple ``(history, avgTime)``. ``history`` starts with ``0`` and gains
        ``(800 - err) / 800`` per round, ``err`` being the error returned by
        the round's second step. ``avgTime`` is the summed wall-clock time of
        the rounds divided by the fixed constant 5.0.
    """
    e = Environment(np.array([4.84, 0, 0, 0, 0, 0, 0, 0, 0, 100, 1001]), damageNum=damageNum, simAll=True)
    e.genInitial()

    # NOTE: joblib.load unpickles the file; only load model files you trust.
    lalModel1 = joblib.load('lalPlaneNN.sav')

    allErrorDiv = [0]
    totalTime = 0
    for i in range(n):
        start = time.time()
        a = findBestA(lalModel1, e, i)
        # The same action is applied on two consecutive steps; only the error
        # from the second application is recorded.
        e.step(a.reshape(4, 1))
        print("action", a)
        _, _, stepError = e.step(a.reshape(4, 1))

        allErrorDiv.append((800 - stepError) / 800)
        totalTime += time.time() - start

    # Every round is counted exactly once. The previous extra accumulation
    # after the loop double-counted the last round and failed when n == 0.
    # NOTE(review): 5.0 looks like a hard-coded number of rounds -- confirm
    # before replacing it with ``n``.
    avgTime = totalTime / 5.0
    print("totalTimeLal", avgTime)
    return allErrorDiv, avgTime

    #mod.print_information()    return allErrorDiv

# Experiment size: planning rounds per trial and number of trials.
numSteps = 5
numIters = 10

# One independent (numIters, numSteps + 1) table per strategy, plus one for the
# per-trial difference curve; row = trial, column = recorded history entry.
(
    metaPerformance,
    divPerformance,
    UPerformance,
    LPerformance,
    lalPerformance,
    diff_meta_div,
) = (np.zeros((numIters, numSteps + 1)) for _ in range(6))

# Status flag and bookkeeping counters.
done = 0
allProbMeta = countMetaTotal = 0
allProblambda = countTotallambda = 0

# Accumulated per-strategy timings.
totalTimeU = totalTimeD = totalTimeM = totalTimeLal = 0
for i in range(0,numIters):
    # Seed per trial so the sampled damage case is reproducible.
    np.random.seed(i)
    d1 = np.random.randint(0, 2)
    d2 = np.random.randint(0, 4)
    p1 = np.random.uniform(low=0, high=1)
    p2 = np.random.uniform(low=.75, high=1)
    damage = np.array([d1, d2, p1, p2])
    print("DAMAGEINTEST", damage)

    # Strategies run in a fixed order. The first one that raises abandons the
    # trial: later strategies are skipped and row ``i`` of every result table
    # stays all-zero. (The old ``i = i - 1`` "retry" had no effect inside a
    # ``for`` loop, and re-seeding with the same ``i`` would only reproduce
    # the same damage case, so it is dropped.)
    try:
        divError, t = diversity_test(damage, numSteps)
    except Exception as exc:
        print("DIVERSITY FAILED", exc)
        continue
    totalTimeD += t
    print("DIVERSITY Succeeded")

    try:
        UError, t = uncertainty_test(damage, numSteps)
    except Exception as exc:
        print("Uncertainty FAILED", exc)
        continue
    totalTimeU += t
    print("Uncertainty Succeeded")

    try:
        LError, _ = lambda_test(damage, numSteps)
    except Exception as exc:
        print("Lambda FAILED", exc)
        continue
    print("Lambda Succeeded")

    try:
        metError, t = meta_test(damage, numSteps)
    except Exception as exc:
        print("META FAILED", exc)
        continue
    totalTimeM += t
    countMetaTotal += 1
    print("META Succeeded")

    # Guarded like the other strategies so one failure cannot abort the run.
    try:
        lalError, t = lal_test(damage, numSteps)
    except Exception as exc:
        print("LAL FAILED", exc)
        continue
    totalTimeLal += t

    print(metError)
    print(divError)
    print(UError)
    print(lalError)
    # Despite the table's name, this is the uncertainty curve minus the meta curve.
    difference = np.asarray(UError) - np.asarray(metError)
    print(difference)

    metaPerformance[i, :] = metError
    divPerformance[i, :] = divError
    UPerformance[i, :] = UError
    LPerformance[i, :] = LError
    lalPerformance[i, :] = lalError
    diff_meta_div[i, :] = difference
    print('FINISHED')
# Average time per trial for each strategy.
# NOTE(review): the divisor is numIters even when some trials failed, and the
# rows of failed trials stay all-zero yet are still averaged into the means
# below -- confirm that this is intended.
print('timeU',totalTimeU/numIters)
print('timeD',totalTimeD/numIters)
print('timeM',totalTimeM/numIters)
print('timeLAL',totalTimeLal/numIters)

# Per-column mean and standard error (std / sqrt(numIters)) for every strategy.
meanM = np.mean(metaPerformance, axis=0)
stdM = np.std(metaPerformance, axis=0) / np.sqrt(numIters)
print("meanMeta", meanM)

meanD = np.mean(divPerformance, axis=0)
stdD = np.std(divPerformance, axis=0) / np.sqrt(numIters)
print("meanD", meanD)

meanU = np.mean(UPerformance, axis=0)
stdU = np.std(UPerformance, axis=0) / np.sqrt(numIters)
print("meanU", meanU)

# The next two summaries used to be printed under the copy-pasted label
# "meanU"; each now carries its own label.
meanL = np.mean(LPerformance, axis=0)
stdL = np.std(LPerformance, axis=0) / np.sqrt(numIters)
print("meanL", meanL)

meanLal = np.mean(lalPerformance, axis=0)
stdLal = np.std(lalPerformance, axis=0) / np.sqrt(numIters)
print("meanLal", meanLal)

#np.savetxt(r"C:\Users\mschr\CORELABProjects\Meta-Learning-Aircraft\Info_Gain_Results\ParamMeta.csv", np.vstack((meanM,stdM)), fmt='%5s',delimiter=",")
#np.savetxt(r"C:\Users\mschr\CORELABProjects\Meta-Learning-Aircraft\Info_Gain_Results\ParamUncertainty.csv", np.vstack((meanU,stdU)), fmt='%5s',delimiter=",")
#np.savetxt(r"C:\Users\mschr\CORELABProjects\Meta-Learning-Aircraft\Info_Gain_Results\ParamDiversity.csv", np.vstack((meanD,stdD)), fmt='%5s',delimiter=",")
#np.savetxt(r"C:\Users\mschr\CORELABProjects\Meta-Learning-Aircraft\Info_Gain_Results\ParamDiversity.csv", np.vstack((meanD,stdD)), fmt='%5s',delimiter=",")
#np.savetxt(r"C:\Users\mschr\CORELABProjects\Meta-Learning-Aircraft\Info_Gain_Results\ParamLal.csv", np.vstack((meanLal,stdLal)), fmt='%5s',delimiter=",")

# Figure 1: mean curve with a +/- standard-error band per strategy, drawn in
# the order diversity (red), uncertainty (green), meta (blue), LAL (orange).
# The lambda curve (meanL / stdL) is not drawn.
stepAxis = range(numSteps+1)
for curveMean, curveErr, curveColor in (
    (meanD, stdD, 'r'),
    (meanU, stdU, 'g'),
    (meanM, stdM, 'b'),
    (meanLal, stdLal, 'orange'),
):
    pl.plot(curveMean, color=curveColor)
    pl.fill_between(stepAxis, curveMean - curveErr, curveMean + curveErr, alpha=0.5, color=curveColor)

pl.show()


# Figure 2: mean of the per-trial difference curve with its standard-error
# band and a dashed zero reference line.
print(diff_meta_div)
meanBoth = np.mean(diff_meta_div, axis=0)
# Standard error uses sqrt(numIters) like every other band, replacing the
# hard-coded 3.16 (~sqrt(10)) that silently assumed numIters == 10.
stdBoth = np.std(diff_meta_div, axis=0) / np.sqrt(numIters)
print(meanBoth)

pl.plot(meanBoth, color='r')
pl.fill_between(stepAxis, meanBoth - stdBoth, meanBoth + stdBoth, alpha=0.5, color='r')
pl.plot(stepAxis, np.zeros(numSteps+1), '--', color='b')
pl.title("DIfference")
pl.show()