from torch import nn
from utils.utils import *
from sklearn.metrics import roc_auc_score
from pytorch_lightning import LightningModule
from src.models.rel_routing import RelationRouter
from src.models.noise_scheduler import NoiseScheduler
from copy import deepcopy
import torch.nn.functional as F
from src.models.base_ssn import BaseSSN
from src.models.ssn_heteroconv import SSNHeteroConv, group
import torch


class NodeClassifier(LightningModule):
    """
    Lightning wrapper class for SSN for node classification tasks.

    Wraps either a plain ``SSN`` or a routed ``RelationRouter`` (R-SSN) and
    implements the train / val / test steps, metric logging and the optimizer.

    Parameters:
        edge_types (list): List of edge types in the heterogeneous graph.
        hidden_sizes (dict): Hidden size per node type (SSN indexes it with an
            edge type's destination, ``edge_type[-1]``).
        n_classes (int): Number of output classes.
        num_layers (int): Number of layers in the model.
        dropout (float): Dropout rate.
        jumping_knowledge (str): Type of jumping knowledge to use (None for no JK).
        normalize (bool): Whether to apply L2 normalization at each layer output.
        learning_rate (float): Learning rate for the Adam optimizer.
        wd (float): Weight decay for the optimizer.
        loss (nn.Module, optional): Loss applied to the masked outputs. Defaults
            to a fresh ``nn.NLLLoss()`` per instance, which matches the
            log-probabilities returned by ``SSN``.
        train_mask (torch.Tensor): Mask for training data.
        val_mask (torch.Tensor): Mask for validation data.
        test_mask (torch.Tensor): Mask for test data.
        conv_type (str): Type of convolutional layer to use.
        device (str): Device string forwarded to the model ('cpu' or 'cuda').
        in_aggr (str): Inner aggregation method forwarded to the model.
        out_aggr (str): Outer aggregation method forwarded to the model.
        lin_res (bool): Whether to use linear skip connections.
        ln (bool): Whether to use layer normalization.
        bn (bool): Whether to use batch normalization.
        input_dims: Input dimensions; only forwarded to ``RelationRouter``.
        routing (bool): If True build a ``RelationRouter`` (R-SSN), else ``SSN``.
        k: Routing hyper-parameter forwarded to ``RelationRouter``; only used
            when ``routing`` is True (presumably the number of selected
            experts -- TODO confirm against RelationRouter).
    """

    def __init__(
        self,
        edge_types,
        hidden_sizes,
        n_classes,
        num_layers,
        dropout=0,
        jumping_knowledge=None,
        normalize=False,
        learning_rate=0.003,
        wd=5e-2,
        loss=None,
        train_mask=None,
        val_mask=None,
        test_mask=None,
        conv_type="SAGE",
        device="cpu",
        in_aggr="sum",
        out_aggr="sum",
        lin_res=False,
        ln=False,
        bn=False,
        input_dims=None,
        routing=False,
        k=None,
    ):

        super().__init__()

        # A module instance as a default argument would be one object shared
        # (and registered as a child module) by every NodeClassifier; build a
        # fresh one per instance instead. Rebinding the local before
        # save_hyperparameters() keeps the loss recorded in hparams.
        if loss is None:
            loss = nn.NLLLoss()

        # Registered as a submodule, so it follows the model across devices.
        self.loss = loss
        self.routing = routing
        self.weight_decay = wd
        self.val_mask = val_mask
        self.test_mask = test_mask
        self.train_mask = train_mask
        self.learning_rate = learning_rate
        self.noise_scheduler = NoiseScheduler()

        self.save_hyperparameters()

        # Define the common arguments for SSN and R-SSN
        common_args = {
            "edge_types": edge_types,
            "hidden_sizes": hidden_sizes,
            "n_classes": n_classes,
            "num_layers": num_layers,
            "dropout": dropout,
            "jumping_knowledge": jumping_knowledge,
            "normalize": normalize,
            "conv_type": conv_type,
            "device": device,
            "in_aggr": in_aggr,
            "out_aggr": out_aggr,
            "lin_res": lin_res,
            "bn": bn,
            "ln": ln,
        }

        # Define additional arguments for R-SSN
        routing_args = {
            "input_dims": input_dims,
            "k": k,
        }

        if routing:
            self.model = RelationRouter(**common_args, **routing_args)
        else:
            self.model = SSN(**common_args)

    def forward(self, x_dict, edge_index_dict, *args, **kwargs):
        """
        Forward pass; delegates directly to the wrapped model.
        """
        return self.model(x_dict, edge_index_dict, *args, **kwargs)

    def compute_loss(self, out, y, mask):
        """Apply ``self.loss`` to the masked (and squeezed) outputs and targets."""
        return self.loss(out[mask].squeeze(), y[mask].squeeze())

    def _get_mask(self, stage):
        """Return the node mask for ``stage`` moved to the module's device.

        Raises:
            ValueError: If no mask was provided for ``stage``.
        """
        mask = getattr(self, f"{stage}_mask")
        if mask is None:
            raise ValueError(f"No {stage}_mask was provided to NodeClassifier.")
        return mask.to(self.device)

    @staticmethod
    def _binary_roc_auc(out, y, mask):
        """Return the ROC-AUC of class 1 on the masked nodes, or None.

        ROC-AUC is only reported when the labels are binary (``y.max() <= 1``)
        and exactly two classes occur among the masked nodes. Otherwise
        ``roc_auc_score`` would raise ``ValueError`` (e.g. a small validation
        split that happens to contain a single class).
        """
        if y.max() > 1:
            return None  # Only applicable for binary classification
        y_masked = y[mask]
        if torch.unique(y_masked).numel() != 2:
            return None  # ROC-AUC is undefined without both classes present
        # softmax is shift-invariant, so applying it to log-probabilities
        # recovers the class probabilities.
        probs = out[mask].softmax(dim=1)[:, 1].detach().cpu().numpy()
        return roc_auc_score(y_masked.cpu().numpy(), probs)

    def shared_step(self, batch, stage):
        """
        Shared logic for training, validation, and testing.

        Parameters:
            batch: Tuple ``(hetero_data, y)``; ``hetero_data`` must expose
                ``x_dict`` and ``edge_index_dict``.
            stage (str): One of ``"train"``, ``"val"``, ``"test"``; selects
                the ``{stage}_mask`` attribute and the logged metric names.

        Returns:
            The total loss. In routing mode this includes the model's
            auxiliary MoE loss.
        """
        hetero_data, y = batch
        x_dict, edge_index_dict = hetero_data.x_dict, hetero_data.edge_index_dict
        y = y.to(self.device)
        mask = self._get_mask(stage)

        if self.routing:
            # Dynamically adjust noise level
            noise_epsilon = self.noise_scheduler.get_noise(self.current_epoch)
            out, moe_loss = self.model(x_dict, edge_index_dict, noise_epsilon)
            total_loss = self.compute_loss(out, y, mask) + moe_loss
        else:
            out = self(x_dict, edge_index_dict)
            total_loss = self.compute_loss(out, y, mask)

        acc = (
            compute_multiclass_accuracy(out[mask].squeeze(), y[mask].squeeze())
            .detach()
            .cpu()
            .item()
        )

        roc_auc = self._binary_roc_auc(out, y, mask)

        # Log metrics
        if stage in ["val", "test"]:
            self.log(f"{stage}_loss", total_loss, prog_bar=True, batch_size=1)
            self.log(f"{stage}_acc", acc, prog_bar=True, batch_size=1)
            if roc_auc is not None:
                self.log(f"{stage}_roc_auc", roc_auc, prog_bar=True, batch_size=1)
        else:
            self.log(f"{stage}_loss", total_loss, prog_bar=True)
            self.log(f"{stage}_acc", acc, prog_bar=True)

        return total_loss

    def training_step(self, batch, batch_idx):
        return self.shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        return self.shared_step(batch, "val")

    def test_step(self, batch, batch_idx):
        return self.shared_step(batch, "test")

    def configure_optimizers(self):
        """Adam over the wrapped model's parameters with the configured lr / wd."""
        return torch.optim.Adam(
            self.model.parameters(),
            lr=self.learning_rate,
            weight_decay=self.weight_decay,
        )


