import torch
from MetaEnvironment import Environment
from BayesianOpt.MetaBayesian import BayesianEnvironment
from DDPG.model import *
import MILPs.MILPModel_2layer as MILP
import MILPs.infoGainMILP_ChanceConstraints as MILPSafety
from BayesianOpt import BayesianAcquisitionFunction
from matplotlib import pyplot as pl
import MetaLearnBO.run_experiments as MBO
import time
from scipy.stats import norm
from sklearn.ensemble import RandomForestRegressor
import joblib

def calcDiff(prevA,prevO,a,nn,numSlots=8):
    """Build the history-difference features for a candidate action.

    Evaluates ``nn`` at ``a`` and, for every entry ``i`` of ``prevA``, stores
    ``prevA[i] - a`` and ``prevO[i] - nn(a)`` in slot ``i`` of two
    zero-initialised rows. Slots without a history entry stay at zero.

    Args:
        prevA: NumPy array of previous actions, indexed along axis 0.
        prevO: Values indexed like ``prevA`` (presumably the outputs recorded
            for those actions -- confirm against the environment).
        a: Candidate action as a NumPy array; callers in this file pass a
            single value shaped (1, 1).
        nn: Callable taking a float32 torch tensor and returning a tensor.
        numSlots: Number of history slots per row. The default of 8 is the
            width that used to be hard-coded here.

    Returns:
        Array of shape (1, 2 * numSlots): the action differences followed by
        the output differences.

    Raises:
        IndexError: If ``prevA`` holds more than ``numSlots`` entries.
        ValueError: If a difference does not reduce to exactly one value.
    """
    o = nn(torch.from_numpy(a).float()).detach().numpy()

    numPrev = prevA.shape[0]
    if numPrev > numSlots:
        raise IndexError(
            "history holds %d entries but only %d slots are available"
            % (numPrev, numSlots)
        )

    allDiffA = np.zeros((1, numSlots))
    allDiffO = np.zeros((1, numSlots))
    for i in range(numPrev):
        # The differences are size-1 arrays (``a`` is 2-D). Pull the scalar
        # out explicitly: assigning a size-1 array with ndim > 0 to a single
        # element relies on a conversion NumPy deprecated in 1.25.
        allDiffA[0, i] = np.asarray(prevA[i] - a).item()
        allDiffO[0, i] = np.asarray(prevO[i] - o).item()
    return np.hstack((allDiffA, allDiffO))

def calcU(nn1,nn2,a):
    """Measure how much two networks disagree at the input ``a``.

    Both networks are evaluated on ``a`` (converted to a float32 tensor); the
    norm of their output difference is taken along axis 0 and the resulting
    values are summed. The sum is printed before it is returned.

    Args:
        nn1: Callable taking a float32 torch tensor and returning a tensor.
        nn2: Second callable with the same calling convention.
        a: NumPy array used as the input of both networks.

    Returns:
        The summed disagreement as an array of shape (1, 1).
    """
    def _evaluate(net):
        # A fresh tensor per call, exactly as each network was fed before.
        return net(torch.from_numpy(a).float()).detach().numpy()

    gap = _evaluate(nn1) - _evaluate(nn2)
    normsAlongAxis0 = np.linalg.norm(gap, axis=0)
    U = np.sum(normsAlongAxis0)
    print(U)
    return U.reshape(1, 1)

def findBestA(lal,actions,out,nn1,nn2,step):
    """Pick the candidate action with the highest predicted score.

    Fifty evenly spaced candidates in [0, 4] are scored one by one. For each
    candidate the feature row is
    ``[candidate, calcDiff(...), calcU(...), step]`` and ``lal.predict`` is
    called on that single row.

    Args:
        lal: Fitted model exposing ``predict``.
        actions: Previous actions, forwarded to ``calcDiff``.
        out: Previous outputs, forwarded to ``calcDiff``.
        nn1: First network; used for both the uncertainty and the differences.
        nn2: Second network; used only for the uncertainty.
        step: Index of the current step, appended as the last feature.

    Returns:
        The winning candidate as a one-element array (a row of the candidate
        grid). The first candidate wins when scores tie, as with ``np.argmax``.
    """
    numCandidates = 50
    candidates = np.atleast_2d(np.linspace(0, 4, numCandidates)).T

    def _score(candidate):
        query = candidate.reshape(1, -1)
        # Uncertainty first, then differences: same call order as the
        # original loop, so the values printed by calcU appear unchanged.
        disagreement = calcU(nn1, nn2, query)
        history = calcDiff(actions, out, query, nn1)
        row = np.hstack((
            candidate.reshape(1, 1),
            history,
            disagreement,
            np.array([step]).reshape(1, 1),
        ))
        return lal.predict(row)

    scores = [_score(candidate) for candidate in candidates]
    return candidates[np.argmax(np.array(scores))]



