from copy import deepcopy
from torch import nn, Tensor
import torch.nn.functional as F
from torch_geometric.nn import GATConv
from typing import Dict, Tuple, Type, Optional
from src.models.ssn_heteroconv import SSNHeteroConv


# -----------------------------------------------------------------------------
#                              Feature Lifter
# -----------------------------------------------------------------------------


class FeatureLifter(nn.Module):
    """Lift **0‑simplices** (nodes) to higher‑order simplices.

    The module stacks *max_dim* heterogeneous convolution layers that connect
    a simplex of dimension *n* to its co‑face of dimension *n+1*.  After each
    convolution, a normalisation → activation → dropout → (optional) L2
    pipeline is applied.

    Parameters
    ----------
    hidden_dim:
        Size of the embedding for each n-simplices dimensions.
    dropout:
        Dropout probability applied after Leaky‑ReLU.
    max_dim:
        Highest simplex dimension in the complex minus **one** (because we lift
        *n → n+1*).  For a 2‑dimensional complex this would therefore be ``2``.
    device:
        Device string – used for *lazy* module placement (e.g. "cpu", "cuda").
    bn / ln:
        Enable *BatchNorm* / *LayerNorm* respectively.
    l2:
        Apply L2 feature normalisation after dropout.
    ConvLayer / conv_kwargs:
        The GNN operator class (``SAGEConv`` | ``GATConv`` | ``GCNConv``) and
        its instantiation kwargs.
    """

    def __init__(
        self,
        hidden_dim: int,
        dropout: float,
        max_dim: int,
        device: str,
        bn: bool,
        ln: bool,
        l2: bool,
        ConvLayer: Type[nn.Module],
        conv_kwargs: Dict[str, int | bool],
    ) -> None:  # pylint: disable=too-many-arguments
        super().__init__()

        self.dropout = dropout
        self.max_dim = max_dim
        self.l2 = l2

        # Normalisation stacks ------------------------------------------------
        self.batch_norms: Optional[nn.ModuleList] = (
            nn.ModuleList(
                [nn.BatchNorm1d(hidden_dim).to(device) for _ in range(max_dim)]
            )
            if bn
            else None
        )
        self.layer_norms: Optional[nn.ModuleList] = (
            nn.ModuleList([nn.LayerNorm(hidden_dim).to(device) for _ in range(max_dim)])
            if ln
            else None
        )

        # Convolution layers ---------------------------------------------------
        in_channels = (-1, -1) if ConvLayer is GATConv else -1
        self.conv_layers: nn.ModuleList = nn.ModuleList(
            [
                SSNHeteroConv(
                    {
                        (f"{i}", "c_a", f"{i + 1}"): ConvLayer(
                            in_channels=in_channels,
                            out_channels=hidden_dim,
                            **conv_kwargs,
                        )
                    },
                    unique=True,
                ).to(device)
                for i in range(max_dim)
            ]
        )

    # ------------------------------------------------------------------
    #  Internals
    # ------------------------------------------------------------------
    def _apply_transforms(self, x: Tensor, dim: int) -> Tensor:
        """(BN|LN) → Leaky‑ReLU → Dropout → L2‑normalisation."""
        if self.batch_norms:
            x = self.batch_norms[dim](x)
        if self.layer_norms:
            x = self.layer_norms[dim](x)
        x = F.leaky_relu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        return F.normalize(x, p=2, dim=1) if self.l2 else x

    # ------------------------------------------------------------------
    #  Forward
    # ------------------------------------------------------------------
    def forward(
        self,
        x_dict: Dict[str, Tensor],
        edge_index_dict: Dict[Tuple[str, str, str], Tensor],
    ) -> Dict[str, Tensor]:
        """Lift 0‑simplices step‑by‑step up to ``max_dim``."""
        x_dict = deepcopy(x_dict)  # never mutate caller
        for dim in range(self.max_dim):
            dst = str(dim + 1)
            x_up = self.conv_layers[dim](x_dict, edge_index_dict)
            x_dict[dst] = self._apply_transforms(x_up, dim)
        return x_dict
