
import torch.nn as nn
import torch.nn.functional as F
from models.layers.geometric.ndr_geometric import NDRGeometric
import torch as T


class ndr_geometric_stack(nn.Module):
    """Apply one shared ``NDRGeometric`` layer repeatedly over a sequence.

    A zero-valued SOS slot is prepended and a zero-valued EOS slot appended
    to the input. The same ``NDRGeometric`` module (shared weights) is then
    applied ``train_max_depth`` times in training mode or ``test_max_depth``
    times in eval mode. The output at the SOS slot is returned as
    ``global_state``.

    Config keys read here: ``hidden_size``, ``dropout``, ``train_max_depth``,
    ``test_max_depth``. The whole ``config`` is also passed to
    ``NDRGeometric``.
    """

    def __init__(self, config):
        super().__init__()

        self.hidden_size = config["hidden_size"]
        self.dropout = config["dropout"]
        self.config = config
        self.train_max_depth = config["train_max_depth"]
        self.test_max_depth = config["test_max_depth"]
        self.EncoderStack = NDRGeometric(config=config)

    def forward(self, sequence, input_mask):
        """Encode ``sequence`` with ``L`` applications of the shared layer.

        N = batch size, S = sequence length, D = feature size.

        Args:
            sequence: tensor of shape (N, S, D).
            input_mask: tensor with N * S elements (e.g. (N, S) or
                (N, S, 1)). Positions where it is 0 are zeroed in
                ``sequence`` before encoding.

        Returns:
            dict with
              - "global_state": (N, D), the encoder output at the SOS slot;
              - "sequence": (N, S, D), encoder outputs at the original
                positions (SOS/EOS slots stripped);
              - "input_mask": (N, S, 1), the input mask values reshaped;
              - "aux_loss": always None.
        """
        N, S, D = sequence.size()
        input_mask = input_mask.reshape(N, S)

        sequence = F.dropout(sequence, p=self.dropout, training=self.training)

        L = self.train_max_depth if self.training else self.test_max_depth

        # Zero SOS/EOS slots. new_zeros matches sequence's dtype and device
        # (the old code passed the bound method `sequence.dim` to `.to`).
        boundary = sequence.new_zeros(N, 1, D)
        sequence = T.cat(
            [boundary, sequence * input_mask.unsqueeze(-1), boundary], dim=1
        )
        assert sequence.size() == (N, S + 2, D)

        # SOS and EOS are always visible. The mask must follow the same
        # layout as the sequence: [SOS, x_1 .. x_S, EOS].
        ones = T.ones(N, 1, device=sequence.device)
        input_mask = T.cat([ones, input_mask, ones], dim=1)
        assert input_mask.size() == (N, S + 2)

        for _ in range(L):
            sequence = self.EncoderStack(sequence=sequence,
                                         input_mask=input_mask)

        global_state = sequence[:, 0, :]
        sequence = sequence[:, 1:-1, :]
        # Strip the SOS/EOS mask slots; the mask is 2-D here.
        input_mask = input_mask[:, 1:-1].reshape(N, S, 1)

        return {"global_state": global_state, "sequence": sequence,
                "input_mask": input_mask, "aux_loss": None}