from typing import List, Literal, Optional, Tuple
from omegaconf import DictConfig
import torch.nn as nn
from torch import Tensor
from torch_geometric import nn as pygnn
from einops import rearrange

from src.attention import SDPAttention
from src.attention import ATTN_TYPE

# Normalization kinds accepted by `get_norm`: "ln" -> nn.LayerNorm, "bn" -> SaneBatchNorm1d.
NormType = Literal["ln", "bn"]


class SaneBatchNorm1d(nn.Module):
    """
    A sane wrapper for BatchNorm1d that handles both 2D and 3D inputs.

    ``nn.BatchNorm1d`` normalizes over dim 1, so it expects ``(N, C)`` or
    ``(N, C, L)`` inputs. Transformer-style code usually carries ``(N, L, C)``
    tensors, so 3D inputs are transposed to channel-first, normalized, and
    transposed back. (N=batch, C=channels, L=sequence_length)

    Note: for 3D inputs the batch statistics are taken over both N and L, so any
    padding positions the caller may include also contribute to them.
    """

    def __init__(self, dim: int):
        """
        Args:
            dim: Number of channels C (the last dimension of the input).
        """
        super().__init__()
        self.bn = nn.BatchNorm1d(dim)

    def forward(self, x: Tensor) -> Tensor:
        """Normalize ``x`` of shape ``(N, C)`` or ``(N, L, C)``; the output has the same shape.

        Raises:
            ValueError: If ``x`` is neither 2D nor 3D.
        """
        if x.ndim == 2:
            return self.bn(x)
        if x.ndim == 3:
            # (N, L, C) -> (N, C, L) so BatchNorm1d sees channels on dim 1,
            # then swap back to the caller's layout.
            return self.bn(x.transpose(1, 2)).transpose(1, 2)
        raise ValueError(f"Unsupported shape {x.shape=}")


def get_norm(dim: int, norm_type: NormType) -> nn.Module:
    """Build the normalization layer selected by ``norm_type``.

    Args:
        dim: Number of channels to normalize (the last dimension of the input).
        norm_type: ``"ln"`` for ``nn.LayerNorm`` or ``"bn"`` for ``SaneBatchNorm1d``.

    Returns:
        A newly constructed normalization module.

    Raises:
        ValueError: If ``norm_type`` is not one of the supported values.
    """
    if norm_type not in ("ln", "bn"):
        raise ValueError(f"Unsupported normalization type: {norm_type}")
    if norm_type == "bn":
        return SaneBatchNorm1d(dim)
    return nn.LayerNorm(dim)


class FeedForward(nn.Module):
    """Position-wise MLP: Linear -> ReLU -> Dropout -> Linear.

    The output's last dimension equals the input's (``dim``). This module applies
    no normalization itself; any pre-/post-norm is the caller's responsibility.
    """

    def __init__(self, dim, hidden_dim=None, mult=4, dropout=0.2):
        r"""
        Args:
            dim (int): Input and output dimension.
            hidden_dim (int, optional): Hidden dimension of the MLP. If falsy
                (None or 0), ``dim * mult`` is used.
            mult (int, optional): Multiplier for the hidden dimension if
                hidden_dim is not specified. Default: 4
            dropout (float, optional): Dropout probability applied after the
                activation. Default: 0.2
        """
        super().__init__()
        hidden_dim = hidden_dim or (dim * mult)
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(p=dropout),
            nn.Linear(hidden_dim, dim),
        )

        self.apply(self.init_weights)

    def init_weights(self, m):
        """Xavier-uniform init for Linear weights and zero init for their biases (if any)."""
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            # Linear layers built with bias=False have m.bias set to None.
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    def forward(self, x):
        """Apply the MLP to the last dimension of ``x``."""
        return self.net(x)


