import torch
from torch.nn import Linear, TransformerEncoder, TransformerEncoderLayer
import torch.nn.functional as F


# TODO: Add support for batched inputs
class MultiheadAttention(torch.nn.Module):
    """
    Multi-head scaled dot-product attention for unbatched inputs.

    Every head works on a full ``embed_dim``-wide projection, i.e. the query,
    key and value projections map to ``embed_dim * num_heads`` features. When
    more than one head is used, the concatenated head outputs are projected
    back to ``embed_dim`` by ``output_layer``; with a single head no output
    projection is created or applied.

    Args:
        embed_dim: Feature size of the queries and of each head.
        num_heads: Number of attention heads.
        kdim: Feature size of the keys (defaults to ``embed_dim``).
        vdim: Feature size of the values (defaults to ``embed_dim``).
        bias: Whether the linear projections use a bias term.
    """

    def __init__(self, embed_dim, num_heads, kdim=None, vdim=None, bias=False):
        super(MultiheadAttention, self).__init__()

        self.query_layer = Linear(embed_dim, embed_dim * num_heads, bias=bias)
        self.key_layer = Linear(
            embed_dim if kdim is None else kdim, embed_dim * num_heads, bias=bias
        )
        self.value_layer = Linear(
            embed_dim if vdim is None else vdim, embed_dim * num_heads, bias=bias
        )

        if num_heads > 1:
            self.output_layer = Linear(embed_dim * num_heads, embed_dim, bias=bias)

        self.num_heads = num_heads

    def forward(self, query, key, value, mask=None):
        """
        Attend from ``query`` to ``key``/``value``.

        Currently only supports unbatched inputs where query is shape (n_q, d),
        key is shape (n_v, kdim), value is shape (n_v, vdim), and mask is shape
        (n_q, n_v). Positions where ``mask == 0`` are excluded from attention.
        A query row whose mask is zero everywhere produces NaN, because the
        softmax is taken over a row of ``-inf``.

        Returns:
            Tensor of shape (n_q, d).

        Raises:
            ValueError: If the number of keys and values differ.
        """

        if key.size(0) != value.size(0):
            raise ValueError("Number of keys and values must match")

        n_q = query.size(0)
        n_v = key.size(0)

        # (n_q, d) -> (n_q, h * d)
        q = self.query_layer(query)

        # (n_v, d) -> (n_v, h * d)
        k = self.key_layer(key)
        v = self.value_layer(value)

        head_dim = q.size(-1) // self.num_heads

        # Split heads: (n, h * d) -> (n, h, d) -> (h, n, d).
        # The head axis must be split off the feature axis and then moved to
        # the front; viewing directly as (h, n, d) would reinterpret memory
        # and mix features of different tokens into the same head slot.
        q = q.reshape(n_q, self.num_heads, head_dim).transpose(0, 1)
        k = k.reshape(n_v, self.num_heads, head_dim).transpose(0, 1)
        v = v.reshape(n_v, self.num_heads, head_dim).transpose(0, 1)

        # (h, n_q, d) x (h, n_v, d) -> (h, n_q, n_v)
        adj = torch.einsum("hqd,hkd->hqk", q, k)

        # scale by the square root of the per-head feature size
        adj = adj / (head_dim ** 0.5)

        # mask out the adj matrix (mask broadcasts over the head axis)
        if mask is not None:
            adj = adj.masked_fill(mask == 0, float("-inf"))

        # normalize attention weights over the keys
        adj = torch.softmax(adj, dim=-1)

        # (h, n_q, n_v) @ (h, n_v, d) -> (h, n_q, d)
        x = torch.einsum("hqv,hvd->hqd", adj, v)

        # Merge heads: (h, n_q, d) -> (n_q, h, d) -> (n_q, h * d),
        # the exact inverse of the head split above.
        x = x.transpose(0, 1).reshape(n_q, self.num_heads * head_dim)

        # combine heads (n_q, h * d) -> (n_q, d)
        if self.num_heads > 1:
            x = self.output_layer(x)

        return x


class CustomTransformerEncoderLayer(torch.nn.Module):
    """
    Single encoder block: self-attention followed by a two-layer ReLU
    feed-forward network. Each sub-layer is wrapped in a residual connection
    and a parameter-free layer normalization over all but the first dimension.
    """

    def __init__(self, d_model=512, d_ff=2048, num_heads=8, bias=False):
        super().__init__()

        # Self-attention sub-layer.
        self.mha = MultiheadAttention(d_model, num_heads, bias=bias)

        # Position-wise feed-forward sub-layer: d_model -> d_ff -> d_model.
        expand = Linear(d_model, d_ff, bias=bias)
        contract = Linear(d_ff, d_model, bias=bias)
        self.ffn = torch.nn.Sequential(expand, torch.nn.ReLU(), contract)

    @staticmethod
    def _add_and_norm(residual, update):
        """Add ``update`` to ``residual`` and layer-normalize the sum."""
        total = residual + update
        return F.layer_norm(total, total.size()[1:])

    def forward(self, x, mask=None):
        """Apply masked self-attention and the feed-forward network to ``x``."""
        attended = self._add_and_norm(x, self.mha(x, x, x, mask=mask))
        return self._add_and_norm(attended, self.ffn(attended))


