import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from lib.dataloader import get_test_loader, get_train_loader, get_val_loader
import os
import sys
from lib.StressLstm import StressLstm
from lib import slowfastnet
# Public API of this module: `Ours` is the only definition here (the previous
# list named resnet* symbols that do not exist, breaking `import *`).
__all__ = ['Ours']



class Ours(nn.Module):
    """Fuse a video representation with a sequence of text segments and classify.

    The video branch (``vis_model``) produces one feature vector per sample.
    Each of ``num_segments`` text segments is encoded by ``segment_lstm``.
    The segment features are run through a GRU, pooled with a learned
    weighted sum over the segment axis and projected to 8 dims. They are then
    concatenated with the video features and passed through two linear layers.
    """

    # Not referenced anywhere in this class; kept so external code reading
    # ``Ours.expansion`` keeps working.
    expansion = 4
    # Number of text segments consumed from ``chatgpt_txt`` along dim 1.
    num_segments = 8

    def __init__(self, num_classes, hidden_dim, nheads,
                 num_encoder_layers, num_decoder_layers):
        """Build the model.

        Args:
            num_classes: output size of the final linear layer.
            hidden_dim, nheads, num_encoder_layers, num_decoder_layers:
                accepted for interface compatibility; not used by this model.
        """
        super().__init__()

        # Video branch. The fusion layer below requires it to return
        # (batch, 64) features (72 - 8) -- TODO confirm against slowfastnet.
        self.vis_model = slowfastnet.resnet18(class_num=2)

        # Per-segment text encoder; assumed to return 256-d features per
        # sample (must match the GRU input size below).
        self.segment_lstm = StressLstm()
        # Single-layer GRU over the segment sequence. ``dropout=0.5`` was
        # dropped: PyTorch applies GRU dropout only between stacked layers, so
        # with num_layers=1 it had no effect and only raised a UserWarning.
        self.seq_lstm = nn.GRU(input_size=256, hidden_size=64, batch_first=True)
        # Scores each GRU time step; the scores are used as pooling weights.
        self.linear1 = nn.Linear(64, 1)
        self.fc_text = nn.Linear(64, 8)
        # 72 = 64 video features + 8 projected text features.
        self.fc = nn.Linear(72, 16)
        self.fc2 = nn.Linear(16, num_classes)

    def forward(self, videos, emo, chatgpt_txt):
        """Compute class scores for a batch.

        Args:
            videos, emo: passed unchanged to ``self.vis_model``.
            chatgpt_txt: 4-D tensor indexed as ``chatgpt_txt[:, i, :, :]``;
                dim 0 is the batch and dim 1 must provide at least
                ``num_segments`` segments.

        Returns:
            Output of ``fc2``, shape (batch, num_classes); no activation is
            applied.
        """
        vis_rep = self.vis_model(videos, emo)

        # Encode every segment and join them along a new dim 1. Built as a
        # list and concatenated once (no placeholder tensor, so the model
        # also runs on CPU). Normally yields (batch, num_segments, 256).
        encoded = [
            torch.unsqueeze(self.segment_lstm(chatgpt_txt[:, i, :, :]), 1)
            for i in range(self.num_segments)
        ]
        segs = torch.cat(encoded, dim=1)
        if segs.dim() != 3:
            # segment_lstm returned 1-D features (presumably batch size 1 with
            # the batch dim squeezed away), so segs is (D, num_segments):
            # restore the batch dim and move segments to dim 1.
            segs = torch.unsqueeze(segs, 0).permute(0, 2, 1)

        output, _ = self.seq_lstm(segs)  # (batch, num_segments, 64)
        # Learned weighted sum over the segment axis.
        # NOTE(review): weights are not softmax-normalized -- confirm intended.
        weights = torch.transpose(self.linear1(output), 1, 2)  # (batch, 1, num_segments)
        clip_rep = torch.squeeze(torch.bmm(weights, output), dim=1)  # (batch, 64)
        clip_rep = self.fc_text(clip_rep)  # (batch, 8)

        rep = torch.cat([vis_rep, clip_rep], dim=1)  # (batch, 72)
        rep = self.fc(rep)  # (batch, 16)
        return self.fc2(rep)  # (batch, num_classes)





    