class SelfAttentionBlock(nn.Module):
    """Pre-norm transformer block (attention sublayer, then feed-forward sublayer).

    Computes::

        y   = x + Dropout(Attn(Norm_attn(x)))
        out = y + Dropout(FFN(Norm_ffn(y)))
    """

    def __init__(
        self,
        feat_dim: int,
        pos_dim: int,
        heads: int,
        ffn_mult: int = 4,
        attn_dropout: float = 0.0,
        bias: bool = True,
        attn_type: ATTN_TYPE = "self",
        norm_type: NormType = "ln",
        cfg: Optional[DictConfig] = None,  # any extra config
    ):
        """
        Args:
            feat_dim: Feature part of the channel dimension.
            pos_dim: Positional part of the channel dimension; the block operates
                on ``feat_dim + pos_dim`` channels.
            heads: Number of attention heads, passed to ``SDPAttention``.
            ffn_mult: Hidden-size multiplier of the feed-forward MLP.
            attn_dropout: Dropout passed to ``SDPAttention``.
            bias: Passed to ``SDPAttention``.
            attn_type: Passed to ``SDPAttention``.
            norm_type: Normalization used for both sublayers (see ``get_norm``).
            cfg: Config object; ``cfg.model.ffn_dropout`` sets the dropout used
                inside the FFN and on both residual branches. Required in
                practice despite the ``None`` default.

        Raises:
            ValueError: If ``cfg`` is None.
        """
        super().__init__()
        if cfg is None:
            # Fail early with an actionable message instead of an
            # AttributeError on `None.model` below.
            raise ValueError(
                "SelfAttentionBlock requires `cfg` (reads cfg.model.ffn_dropout)"
            )

        dim = pos_dim + feat_dim
        self.norm_attn = get_norm(dim, norm_type)
        self.attn = SDPAttention(
            dim=dim,
            heads=heads,
            dropout=attn_dropout,
            bias=bias,
            attn_type=attn_type,
        )

        # One dropout probability is shared by the FFN internals and both residual branches.
        dropout_p = cfg.model.ffn_dropout
        self.norm_ffn = get_norm(dim, norm_type)
        self.ffn = FeedForward(dim=dim, dropout=dropout_p, mult=ffn_mult)
        self.dropout = nn.Dropout(dropout_p)

    def forward(
        self,
        x: Tensor,
        seqlen: Optional[List[int]] = None,
        edge_index: Optional[Tensor] = None,
    ) -> Tensor:
        """Apply the block.

        Args:
            x: Input whose last dimension is ``feat_dim + pos_dim``.
            seqlen: Forwarded to ``SDPAttention`` as ``seqlen_x`` (presumably
                per-sequence lengths for packed inputs — confirm).
            edge_index: Unused by this block.

        Returns:
            Tensor with the same shape as ``x``.
        """
        y = x + self.dropout(
            self.attn(
                self.norm_attn(x),
                seqlen_x=seqlen,
            )
        )
        y = y + self.dropout(self.ffn(self.norm_ffn(y)))
        return y


class SelfAttentionPostNormBlock(nn.Module):
    """Post-norm transformer block (attention sublayer, then feed-forward sublayer).

    Computes::

        y   = Norm_attn(x + Dropout(Attn(x)))
        out = Norm_ffn(y + Dropout(FFN(y)))
    """

    def __init__(
        self,
        feat_dim: int,
        pos_dim: int,
        heads: int,
        ffn_mult: int = 4,
        attn_dropout: float = 0.0,
        bias: bool = True,
        attn_type: ATTN_TYPE = "self",
        norm_type: NormType = "ln",
        cfg: Optional[DictConfig] = None,  # any extra config
    ):
        """
        Args:
            feat_dim: Feature part of the channel dimension.
            pos_dim: Positional part of the channel dimension; the block operates
                on ``feat_dim + pos_dim`` channels.
            heads: Number of attention heads, passed to ``SDPAttention``.
            ffn_mult: Hidden-size multiplier of the feed-forward MLP.
            attn_dropout: Dropout passed to ``SDPAttention``.
            bias: Passed to ``SDPAttention``.
            attn_type: Passed to ``SDPAttention``.
            norm_type: Normalization used after both sublayers (see ``get_norm``).
            cfg: Config object; ``cfg.model.ffn_dropout`` sets the dropout used
                inside the FFN and on both residual branches. Required in
                practice despite the ``None`` default.

        Raises:
            ValueError: If ``cfg`` is None.
        """
        super().__init__()
        if cfg is None:
            # Fail early with an actionable message instead of an
            # AttributeError on `None.model` below.
            raise ValueError(
                "SelfAttentionPostNormBlock requires `cfg` (reads cfg.model.ffn_dropout)"
            )

        dim = pos_dim + feat_dim
        self.attn = SDPAttention(
            dim=dim,
            heads=heads,
            dropout=attn_dropout,
            bias=bias,
            attn_type=attn_type,
        )

        # One dropout probability is shared by the FFN internals and both residual branches.
        dropout_p = cfg.model.ffn_dropout
        self.ffn = FeedForward(dim=dim, dropout=dropout_p, mult=ffn_mult)
        self.dropout = nn.Dropout(dropout_p)

        self.norm_attn = get_norm(dim, norm_type)
        self.norm_ffn = get_norm(dim, norm_type)

    def forward(
        self,
        x: Tensor,
        seqlen: Optional[List[int]] = None,
        edge_index: Optional[Tensor] = None,
    ) -> Tensor:
        """Apply the block.

        Args:
            x: Input whose last dimension is ``feat_dim + pos_dim``.
            seqlen: Forwarded to ``SDPAttention`` as ``seqlen_x`` (presumably
                per-sequence lengths for packed inputs — confirm).
            edge_index: Unused by this block.

        Returns:
            Tensor with the same shape as ``x``.
        """
        y = x + self.dropout(self.attn(x=x, seqlen_x=seqlen))
        y = self.norm_attn(y)
        y = y + self.dropout(self.ffn(y))
        y = self.norm_ffn(y)
        return y


