from temporal_encoder import TEncoder
from spatial_encoder import SEncoder
import torch
from torch import nn

class TSEncoder(nn.Module):
    """Fuse a temporal encoder and a spatial encoder into one representation.

    The module runs ``TEncoder`` and ``SEncoder`` on the same input, projects
    the spatial output back onto the time axis with a batched matrix product,
    concatenates both feature sets along the last dimension and maps the
    result to ``output_dims`` through two linear layers.

    Args:
        input_dims: Forwarded to ``TEncoder`` as its input size
            (presumably the number of variables V -- confirm against
            ``temporal_encoder``).
        seq_lengths: Forwarded to ``SEncoder`` as its input size
            (presumably the cropped sequence length -- confirm against
            ``spatial_encoder``).
        output_dims: Size of the last dimension of the returned tensor; also
            the output size requested from both sub-encoders.
        hidden_dims: Width of the hidden layer of the output head; also
            forwarded to both sub-encoders.
        depth: Forwarded to ``TEncoder``.
        p: Forwarded to ``SEncoder``.
        mask_mode: Forwarded to both sub-encoders.
    """

    def __init__(self, input_dims, seq_lengths, output_dims, hidden_dims=64, depth=10, p=0.5, mask_mode='binomial'):
        super().__init__()
        self.input_dims = input_dims
        self.seq_lengths = seq_lengths
        self.output_dims = output_dims
        self.hidden_dims = hidden_dims
        self.tencoder = TEncoder(self.input_dims, self.output_dims, self.hidden_dims, depth=depth, mask_mode=mask_mode)
        self.sencoder = SEncoder(self.seq_lengths, self.output_dims, self.hidden_dims, p=p, mask_mode=mask_mode)
        # Output head: the concatenated features are 2 * output_dims wide.
        self.output_fc1 = nn.Linear(self.output_dims * 2, self.hidden_dims)
        self.output_fc2 = nn.Linear(self.hidden_dims, self.output_dims)
        # NOTE(review): registered but not applied in forward(); kept so the
        # attribute stays available to existing callers -- confirm intent.
        self.repr_dropout = nn.Dropout(p=0.1)

    def forward(self, x, mask=None):  # x: B cut_L V
        """Encode ``x`` and return the fused representation.

        Args:
            x: 3-D input tensor of shape ``(B, cut_L, V)``.
            mask: Passed unchanged to both sub-encoders.

        Returns:
            Tensor of shape ``(B, cut_L, output_dims)``. It stays attached to
            the autograd graph, so a loss computed on it back-propagates into
            both sub-encoders as well as the output head.
        """
        t_x = self.tencoder(x, mask)  # B cut_L Zo

        s_x = self.sencoder(x, mask)  # B V Zo
        # Weight the per-variable embeddings by the raw inputs to get one
        # embedding per time step: (B, cut_L, V) @ (B, V, Zo).
        s_x = torch.bmm(x, s_x)  # B cut_L Zo

        # Keep the concatenation inside the graph. Detaching here would block
        # every gradient from reaching tencoder / sencoder, leaving only the
        # two linear layers below trainable.
        fused = torch.cat((t_x, s_x), 2)  # B cut_L 2*Zo
        # NOTE(review): no activation between the two linear layers, so the
        # head is an affine map overall -- confirm this is intended.
        fused = self.output_fc1(fused)
        fused = self.output_fc2(fused)
        return fused