import torch
from MetaEnvironment import Environment
from DDPG.model import *

from sklearn.ensemble import RandomForestRegressor
import joblib

def calcDiff(prevA,prevO,a,e,numS=1):
    """Return differences between the two previous samples and candidate(s) ``a``.

    ``e.simDamage(a)`` is evaluated once to obtain the simulated output for the
    candidate action(s).  ``a`` is then viewed as ``(4, N)`` and the output as
    ``(11, N)``, i.e. one candidate per column.  For every candidate the result
    row is laid out as::

        [prevA[:, 0] - a, prevA[:, 1] - a, prevO[:, 0] - o, prevO[:, 1] - o]

    which is 4 + 4 + 11 + 11 = 30 columns.

    Args:
        prevA: Array whose columns 0 and 1 are the two previous actions
            (assumed to have 4 rows so it broadcasts against ``a``).
        prevO: Array whose columns 0 and 1 are the two previous outputs
            (assumed to have 11 rows so it broadcasts against the output).
        a: Candidate action(s); must be reshapeable to ``(4, N)``.
        e: Environment object exposing ``simDamage(a)``.
        numS: Unused; kept so existing callers that pass it keep working.

    Returns:
        Array of shape ``(N, 30)``, one row per candidate action.
    """
    o = e.simDamage(a)
    a = a.reshape(4, -1)
    o = o.reshape(11, -1)

    # Transpose (not reshape) so that row i holds the differences for candidate
    # i.  A plain reshape only coincides with the transpose when N == 1 and
    # mixes values of different candidates otherwise.
    dA1 = (prevA[:, 0].reshape(-1, 1) - a).T
    dA2 = (prevA[:, 1].reshape(-1, 1) - a).T
    dO1 = (prevO[:, 0].reshape(-1, 1) - o).T
    dO2 = (prevO[:, 1].reshape(-1, 1) - o).T

    return np.hstack((dA1, dA2, dO1, dO2))

def calcU(a,e):
    """Return whatever ``e.getUncertainty`` reports for action(s) ``a``."""
    return e.getUncertainty(a)

def findBestA(lal,e,step,numSamples=30):
    """Return the grid action with the highest score predicted by ``lal``.

    A regular grid of candidate actions (elevator, thrust, aileron, rudder) is
    built and every candidate is pushed through ``e.sim.simulateDamage`` twice.
    Candidates whose simulated state leaves the box ``xref +/- r`` are dropped.
    Each remaining candidate becomes one 36-column feature row (4 action
    values, 30 values from ``calcDiff``, 1 value from ``calcU`` and ``step``)
    that is scored with ``lal.predict``.

    Args:
        lal: Model exposing ``predict(features)``.
        e: Environment exposing ``sim.currentState``, ``sim.simulateDamage``,
            ``getUncertainty``, ``simDamage``, ``d.actions`` and ``d.states``.
        step: Step index, appended as the last feature column.
        numSamples: Grid points per action dimension; the grid contains
            ``numSamples ** 4`` candidates.

    Returns:
        1-D array holding the 4 values of the selected action.

    Raises:
        ValueError: If no candidate action stays inside the bounds.
    """
    numCandidates = numSamples ** 4

    # Half-widths (r) and centre (xref) of the admissible box for the 10
    # state rows that are checked below.
    r = np.array([10, 10, 100, 50, 50, 50, 50, 10, 10, 300]).reshape(10, -1)*5
    xref = np.array([5, 0, 0, 0, 0, 0, 0, 0, 0, 1000]).reshape(10, -1)
    upperBound = np.tile(xref + r, (1, numCandidates))
    lowerBound = np.tile(xref - r, (1, numCandidates))

    elevator = np.linspace(-3, 3, numSamples)
    thrust = np.linspace(6000, 8000, numSamples)
    aileron = np.linspace(-5, 5, numSamples)
    rudder = np.linspace(-2, 2, numSamples)
    # Row k holds the k-th control for every grid point, so each column of
    # possA is one candidate action.
    possA = np.array(np.meshgrid(elevator, thrust, aileron, rudder)).reshape(4, -1)

    startStates = np.tile(e.sim.currentState, (1, numCandidates))
    out = e.sim.simulateDamage(startStates, possA)
    out2 = e.sim.simulateDamage(out, possA)
    # NOTE(review): the last checked row comes from the first simulated step
    # (out), not the second one (out2) -- confirm this is intended.
    out2 = np.vstack((out2[0:9, :], out[10, :]))
    print("action",possA[:,0:2])
    print("out",out2[:,0:2])
    print("S",out2.shape)
    print(upperBound.shape)

    # Keep only the candidates for which every checked row is strictly
    # inside the bounds.
    indicator = np.logical_and(out2 < upperBound, out2 > lowerBound)
    print("indicator",indicator[:,0:2])
    indicator2 = np.all(indicator, axis=0)
    indicator3 = np.where(indicator2)
    print(indicator3)
    possA = possA[:, indicator3]
    print("PoSSA",possA.shape)
    possA = possA.reshape(4, -1)
    print("I3",indicator3)

    numFeasible = possA.shape[1]
    if numFeasible == 0:
        raise ValueError("findBestA: no candidate action stays inside the state bounds")

    uncertainty = calcU(possA, e)
    diff = calcDiff(e.d.actions[:, -2:], e.d.states[:, -2:], possA, e)
    # possA stores one candidate per column; transpose (not reshape) so that
    # row i of the feature matrix starts with the 4 controls of candidate i.
    allFeatures = np.concatenate(
        (
            possA.T,
            diff.reshape(-1, 30),
            uncertainty.reshape(-1, 1),
            np.tile(step, (numFeasible, 1)),
        ),
        axis=1,
    )
    print(allFeatures.shape,"ALL")
    scores = lal.predict(allFeatures)
    index = np.argmax(scores)
    print(allFeatures)
    a = allFeatures[index, 0:4]

    print("I",index)

    return a





