from DamageBayesian import Damage
from Simulation import Simulation
from Simulation import SimulationAllDamage
import numpy as np
from encodeData import Encode

from scipy.interpolate import interp1d

class Environment:
    """RL-style environment around a damage simulator and an online Damage model.

    Every transition produced by ``self.sim.simulateDamage`` is fed to
    ``self.d.updateModel`` as ``(state, action, deltaS)``, where ``deltaS`` is
    the finite-difference derivative ``(next - current) / dt``. Rewards and
    errors come from ``self.d.getReward`` evaluated on a seeded test set built
    by ``genTestingData``.

    State vectors have 11 entries (every method reshapes them to ``(11, ...)``);
    a single action has 4 entries.
    """

    # Nominal starting state and action used by the warm-up in reset()/genInitial().
    _WARMUP_STATE = (4.84, .002, .0001, .001, 0, 0, 0, .01, -.01, 100, 1001)
    _WARMUP_ACTION = (0, 7269, 0, 0)
    # Number of warm-up transitions fed to the model before an episode starts.
    _WARMUP_STEPS = 3

    def __init__(self, currentState, enforceBoundary=True, damageNum=None, simAll=False):
        """Build the simulator, the Damage model, the test set and the encoder.

        Args:
            currentState: initial state; must be reshapeable to ``(11, 1)``.
            enforceBoundary: stored on the instance (not read within this class).
            damageNum: damage configuration passed to the simulator; defaults
                to a fresh ``[1, 1, 0, 0]`` per instance.
            simAll: if true use ``SimulationAllDamage``, otherwise ``Simulation``.
        """
        if damageNum is None:
            # A list literal as the default would be shared by every instance
            # relying on it (and by the simulator if it keeps a reference).
            damageNum = [1, 1, 0, 0]
        self.damageNum = damageNum
        simClass = SimulationAllDamage if simAll else Simulation
        self.sim = simClass(damageNum=damageNum)

        self.sim.currentState = currentState.reshape(11, 1)
        # 8 entries -- presumably bounds for two consecutive 4-element actions.
        # NOTE(review): last high bound is 55 while the matching low is -35; confirm.
        self.action_space_high = np.array([35, 8000000, 35, 35, 35, 8000000, 35, 55])
        self.action_space_low = np.array([-35, 100, -35, -35, -35, 100, -35, -35])
        # Time step used to turn state differences into derivatives in step()/step2().
        self.t = .05
        self.enforceBoundary = enforceBoundary

        self.d = Damage()
        self.testS, self.testA, self.testDS = self.genTestingData()

        self.encoder = Encode()

    def _advance(self, action):
        """Simulate ``action`` from the current state and feed the transition to the model.

        Does NOT update ``self.sim.currentState``; returns the simulator's next
        state unchanged in shape so callers decide how to store it.
        """
        current = self.sim.currentState
        nextState = self.sim.simulateDamage(current, action)
        deltaS = (nextState.reshape(11, -1) - current.reshape(11, -1)) / self.t
        self.d.updateModel(current, action, deltaS)
        return nextState

    def step2(self, action):
        """Apply three consecutive 4-element actions and return model/encoder outputs.

        ``action`` is split into ``action[0:4]``, ``action[4:8]`` and
        ``action[8:12]``; each is simulated from the current state and fed to
        the Damage model in turn.

        Returns:
            ``(obs, reward, isDone, info)`` where ``obs`` is
            ``(A, B, Asigma, Bsigma, Z, nextState, u0, prevS_A)``: the model's
            matrices, the encoder output over the model's stored history, the
            final state flattened to 1-D, the last applied action, and the
            vertically stacked stored states and actions. ``isDone`` is always
            False and ``info`` is always 0.
        """
        a1, a2, a3 = action[0:4], action[4:8], action[8:12]

        self.sim.currentState = self._advance(a1)
        self.sim.currentState = self._advance(a2)
        nextState3 = self._advance(a3)
        r1, _error = self.d.getReward(self.testS, self.testA, self.testDS)
        self.sim.currentState = nextState3

        prevStates = self.d.states
        prevActions = self.d.actions
        Z = self.encoder.encode(prevStates, prevActions, self.d.deltS)
        prevS_A = np.vstack((prevStates, prevActions))

        # TODO: nextState should come from self.d.states / self.d.actions.
        obs = (self.d.A, self.d.B, self.d.Asigma, self.d.Bsigma, Z,
               nextState3.reshape(-1), a3, prevS_A)
        return obs, r1, False, 0

    def step(self, action):
        """Apply one 4-element action.

        Returns:
            ``(nextState, reward, error)`` with ``nextState`` flattened to 1-D.
            The flattened state also becomes ``self.sim.currentState``.
        """
        nextState = self._advance(action).reshape(-1)
        self.sim.currentState = nextState

        r1, error = self.d.getReward(self.testS, self.testA, self.testDS)
        return nextState, r1, error

    def reset(self):
        """Start a new episode with a randomly drawn damage configuration.

        Resets the Damage model, reseeds NumPy's global RNG from OS entropy,
        draws ``[d1, d2, p1, p2]`` (``d1`` in {0, 1, 2}, ``d2`` in {0..4},
        ``p1`` ~ U[0, 1), ``p2`` ~ U[0.75, 1)), rebuilds the simulator via
        ``newSimulation`` and runs the warm-up.

        Returns:
            ``(A, B, Asigma, Bsigma, state, action, prevS_A)``.
        """
        print('reset')
        self.d.reset()
        np.random.seed()
        d1 = np.random.randint(0, 3)
        d2 = np.random.randint(0, 5)
        p1 = np.random.uniform(low=0, high=1)
        p2 = np.random.uniform(low=.75, high=1)
        self.newSimulation(np.array([d1, d2, p1, p2]))

        s2, a = self._warmStart()
        prevS_A = np.vstack((self.d.states, self.d.actions))
        return (self.d.A, self.d.B, self.d.Asigma, self.d.Bsigma, s2, a, prevS_A)

    def _warmStart(self):
        """Feed the model ``_WARMUP_STEPS`` transitions starting from the nominal state.

        The first transition uses the nominal action; each later one adds
        U[0, 0.01) noise per component. Sets ``self.sim.currentState`` to the
        last state and returns ``(lastState, a)``, where ``a`` is the noisy
        action drawn after the final transition (it has not been simulated).
        """
        s1 = np.array(self._WARMUP_STATE).reshape(11, 1)
        a = np.array(self._WARMUP_ACTION)
        s2 = s1
        for _ in range(self._WARMUP_STEPS):
            s2 = self.sim.simulateDamage(s1, a)
            # NOTE(review): uses self.sim.t whereas step()/step2() use self.t;
            # confirm the two time steps are meant to be equal.
            deltS1 = (s2.reshape(11, -1) - s1.reshape(11, -1)) / self.sim.t
            self.d.updateModel(s1, a, deltS1)
            s1 = s2
            a = np.array(self._WARMUP_ACTION) + np.random.rand(4) * .01

        self.sim.currentState = s2
        return s2, a

    def genInitial(self):
        """Run the warm-up on the current simulator.

        Returns:
            ``(state, action, error)`` where ``error`` is the second value
            returned by ``self.d.getReward`` on the test set.
        """
        s2, a = self._warmStart()
        # Result unused; call kept in case the encoder keeps internal state.
        self.encoder.encode(self.d.states, self.d.actions, self.d.deltS)
        _reward, error = self.d.getReward(self.testS, self.testA, self.testDS)
        return s2, a, error

    def genTestingData(self, numData=5, seed=1):
        """Build a reproducible test set of random state/action samples.

        Args:
            numData: number of samples (columns).
            seed: seed applied to NumPy's global RNG while the set is generated;
                the caller's RNG state is restored afterwards.

        Returns:
            ``(states, actions, derivatives)``: ``states`` is ``(11, numData)``,
            ``actions`` is ``(4, numData)``, and ``derivatives`` is
            ``(simulateDamage(states, actions) - states) / self.sim.t``.
        """
        # (name, low, high) for each state row, in row order.
        stateRanges = (
            ("fVel", 0, 200), ("vVel", -1, 1), ("pr", -1, 1), ("pa", -1, 1),
            ("sa", -1, 1), ("rr", -1, 1), ("yr", -1, 1), ("roll", -1, 1),
            ("yaw", -1, 1), ("X", 0, 4000), ("Z", -10, 10),
        )
        # (name, low, high) for each action row, in row order.
        actionRanges = (
            ("elevator", -1, 1), ("thrust", 0, 10000),
            ("aileron", -1, 1), ("rudder", -1, 1),
        )

        # Seed so the test set (and any global-RNG use inside simulateDamage)
        # is reproducible, but restore the caller's RNG state afterwards rather
        # than leaving the whole program seeded with `seed`.
        savedState = np.random.get_state()
        np.random.seed(seed)
        try:
            testInputsStates = np.vstack(
                [np.random.uniform(lo, hi, (1, numData)) for _, lo, hi in stateRanges])
            testInputsActions = np.vstack(
                [np.random.uniform(lo, hi, (1, numData)) for _, lo, hi in actionRanges])
            testStates2 = self.sim.simulateDamage(testInputsStates, testInputsActions)
        finally:
            np.random.set_state(savedState)

        testOutputs = (testStates2 - testInputsStates) / self.sim.t
        return testInputsStates, testInputsActions, testOutputs

    def get_StateError(self):
        """Mean absolute error, per state row, of the linear model ``A s + B a``.

        Compares ``self.d.A @ s + self.d.B @ a`` with the simulator-derived
        derivatives on a freshly generated (seeded) test set. Returns a 1-D
        array with one entry per state row.
        """
        testStates, testActions, testOutputs = self.genTestingData()
        dSEst = np.dot(self.d.A, testStates) + np.dot(self.d.B, testActions)
        return np.mean(np.abs(dSEst - testOutputs), axis=1)

    def newSimulation(self, damageNum):
        """Replace the simulator with ``SimulationAllDamage(damageNum)`` and rebuild the test set."""
        self.damageNum = damageNum
        self.sim = SimulationAllDamage(damageNum=damageNum)
        self.testS, self.testA, self.testDS = self.genTestingData()