import torch
from copy import deepcopy
from torch import Tensor, nn
import torch.nn.functional as F
from pytorch_lightning import LightningModule
from typing import Optional, List, Tuple, Dict
from torch_geometric.nn import global_mean_pool

# Project‑internal imports -----------------------------------------------------
from src.models.base_ssn import BaseSSN
from src.models.rel_routing import RelationRouter
from utils.utils import compute_multiclass_accuracy
from src.models.noise_scheduler import NoiseScheduler
from src.models.ssn_heteroconv import SSNHeteroConv, group


class DACClassifier(LightningModule):
    """Dynamical Activity Complexes Classifier.

    LightningModule wrapper that builds either an :class:`SSN` or, when
    ``routing=True``, an :class:`RSSN`, and implements the shared
    train/val/test step plus the Adam optimiser set-up.
    """

    def __init__(
        self,
        edge_types: List[Tuple[str, str, str]],
        hidden_sizes: List[int],
        n_classes: int,
        num_layers: int,
        dropout: float = 0.0,
        jumping_knowledge: Optional[str] = None,
        normalize: bool = False,
        learning_rate: float = 1e-3,
        wd: float = 5e-2,
        loss: Optional[nn.Module] = None,
        conv_type: str = "SAGE",
        device: str = "gpu",
        in_aggr: str = "sum",
        out_aggr: str = "mean",
        lin_res: bool = False,
        ln: bool = False,
        bn: bool = False,
        input_dims: Optional[Dict[str, int]] = None,
        routing: bool = False,
        k: Optional[int] = None,
    ) -> None:
        """Build the classifier and its underlying SSN / RSSN model.

        Args:
            edge_types, hidden_sizes, n_classes, num_layers, dropout,
            jumping_knowledge, normalize, conv_type, device, in_aggr,
            out_aggr, lin_res, ln, bn: forwarded unchanged to the
                underlying model. NOTE(review): ``hidden_sizes`` is annotated
                as a list, but SSN/RSSN use ``self.hidden_sizes`` like a
                mapping keyed by node type — presumably BaseSSN converts it;
                confirm.
            learning_rate: Adam learning rate.
            wd: Adam weight decay.
            loss: criterion applied to ``(logits, target)``. Defaults to a
                new ``nn.NLLLoss`` per instance (the models return
                ``log_softmax`` outputs).
            input_dims: RSSN only — per-node-type input sizes used as the
                router ``input_size`` in the first layer.
            routing: build an RSSN (whose router loss is added to the
                classification loss) instead of an SSN.
            k: RSSN only — forwarded to every ``RelationRouter``.
        """
        super().__init__()

        # The loss is a module, not a hyper-parameter, so keep it out of hparams.
        self.save_hyperparameters(ignore=["loss"])
        # Build a fresh criterion per instance: a module used as a default
        # argument would be a single object shared by every classifier.
        self.loss = loss if loss is not None else nn.NLLLoss()
        self.routing = routing
        self.weight_decay = wd
        self.learning_rate = learning_rate
        self.noise_scheduler = NoiseScheduler()

        # ------------------------------------------------------------------
        #  Instantiate the underlying SSN or RoutingSSN model
        # ------------------------------------------------------------------

        common_args = dict(
            edge_types=edge_types,
            hidden_sizes=hidden_sizes,
            n_classes=n_classes,
            num_layers=num_layers,
            dropout=dropout,
            jumping_knowledge=jumping_knowledge,
            normalize=normalize,
            conv_type=conv_type,
            device=device,
            in_aggr=in_aggr,
            out_aggr=out_aggr,
            lin_res=lin_res,
            ln=ln,
            bn=bn,
        )

        if routing:
            self.model: BaseSSN = RSSN(
                **common_args, input_dims=input_dims, k=k  # type: ignore[arg-type]
            )
        else:
            self.model = SSN(**common_args)

    # ---------------------------------------------------------------------
    # Forward
    # ---------------------------------------------------------------------

    def forward(
        self,
        x_dict: Dict[str, Tensor],
        edge_index_dict: Dict[Tuple[str, str, str], Tensor],
        batch_dict: Dict[str, Tensor],
        noise_epsilon: Optional[float] = None,
    ) -> Tensor | Tuple[Tensor, Tensor]:
        """Delegate to the wrapped model.

        Returns:
            For an SSN, the per-graph log-probabilities. For an RSSN, a
            ``(log_probs, routing_loss)`` tuple; ``noise_epsilon`` is only
            passed through in that case.
        """
        if self.routing:
            return self.model(x_dict, edge_index_dict, batch_dict, noise_epsilon)

        return self.model(x_dict, edge_index_dict, batch_dict)

    # ------------------------------------------------------------------
    # Training / validation / test
    # ------------------------------------------------------------------

    def compute_loss(self, logits: Tensor, target: Tensor) -> Tensor:
        """Apply ``self.loss`` after squeezing singleton dims of both inputs."""
        return self.loss(logits.squeeze(), target.squeeze())

    def shared_step(self, batch, stage: str) -> Tensor:
        """Common logic for the train/val/test steps.

        Args:
            batch: the batched heterogeneous graph object produced by the
                DataLoader. It must expose ``y``, ``x_dict``,
                ``edge_index_dict`` and, for every node type in ``x_dict``,
                a node-to-graph vector at ``batch[node_type].batch``.
            stage: ``"train"``, ``"val"`` or ``"test"``; used as the prefix
                of the logged metric names.

        Returns:
            The loss for the batch (plus the router loss when routing).
        """
        hetero_data = batch

        # ``as_tensor`` accepts both Python lists and tensors, including
        # tensors Lightning has already moved to the GPU; the legacy
        # ``torch.LongTensor(...)`` constructor only accepts CPU data of the
        # matching dtype.
        target = torch.as_tensor(hetero_data.y, dtype=torch.long, device=self.device)
        x_dict = hetero_data.x_dict
        edge_index_dict = hetero_data.edge_index_dict
        batch_dict = {node_type: hetero_data[node_type].batch for node_type in x_dict}

        # Forward -----------------------------------------------------------------
        if self.routing:
            noise_epsilon = self.noise_scheduler.get_noise(self.current_epoch)
            logits, routing_loss = self(
                x_dict, edge_index_dict, batch_dict, noise_epsilon
            )
            loss = self.compute_loss(logits, target) + routing_loss
        else:
            logits = self(x_dict, edge_index_dict, batch_dict)
            loss = self.compute_loss(logits, target)

        # Metrics -----------------------------------------------------------------
        acc = compute_multiclass_accuracy(logits, target).detach().cpu().item()
        log_kwargs = (
            {"prog_bar": True}
            if stage == "train"
            else {"prog_bar": True, "batch_size": 1}
        )

        self.log(f"{stage}_loss", loss, **log_kwargs)
        self.log(f"{stage}_acc", acc, **log_kwargs)
        return loss

    # Lightning -------------------------------------------------------------------
    def training_step(self, batch, batch_idx):
        return self.shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        return self.shared_step(batch, "val")

    def test_step(self, batch, batch_idx):
        return self.shared_step(batch, "test")

    def configure_optimizers(self):
        """Adam over the wrapped model's parameters with L2 weight decay."""
        return torch.optim.Adam(
            self.model.parameters(),
            lr=self.learning_rate,
            weight_decay=self.weight_decay,
        )


