import torch
from MetaEnvironment import Environment
from Bao.MetaBayesian import BayesianEnvironment
from QLearning.model import *
import MILPs.MILPModel_2layer as MILP
import MILPs.infoGainMILP_ChanceConstraints as MILPSafety
from Bao import BayesianAcquisitionFunction
from matplotlib import pyplot as pl
import MetaLearnBO.run_experiments as MBO
import time
from scipy.stats import norm

from sklearn.ensemble import RandomForestRegressor
import joblib

def calcDiff(prevA,prevO,a,gp,maxLen=8):
    """Build the fixed-width history-difference features fed to the LAL model.

    For every previous action ``prevA[i]`` the first half holds
    ``prevA[i] - a``; for every previous observation ``prevO[i]`` the second
    half holds ``prevO[i] - gp.predict(a)``. Both halves are zero-padded to
    ``maxLen`` slots and concatenated.

    Args:
        prevA: Previously queried actions, indexed along axis 0.
        prevO: Observations for ``prevA`` (assumed to have at least
            ``prevA.shape[0]`` rows -- TODO confirm with callers).
        a: Candidate action, passed unchanged to ``gp.predict``.
        gp: Surrogate model exposing ``predict``.
        maxLen: Number of history slots per half. Defaults to 8, the width
            used so far (presumably what the saved LAL model expects -- verify).

    Returns:
        ``np.ndarray`` of shape ``(1, 2 * maxLen)``.

    Raises:
        ValueError: If ``prevA`` has more than ``maxLen`` entries, or if a
            difference is not a single value.
    """
    o=gp.predict(a)
    numPrev=prevA.shape[0]
    if numPrev>maxLen:
        raise ValueError(
            "history has %d entries but only %d feature slots are available" % (numPrev, maxLen))
    allDiffA=np.zeros((1,maxLen))
    allDiffO=np.zeros((1,maxLen))
    for i in range(numPrev):
        # The differences are size-1 arrays; extract the scalar explicitly,
        # since storing an ndim>0 array in a scalar slot is deprecated
        # (NumPy >= 1.25).
        allDiffA[0,i]=np.asarray(prevA[i]-a).item()
        allDiffO[0,i]=np.asarray(prevO[i]-o).item()
    return np.hstack((allDiffA,allDiffO))

def calcU(gp,a):
    """Return the surrogate's predictive standard deviation at ``a``.

    ``gp.predict`` is called with ``return_std=True``; its second output is
    wrapped in a one-element list and converted to an array, which adds one
    leading axis to whatever shape the std has.
    """
    _mean, stdDev = gp.predict(a, return_std=True)
    wrapped = [stdDev]
    return np.array(wrapped)

def findBestA(lal,actions,out,gp,step,numSamples=50,low=0.0,high=4.0):
    """Return the grid candidate action that the LAL regressor scores highest.

    Candidates are ``numSamples`` evenly spaced points in ``[low, high]``.
    Each candidate becomes one feature row
    ``[action, calcDiff(...), calcU(...), step]``, which is scored with
    ``lal.predict``. The first highest-scoring candidate wins.

    Args:
        lal: Fitted regressor exposing ``predict``.
        actions: Previously queried actions, forwarded to ``calcDiff``.
        out: Observations for ``actions``, forwarded to ``calcDiff``.
        gp: Surrogate model exposing ``predict`` (also with
            ``return_std=True``).
        step: Current step index, appended as the last feature.
        numSamples: Number of grid points (default 50, unchanged).
        low: Lower grid bound (default 0.0, unchanged).
        high: Upper grid bound (default 4.0, unchanged).

    Returns:
        The chosen action as a length-1 ``np.ndarray``.
    """
    possA=np.atleast_2d(np.linspace(low, high, numSamples)).T
    # The step feature is identical for every candidate, so build it once.
    stepFeature=np.array([step]).reshape(1, 1)
    allR=[]
    for candidate in possA:
        candidate=candidate.reshape(1,-1)
        uncertainty=calcU(gp,candidate)
        diff=calcDiff(actions,out,candidate,gp)
        allFeatures=np.hstack((candidate.reshape(1,1), diff, uncertainty, stepFeature))
        allR.append(lal.predict(allFeatures))

    # argmax over the flattened scores returns the first maximum.
    best=np.argmax(np.array(allR))
    return possA[best]


action_dim = 1


