from torch import nn
from .transformer import TransformerEn, TransformerEn_aux
from .transformer import PositionalEncoding

class TransformerModel(nn.Module):
    """Transformer encoder with positional encoding and an optional output projection.

    The module builds a ``PositionalEncoding`` layer and a ``TransformerEn``
    encoder. When ``in_channel`` differs from ``d_model`` a linear layer maps
    the encoder output from ``d_model`` features to ``in_channel`` features;
    otherwise the encoder output is returned as is.

    Args:
        in_channel: Feature size of the final output. Only used to decide
            whether an output projection is needed and to size it.
        d_model: Width handed to the positional encoding and the encoder.
        nhead: Forwarded to ``TransformerEn``.
        num_encoder_layers: Forwarded to ``TransformerEn``.
        dim_feedforward: Forwarded to ``TransformerEn``.
        dropout: Forwarded to ``TransformerEn``.
        activation: Forwarded to ``TransformerEn``.
        normalize_before: Forwarded to ``TransformerEn``.
        max_len: Forwarded to ``PositionalEncoding``.
    """

    def __init__(self, in_channel, d_model, nhead=8, num_encoder_layers=6, dim_feedforward=2048,
                 dropout=0.1, activation="relu", normalize_before=False, max_len=2000):
        super().__init__()
        # NOTE(review): the positional-encoding dropout is fixed at 0.1 and does
        # not follow the ``dropout`` argument -- confirm whether this is intended.
        self.pe_layer = PositionalEncoding(d_model=d_model, dropout=0.1, max_len=max_len)

        # Original author's note: the norm at the end of the transformer is
        # included in the encoder itself.
        self.transformer = TransformerEn(
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation=activation,
            normalize_before=normalize_before,
        )

        # Project back to ``in_channel`` features only when the widths differ.
        if in_channel != d_model:
            self.proj_out = nn.Linear(in_features=d_model, out_features=in_channel, bias=True)
        else:
            self.proj_out = None

    def forward(self, x, *args, **kwargs):
        """Encode ``x`` and optionally project the result.

        Args:
            x: Input tensor. The original comments describe it as
                ``[batch, length, channels]`` -- TODO confirm against callers.
            *args: Ignored.
            **kwargs: Only ``mask`` is read. When given and not ``None`` its
                shape must equal ``x.shape[:-1]``; it is passed to the encoder
                unchanged. A missing or ``None`` mask means "no mask".

        Returns:
            The encoder output, passed through ``proj_out`` when that layer
            exists.

        Raises:
            AssertionError: If a mask is supplied whose shape differs from
                ``x.shape[:-1]``.
        """
        # ``get`` treats an absent mask and an explicit ``mask=None`` alike, so
        # the shape check below never dereferences ``None``.
        mask = kwargs.get('mask')
        if mask is not None:
            assert x.shape[:-1] == mask.shape, (
                f"mask shape {tuple(mask.shape)} must match x.shape[:-1] "
                f"{tuple(x.shape[:-1])}"
            )

        pos = self.pe_layer(x)

        mem = self.transformer(src=x, mask=mask, pos_embed=pos)

        if self.proj_out is not None:
            mem = self.proj_out(mem)
        return mem

# ======== Auxiliary variant: forward() also returns the encoder's `out_list` ========
class TransformerModel_aux(nn.Module):
    """Transformer encoder wrapper that also returns the encoder's extra output list.

    The module builds a ``PositionalEncoding`` layer and a ``TransformerEn_aux``
    encoder. When ``in_channel`` differs from ``d_model`` a linear layer maps
    the main encoder output from ``d_model`` features to ``in_channel``
    features; the second value returned by the encoder is never projected.

    Args:
        in_channel: Feature size of the main output. Only used to decide
            whether an output projection is needed and to size it.
        d_model: Width handed to the positional encoding and the encoder.
        nhead: Forwarded to ``TransformerEn_aux``.
        num_encoder_layers: Forwarded to ``TransformerEn_aux``.
        dim_feedforward: Forwarded to ``TransformerEn_aux``.
        dropout: Forwarded to ``TransformerEn_aux``.
        activation: Forwarded to ``TransformerEn_aux``.
        normalize_before: Forwarded to ``TransformerEn_aux``.
        max_len: Forwarded to ``PositionalEncoding``.
    """

    def __init__(self, in_channel, d_model, nhead=8, num_encoder_layers=6, dim_feedforward=2048,
                 dropout=0.1, activation="relu", normalize_before=False, max_len=2000):
        super().__init__()
        # NOTE(review): the positional-encoding dropout is fixed at 0.1 and does
        # not follow the ``dropout`` argument -- confirm whether this is intended.
        self.pe_layer = PositionalEncoding(d_model=d_model, dropout=0.1, max_len=max_len)

        # Original author's note: the norm at the end of the transformer is
        # included in the encoder itself.
        self.transformer = TransformerEn_aux(
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation=activation,
            normalize_before=normalize_before,
        )

        # Project back to ``in_channel`` features only when the widths differ.
        if in_channel != d_model:
            self.proj_out = nn.Linear(in_features=d_model, out_features=in_channel, bias=True)
        else:
            self.proj_out = None

    def forward(self, x, *args, **kwargs):
        """Encode ``x`` and return the main output together with the encoder's list.

        Args:
            x: Input tensor. The original comments describe it as
                ``[batch, length, channels]`` -- TODO confirm against callers.
            *args: Ignored.
            **kwargs: Only ``mask`` is read. When given and not ``None`` its
                shape must equal ``x.shape[:-1]``; it is passed to the encoder
                unchanged. A missing or ``None`` mask means "no mask".

        Returns:
            A ``(mem, out_list)`` tuple. ``mem`` is the first value returned by
            the encoder, passed through ``proj_out`` when that layer exists;
            ``out_list`` is the encoder's second return value, handed back
            untouched.

        Raises:
            AssertionError: If a mask is supplied whose shape differs from
                ``x.shape[:-1]``.
        """
        # ``get`` treats an absent mask and an explicit ``mask=None`` alike, so
        # the shape check below never dereferences ``None``.
        mask = kwargs.get('mask')
        if mask is not None:
            assert x.shape[:-1] == mask.shape, (
                f"mask shape {tuple(mask.shape)} must match x.shape[:-1] "
                f"{tuple(x.shape[:-1])}"
            )

        pos = self.pe_layer(x)

        mem, out_list = self.transformer(src=x, mask=mask, pos_embed=pos)

        if self.proj_out is not None:
            mem = self.proj_out(mem)
        return mem, out_list

