from Encoding.LSTMDynamics import LSTMDyn
import torch
import numpy as np
from Encoding.LSTMUncertainty import LSTMU

class Encode:
    """Encode a window of past actions and state changes into a context vector.

    The LSTM encoders (``LSTMDyn`` / ``LSTMU``) this class was written around
    are currently disabled, so ``encode`` returns simple summary statistics
    (mean and standard deviation) of its inputs instead of learned embeddings.
    """

    def __init__(self):
        # NOTE: loading of the LSTMDyn / LSTMU models (previously from
        # machine-specific absolute paths) has been disabled, so there is no
        # state to initialise. Reinstate model loading here if the LSTM
        # encoders are re-enabled.
        pass

    def encode(self, prevActions, prevDS):
        """Summarise previous actions and state deltas as a ``(1, 4)`` tensor.

        Every element of each input is pooled into a single column before the
        statistics are taken. The result therefore always has four entries,
        whatever the input shapes::

            [mean(prevActions), mean(prevDS), std(prevActions), std(prevDS)]

        NOTE(review): because the inputs are flattened, multi-dimensional
        actions are summarised by one scalar each; confirm that
        per-dimension statistics were not intended.

        Args:
            prevActions: array-like of past actions (any shape).
            prevDS: array-like of past state changes (any shape).

        Returns:
            A float32 CPU ``torch.Tensor`` of shape ``(1, 4)``. Empty inputs
            yield NaN entries (NumPy emits a RuntimeWarning).
        """
        # np.asarray also accepts lists and other array-likes, not only ndarrays.
        actions = np.asarray(prevActions).reshape(-1, 1)
        deltas = np.asarray(prevDS).reshape(-1, 1)

        # np.std defaults to the population standard deviation (ddof=0).
        stats = np.concatenate((
            np.mean(actions, axis=0),
            np.mean(deltas, axis=0),
            np.std(actions, axis=0),
            np.std(deltas, axis=0),
        ))
        return torch.from_numpy(stats).float().reshape(1, -1)


# TODO: planned LSTM-based encoding pipeline:
#   1. Pad the input sequences.
#   2. Import the LSTM models.
#   3. Pass the data through the 3 LSTM models.
#   4. Concatenate the results to get ZPrime.