# `e` drives the meta-learned policy; `eBay` is shared by all baseline methods.
e = Environment(optimizeModel=False)
eBay = BayesianEnvironment(optimizeModel=False)

Z_dim = 16

fC2_Hidden=60
fC3_Hidden=30
LSTM_input_size=3
Z_hidden=10
critic=Critic(Z_dim,action_dim,h1=fC2_Hidden,h2=fC3_Hidden)
if torch.cuda.is_available():
    critic.cuda()

# Raw string: the backslashes in this Windows path are not valid escape
# sequences (DeprecationWarning, and SyntaxWarning on Python 3.12+).
critic.load_state_dict(torch.load(r'G:\ResearchGatech\RNS-Meta\InfoGainModels\ParamModelAddedBestParam5-21\critic.pkl', map_location='cpu'))



# Copy the critic's fc2/fc3 weights and biases to NumPy; MILP.solve_br in the
# meta loop consumes them.
fC2 = critic.fc2.weight.data.cpu().numpy()
fC2 = fC2.reshape(fC3_Hidden, -1)
fB2 = critic.fc2.bias.data.cpu().numpy()

fC3 = critic.fc3.weight.data.cpu().numpy()
fC3 = fC3.reshape(fC3_Hidden)
fB3 = critic.fc3.bias.data.cpu().numpy()


numIters=10
numSteps=4
##############META#################
# Meta-learned policy. At each step the critic encodes the episode state into
# Z, and MILP.solve_br picks the next action from Z and the critic's fc2/fc3
# weights.

DAErrorMeta2=np.zeros((numIters,numSteps+1))
ParamErrorMeta2=np.zeros((numIters,numSteps+1))
allRats=[]
totalTime=0

fR1_Hidden=1
fR2_Hidden=200
fR3_Hidden=50
# NOTE(review): totalProb is never accumulated, so "average Prob" is always 0.
totalProb=0
# Create the dummy action on the critic's own device. An unconditional
# .cuda() crashes on CPU-only machines, where the critic stays on the CPU.
criticDevice=next(critic.parameters()).device
for j in range(0,numIters):
    np.random.seed(j*5)
    DAErrorMeta2[j, 0] = 1.0
    ParamErrorMeta2[j,0] = 0
    print("Meta iteration",j)
    ratN = np.random.randint(0, 5)

    # Remember the rat so every baseline below runs on the same subjects.
    allRats.append(ratN)
    (ZPrime, prevS_A,guess,_),r_prev,error,DAError= e.reset(ratNum=ratN,seed=j)
    for i in range(0,numSteps):
        bestP=e.ratEst1.BestGuess
        # .item() pulls out the single value explicitly; storing an ndim>0
        # array in a scalar slot is deprecated since NumPy 1.25.
        DAErrorMeta2[j, i +1] = np.asarray(abs(np.multiply(e.sim.simulateRat(bestP.reshape(1,-1)),bestP)-e.testBestDA)).item()
        ParamErrorMeta2[j, i +1]=np.asarray((4-abs(bestP-e.sim.bestP))/4.0).item()
        fakeA = torch.zeros((1, action_dim), device=criticDevice)
        prevS_A=prevS_A.reshape(1,-1,LSTM_input_size).float()
        start=time.time()
        _, Z,ratNumPred = critic(ZPrime,fakeA,prevS_A,guess)
        Z = Z.detach().cpu().numpy().reshape(fC2_Hidden, -1)

        a= MILP.solve_br(Z=Z, fC2=fC2, fC3=fC3, fB2=fB2, fB3=fB3)
        end=time.time()
        a = np.asarray(a).reshape(-1, 1)

        (ZPrime,prevS_A,guess,_),r1,error,bestParamError,DAError=e.step(a)
        totalTime+=end-start


print("Total Time Meta",totalTime/(numIters*numSteps))
print("average Prob",totalProb/float(numIters*numSteps))



###########Bayesian##################
# Baseline: Bayesian optimisation with BayesianAcquisitionFunction's EI
# acquisition, on the same rats and seeds as the meta policy.
DAErrorBayesian2=np.zeros((numIters,numSteps+1))
ParamErrorBayesian2=np.zeros((numIters,numSteps+1))

