import torch
import torch.nn as nn

from torch.utils.checkpoint import checkpoint

class LSTMWrapper(nn.Module):
    """Adapter that makes a recurrent module return a single value.

    The wrapped module is called with one argument and must return a
    two-element result (for ``torch.nn.LSTM`` this is the output sequence
    and the final state); only the first element is passed on.
    """

    def __init__(self, lstm):
        super().__init__()
        self.lstm = lstm

    def forward(self, x):
        """Run the wrapped module on ``x`` and return its first output.

        Args:
            x: Input forwarded unchanged to the wrapped module.

        Returns:
            The first element of the pair returned by the wrapped module;
            the second element is discarded.
        """
        sequence_output, _discarded_state = self.lstm(x)
        return sequence_output

class TransformerDecoderLayerWrapper(nn.Module):
    """Single-input adapter around a two-argument decoder layer.

    The wrapped layer is invoked with the same tensor in both positional
    slots. For ``torch.nn.TransformerDecoderLayer`` those slots are the
    target sequence and the memory it attends to.
    """

    def __init__(self, layer):
        super().__init__()
        self.layer = layer

    def forward(self, x):
        """Call the wrapped layer with ``x`` as both of its arguments.

        Args:
            x: Tensor passed as the first and the second positional argument.

        Returns:
            Whatever the wrapped layer returns for ``layer(x, x)``.
        """
        first_arg = second_arg = x
        return self.layer(first_arg, second_arg)
    
class SeqClassWrapper(nn.Module):
    """Run a recurrent module on padded sequences and keep the last valid step.

    The padded batch is packed using the stored lengths, passed through the
    wrapped module, unpacked again, and for every sequence the output at
    index ``length - 1`` along dimension 1 is returned.

    Args:
        lstm: Module that accepts a ``PackedSequence`` and returns a
            two-element result whose first element is a ``PackedSequence``.
        lengths: Number of valid timesteps of each sequence, in batch order.
            A 1-D integer tensor on any device, or any sequence of ints
            accepted by ``torch.as_tensor``.
    """

    def __init__(self, lstm, lengths):
        super(SeqClassWrapper, self).__init__()
        self.lstm = lstm
        # Stored as a plain attribute, not a registered buffer, so it does
        # not follow the module across devices; forward() aligns it itself.
        self.lengths = lengths

    def forward(self, x):
        """Return the wrapped module's output at each sequence's last valid step.

        Args:
            x: Padded batch with the batch on dimension 0 and time on
                dimension 1 (it is packed with ``batch_first=True``).

        Returns:
            Tensor of shape ``(batch, features)`` where ``features`` is the
            last dimension of the unpacked output.
        """
        # Normalise once: gather needs an int64 index, and accepting plain
        # int sequences costs nothing.
        lengths = torch.as_tensor(self.lengths, dtype=torch.long)

        # pack_padded_sequence wants the lengths on the CPU.
        packed_input = nn.utils.rnn.pack_padded_sequence(
            x, lengths.cpu(), batch_first=True, enforce_sorted=False
        )
        packed_output, _ = self.lstm(packed_input)
        out, _ = nn.utils.rnn.pad_packed_sequence(packed_output, batch_first=True)

        # The index tensor must live on the same device as the tensor it
        # indexes, which may differ from where ``lengths`` was created.
        last_step = (lengths - 1).to(out.device).reshape(-1, 1, 1)
        last_step = last_step.expand(-1, -1, out.shape[-1])
        return torch.gather(out, 1, last_step).squeeze(1)
    
class TransEncWrapper(nn.Module):
    """Apply an encoder layer with a padding mask and mean-pool the result.

    The wrapped layer is called with the stored mask as
    ``src_key_padding_mask``; its output is then averaged along dimension 1
    over the positions that are not marked as padding.

    Args:
        layer: Callable accepting ``(x, src_key_padding_mask=...)``.
        mask: Boolean tensor in which ``True`` marks a padded position. Its
            dimensions must line up with the leading dimensions of the
            layer output so that ``mask.unsqueeze(-1)`` broadcasts against it.
    """

    def __init__(self, layer, mask):
        super(TransEncWrapper, self).__init__()
        self.layer = layer
        # Plain attribute, not a registered buffer: it does not move with
        # the module, so forward() places it on the input's device.
        self.mask = mask

    def forward(self, x):
        """Return the mean of the layer output over non-padded positions.

        Args:
            x: Input handed to the wrapped layer.

        Returns:
            The layer output summed over dimension 1 at non-padded positions
            and divided by the number of such positions. Rows that consist
            only of padding yield zeros.
        """
        mask = self.mask.to(x.device)
        x = self.layer(x, src_key_padding_mask=mask)

        pad = mask.to(x.device).unsqueeze(-1)
        # Overwrite padded positions instead of multiplying by zero: a NaN or
        # Inf there would survive the multiplication and poison the mean,
        # defeating the clamp below for all-padding rows.
        # Accumulate in the dtype the float-mask multiplication used to
        # produce, so the result dtype is unchanged.
        acc_dtype = torch.promote_types(x.dtype, torch.float32)
        summed = x.masked_fill(pad, 0).sum(dim=1, dtype=acc_dtype)

        # clamp avoids a division by zero when a row has no valid position.
        valid_count = (~pad).sum(dim=1).clamp(min=1)
        return summed / valid_count