class CustomTransformerEncoder(torch.nn.Module):
    """
    Stack of ``num_layers`` identically configured
    ``CustomTransformerEncoderLayer`` blocks applied one after another.
    """

    def __init__(self, num_layers=6, d_model=512, d_ff=2048, num_heads=8, bias=False):
        super().__init__()

        layer_kwargs = dict(d_model=d_model, d_ff=d_ff, num_heads=num_heads, bias=bias)

        stack = torch.nn.ModuleList()
        for _ in range(num_layers):
            stack.append(CustomTransformerEncoderLayer(**layer_kwargs))
        self.layers = stack

    def forward(self, x, mask=None):
        """Run ``x`` through every layer in order, passing the same ``mask`` to each."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden, mask=mask)
        return hidden


class CustomTransformerDecoderLayer(torch.nn.Module):
    """
    Single decoder block. In order, it applies: attention from the decoder
    input to the encoder output (masked), unmasked self-attention, and a
    two-layer ReLU feed-forward network. Each sub-layer is wrapped in a
    residual connection and a parameter-free layer normalization over all but
    the first dimension.
    """

    def __init__(self, d_model=512, d_ff=2048, num_heads=8, bias=False):
        super().__init__()

        # Attention over the encoder output, then over the decoder states.
        self.mha_enc_dec = MultiheadAttention(d_model, num_heads, bias=bias)
        self.mha_dec_dec = MultiheadAttention(d_model, num_heads, bias=bias)

        # Position-wise feed-forward sub-layer: d_model -> d_ff -> d_model.
        expand = Linear(d_model, d_ff, bias=bias)
        contract = Linear(d_ff, d_model, bias=bias)
        self.ffn = torch.nn.Sequential(expand, torch.nn.ReLU(), contract)

    @staticmethod
    def _add_and_norm(residual, update):
        """Add ``update`` to ``residual`` and layer-normalize the sum."""
        total = residual + update
        return F.layer_norm(total, total.size()[1:])

    def forward(self, x, enc, mask=None):
        """
        Update decoder states ``x`` using encoder output ``enc``.

        ``mask`` is only applied to the encoder-decoder attention.
        """
        cross = self._add_and_norm(x, self.mha_enc_dec(x, enc, enc, mask=mask))
        own = self._add_and_norm(cross, self.mha_dec_dec(cross, cross, cross))
        return self._add_and_norm(own, self.ffn(own))


class CustomTransformerDecoder(torch.nn.Module):
    """
    Stack of ``num_layers`` identically configured
    ``CustomTransformerDecoderLayer`` blocks applied one after another.
    """

    def __init__(self, num_layers=6, d_model=512, d_ff=2048, num_heads=8, bias=False):
        super().__init__()

        layer_kwargs = dict(d_model=d_model, d_ff=d_ff, num_heads=num_heads, bias=bias)

        stack = torch.nn.ModuleList()
        for _ in range(num_layers):
            stack.append(CustomTransformerDecoderLayer(**layer_kwargs))
        self.layers = stack

    def forward(self, x, enc, mask=None):
        """
        Run ``x`` through every layer in order; each layer receives the same
        encoder output ``enc`` and the same ``mask``.
        """
        hidden = x
        for block in self.layers:
            hidden = block(hidden, enc, mask=mask)
        return hidden


class CustomTransformer(torch.nn.Module):
    """
    Encoder(-decoder) model built from the custom attention layers.

    The encoder input is projected from ``in_channels`` to ``d_model`` and
    passed through the encoder stack. Unless ``encoder_only`` is set, the
    decoder input is projected (``d_model`` -> ``d_model``) and passed through
    the decoder stack together with the encoder output.

    Args:
        in_channels: Feature size of the encoder input.
        out_channels: Output size of the ``out`` projection.
        num_layers: Number of layers in the encoder and in the decoder.
        d_model: Model feature size.
        d_ff: Hidden size of the feed-forward sub-layers.
        num_heads: Number of attention heads.
        bias: Whether the linear layers use a bias term.
        encoder_only: If True, no decoder is built or run.

    NOTE(review): ``self.out`` is constructed but never applied in
    ``forward``, so the returned features have size ``d_model`` rather than
    ``out_channels``. This looks unintended; confirm with callers before
    changing the output shape.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        num_layers=6,
        d_model=512,
        d_ff=2048,
        num_heads=8,
        bias=False,
        encoder_only=False,
    ):
        super(CustomTransformer, self).__init__()

        self.encoder_in = Linear(in_channels, d_model, bias=bias)
        self.encoder = CustomTransformerEncoder(
            num_layers=num_layers,
            d_model=d_model,
            d_ff=d_ff,
            num_heads=num_heads,
            bias=bias,
        )

        # The decoder is only needed (and only used by forward) when the
        # model is NOT encoder-only.
        if not encoder_only:
            self.decoder_in = Linear(d_model, d_model, bias=bias)
            self.decoder = CustomTransformerDecoder(
                num_layers=num_layers,
                d_model=d_model,
                d_ff=d_ff,
                num_heads=num_heads,
                bias=bias,
            )

        self.out = Linear(d_model, out_channels, bias=bias)
        self.encoder_only = encoder_only

    def forward(self, x_encoder, x_decoder=None, mask=None):
        """
        Encode ``x_encoder`` and, unless encoder-only, decode ``x_decoder``.

        Args:
            x_encoder: Encoder input with ``in_channels`` features.
            x_decoder: Decoder input with ``d_model`` features; required when
                the model is not encoder-only.
            mask: Attention mask passed to both the encoder self-attention
                and the decoder's encoder-decoder attention.

        Returns:
            The encoder output when encoder-only, otherwise the decoder output.

        Raises:
            ValueError: If the model has a decoder but ``x_decoder`` is None.
        """
        if not self.encoder_only and x_decoder is None:
            raise ValueError("x_decoder is required when encoder_only is False")

        x_encoder = self.encoder_in(x_encoder)
        x = self.encoder(x_encoder, mask=mask)
        if not self.encoder_only:
            x_decoder = self.decoder_in(x_decoder)
            x = self.decoder(x_decoder, x, mask=mask)
        return x


