from torch import nn, Tensor
import torch.nn.functional as F
from abc import ABC, abstractmethod
from src.models.gcn_bipartite import GCNConv
from typing import List, Optional, Tuple, Dict
from torch_geometric.nn import JumpingKnowledge
from torch_geometric.nn import SAGEConv, GATConv
from src.models.feature_lifter import FeatureLifter
from src.models.feature_projector import FeatureProjector


class BaseSSN(nn.Module, ABC):
    """Abstract base-class for *Sparse Simplicial Networks*.

    Stores the hyper-parameters and builds the components shared by all
    concrete SSN variants: optional residual linears, optional per-layer
    normalisation, feature projection/lifting for higher-order simplices and
    the jumping-knowledge readout + classifier. Subclasses supply the
    message-passing component and the forward pass.

    Parameters
    ----------
    edge_types:
        List of relation type triples *(source_dim, rel_type, target_dim)*.
        Simplex dimensions are integer strings (``"0"``, ``"1"``, ...).
    hidden_sizes:
        Dictionary mapping each simplex dimension (string key, e.g. ``"0"``)
        to its embedding dimension.
    n_classes:
        Number of output classes.
    num_layers:
        Message-passing depth.
    conv_type:
        Choice of underlying GNN operator – one of ``{"SAGE", "GAT", "GCN"}``.
    jumping_knowledge:
        ``None`` | "cat" | "lstm" | "max" – passed to
        :class:`torch_geometric.nn.JumpingKnowledge`.
    dropout:
        Feature dropout probability after each layer.
    normalize:
        If *True*, apply L2 feature normalisation at the end of each layer.
    in_aggr / out_aggr:
        Reduction ops for *inner* (multiple relations into dst) and *outer*
        (multiple simplices into common dst) aggregation.
    lin_res:
        Enable per-layer residual MLPs.
    device:
        Accelerator identifier. The alias ``"gpu"`` (the default) is mapped
        to torch's ``"cuda"`` when placing the modules built by this class;
        the raw value is kept in ``self.device`` and forwarded unchanged to
        the feature projector / lifter.
    bn / ln:
        Enable BatchNorm / LayerNorm blocks, respectively.

    Raises
    ------
    ValueError
        If ``conv_type`` is unsupported or ``edge_types`` is empty.
    """

    def __init__(
        self,
        edge_types: List[Tuple[str, str, str]],
        hidden_sizes: Dict[str, int],
        n_classes: int,
        num_layers: int,
        conv_type: str = "SAGE",
        jumping_knowledge: Optional[str] = None,
        dropout: float = 0.0,
        normalize: bool = False,
        in_aggr: str = "sum",
        out_aggr: str = "sum",
        lin_res: bool = False,
        device: str = "gpu",
        bn: bool = False,
        ln: bool = False,
    ) -> None:
        super().__init__()

        # ---------------  Hyper-parameters & meta-data  ----------------------
        self.edge_types = edge_types
        self.hidden_sizes = hidden_sizes
        self.num_layers = num_layers
        self.conv_type = conv_type
        self.num_classes = n_classes
        self.device = device  # used for *lazy* module allocation

        self.dropout = dropout
        self.l2 = normalize
        self.inner_aggregation = in_aggr
        self.outer_aggregation = out_aggr
        self.jumping_knowledge = jumping_knowledge

        # Highest simplex dimension appearing in the complex, plus one.
        self.mdim = self._num_simplex_dims(edge_types)

        # ---------------------  Convolution layer  ---------------------------
        layer_mapping = {"SAGE": SAGEConv, "GAT": GATConv, "GCN": GCNConv}
        if conv_type not in layer_mapping:
            raise ValueError(
                f"Unsupported conv_type '{conv_type}'. Choose from {set(layer_mapping)}."
            )
        self.ConvLayer = layer_mapping[conv_type]
        self.conv_kwargs = self._get_layer_kwargs(conv_type)

        # ---------------------  Optional blocks  -----------------------------
        self.lins_res: Optional[nn.ModuleDict] = (
            self._init_residual_mlp(hidden_sizes, device) if lin_res else None
        )
        self.layer_norms: Optional[nn.ModuleList] = (
            self._init_layer_norms(hidden_sizes, device) if ln else None
        )
        self.batch_norms: Optional[nn.ModuleList] = (
            self._init_batch_norms() if bn else None
        )

        # Projection & lifting of (higher-order) node features -----------------
        if self.mdim > 1 and "0" in hidden_sizes:
            self.feature_projector = FeatureProjector(
                hidden_sizes["0"],
                self.mdim - 1,
                device,
                self.ConvLayer,
                self.conv_kwargs,
            )
            self.feature_lifter = FeatureLifter(
                hidden_sizes["0"],
                dropout,
                self.mdim - 1,
                device,
                bn,
                ln,
                normalize,
                self.ConvLayer,
                self.conv_kwargs,
            )
        else:  # 1-simplicial (graph) case – no projection / lifting needed
            self.feature_projector = None
            self.feature_lifter = None

        # Jumping-knowledge readout & final classifier ------------------------
        self.jump, self.lin_nodes = self._init_jumping_knowledge(
            hidden_sizes, n_classes, num_layers, jumping_knowledge
        )

    # ------------------------------------------------------------------
    #  Module builders
    # ------------------------------------------------------------------
    @staticmethod
    def _num_simplex_dims(edge_types: List[Tuple[str, str, str]]) -> int:
        """Return ``max simplex dimension + 1`` over all sources and targets.

        Dimensions are converted to ``int`` *before* taking the maximum; a max
        over the raw strings would be lexicographic (``"10" < "9"``).

        Raises ``ValueError`` if ``edge_types`` is empty.
        """
        dims = {int(src) for src, _, _ in edge_types}
        dims.update(int(dst) for _, _, dst in edge_types)
        return max(dims) + 1

    @staticmethod
    def _as_torch_device(device: str) -> str:
        """Map the ``"gpu"`` alias to torch's ``"cuda"``; pass anything else through.

        ``"gpu"`` is not a valid torch device type, so ``module.to("gpu")``
        raises ``RuntimeError``.
        """
        return "cuda" if device == "gpu" else device

    def _init_residual_mlp(
        self, hidden_sizes: Dict[str, int], accelerator: str
    ) -> nn.ModuleDict:
        """One *lazy* linear per layer & dimension, used for ResNet-style skip.

        Every dimension ``0 .. mdim-1`` gets a (possibly empty) ``ModuleList``;
        only dimensions present in ``hidden_sizes`` receive ``num_layers``
        ``LazyLinear`` modules projecting to that dimension's hidden size.
        """
        target = self._as_torch_device(accelerator)
        lins = nn.ModuleDict({str(d): nn.ModuleList() for d in range(self.mdim)})
        for dim in range(self.mdim):
            key = str(dim)
            if key not in hidden_sizes:
                continue
            for _ in range(self.num_layers):
                lins[key].append(nn.LazyLinear(hidden_sizes[key]).to(target))
        return lins

    def _init_norm_blocks(
        self,
        norm_cls: type[nn.Module],
        hidden_sizes: Dict[str, int],
        accelerator: str,
    ) -> nn.ModuleList:
        """Build one ``ModuleDict`` of ``norm_cls`` modules per layer.

        Layer ``i`` holds a norm for every destination dimension ``dst`` in
        ``edge_types`` with ``int(dst) <= i + 1``, sized ``hidden_sizes[dst]``.
        Entries are registered in ascending dimension order: iterating a
        ``set`` of strings is not reproducible across interpreter runs (hash
        randomisation), which would make the parameter order differ between
        processes.
        """
        target = self._as_torch_device(accelerator)
        blocks: nn.ModuleList = nn.ModuleList()
        for layer_idx in range(self.num_layers):
            per_dim = nn.ModuleDict()
            valid_dst = {
                dst for _, _, dst in self.edge_types if int(dst) <= layer_idx + 1
            }
            for dst in sorted(valid_dst, key=int):
                per_dim[dst] = norm_cls(hidden_sizes[dst]).to(target)
            blocks.append(per_dim)
        return blocks

    def _init_layer_norms(
        self, hidden_sizes: Dict[str, int], accelerator: str
    ) -> nn.ModuleList:
        """Per-layer ``LayerNorm`` blocks keyed by destination dimension."""
        return self._init_norm_blocks(nn.LayerNorm, hidden_sizes, accelerator)

    def _init_batch_norms(self) -> nn.ModuleList:
        """Per-layer ``BatchNorm1d`` blocks keyed by destination dimension."""
        return self._init_norm_blocks(nn.BatchNorm1d, self.hidden_sizes, self.device)

    def _init_jumping_knowledge(
        self,
        hidden_sizes: Dict[str, int],
        n_classes: int,
        num_layers: int,
        mode: Optional[str],
    ) -> Tuple[Optional[JumpingKnowledge], Optional[nn.Linear]]:
        """Configure Jumping-Knowledge readout and classifier head.

        Returns ``(jump, classifier)``. With ``mode=None`` no JK module is
        built; the classifier is ``None`` as well when ``"0"`` is missing from
        ``hidden_sizes``. With a JK mode, ``hidden_sizes["0"]`` is required.
        """
        if mode is None:
            if "0" in hidden_sizes:
                return None, nn.Linear(hidden_sizes["0"], n_classes)
            return None, None

        channels = hidden_sizes["0"]
        # Number of layer outputs handed to the JK module.
        jk_layers = num_layers + 1 - self.mdim

        if mode == "cat":
            # NOTE(review): for mdim == 1 the concatenated width uses
            # ``num_layers - 1`` outputs while ``jk_layers`` is ``num_layers``;
            # confirm against the subclass forward pass.
            input_dim = channels * (num_layers - 1 if self.mdim == 1 else jk_layers)
        else:  # "max" | "lstm"
            input_dim = channels

        jump = JumpingKnowledge(mode=mode, channels=channels, num_layers=jk_layers)
        return jump, nn.Linear(input_dim, n_classes)

    # ------------------------------------------------------------------
    #  Utilities
    # ------------------------------------------------------------------
    @staticmethod
    def _get_layer_kwargs(conv_type: str) -> Dict[str, int | bool]:
        """Return operator-specific keyword arguments (a fresh dict per call)."""
        if conv_type == "GAT":
            return {"concat": True, "add_self_loops": False, "bias": False, "heads": 1}
        if conv_type == "SAGE":
            return {"root_weight": False}
        if conv_type == "GCN":
            return {"add_self_loops": False}
        raise RuntimeError("Unreachable")

    def _apply_layer_transforms(
        self,
        x_dict: Dict[str, Tensor],
        dropout: float,
        training: bool,
        layer_idx: int,
    ) -> Dict[str, Tensor]:
        """Common post-processing stack: BN → LN → LeakyReLU → Dropout → L2-norm.

        BatchNorm and LayerNorm are each applied only if enabled, and only to
        keys that have a norm registered for ``layer_idx``; other keys pass
        through unchanged. L2 normalisation (over ``dim=1``) runs only when
        ``normalize`` was set at construction.
        """
        for norms in (self.batch_norms, self.layer_norms):
            if not norms:
                continue
            per_dim = norms[layer_idx]
            x_dict = {
                k: per_dim[k](v) if k in per_dim else v for k, v in x_dict.items()
            }
        x_dict = {
            k: F.dropout(F.leaky_relu(v), p=dropout, training=training)
            for k, v in x_dict.items()
        }
        if self.l2:
            x_dict = {k: F.normalize(v, p=2, dim=1) for k, v in x_dict.items()}
        return x_dict

    # ------------------------------------------------------------------
    #  Abstract interface ------------------------------------------------
    # ------------------------------------------------------------------
    @abstractmethod
    def _build_direction_experts(self) -> nn.Module:
        """Return the *message-passing* component of the network."""

    @abstractmethod
    def forward(self, *args, **kwargs):
        """Sub-class specific forward pass."""
