from typing import List, Optional, Tuple

import torch
import torch.nn.functional as F
from einops import rearrange
from torch import nn, Tensor

import csiva.model.position_encoding as pe
from csiva.model.alternate_attention import AlternatingAttentionStack


class CausalEncoder(nn.Module):
    """
    Encoder part of the model as described in 'Learning to induce causal structure' by Ke et al.

    Every [node, sample] entry of the dataset is embedded with a node-specific linear layer,
    combined with a node identity encoding, processed by a stack of alternating attention
    layers, and finally summarized into one embedding per node.

    Parameters
    ----------
    num_nodes : int
        Number of nodes (variables) in the dataset, i.e. the size of its last axis.
        Corresponds to the maximum number of nodes in a graph.
    d_model : int
        Dimension of the key, query, value input embeddings. Half of it (``d_model // 2``)
        is produced by the value embedding and half by the node identity embedding.
    dim_feedforward : int
        Hidden dimension of the MLP after MHSA in the encoder layer. Must be even.
    num_encoder_layers : int
        Number of stacked alternate attention blocks.
    num_heads : int
        Number of heads in multi-head attention in the alternate attention blocks
        (and in the summary layer when ``encoder_summary_type="mhsa"``).
    p_dropout : float, default 0.
        Dropout probability in multi-head attention in the alternate attention blocks
        (and in the summary layer when ``encoder_summary_type="mhsa"``).
    eps_layer_norm : float, default 0.00005
        LayerNorm epsilon in multi-head attention in the alternate attention blocks.
        Required for numerical stability.
    encoder_layer_type : str, default "torch"
        Specify which class to use for the encoder layer. Use "custom" for csiva.model.MHSA,
        or "torch" for torch.nn.TransformerEncoderLayer.
    encoder_summary_type : str, default "sdp"
        Specify how to compute the encoder summary. Use "sdp" for
        torch.nn.functional.scaled_dot_product_attention (no learnable parameters), or
        "mhsa" for a learnable torch.nn.MultiheadAttention layer.
    rff_depth : int, default 1
        Number of feed-forward layers in the MLP of the encoder layer in alternate attention.

    Raises
    ------
    ValueError
        If ``dim_feedforward`` is odd, or ``encoder_summary_type`` is not "sdp" or "mhsa".
    """

    def __init__(
        self,
        num_nodes: int,
        d_model: int,
        dim_feedforward: int,
        num_encoder_layers: int,
        num_heads: int,
        p_dropout: float = 0.,
        eps_layer_norm: float = 0.00005,
        encoder_layer_type: str = "torch",
        encoder_summary_type: str = "sdp",
        rff_depth: int = 1
    ):
        # NOTE(review): the embedding below halves d_model, not dim_feedforward; confirm
        # whether this evenness requirement was meant for d_model instead (or in addition).
        if dim_feedforward % 2 != 0:
            raise ValueError("dim_feedforward value is an odd number. Please provide an even value.")
        # Validate before building any (potentially large) sub-module.
        if encoder_summary_type not in ("sdp", "mhsa"):
            raise ValueError(
                "The accepted values for encoder_summary_type input arguments are 'sdp', 'mhsa'."
                f" Got instead {encoder_summary_type}."
            )

        super().__init__()
        self.num_nodes = num_nodes
        self.embed_dim = d_model
        self.dim_feedforward = dim_feedforward
        self.num_encoder_layers = num_encoder_layers
        self.num_heads = num_heads
        self.eps_layer_norm = eps_layer_norm
        self.half_embed_dim = d_model // 2
        self.encoder_summary_type = encoder_summary_type

        # *** Model construction ***

        # * Embedding layer (learnable) *
        # Embed each [node, sample] scalar to a vector of size E/2, one linear map per node.
        self.input_embedding = nn.ModuleList([
            nn.Linear(1, self.half_embed_dim)
            for _ in range(num_nodes)
        ])

        # Node identity encoding, sized to the other half of the embedding.
        self.identity_embedding = pe.NonlinearPositionalEncoding(self.half_embed_dim, self.num_nodes)

        # Build the encoder stacking alternate attention layers
        self.encoder = AlternatingAttentionStack(
            num_encoder_layers, d_model, dim_feedforward, num_heads,
            eps_layer_norm, p_dropout, encoder_layer_type, rff_depth
        )

        if encoder_summary_type == "mhsa":
            self.summary_layer = nn.MultiheadAttention(
                self.embed_dim,
                self.num_heads,
                p_dropout,
                batch_first=True
            )

    def forward(
            self,
            dataset: Tensor,
            key_padding_mask: Optional[Tensor] = None
    ) -> Tensor:
        """Encode a batch of datasets into one embedding per node.

        Parameters
        ----------
        dataset : Tensor of shape (b, n, d)
            Batch size (b), num samples (n), num nodes (d). Node ``i`` is embedded by
            ``self.input_embedding[i]``.
        key_padding_mask : Tensor, optional
            Forwarded unchanged to the alternating attention stack; its expected shape
            is defined by that module.

        Returns
        -------
        Tensor of shape (b, d, e)
            Node-wise embeddings for each dataset (embedding dimension e), ready to be
            fed to a decoder.
        """
        # (B, N, D) -> (B, D, N, 1): one scalar per [node, sample] entry.
        dataset = dataset.transpose(2, 1).unsqueeze(dim=-1)

        # Embed every dataset[node, sample] with node-specific embedding. Output shape B x D x N x E/2
        X = torch.stack([
            embed_layer(dataset[:, node]) for node, embed_layer in enumerate(self.input_embedding)
        ], dim=1)

        # Concatenate positional encodings. Output shape B x D x N x E
        X = self.identity_embedding(X)

        # Add summary entry. Output shape B x D x N+1 x E
        X, datapoints_mask = self._add_summary_dim(X)

        # Apply the encoder. Output shape is B x D x N+1 x E
        X = self.encoder(X, key_padding_mask, datapoints_mask)

        # Encoder summary. Output is B x D x E
        # NOTE(review): key_padding_mask is not applied here, so if it marks padded samples
        # those samples still contribute to the summary — confirm this is intended.
        return self._encoder_summary(X)

    def _encoder_summary(
            self,
            X: Tensor
    ) -> Tensor:
        """Output the encoder summary as an attention-weighted average of the X samples.

        Parameters
        ----------
        X : Tensor of shape (batch_size, num_nodes, num_samples+1, embed_dim)
            Tensor transformed by multiple stacks of alternate attention. The last entry
            along axis 2 is the summary entry.

        Returns
        -------
        Tensor of shape (batch_size, num_nodes, embed_dim)
            For every (batch, node) pair independently: attention whose query is the
            summary entry ``X[:, :, -1]`` and whose keys/values are the data samples
            ``X[:, :, :-1]``.

        Raises
        ------
        ValueError
            If ``self.encoder_summary_type`` is neither "sdp" nor "mhsa".
        """
        query = X[:, :, -1:]     # (B, D, 1, E)
        samples = X[:, :, :-1]   # (B, D, N, E)

        if self.encoder_summary_type == "sdp":
            # The node axis plays the role of the "heads" axis, so each node is
            # summarized independently. (B, D, 1, E) -> (B, D, E)
            return F.scaled_dot_product_attention(query, samples, samples).squeeze(2)

        if self.encoder_summary_type == "mhsa":
            b, d, n, e = samples.shape
            # Fold nodes into the batch axis: nn.MultiheadAttention (batch_first) expects
            # (batch, seq, embed). Row i*d + j of the folded batch is (batch i, node j).
            folded_samples = samples.reshape(b * d, n, e)
            attn_output, _ = self.summary_layer(
                query.reshape(b * d, 1, e),
                folded_samples,
                folded_samples
            )
            # Undo the fold with the same row ordering: (b*d, 1, e) -> (b, d, e)
            return attn_output.reshape(b, d, e)

        raise ValueError(
            "The accepted values for encoder_summary_type are 'sdp', 'mhsa'."
            f" Got instead {self.encoder_summary_type}."
        )

    def _add_summary_dim(
            self,
            X: Tensor
    ) -> Tuple[Tensor, Tensor]:
        """Append a zero summary entry along the samples axis (dim 2) of ``X``.

        Parameters
        ----------
        X : Tensor
            Dataset tensor of shape (batch_size, num_nodes, num_samples, embed_dim).

        Returns
        -------
        X : Tensor
            Tensor of shape (batch_size, num_nodes, num_samples+1, embed_dim) whose last
            entry along dim 2 is all zeros; same dtype and device as the input.
        datapoints_mask : Tensor
            Float mask of shape (batch_size, num_nodes, num_samples+1), same dtype as ``X``:
            0 for data samples and -inf for the added summary entry, used to prevent
            attending to the summary entries.
        """
        B, D, N, E = X.shape

        # Additive float mask: 0 keeps an entry (exp(0) = 1), -inf removes it after the
        # softmax (exp(-inf) = 0). Presumably added to the attention logits inside the
        # alternating attention stack — confirm there.
        datapoints_mask = torch.zeros((B, D, N + 1), dtype=X.dtype, device=X.device)
        datapoints_mask[:, :, -1] = float("-inf")

        # new_zeros inherits X's dtype/device, so torch.cat neither promotes
        # half-precision activations to float32 nor mismatches the embedding width.
        summary_entry = X.new_zeros((B, D, 1, E))
        X = torch.cat((X, summary_entry), dim=2)
        return X, datapoints_mask