class GCNSelfAttentionPostNorm2Block(nn.Module):
    """Parallel self-attention + GCN block with post-norm residuals.

    Computes::

        a   = Norm_attn(x + Dropout(Attn(x)))
        g   = Norm_gcn(x + Dropout(GCN(x, edge_index)))
        y   = a + g
        out = Norm_ffn(y + Dropout(FFN(y)))
    """

    def __init__(
        self,
        feat_dim: int,
        pos_dim: int,
        heads: int,
        ffn_mult: int = 4,
        attn_dropout: float = 0.0,
        bias: bool = True,
        norm_type: NormType = "ln",
        cfg: Optional[DictConfig] = None,  # any extra config
    ):
        """
        Args:
            feat_dim: Feature part of the channel dimension.
            pos_dim: Positional part of the channel dimension; the block operates
                on ``feat_dim + pos_dim`` channels.
            heads: Number of attention heads, passed to ``SDPAttention``
                (always built with ``attn_type="self"``).
            ffn_mult: Hidden-size multiplier of the feed-forward MLP.
            attn_dropout: Dropout passed to ``SDPAttention``.
            bias: Passed to ``SDPAttention``.
            norm_type: Normalization used for all three norms (see ``get_norm``).
            cfg: Config object; ``cfg.model.ffn_dropout`` sets the dropout used
                inside the FFN and on all residual branches. Required in
                practice despite the ``None`` default.

        Raises:
            ValueError: If ``cfg`` is None.
        """
        super().__init__()
        if cfg is None:
            # Fail early with an actionable message instead of an
            # AttributeError on `None.model` below.
            raise ValueError(
                "GCNSelfAttentionPostNorm2Block requires `cfg` (reads cfg.model.ffn_dropout)"
            )

        dim = pos_dim + feat_dim
        self.attn = SDPAttention(
            dim=dim,
            heads=heads,
            dropout=attn_dropout,
            bias=bias,
            attn_type="self",
        )

        self.gcn = pygnn.GCNConv(dim, dim)

        # One dropout probability is shared by the FFN internals and all residual branches.
        dropout_p = cfg.model.ffn_dropout
        self.ffn = FeedForward(dim=dim, dropout=dropout_p, mult=ffn_mult)
        self.dropout = nn.Dropout(dropout_p)

        self.norm_attn = get_norm(dim, norm_type)
        self.norm_gcn = get_norm(dim, norm_type)
        self.norm_ffn = get_norm(dim, norm_type)

    def forward(
        self,
        x: Tensor,
        seqlen: Optional[List[int]] = None,
        edge_index: Optional[Tensor] = None,
    ) -> Tensor:
        """Apply the block.

        Args:
            x: Input whose last dimension is ``feat_dim + pos_dim``; it is fed to
                both attention and GCNConv (presumably ``(num_nodes, dim)`` node
                features — confirm).
            seqlen: Forwarded to ``SDPAttention`` as ``seqlen_x``.
            edge_index: Graph connectivity passed to ``GCNConv``. Required.

        Returns:
            Tensor with the same shape as ``x``.

        Raises:
            ValueError: If ``edge_index`` is None.
        """
        if edge_index is None:
            # The GCN branch needs a graph; without this check None would only
            # fail somewhere inside torch_geometric with a less helpful error.
            raise ValueError(
                "GCNSelfAttentionPostNorm2Block.forward requires `edge_index`"
            )

        yattn = self.attn(x, seqlen_x=seqlen)
        ygcn = self.gcn(x, edge_index)

        # Two parallel residual branches, each post-normalized, then summed.
        y = self.norm_attn(x + self.dropout(yattn)) + self.norm_gcn(
            x + self.dropout(ygcn)
        )
        y = y + self.dropout(self.ffn(y))
        y = self.norm_ffn(y)
        return y