totalTime=0
for j in range(0,numIters):
    r_prev,error,DAError = eBay.reset(ratNum=allRats[j],seed=j)
    DAErrorBayesian2[j, 0] = 1.0
    ParamErrorBayesian2[j,0]=0
    print("BAO iteration",j)
    for i in range(0,numSteps):
        bestP=eBay.ratEst.BestGuess
        # .item() pulls out the single value explicitly; storing an ndim>0
        # array in a scalar slot is deprecated since NumPy 1.25.
        DAErrorBayesian2[j, i +1] = np.asarray(abs(np.multiply(eBay.sim.simulateRat(bestP.reshape(1,-1)),bestP)-eBay.testBestDA)).item()
        ParamErrorBayesian2[j, i + 1]=np.asarray((4-abs(bestP-eBay.sim.bestP))/4.0).item()
        # NOTE(review): reseeding the global RNG from OS entropy each step
        # makes this baseline non-reproducible across runs; confirm intended.
        np.random.seed()
        start=time.time()
        a = BayesianAcquisitionFunction.acquire_bayesian_ei(eBay.ratEst.rat,eBay.ratEst.actions,eBay.ratEst.out)

        a = np.asarray(a).reshape(-1, 1)
        end=time.time()
        print('action chosen',a)
        r1,error,bestParamError,DAError=eBay.step(a)
        print(r1)
        totalTime+=end-start


print("total time Bay",totalTime/(numIters*numSteps))
##############RANDOM####################
# Baseline: actions drawn uniformly from [0, 4), evaluated with the same
# Bayesian rat estimator and the same rats/seeds.
DAErrorRandom2=np.zeros((numIters,numSteps+1))
ParamErrorRandom2=np.zeros((numIters,numSteps+1))

for j in range(0,numIters):
    r_prev,error,DAError = eBay.reset(ratNum=allRats[j],seed=j)
    DAErrorRandom2[j, 0] = 1.0
    ParamErrorRandom2[j,0]=0
    print("Random iteration",j)
    for i in range(0,numSteps):
        bestP=eBay.ratEst.BestGuess
        # .item() pulls out the single value explicitly; storing an ndim>0
        # array in a scalar slot is deprecated since NumPy 1.25.
        DAErrorRandom2[j, i +1] = np.asarray(abs(np.multiply(eBay.sim.simulateRat(bestP.reshape(1,-1)),bestP)-eBay.testBestDA)).item()
        ParamErrorRandom2[j, i + 1]=np.asarray((4-abs(bestP-eBay.sim.bestP))/4.0).item()
        # NOTE(review): reseeding from OS entropy makes this baseline
        # non-reproducible across runs; confirm intended.
        np.random.seed()
        # size=(1, 1) already yields the (1, 1) action shape eBay.step gets.
        a=np.random.uniform(low=0.0, high=4.0, size=(1, 1))
        print('action chosen',a)
        r1,error,bestParamError,DAError=eBay.step(a)


totalTime=0
# LAL baseline: a pre-trained regressor scores a grid of candidate actions
# through findBestA. joblib.load unpickles the file and can execute arbitrary
# code, so only load model files you trust.
lalModel1=joblib.load('lalModel2.sav')
DAErrorLAL2=np.zeros((numIters,numSteps+1))
ParamErrorLAL2=np.zeros((numIters,numSteps+1))
for j in range(0,numIters):
    DAErrorLAL2[j, 0] = 1.0
    ParamErrorLAL2[j, 0] = 0
    r_prev,error,DAError = eBay.reset(ratNum=allRats[j],seed=j)
    print("LAL iteration",j)
    for i in range(0,numSteps):
        bestP = eBay.ratEst.BestGuess
        # .item() pulls out the single value explicitly; storing an ndim>0
        # array in a scalar slot is deprecated since NumPy 1.25.
        DAErrorLAL2[j, i + 1] = np.asarray(abs(np.multiply(eBay.sim.simulateRat(bestP.reshape(1, -1)), bestP) - eBay.testBestDA)).item()
        ParamErrorLAL2[j, i + 1] = np.asarray((4-abs(bestP-eBay.sim.bestP))/4.0).item()
        start = time.time()
        a =findBestA(lalModel1,eBay.ratEst.actions,eBay.ratEst.out,eBay.ratEst.rat,i)
        end = time.time()
        print("action chosen",a)
        r1,error,bestParamError,DAError=eBay.step(a.reshape(1,1))
        totalTime += end - start

print("total time LAL", totalTime / (numIters * numSteps))