# Shared environment instance; it is reused by every loop below.
e = Environment(optimizeModel=False)

# Number of data-collection episodes and actions taken per episode.
numIters=500
numSteps=8
############## RANDOM: collect training data with uniformly random actions ##############
# NOTE(review): these two buffers are allocated but never written or read in
# the visible part of this file -- confirm whether they are still needed.
DAErrorRandom2=np.zeros((numIters,numSteps+1))
ParamErrorRandom2=np.zeros((numIters,numSteps+1))

# Width of one feature row, matching the hstack below:
# 1 (action) + 16 (calcDiff with its 8 + 8 slots) + 1 (calcU) + 1 (step index).
featNum=19
allFeatures=np.zeros((numIters*numSteps,featNum))
allError=np.zeros(numIters*numSteps)
# Running row index into allFeatures / allError.
c=0
for j in range(0,numIters):
    print("Iter",j)
    # Seed with the episode index so the choice of ratN below is repeatable.
    np.random.seed(j)
    ratN = np.random.randint(0, 5)
    print('rat',ratN)
    states,r_prev,error,DAError = e.reset(ratNum=ratN,seed=j)
    prevError=error

    for i in range(0,numSteps):
        # Re-seed the global NumPy RNG without a fixed seed, so the sampled
        # actions are NOT tied to the episode seed set above.
        np.random.seed(None)
        a=np.random.uniform(low=0.0, high=4.0, size=(1, 1)).reshape(1, 1)
        a = np.asarray(a).reshape(-1, 1)
        # Features are computed before e.step(a), i.e. from the estimator
        # state as it is before this action is applied.
        uncertainty=calcU(e.allRatsEst[0].rat,e.allRatsEst[1].rat,a)
        # NOTE(review): this passes e.ratEst1.rat, while the roll-out loop
        # further down passes e.allRatsEst[0].rat to the same calcDiff
        # argument -- confirm both refer to the same network.
        diff=calcDiff(e.allRatsEst[0].actions,e.allRatsEst[0].out,a,e.ratEst1.rat)

        # The second value returned by e.step is stored as the regression
        # target (presumably a reward -- confirm against Environment.step).
        _,r1,error,bestParamError,DAError=e.step(a)
        allError[c]=r1
        allFeatures[c,:]=np.hstack((a,diff,uncertainty,np.array([i]).reshape(1,1)))

        prevError=error
        c+=1


# Random-forest settings: number of trees, maximum tree depth and the number
# of features considered at each split.
parameters = dict(est=2000, depth=40, feat=6)
lalModel1 = RandomForestRegressor(
    n_estimators=parameters['est'],
    max_depth=parameters['depth'],
    max_features=parameters['feat'],
    oob_score=True,
    n_jobs=8,
)

# Fit on the collected (features, target) pairs, report the out-of-bag score
# and persist the fitted model to disk.
lalModel1.fit(allFeatures, allError.ravel())
print('Oob score = ', lalModel1.oob_score_)
joblib.dump(lalModel1, 'lalModelNN.sav')

# Reload the model from the file written above and roll it out: at every step
# the action is the candidate that findBestA scores highest.
lalModel1=joblib.load('lalModelNN.sav')
numIters=10
numSteps=8

for j in range(0,numIters):
    # NOTE(review): seeds 0-9 were also used by the data-collection loop, so
    # these episodes are reset with the same ratNum/seed as the first ten
    # training episodes -- confirm that this overlap is intended.
    np.random.seed(j)
    print("Iter",j)
    ratN = np.random.randint(0, 5)
    print("rat",ratN)
    _,r_prev,error,DAError = e.reset(ratNum=ratN,seed=j)
    prevError=error

    for i in range(0,numSteps):
        # Both the difference and the uncertainty features use allRatsEst[0].rat
        # as the first network; allRatsEst[1].rat is the second network for the
        # uncertainty.
        a =findBestA(lalModel1,e.allRatsEst[0].actions,e.allRatsEst[0].out,e.allRatsEst[0].rat,e.allRatsEst[1].rat,i)
        print("action",a)
        # findBestA returns a one-element array; e.step is given it as (1, 1).
        # The values returned by e.step are not used further in the visible code.
        _,r1,error,bestParamError,DAError=e.step(a.reshape(1,1))