class Transformer(torch.nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        num_layers=6,
        d_model=512,
        d_ff=2048,
        num_heads=8,
        dropout=0.1,
        bias=True,
    ):
        super(Transformer, self).__init__()

        self.encoder_in = Linear(
            in_channels,
            d_model,
            bias=bias,
        )
        self.encoder = TransformerEncoder(
            TransformerEncoderLayer(
                nhead=num_heads,
                d_model=d_model,
                dim_feedforward=d_ff,
                bias=bias,
                dropout=dropout,
                batch_first=True,
            ),
            num_layers,
        )

        self.out = Linear(
            d_model,
            out_channels,
            bias=bias,
        )

        self.dropout = dropout

    def forward(self, x, mask=None):
        if x.dim() == 2:
            unsqueezed = True
            x = x.unsqueeze(0)
        else:
            unsqueezed = False

        x = self.encoder_in(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.encoder(x, mask=mask)
        x = self.out(x)

        if unsqueezed:
            x = x.squeeze(0)

        return x


class BlockAttention(torch.nn.Module):
    """

    Original Transformer scaled dot-product attention mechanism is:

        X := softmax( Q K^T / sqrt(d_k) ) V := A V

    To match the softmax normalization, for Block Attention:

        X := softmax(S) softmax(B) softmax(S^T) V := A V

    where softmax(S^T) means node-to-block assignments sum to 1, and softmax(S) means block-to-node assignments sum to 1.

    Multi-head attention focuses on stability of the attention outputs, but not of the attention weights.
    Block Attention additionally focuses on stability of the inferred graph using a mechanism similar to multi-head attention, where multiple S and B matrices are concatenated followed by a linear transformation to produce the final S and B matrices.
    Equivalently, the final S and B matrices can be sum of linear transformations of S and B matrices (see ["Are Transformers universal approximators of sequence-to-sequence functions?"](https://openreview.net/forum?id=ByxRM0Ntvr)).
    This could produce more stable clusters S also in unsupervised settings, matching the reverse Block Autoencoder.
    Could this also be used to improve existing graph (neural network) clustering algorithms?

    """

    def __init__(self, embed_dim, num_heads, bias=False):
        super(BlockAttention, self).__init__()

        self.clusters_layer = Linear(embed_dim, embed_dim * num_heads, bias=bias)
        self.blocks_in_layer = Linear(embed_dim, embed_dim * num_heads, bias=bias)
        self.blocks_out_layer = Linear(embed_dim, embed_dim * num_heads, bias=bias)

        self.values_layer = Linear(embed_dim, embed_dim * num_heads, bias=bias)