# -----------------------------------------------------------------------------
#                    Semi-Simplicial Neural Network
# -----------------------------------------------------------------------------


class SSN(BaseSSN):
    """Semi-Simplicial Neural Network.

    Stacks ``num_layers`` :class:`SSNHeteroConv` layers (one conv per edge
    type), aggregates incoming messages per destination node type, then
    mean-pools every node type per graph and classifies the concatenation.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.direction_experts = self._build_direction_experts()
        # The readout concatenates one pooled vector per node type.
        self.linear = nn.Linear(sum(self.hidden_sizes.values()), self.num_classes)

    def _build_direction_experts(self) -> nn.ModuleList:
        """Build one SSNHeteroConv per layer holding one conv per edge type."""
        convs = nn.ModuleList()
        # GAT gets separate (source, destination) input sizes so that edge
        # types connecting differently sized node types work; -1 leaves the
        # size to be inferred by the layer.
        in_channels = (-1, -1) if self.conv_type == "GAT" else -1

        for _ in range(self.num_layers):
            conv = SSNHeteroConv(
                {
                    edge_type: self.ConvLayer(
                        in_channels=in_channels,
                        out_channels=self.hidden_sizes[edge_type[-1]],
                        **self.conv_kwargs,
                    )
                    for edge_type in self.edge_types
                }
            )
            convs.append(conv)
        return convs

    # ------------------------------------------------------------------
    #  Forward
    # ------------------------------------------------------------------
    def forward(
        self,
        x_dict: Dict[str, Tensor],
        edge_index_dict: Dict[Tuple[str, str, str], Tensor],
        batch_dict: Dict[str, Tensor],
    ) -> Tensor:
        """Message-passing forward pass.

        Args:
            x_dict: node features per node type.
            edge_index_dict: edge indices per ``(src, rel, dst)`` edge type.
            batch_dict: node-to-graph assignment vector per node type; must
                contain every key of ``x_dict``.

        Returns:
            Per-graph class log-probabilities (``log_softmax`` over dim 1).
        """
        # Work on copies so the caller's tensors are never touched. Unlike
        # ``deepcopy``, ``clone`` also works on non-leaf tensors that require
        # grad and keeps them connected to the autograd graph.
        x_dict = {node_type: x.clone() for node_type, x in x_dict.items()}

        for layer_idx, conv in enumerate(self.direction_experts):
            # Intra-layer message passing: inner aggregation per source type.
            tmp_dict: Dict[str, List[Tensor]] = {}
            inter_dict = conv(x_dict, edge_index_dict)

            for dst, rel_dict in inter_dict.items():
                for src in rel_dict:
                    tmp_dict.setdefault(dst, []).append(
                        group(rel_dict[src], self.inner_aggregation)
                    )

            # Outer aggregation across source types, plus optional residual.
            for dst, tensor_list in tmp_dict.items():
                aggregated = group(tensor_list, self.outer_aggregation)
                if self.lins_res:
                    # Out-of-place add: if ``group`` hands back one of its
                    # inputs unchanged, an in-place ``+=`` would modify a conv
                    # output that autograd may still need.
                    aggregated = aggregated + self.lins_res[dst][layer_idx](
                        x_dict[dst]
                    )
                x_dict[dst] = aggregated

            # Per-layer transforms implemented by BaseSSN (presumably
            # activation / dropout / BN / LN / L2 — see BaseSSN).
            x_dict = self._apply_layer_transforms(
                x_dict, self.dropout, self.training, layer_idx
            )

        logits = self.linear(self._readout(x_dict, batch_dict))
        return F.log_softmax(logits, dim=1)

    # ------------------------------------------------------------------
    #  Dynamical Activity Complex Readout
    # ------------------------------------------------------------------
    def _readout(
        self, x_dict: Dict[str, Tensor], batch_dict: Dict[str, Tensor]
    ) -> Tensor:
        """Mean-pool every node type per graph and concatenate along dim 1.

        A node type whose pooled width differs from ``hidden_sizes[type]``
        (e.g. a type that was never a message destination and so still has
        its input width) contributes a zero block of the expected width.
        """
        # Pool every node type into the same number of rows. Without an
        # explicit size each type infers ``batch.max() + 1`` on its own, and
        # torch.cat fails when a trailing graph has no nodes of some type.
        graph_counts = [
            int(batch_dict[node_type].max()) + 1
            for node_type in x_dict
            if batch_dict[node_type] is not None and batch_dict[node_type].numel() > 0
        ]
        num_graphs = max(graph_counts) if graph_counts else None

        levels: List[Tensor] = []
        for node_type, x in x_dict.items():
            pooled = global_mean_pool(x, batch_dict[node_type], size=num_graphs)
            width = self.hidden_sizes[node_type]
            if pooled.shape[1] != width:
                # Same device/dtype as the real features, so no device string
                # is needed.
                pooled = pooled.new_zeros((pooled.shape[0], width))
            levels.append(pooled)
        return torch.cat(levels, dim=1)


# -----------------------------------------------------------------------------
#            Routing Semi-Simplicial Neural Network (R-SSN)
# -----------------------------------------------------------------------------


class RSSN(BaseSSN):
    """Routing Semi-Simplicial Neural Network (R-SSN).

    Edge types are grouped by ``(source node type, relation prefix)``.
    Groups with several edge types are handled by a :class:`RelationRouter`
    over one single-edge-type expert each; groups with a single edge type use
    that :class:`SSNHeteroConv` directly. The forward pass returns the class
    log-probabilities and the mean router loss.
    """

    def __init__(self, *args, **kwargs) -> None:
        # RSSN-only options are removed before BaseSSN sees kwargs.
        self.k = kwargs.pop("k", None)
        self.input_dims = kwargs.pop("input_dims", None)
        super().__init__(*args, **kwargs)
        # The readout concatenates one pooled vector per node type.
        self.linear = nn.Linear(sum(self.hidden_sizes.values()), self.num_classes)
        self.topo_rels = self._extract_topological_relations()
        self.direction_experts = self._build_direction_experts()

    def _extract_topological_relations(self) -> Dict[str, List[Tuple[str, str, str]]]:
        """Group edge types by source node type and relation prefix.

        The group key is ``f"{src}_{prefix}"``, where ``prefix`` is the part
        of the relation name before its first ``_`` (presumably ``up`` /
        ``down``). Each edge type keeps its full relation name so it still
        matches the keys of ``edge_index_dict``.
        """
        relations: Dict[str, List[Tuple[str, str, str]]] = {}
        for src, rel, dst in self.edge_types:
            prefix = rel.split("_", 1)[0]
            relations.setdefault(f"{src}_{prefix}", []).append((src, rel, dst))
        return relations

    def _make_expert(self, edge_type: Tuple[str, str, str], in_channels) -> nn.Module:
        """Wrap one conv layer for ``edge_type`` in a unique SSNHeteroConv."""
        return SSNHeteroConv(
            {
                edge_type: self.ConvLayer(
                    in_channels=in_channels,
                    out_channels=self.hidden_sizes[edge_type[-1]],
                    **self.conv_kwargs,
                )
            },
            unique=True,
        )

    def _build_direction_experts(self) -> nn.ModuleList:
        """Create, per layer, one module per topological relation group.

        A group with more than one edge type gets a ``RelationRouter`` over
        one expert per edge type; a single-edge-type group gets that expert
        directly. Each layer is an ``nn.ModuleDict`` keyed by group key.
        """
        # Same convention as SSN: GAT gets separate (source, destination)
        # input sizes; -1 leaves the size to be inferred by the layer.
        in_channels = (-1, -1) if self.conv_type == "GAT" else -1
        # ``input_dims`` is optional; treat a missing mapping as empty.
        input_dims = self.input_dims or {}
        layers: nn.ModuleList = nn.ModuleList()

        for layer_idx in range(self.num_layers):
            module_dict: Dict[str, nn.Module] = {}
            for topo_key, edge_group in self.topo_rels.items():
                src = edge_group[0][0]
                # The first layer's router sees raw input features when their
                # width is known; later layers see hidden features.
                if layer_idx == 0 and src in input_dims:
                    in_dim = input_dims[src]
                else:
                    in_dim = self.hidden_sizes[src]

                if len(edge_group) > 1:  # --- Mixture of Experts ---------------
                    experts = nn.ModuleList(
                        [self._make_expert(et, in_channels) for et in edge_group]
                    )
                    module_dict[topo_key] = RelationRouter(
                        experts=experts,
                        input_size=in_dim,
                        output_size=self.hidden_sizes[src],
                        device=self.device,
                        k=self.k,
                    )
                else:  # --------------------------------------------------------
                    module_dict[topo_key] = self._make_expert(
                        edge_group[0], in_channels
                    )
            layers.append(nn.ModuleDict(module_dict))

        return layers

    # ------------------------------------------------------------------
    #  Forward
    # ------------------------------------------------------------------

    def forward(
        self,
        x_dict: Dict[str, Tensor],
        edge_index_dict: Dict[Tuple[str, str, str], Tensor],
        batch_dict: Dict[str, Tensor],
        noise_epsilon: float | None = 1e-2,
    ) -> Tuple[Tensor, Tensor]:
        """Forward pass with (optional) noisy routing.

        Args:
            x_dict: node features per node type.
            edge_index_dict: edge indices per ``(src, rel, dst)`` edge type.
            batch_dict: node-to-graph assignment vector per node type; must
                contain every key of ``x_dict``.
            noise_epsilon: forwarded to every ``RelationRouter``.

        Returns:
            ``(log_probs, routing_loss)``. ``routing_loss`` is the mean of
            the router losses, or a zero scalar when no group uses a router.
        """
        router_losses: List[Tensor] = []
        # Copies keep the caller's tensors untouched; ``clone`` (unlike
        # deepcopy) also works on non-leaf tensors that require grad.
        x_dict = {node_type: x.clone() for node_type, x in x_dict.items()}

        for layer_idx, layer_group in enumerate(self.direction_experts):
            tmp_dict: Dict[str, List[Tensor]] = {}

            # Iterate over topological groups -------------------------------------
            for topo_key, sub_layer in layer_group.items():
                edge_group = self.topo_rels[topo_key]
                # Output is attributed to the first edge type's destination.
                dst = edge_group[0][-1]
                if isinstance(sub_layer, RelationRouter):
                    y_adj, router_loss = sub_layer(
                        x_dict,
                        edge_index_dict,
                        edge_group[0][0],
                        self.training,
                        layer_idx,
                        noise_epsilon,
                    )
                    router_losses.append(router_loss)
                else:
                    y_adj = sub_layer(x_dict, edge_index_dict)
                tmp_dict.setdefault(dst, []).append(y_adj)

            # Aggregate & residual ------------------------------------------------
            for dst, tensors in tmp_dict.items():
                aggregated = group(tensors, self.outer_aggregation)
                if self.lins_res:
                    # Out-of-place so a tensor returned unchanged by ``group``
                    # (still needed by autograd) is never modified in place.
                    aggregated = aggregated + self.lins_res[dst][layer_idx](
                        x_dict[dst]
                    )
                x_dict[dst] = aggregated

            # Per-layer transforms implemented by BaseSSN -------------------------
            x_dict = self._apply_layer_transforms(
                x_dict, self.dropout, self.training, layer_idx
            )

        logits = self.linear(self._readout(x_dict, batch_dict))
        # Zero fallback lives on the logits' device so it can be added to the
        # classification loss without a device mismatch.
        routing_loss = (
            torch.stack(router_losses).mean() if router_losses else logits.new_zeros(())
        )
        return F.log_softmax(logits, dim=1), routing_loss

    # ------------------------------------------------------------------
    #  Dynamical Activity Complex Readout
    # ------------------------------------------------------------------
    def _readout(
        self, x_dict: Dict[str, Tensor], batch_dict: Dict[str, Tensor]
    ) -> Tensor:
        """Mean-pool every node type per graph and concatenate along dim 1.

        A node type whose pooled width differs from ``hidden_sizes[type]``
        contributes a zero block of the expected width instead.
        """
        # A shared graph count keeps every pooled level at the same number of
        # rows even when a trailing graph lacks nodes of some type.
        graph_counts = [
            int(batch_dict[node_type].max()) + 1
            for node_type in x_dict
            if batch_dict[node_type] is not None and batch_dict[node_type].numel() > 0
        ]
        num_graphs = max(graph_counts) if graph_counts else None

        levels: List[Tensor] = []
        for node_type, x in x_dict.items():
            pooled = global_mean_pool(x, batch_dict[node_type], size=num_graphs)
            width = self.hidden_sizes[node_type]
            if pooled.shape[1] != width:
                pooled = pooled.new_zeros((pooled.shape[0], width))
            levels.append(pooled)
        return torch.cat(levels, dim=1)