def _meanAndSem(errors):
    """Collapse an (iterations x steps) table into per-step summaries.

    Returns the column means and the column standard deviations (NumPy's
    default ddof=0) divided by sqrt(numIters), a standard-error-style spread.
    """
    columnMean=np.mean(errors,axis=0)
    columnSpread=np.std(errors,axis=0)
    return columnMean,columnSpread/np.sqrt(numIters)


# DA error: per-step mean and standard error for each method.
DAErrormetaMean2,DAErrormetaSTD2=_meanAndSem(DAErrorMeta2)
DAErrorbayMean2,DAErrorbaySTD2=_meanAndSem(DAErrorBayesian2)
DAErrorrandMean2,DAErrorrandSTD2=_meanAndSem(DAErrorRandom2)
DAErrorLALMean2,DAErrorLALSTD2=_meanAndSem(DAErrorLAL2)

# Parameter score: same summaries.
ParamErrormetaMean2,ParamErrormetaSTD2=_meanAndSem(ParamErrorMeta2)
ParamErrorbayMean2,ParamErrorbaySTD2=_meanAndSem(ParamErrorBayesian2)
ParamErrorrandMean2,ParamErrorrandSTD2=_meanAndSem(ParamErrorRandom2)
ParamErrorLALMean2,ParamErrorLALSTD2=_meanAndSem(ParamErrorLAL2)


# Per-step mean (line) +/- one standard error (shaded band) for each method.
# Each curve gets an explicit label= and legend() is called with no arguments.
# A bare list of labels is matched to artists by creation order, and depending
# on the Matplotlib version the fill_between bands are counted too, which
# silently mislabels curves.
steps=range(numSteps+1)

pl.figure()
for mean,sem,color,label in (
        (DAErrormetaMean2,DAErrormetaSTD2,'b','meta'),
        (DAErrorbayMean2,DAErrorbaySTD2,'r','bayesopt'),
        (DAErrorrandMean2,DAErrorrandSTD2,'g','random_bayesian'),
        (DAErrorLALMean2,DAErrorLALSTD2,'C1','LAL'),
):
    pl.plot(mean,color=color,label=label)
    pl.fill_between(steps, mean-sem, mean+sem, alpha = 0.5,color=color)

pl.title('Model DA Error')
pl.legend()
#pl.ylim((0,2))
pl.show()


pl.figure()
for mean,sem,color,label in (
        (ParamErrormetaMean2,ParamErrormetaSTD2,'b','Ours'),
        (ParamErrorbayMean2,ParamErrorbaySTD2,'r','BayesOpt'),
        (ParamErrorrandMean2,ParamErrorrandSTD2,'g','Random Bayesian'),
        (ParamErrorLALMean2,ParamErrorLALSTD2,'C1','LAL'),
):
    pl.plot(mean,color=color,label=label)
    pl.fill_between(steps, mean-sem, mean+sem, alpha = 0.5,color=color)

pl.title('Model Param Error')
pl.legend()
#pl.ylim((0,2))
pl.show()
# Save each method's per-step mean (row 0) and standard error (row 1) as CSV.
# The directory is a raw string because its backslashes are not valid escape
# sequences (DeprecationWarning, and SyntaxWarning on Python 3.12+).
outDir=r"G:\ResearchGatech\RNS-Meta\InfoGain_Outputs"
for fileName,mean,sem in (
        ("DAMeta",DAErrormetaMean2,DAErrormetaSTD2),
        ("DABay",DAErrorbayMean2,DAErrorbaySTD2),
        ("DARand",DAErrorrandMean2,DAErrorrandSTD2),
        ("DALAL",DAErrorLALMean2,DAErrorLALSTD2),
        ("ParamMeta",ParamErrormetaMean2,ParamErrormetaSTD2),
        ("ParamBay",ParamErrorbayMean2,ParamErrorbaySTD2),
        ("ParamRand",ParamErrorrandMean2,ParamErrorrandSTD2),
        ("ParamLAL",ParamErrorLALMean2,ParamErrorLALSTD2),
):
    np.savetxt(outDir+"\\"+fileName+".csv", np.vstack((mean,sem)), fmt='%5s',delimiter=",")