# These two values are only used by the disabled training code in the string
# literal below; both names are reassigned before the evaluation loop runs.
numIters=1000
numSteps=5
# The bare string literal that follows is disabled code kept as a no-op
# expression.  It generated numIters*numSteps samples with random actions,
# fitted a RandomForestRegressor on them and dumped it to 'lalPlaneNN.sav'.
"""
##############RANDOM####################
featNum=36
allFeatures=np.zeros((numIters*numSteps,featNum))
allError=np.zeros(numIters*numSteps)
c=0
for j in range(0,numIters):
    print("Iter",j)
    np.random.seed(j)
    d1 = np.random.randint(0, 2)
    d2 = np.random.randint(0, 4)
    p1 = np.random.uniform(low=0, high=1)
    p2 = np.random.uniform(low=.75, high=1)
    damage=np.array([d1,         d2 ,        p1, p2])
    e = Environment(np.array([4.84, 0, 0, 0, 0, 0, 0, 0, 0, 100, 1001]), damageNum=damage, simAll=True)
    stateCurrent,uCurrent,error=e.genInitial()
    prevError=error

    for i in range(0,numSteps):
        np.random.seed(None)
        a=np.random.uniform(low=0.0, high=4.0, size=(4, 1))
        a = np.asarray(a).reshape(-1, 1)
        uncertainty=calcU(a,e)
        diff=calcDiff(e.d.actions[:,-2:],e.d.states[:,-2:],a,e)

        nextState, r1, error = e.step(a.reshape(4, 1))
        allError[c]=prevError-error
        allFeatures[c,:]=np.hstack((a.reshape(-1,4),diff,uncertainty.reshape(1,1),np.array([i]).reshape(1,1)))

        prevError=error
        c+=1


parameters = {'est': 2000, 'depth': 40, 'feat': 6 }
lalModel1 = RandomForestRegressor(n_estimators = parameters['est'], max_depth = parameters['depth'],
                                 max_features=parameters['feat'], oob_score=True, n_jobs=8)

lalModel1.fit(allFeatures, np.ravel(allError))
print('Oob score = ', lalModel1.oob_score_)
joblib.dump(lalModel1, 'lalPlaneNN.sav')
"""
# Load the previously saved model instead of retraining it.
# NOTE(review): joblib.load unpickles the file, which can execute arbitrary
# code -- only load model files that come from a trusted source.
lalModel1=joblib.load('lalPlaneNN.sav')
numIters=10
numSteps=8

# Evaluation: for each iteration build a freshly damaged environment and let
# findBestA choose the action for numSteps consecutive steps.
for j in range(0,numIters):
    print("Iter",j)
    # Seeding with the iteration index makes the damage draw below
    # reproducible for a given j.
    np.random.seed(j)
    d1 = np.random.randint(0, 2)
    d2 = np.random.randint(0, 4)
    p1 = np.random.uniform(low=0, high=1)
    p2 = np.random.uniform(low=.75, high=1)
    damage=np.array([d1,         d2 ,        p1, p2])
    e = Environment(np.array([4.84, 0, 0, 0, 0, 0, 0, 0, 0, 100, 1001]), damageNum=damage, simAll=True)
    # Only `error` is used from the returned triple, and prevError is not read
    # again inside this loop.
    stateCurrent,uCurrent,error=e.genInitial()
    prevError=error

    for i in range(0,numSteps):
        # The step index i is passed on as the `step` feature.
        a =findBestA(lalModel1,e,i)
        print("action",a)
        # The values returned by e.step are not used afterwards.
        nextState, r1, error = e.step(a.reshape(4, 1))
