import numpy as np


class Damage:
    """Bootstrapped least-squares estimate of a linear dynamics model.

    The model predicts a state change ``deltS = A @ state + B @ action`` for
    an 11-row state and a 4-row action.  ``A`` and ``B`` are obtained as a
    learned linear correction applied to the fixed nominal model
    (``AnominalEst``, ``BnominalEst``).  All data is stored column-wise, one
    column per sample.

    Attributes set by ``updateModel``:
        A, B            -- mean model over the three best bootstrap fits.
        A1..A3, B1..B3  -- the three best fits, lowest training error first.
        Asigma, Bsigma  -- std. dev. over all bootstrap fits, divided by the
                           number of samples.
        errorTrain      -- mean training error over all bootstrap fits.
        stateError      -- per-state mean absolute error of (A, B).
    """

    def __init__(self, states=None, actions=None, deltS=None):
        """Create a model, optionally seeded with stored training data.

        Args:
            states: stored states, used as an (11, n) array; empty if omitted.
            actions: stored actions, used as a (4, n) array; empty if omitted.
            deltS: stored state changes, used as an (11, n) array; empty if
                omitted.
        """
        # Nominal state matrix (11 x 11).
        self.AnominalEst = np.array([
            [-0.02204, 0.001422, 0, -32.17, 0, 0, 0, 0, 0, 0, 0],
            [-0.05759, -0.3943, 871, 0, 0, 0, 0, 0, 0, 0, 0],
            [-9.5e-05, -0.001793, -0.5468, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, -0.06399, 0, -1, 0.03691, 0, 0, 0],
            [0, 0, 0, 0, -2.111, -0.5048, 0.193, 0, 0, 0, 0],
            [0, 0, 0, 0, 0.7732, 0.01087, -0.1794, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
            [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        ])
        # Nominal input matrix (11 x 4).
        self.BnominalEst = np.array([
            [0, 1.91e-05, 0, 0],
            [-18.58, 0, 0, 0],
            [-1.2, 0, 0, 0],
            [0, 0, 0, 0],
            [0, 0, 0, 0.00839],
            [0, 0, 0.1848, 0.1056],
            [0, 0, -0.008699, -0.4591],
            [0, 0, 0, 0],
            [0, 0, 0, 0],
            [0, 0, 0, 0],
            [0, 0, 0, 0],
        ])

        self.A = np.zeros((11, 11))
        self.B = np.zeros((11, 4))
        self.Asigma = np.zeros((11, 11))
        self.Bsigma = np.zeros((11, 4))

        # A fresh empty array per instance instead of a shared default object.
        self.states = self._orEmpty(states)
        self.actions = self._orEmpty(actions)
        self.deltS = self._orEmpty(deltS)

        self.errorTrain = None
        self.stateError = None

        # Error of the previous getReward call; 3000 is the initial baseline.
        self.r_prev = 3000

        self.A1 = []
        self.A2 = []
        self.B1 = []
        self.B2 = []
        self.A3 = []
        self.B3 = []

    @staticmethod
    def _orEmpty(value):
        """Return *value* as an ndarray, mapping ``None`` to an empty array."""
        return np.array([]) if value is None else np.asarray(value)

    def updateModel(self, newinputsStates=None, newinputsActions=None,
                    newoutputs=None, numNets=100):
        """Refit the model on the stored data plus any new samples.

        New samples are reshaped to (11, n), (4, n) and (11, n) and appended
        to the stored data.  ``numNets`` least-squares fits are computed on
        bootstrap resamples (drawn with ``np.random.randint``); the three with
        the lowest training error on the full data are averaged into
        ``self.A`` / ``self.B``.

        Args:
            newinputsStates: new states; omit (or pass an empty array) to
                refit on the stored data only.
            newinputsActions: new actions matching ``newinputsStates``.
            newoutputs: new state changes matching ``newinputsStates``.
            numNets: number of bootstrap fits; must be at least 3.

        Raises:
            ValueError: if there is neither stored nor new data, or if
                ``numNets`` is smaller than 3.
        """
        if numNets < 3:
            raise ValueError('numNets must be at least 3')

        newinputsStates = self._orEmpty(newinputsStates)
        hasNew = newinputsStates.shape[0] != 0
        hasStored = self.states.shape[0] != 0
        if not hasNew and not hasStored:
            raise ValueError('No data')

        if not hasNew:
            inputsStates = self.states
            inputsActions = self.actions
            outputs = self.deltS
        else:
            inputsStates = newinputsStates.reshape(11, -1)
            inputsActions = self._orEmpty(newinputsActions).reshape(4, -1)
            outputs = self._orEmpty(newoutputs).reshape(11, -1)
            if hasStored:
                inputsStates = np.hstack((self.states, inputsStates))
                inputsActions = np.hstack((self.actions, inputsActions))
                outputs = np.hstack((self.deltS, outputs))

        # Regressors: the raw actions stacked on the nominal prediction, so
        # each fit learns a linear correction of the nominal model.
        deltaSEst = self.Nominal(inputsStates, inputsActions)
        inputs = np.vstack((inputsActions, deltaSEst))
        n = inputs.shape[1]
        numStates = self.AnominalEst.shape[0]
        numActions = self.BnominalEst.shape[1]

        # NETS[:, :, i].T maps `inputs` to `outputs` for bootstrap fit i.
        NETS = np.zeros((numActions + numStates, numStates, numNets))
        errorTrain = np.zeros(numNets)
        for i in range(numNets):
            # `high` is exclusive: n (not n-1) lets every sample be drawn and
            # also covers the single-sample case.
            bootSamples = np.random.randint(low=0, high=n, size=n)
            xBoot = inputs[:, bootSamples]
            yBoot = outputs[:, bootSamples]

            NETS[:, :, i] = np.dot(np.linalg.pinv(xBoot.T), yBoot.T)
            residual = np.dot(NETS[:, :, i].T, inputs) - outputs
            errorTrain[i] = np.sum(np.linalg.norm(residual, axis=0)) / n
        allErrorTrain = np.mean(errorTrain)

        # weights[k] = [Wu | Wd] with outputs ~= Wu @ actions + Wd @ nominal.
        # Substituting nominal = Anom @ s + Bnom @ u gives
        #   A = Wd @ Anom,   B = Wu + Wd @ Bnom.
        # The same (transposed) Wd is used for A and B so that they agree
        # with the fit whose error was measured above.
        weights = np.transpose(NETS, (2, 1, 0))
        Wu = weights[:, :, :numActions]
        Wd = weights[:, :, numActions:]
        AAll = np.dot(Wd, self.AnominalEst)
        BAll = Wu + np.dot(Wd, self.BnominalEst)

        # Three best fits, lowest training error first.
        best = np.argsort(errorTrain)[:3]
        ABest = AAll[best]
        BBest = BAll[best]

        self.A = np.mean(ABest, 0)
        self.B = np.mean(BBest, 0)
        self.Asigma = np.std(AAll, 0) / n
        self.Bsigma = np.std(BAll, 0) / n

        self.errorTrain = allErrorTrain

        # Per-state mean absolute error of the averaged model.
        est = np.dot(self.A, inputsStates) + np.dot(self.B, inputsActions)
        self.stateError = np.sum(np.abs(est - outputs), axis=1) / n

        self.states = inputsStates
        self.actions = inputsActions
        self.deltS = outputs

        self.A1 = ABest[0, :, :]
        self.A2 = ABest[1, :, :]
        self.A3 = ABest[2, :, :]
        self.B1 = BBest[0, :, :]
        self.B2 = BBest[1, :, :]
        self.B3 = BBest[2, :, :]

    def getReward(self, testStates, testActions, testDeltaS):
        """Score the current model on test data.

        Args:
            testStates: states, one column per sample.
            testActions: actions, one column per sample.
            testDeltaS: observed state changes, one column per sample.

        Returns:
            ``(rNew, r)`` where ``r`` is the mean per-sample Euclidean norm of
            the prediction error and ``rNew`` is the relative improvement
            ``(r_prev - r) / r_prev`` over the previous call.  ``r`` is stored
            as the new ``r_prev``.
        """
        dS = np.dot(self.A, testStates) + np.dot(self.B, testActions)
        r = np.sum(np.linalg.norm(dS - testDeltaS, axis=0)) / testStates.shape[1]
        rNew = (self.r_prev - r) / self.r_prev
        self.r_prev = r
        return rNew, r

    def Nominal(self, state, action):
        """Return the nominal model's state change for the given samples.

        ``state`` is reshaped to (11, n) and ``action`` to (4, n); the result
        is ``AnominalEst @ state + BnominalEst @ action`` with shape (11, n).
        """
        action = action.reshape(4, -1)
        state = state.reshape(11, -1)
        return np.dot(self.AnominalEst, state) + np.dot(self.BnominalEst, action)

    def reset(self):
        """Drop the stored training data and restore the reward baseline.

        The fitted matrices (``A``, ``B``, ...) are left untouched.
        """
        self.states = np.array([])
        self.actions = np.array([])
        self.deltS = np.array([])
        self.r_prev = 3000
