import torch
from einops import rearrange
from torch import nn, Tensor
from torch.nn import Sigmoid

import csiva.model.position_encoding as pe


class CausalDecoder(nn.Module):
    """Transformer decoder that maps a sequence of adjacency entries to probabilities.

    Each scalar adjacency entry is embedded, passed through the positional
    encoding, decoded against the encoder memory with a stack of
    ``nn.TransformerDecoderLayer`` modules and finally squashed to (0, 1) by a
    LayerNorm -> Linear -> Sigmoid head.
    """

    def __init__(
        self,
        max_seq_length: int,
        embed_dim: int,
        depth: int,
        num_heads: int,
        dim_feedforward: int,
        dropout: float = 0.,
    ) -> None:
        """
        Build the decoder stack, the adjacency-entry embedding and the output head.

        :param max_seq_length: maximum sequence length handed to the positional encoding.
        :param embed_dim: model width of the transformer. Must be even: the entry embedding
        has width embed_dim / 2 and the positional encoding is built with the same width.
        :param depth: number of stacked decoder layers.
        :param num_heads: number of attention heads per decoder layer.
        :param dim_feedforward: hidden width of the feed-forward network in each decoder layer.
        :param dropout: dropout probability used in the decoder layers and the entry embedding.
        :raises ValueError: if embed_dim is odd.
        """
        super().__init__()

        # Fail early with a clear message. With an odd width the two halves cannot add up to
        # embed_dim and the error would otherwise only surface as a shape mismatch inside
        # the transformer.
        if embed_dim % 2 != 0:
            raise ValueError(
                f"embed_dim must be even so it can be split into two halves, got {embed_dim}"
            )
        half_embed_dim = embed_dim // 2

        # batch_first=True: all sequences are laid out as (batch, sequence, feature).
        decoder_layer = nn.TransformerDecoderLayer(
            embed_dim, num_heads, dim_feedforward, dropout, batch_first=True
        )
        self.transformer = nn.TransformerDecoder(decoder_layer, depth)

        # Output head: one value in (0, 1) per sequence position.
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, 1),
            Sigmoid()
        )

        # Lifts every scalar adjacency entry to a vector of width half_embed_dim.
        self.to_embedding = nn.Sequential(
            nn.Linear(1, half_embed_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(half_embed_dim, half_embed_dim),
            nn.Dropout(dropout)
        )

        # Identity embedding layer.
        # NOTE(review): presumably concatenates a code of width half_embed_dim so that the
        # result has width embed_dim -- verify against csiva.model.position_encoding.
        self.identity_embedding = pe.NonlinearPositionalEncoding(half_embed_dim, max_seq_length)

    def forward(self, adj: Tensor, m: Tensor, tgt_mask: Tensor) -> Tensor:
        """
        Decode input autoregressively into an adjacency matrix.

        :param adj: input adjacency matrix for autoregression, flattened. Dimension b x d^2 with batch size b
        and d being the number of nodes.
        :param m: memory tensor from the encoder. Dimension b x d x e with batch size (b), number of nodes (d)
        and embedding dimension (e).
        :param tgt_mask: mask for autoregressive behaviour, masking entries of adj. Forwarded unchanged as
        ``tgt_mask`` to ``nn.TransformerDecoder``.
        NOTE(review): PyTorch documents the shapes (d^2, d^2) or (b * num_heads, d^2, d^2) for this argument;
        confirm what callers pass.
        :return: output adjacency matrix with entries in (0, 1). Dimension b x d^2.
        """
        # Add a trailing feature dimension: b x d^2 -> b x d^2 x 1
        x = torch.unsqueeze(adj, -1)
        # Embed adjacency entries
        x = self.to_embedding(x)

        # Apply the positional (identity) encoding
        x = self.identity_embedding(x)

        # Apply transformer units; the mask is passed by keyword so it cannot be confused
        # with the other mask arguments of the decoder.
        x = self.transformer(x, m, tgt_mask=tgt_mask)
        # Transform with MLP head to outputs in (0, 1)
        x = self.mlp_head(x)
        # Have shape B x D^2 x 1. Change to B x D^2
        return rearrange(x, 'b dd e -> b (dd e)')

