
from torch import nn
from torch.nn import functional as F
import torch
# from config import DefaultConfig
import torchvision.models as models

# from .LSTM_Memory import MemLSTMCell

import time
from collections import OrderedDict

class StressLstm(nn.Module):
    """GRU sequence encoder followed by a learned, score-weighted pooling.

    Despite the class name, the recurrent layer is an ``nn.GRU``.  Every
    time step of the GRU output is scored by a linear layer, and the
    sequence is reduced to one vector per sample as the score-weighted sum
    of the GRU outputs.  The scores are used raw (no softmax is applied).

    Args:
        input_size: Feature size of each input time step.
        hidden_size: GRU hidden size; also the size of the returned vector.
        num_layers: Number of stacked GRU layers.
        dropout: Dropout probability applied between stacked GRU layers.
            It is ignored when ``num_layers == 1`` because ``nn.GRU`` never
            applies dropout after the last layer.
    """

    def __init__(self, input_size=768, hidden_size=256, num_layers=1,
                 dropout=0.75):
        super(StressLstm, self).__init__()
        self.model_name = 'StressLstm'

        # nn.GRU only applies dropout *between* stacked layers.  Passing a
        # non-zero value to a single-layer GRU does nothing except emit a
        # UserWarning, so forward it only when it can take effect.
        inter_layer_dropout = dropout if num_layers > 1 else 0.0
        self.postgru = nn.GRU(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=inter_layer_dropout,
        )
        # Produces one scalar score per time step.
        self.postlinear = nn.Linear(hidden_size, 1)

    def forward(self, x):
        """Encode a batch of sequences into one vector per sample.

        Args:
            x: Tensor of shape ``(batch, seq_len, input_size)``.

        Returns:
            Tensor of shape ``(batch, hidden_size)``.  The batch axis is
            kept even when ``batch == 1``.
        """
        states, _ = self.postgru(x)            # (batch, seq_len, hidden)
        scores = self.postlinear(states)       # (batch, seq_len, 1)
        # (batch, hidden, seq_len) @ (batch, seq_len, 1) -> (batch, hidden, 1)
        pooled = torch.bmm(states.permute(0, 2, 1), scores)
        # Drop only the trailing singleton axis.  A bare squeeze() would
        # also remove the batch axis whenever the batch holds one sample.
        return pooled.squeeze(-1)