class SSN(BaseSSN):
    """
    SSN adaptation for node classification tasks.

    Builds one ``SSNHeteroConv`` per layer (the "direction experts") from the
    configuration held by ``BaseSSN``, and returns per-class log-probabilities
    for the nodes of type ``"0"``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.direction_experts = self._build_direction_experts()

    def _build_direction_experts(self):
        """
        Create one heterogeneous convolution per layer.

        Layer ``i`` only receives the edge types whose source index
        (``int(edge_type[0])``) is at most ``i`` and whose destination index
        (``int(edge_type[2])``) is at most ``i + 1``. Each admitted relation gets
        its own ``self.ConvLayer`` whose output size is
        ``self.hidden_sizes[edge_type[-1]]``.

        Returns:
            nn.ModuleList: The per-layer ``SSNHeteroConv`` modules.
        """
        # GAT requires in_channels to be a tuple; -1 presumably requests lazy
        # size inference from the conv layer.
        lazy_in = (-1, -1) if self.conv_type == "GAT" else -1

        def _edges_for(depth):
            # Relations admitted at this depth, kept in self.edge_types order.
            return [
                rel
                for rel in self.edge_types
                if int(rel[0]) <= depth and int(rel[2]) <= depth + 1
            ]

        def _make_layer(depth):
            # One ConvLayer per admitted relation, wrapped in a single hetero conv.
            return SSNHeteroConv(
                {
                    rel: self.ConvLayer(
                        in_channels=lazy_in,
                        out_channels=self.hidden_sizes[rel[-1]],
                        **self.conv_kwargs,
                    )
                    for rel in _edges_for(depth)
                }
            )

        return nn.ModuleList([_make_layer(depth) for depth in range(self.num_layers)])

    def forward(self, x_dict, edge_index_dict):
        """Run the heterogeneous GNN and classify the type-"0" nodes.

        Parameters:
            x_dict: Node features for each node type.
            edge_index_dict: Edge indices for each edge type.

        Returns:
            ``log_softmax`` (over dim 1) of ``self.lin_nodes`` applied to the
            final type-"0" representation, or to the jumping-knowledge
            combination when ``self.jumping_knowledge`` is set.
        """
        jk_states = []
        # Deep copy so the per-layer updates below never touch the caller's
        # dict or its tensors.
        feats = deepcopy(x_dict)

        for depth, hetero_conv in enumerate(self.direction_experts):
            per_relation = hetero_conv(feats, edge_index_dict)

            # Inner aggregation: for every destination type, reduce each entry
            # of its nested dict separately (presumably one entry per source
            # simplex level -- see SSNHeteroConv). Destinations with no
            # entries are skipped.
            level_outputs = {
                dst: [group(msgs, self.inner_aggregation) for msgs in by_src.values()]
                for dst, by_src in per_relation.items()
                if by_src
            }

            # Outer aggregation across those groups, plus the optional
            # linear skip connection from the previous representation.
            for node_type, partials in level_outputs.items():
                merged = group(partials, self.outer_aggregation)
                if self.lins_res:
                    merged = merged + self.lins_res[node_type][depth](feats[node_type])
                feats[node_type] = merged

            feats = self._apply_layer_transforms(
                feats, self.dropout, self.training, depth
            )

            # Type-"0" states from layer index self.mdim onward feed JK.
            if depth >= self.mdim:
                jk_states.append(feats["0"].clone())

        if self.mdim > 1:
            feats = self.feature_projector(feats, edge_index_dict)
            jk_states.append(feats["0"].clone())

        if self.jumping_knowledge:
            node_repr = self.jump(jk_states)
        else:
            node_repr = feats["0"]

        return F.log_softmax(self.lin_nodes(node_repr), dim=1)