# Meta-BO baseline. MBO.callMetaBO runs one whole episode per rat and returns
# two sequences (p, d) that fill columns 1..numSteps. Each is assumed to hold
# exactly numSteps values -- TODO confirm in MetaLearnBO.run_experiments.
DAErrorBO2=np.zeros((numIters,numSteps+1))
ParamErrorBO2=np.zeros((numIters,numSteps+1))
totalTime=0
for j in range(0,numIters):
    DAErrorBO2[j,0] = 1.0
    ParamErrorBO2[j,0]=0
    print("Meta BO iteration",j)
    start=time.time()
    p,d,_=MBO.callMetaBO(ratN=allRats[j])
    end=time.time()
    ParamErrorBO2[j, 1:]=np.array(p).reshape(-1,)
    DAErrorBO2[j,1:]=np.array(d).reshape(-1,)
    totalTime+=end-start

# The timer spans the whole callMetaBO episode; report it per step.
print("total time metaBO",totalTime/(numIters*numSteps))

# Meta-BO per-step mean and standard error (std with ddof=0, over sqrt(numIters)).
DAErrorBOMean2=np.mean(DAErrorBO2,axis=0)
DAErrorBOSTD2=np.std(DAErrorBO2,axis=0)
DAErrorBOSTD2=DAErrorBOSTD2/np.sqrt(numIters)


ParamErrorBOMean2=np.mean(ParamErrorBO2,axis=0)
ParamErrorBOSTD2=np.std(ParamErrorBO2,axis=0)
ParamErrorBOSTD2=ParamErrorBOSTD2/np.sqrt(numIters)


# Raw strings: the backslashes in these Windows paths are not valid escape
# sequences (DeprecationWarning, and SyntaxWarning on Python 3.12+).
np.savetxt(r"G:\ResearchGatech\RNS-Meta\InfoGain_Outputs\ParamBO.csv", np.vstack((ParamErrorBOMean2,ParamErrorBOSTD2)), fmt='%5s',delimiter=",")
np.savetxt(r"G:\ResearchGatech\RNS-Meta\InfoGain_Outputs\DABO.csv", np.vstack((DAErrorBOMean2,DAErrorBOSTD2)), fmt='%5s',delimiter=",")



# Echo the per-step DA-error summaries (mean and standard error) per method.
for statName,statValues in (
        ('MetaMean2ErrorDA',DAErrormetaMean2),
        ('MetaSTD2DA',DAErrormetaSTD2),
        ('BayesianMean2ErrorDA',DAErrorbayMean2),
        ('BayesianSTDDA',DAErrorbaySTD2),
        ('RandomDA',DAErrorrandMean2),
        ('RandomSTDDA',DAErrorrandSTD2),
        ('BODA',DAErrorBOMean2),
        ('BOSTDDA',DAErrorBOSTD2),
):
    print(statName,statValues)



# Final comparison including Meta-BO: per-step mean +/- one standard error.
# Each curve carries its own label= so the legend always matches the plotted
# lines. Previously the DA figure drew five curves but listed four labels, so
# "LAL" was attached to the Meta-BO curve.
steps=range(numSteps+1)

pl.figure()
for mean,sem,color,label in (
        (DAErrormetaMean2,DAErrormetaSTD2,'b','meta'),
        (DAErrorbayMean2,DAErrorbaySTD2,'r','bayesopt'),
        (DAErrorrandMean2,DAErrorrandSTD2,'g','random_bayesian'),
        (DAErrorBOMean2,DAErrorBOSTD2,'y','MetaBO'),
        (DAErrorLALMean2,DAErrorLALSTD2,'C1','LAL'),
):
    pl.plot(mean,color=color,label=label)
    pl.fill_between(steps, mean-sem, mean+sem, alpha = 0.5,color=color)

pl.title('Model DA Error')
pl.legend()
#pl.ylim((0,2))
pl.show()


pl.figure()
for mean,sem,color,label in (
        (ParamErrormetaMean2,ParamErrormetaSTD2,'b','Ours'),
        (ParamErrorbayMean2,ParamErrorbaySTD2,'r','BayesOpt'),
        (ParamErrorrandMean2,ParamErrorrandSTD2,'g','Random Bayesian'),
        (ParamErrorBOMean2,ParamErrorBOSTD2,'y','MetaBO'),
        (ParamErrorLALMean2,ParamErrorLALSTD2,'C1','LAL'),
):
    pl.plot(mean,color=color,label=label)
    pl.fill_between(steps, mean-sem, mean+sem, alpha = 0.5,color=color)

pl.title('Model Param Error')
pl.legend()
#pl.ylim((0,2))
pl.show()


