from Encoding.LSTMDynamics import LSTMDyn
import torch
import numpy as np
from Encoding.LSTMUncertainty import LSTMU

class Encode():
    """Build a fixed-size context vector from a window of past transitions.

    Two pre-trained LSTM models (``LSTMDyn`` and ``LSTMU``) are instantiated,
    moved to the GPU when CUDA is available, loaded from checkpoint files and
    switched to evaluation mode. ``encode`` then turns a window of previous
    states, actions and state deltas into a single ``(1, D)`` tensor.
    """

    # Per-timestep feature widths the models are constructed with.
    STATE_SIZE = 11
    ACTION_SIZE = 4

    # Default checkpoint locations, keyed by ``torch.cuda.is_available()``.
    # They are machine-specific and kept only so existing callers that
    # construct ``Encode()`` without arguments keep working.
    _DEFAULT_WEIGHTS = {
        True: (
            "C:/Users/mschr/CORELABProjects/Meta-Learning-Aircraft/Encoding/LSTMDynamics",
            "C:/Users/mschr/CORELABProjects/Meta-Learning-Aircraft/Encoding/LSTMUncertainty",
        ),
        False: (
            "E:/ResearchGatech/Meta-Learning Aircraft/LSTMDynamics",
            "E:/ResearchGatech/Meta-Learning Aircraft/LSTMUncertainty",
        ),
    }

    def __init__(self, dyn_weights=None, unc_weights=None):
        """Load both LSTM models and put them in evaluation mode.

        Args:
            dyn_weights: Path to the ``LSTMDyn`` state dict. ``None`` selects
                the built-in default for the current device type.
            unc_weights: Path to the ``LSTMU`` state dict. ``None`` selects
                the built-in default for the current device type.
        """
        use_cuda = torch.cuda.is_available()
        device = torch.device("cuda" if use_cuda else "cpu")

        default_dyn, default_unc = self._DEFAULT_WEIGHTS[use_cuda]
        if dyn_weights is None:
            dyn_weights = default_dyn
        if unc_weights is None:
            unc_weights = default_unc

        feature_size = self.STATE_SIZE + self.ACTION_SIZE

        self.modelDyn = LSTMDyn(input_size=feature_size, state_size=self.STATE_SIZE).to(device)
        # map_location places the checkpoint tensors on the model's device
        # regardless of the device they were saved from.
        self.modelDyn.load_state_dict(torch.load(dyn_weights, map_location=device))

        self.modelU = LSTMU(input_size=feature_size, state_size=self.STATE_SIZE).to(device)
        self.modelU.load_state_dict(torch.load(unc_weights, map_location=device))

        # Inference only: switch both modules to evaluation mode.
        self.modelDyn.eval()
        self.modelU.eval()

    def encode(self, prevStates, prevActions, prevDS):
        """Encode a window of previous transitions into one row vector.

        Args:
            prevStates: Array-like reshapeable to ``(T, 11)``.
            prevActions: Array-like reshapeable to ``(T, 4)``.
            prevDS: Array-like reshapeable to ``(T, 11)``.

        Returns:
            A float32 ``torch.Tensor`` of shape ``(1, D)`` holding, in order:
            the flattened last output of ``modelDyn``; the per-column means of
            states, actions and deltas; then the per-column standard
            deviations of states, actions and deltas.
        """
        states = np.asarray(prevStates).reshape(-1, self.STATE_SIZE)
        actions = np.asarray(prevActions).reshape(-1, self.ACTION_SIZE)
        deltas = np.asarray(prevDS).reshape(-1, self.STATE_SIZE)

        # Model input is one batch of T steps; each step is [action | state].
        sequence = np.hstack((actions, states)).reshape(
            1, -1, self.STATE_SIZE + self.ACTION_SIZE)
        batch = torch.from_numpy(sequence).float()
        if torch.cuda.is_available():
            batch = batch.cuda()

        # No gradients are needed here; skipping autograd avoids building a
        # graph that would be thrown away immediately.
        with torch.no_grad():
            _, _, _, zDyn = self.modelDyn(batch)
            # NOTE(review): the uncertainty model is evaluated but its latent
            # is not part of the returned vector. The call is kept so runtime
            # behaviour stays the same; confirm whether its output was meant
            # to be concatenated into Z.
            self.modelU(batch)

        dyn_latent = zDyn.detach().reshape(-1).cpu().numpy()

        # Order: mean(states, actions, deltas) then std(states, actions, deltas).
        summary = [
            stat(block, axis=0)
            for stat in (np.mean, np.std)
            for block in (states, actions, deltas)
        ]

        Z = np.concatenate([dyn_latent, *summary])
        return torch.from_numpy(Z).float().reshape(1, -1)


# Pad sequences.
# Import the LSTM models.
# Pass the data through 3 LSTM models.
# Concatenate the results to get ZPrime.