from torch import nn
from pytorch_lightning import LightningModule
import torch
from torch import nn
from copy import deepcopy
from src.models.ssn_heteroconv import group, SSNHeteroConv
from src.models.base_ssn import BaseSSN


class EdgeRegressor(LightningModule):
    """
    Lightning wrapper class for SSN for edge regression tasks.

    The wrapper builds an ``SSN`` model from the architecture arguments and
    implements the train/validation/test steps around ``compute_loss``.

    Parameters:
        edge_types (list): List of edge types in the heterogeneous graph.
        hidden_sizes (dict): Dictionary of hidden layer sizes for each layer.
        n_classes (int): Number of output classes.
        num_layers (int): Number of layers in the model.
        dropout (float): Dropout rate.
        jumping_knowledge (str): Type of jumping knowledge to use (None for no JK).
        normalize (bool): Whether to apply L2 normalization at each layer output.
        learning_rate (float): Learning rate for the optimizer.
        wd (float): Weight decay for the optimizer.
        loss (nn.Module): Loss function to use.
        train_mask (torch.Tensor): Mask for training data (may be None if the
            training split is never evaluated).
        val_mask (torch.Tensor): Mask for validation data (may be None).
        test_mask (torch.Tensor): Mask for test data (may be None).
        conv_type (str): Type of convolutional layer to use.
        device (str): Device accelerator identifier, forwarded unchanged to SSN.
            NOTE(review): the default is "gpu" while the values historically
            documented here are 'cpu' or 'cuda' -- confirm what BaseSSN expects.
        in_aggr (str): Input aggregation method.
        out_aggr (str): Output aggregation method.
        lin_res (bool): Whether to use skip connections.
        directed_masks (tuple of torch.tensor): masks for (starting undirected
            edges, ending undirected edges, directed edges). Required by
            ``compute_loss``.
        ln (bool): Whether to use layer normalization.
        bn (bool): Whether to use batch normalization.
        loss_alpha (float): during training, weight of the loss on the edge
            predictions themselves; ``1 - loss_alpha`` weights the loss on the
            prediction differences of paired undirected edges.
        routing (bool): not implemented, here for compatibility
    """

    def __init__(
        self,
        edge_types,
        hidden_sizes,
        n_classes,
        num_layers,
        dropout=0,
        jumping_knowledge=None,
        normalize=False,
        learning_rate=0.003,
        wd=5e-2,
        loss=nn.NLLLoss(),
        train_mask=None,
        val_mask=None,
        test_mask=None,
        conv_type="SAGE",
        device="gpu",
        in_aggr="sum",
        out_aggr="sum",
        lin_res=False,
        directed_masks=None,
        ln=False,
        bn=False,
        loss_alpha=0.0,
        routing=False,
    ):
        super().__init__()

        # NOTE: the default ``loss`` instance is created once at function
        # definition time and is therefore shared by every EdgeRegressor built
        # without an explicit ``loss`` argument.
        self.loss = loss.to(self.device)
        self.directed_masks = directed_masks
        self.weight_decay = wd
        self.learning_rate = learning_rate
        # The masks default to None, so only move the ones that were provided.
        self.val_mask = self._mask_to_device(val_mask)
        self.test_mask = self._mask_to_device(test_mask)
        self.train_mask = self._mask_to_device(train_mask)
        self.loss_alpha = loss_alpha

        self.save_hyperparameters()

        # Architecture arguments forwarded verbatim to the underlying SSN
        common_args = {
            "edge_types": edge_types,
            "hidden_sizes": hidden_sizes,
            "n_classes": n_classes,
            "num_layers": num_layers,
            "dropout": dropout,
            "jumping_knowledge": jumping_knowledge,
            "normalize": normalize,
            "conv_type": conv_type,
            "device": device,
            "in_aggr": in_aggr,
            "out_aggr": out_aggr,
            "lin_res": lin_res,
            "bn": bn,
            "ln": ln,
        }

        self.model = SSN(**common_args)

    def _mask_to_device(self, mask):
        """Return ``mask`` moved to the module device, or None if it is None."""
        return None if mask is None else mask.to(self.device)

    def forward(self, x_dict, edge_index_dict):
        """Run the wrapped SSN model and return its output."""
        return self.model(x_dict, edge_index_dict)

    def prepare_data(self):
        """No data preparation is performed by this module."""
        pass

    def compute_loss(self, out, y, mask, train=False):
        """
        Compute the loss for model training and evaluation.

        Undirected edges are handled through the absolute difference between
        the two entries selected by the first two ``directed_masks``; that
        difference is compared between predictions and labels.

        Parameters:
            out (torch.tensor): Model predictions.
            y (torch.tensor): Ground truth labels.
            mask (torch.tensor): Mask of the current split (train, val, test).
            train (bool): whether the split is for training or not.
        Returns:
            lossValue: Computed loss value.
        """

        undirected_mask_1, undirected_mask_2, directed_mask = self.directed_masks
        # Restrict both sides of the undirected pairs to the current split
        split_undirected_1 = mask[undirected_mask_1]
        split_undirected_2 = mask[undirected_mask_2]

        # Compute difference of predictions and labels for undirected edges
        out_diff = (
            out[undirected_mask_1][split_undirected_1].squeeze()
            - out[undirected_mask_2][split_undirected_2].squeeze()
        ).abs()
        y_diff = (
            y[undirected_mask_1][split_undirected_1].squeeze()
            - y[undirected_mask_2][split_undirected_2].squeeze()
        ).abs()

        if train:
            # During training, consider all edges of the split as directed and
            # balance between loss on the edges and loss on the difference of
            # undirected edges.
            directed_out = out[mask].squeeze()
            directed_y = y[mask].squeeze()
            return (self.loss_alpha) * self.loss(directed_out, directed_y) + (
                1 - self.loss_alpha
            ) * self.loss(out_diff, y_diff)

        # During test, only directed edges should be considered in the
        # direct-prediction term of the loss.
        merged_directed_mask = torch.logical_and(directed_mask, mask)
        directed_out = out[merged_directed_mask].squeeze()
        directed_y = y[merged_directed_mask].squeeze()

        # Concatenate predictions and labels for the complete split.
        # ``squeeze()`` yields a 0-dim tensor when exactly one element is
        # selected and ``torch.cat`` rejects 0-dim inputs, so promote to 1-D.
        all_out = torch.cat(
            [torch.atleast_1d(out_diff), torch.atleast_1d(directed_out)]
        )
        all_y = torch.cat([torch.atleast_1d(y_diff), torch.atleast_1d(directed_y)])

        return self.loss(all_out, all_y)

    def training_step(self, batch, batch_idx):
        """Run one training step on a ``(hetero_data, y)`` batch and log ``train_loss``."""
        hetero_data, y = batch
        x_dict, edge_index_dict = hetero_data.x_dict, hetero_data.edge_index_dict

        out = self.model(x_dict, edge_index_dict)
        y = y.to(self.device)
        lossValueTrain = self.compute_loss(out, y, self.train_mask, train=True)
        # batch_size is given explicitly, consistently with the other steps
        self.log("train_loss", lossValueTrain, prog_bar=True, batch_size=1)

        return lossValueTrain

    def validation_step(self, batch, batch_idx):
        """Run one validation step on a ``(hetero_data, y)`` batch and log ``val_loss``."""
        hetero_data, y = batch
        x_dict, edge_index_dict = hetero_data.x_dict, hetero_data.edge_index_dict

        out = self.model(x_dict, edge_index_dict)

        # Move labels to the module device, as done in training and test steps
        y = y.to(self.device)
        val_loss = self.compute_loss(out, y, self.val_mask)
        self.log("val_loss", val_loss, prog_bar=True, batch_size=1)

        return val_loss

    def test_step(self, batch, batch_idx):
        """Run one test step on a ``(hetero_data, y)`` batch and log ``test_loss``."""
        hetero_data, y = batch
        x_dict, edge_index_dict = hetero_data.x_dict, hetero_data.edge_index_dict

        out = self.model(x_dict, edge_index_dict)

        y = y.to(self.device)
        test_loss = self.compute_loss(out, y, self.test_mask)
        self.log("test_loss", test_loss, batch_size=1)

        return test_loss

    def configure_optimizers(self):
        """Return an Adam optimizer over the parameters of the wrapped model."""
        optimizer = torch.optim.Adam(
            self.model.parameters(),
            lr=self.learning_rate,
            weight_decay=self.weight_decay,
        )
        return optimizer


class SSN(BaseSSN):
    """
    SSN model with one SSNHeteroConv per layer and a linear readout applied to
    the features stored under key "1".

    All constructor arguments are forwarded to ``BaseSSN``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.direction_experts = self._build_direction_experts()
        # Define the readout layer: maps the hidden features of key "1" to the
        # model outputs.
        input_dim = self.hidden_sizes["1"]
        self.lin_edges = torch.nn.Linear(input_dim, self.num_classes)

    def _build_direction_experts(self):
        """
        Define SSN layers on each adjacency

        Returns:
            nn.ModuleList with ``num_layers`` SSNHeteroConv modules, each
            holding one convolution per edge type.
        """
        convs = nn.ModuleList()
        # -1 requests lazily inferred input sizes; GAT gets a (src, dst) pair
        in_channels = (-1, -1) if self.conv_type == "GAT" else -1

        for _ in range(self.num_layers):
            # Instantiate one HeteroConv object holding a convolution for each
            # adjacency; the output size is looked up from the last element of
            # the edge type.
            conv = SSNHeteroConv(
                {
                    edge_type: self.ConvLayer(
                        in_channels=in_channels,
                        out_channels=self.hidden_sizes[edge_type[-1]],
                        **self.conv_kwargs,
                    )
                    for edge_type in self.edge_types
                }
            )
            convs.append(conv)

        return convs

    def forward(self, x_dict, edge_index_dict):
        """Forward pass for the heterogeneous GNN.

        Parameters:
            x_dict: Node features for each node type.
            edge_index_dict: Edge indices for each edge type.

        Returns:
            Output of the linear readout applied to the last-layer features
            stored under key "1".
        """

        # Work on a copy so the caller's feature dictionary is left untouched.
        # NOTE(review): a shallow copy may be sufficient if
        # _apply_layer_transforms never mutates tensors in place -- confirm.
        x_dict = deepcopy(x_dict)

        for layer_idx, conv in enumerate(self.direction_experts):
            # Forward pass through the HeteroConv layers
            inter_dict = conv(x_dict, edge_index_dict)

            # Aggregate separately the outputs of convolutions on different
            # simplex levels: one aggregated tensor per inner entry of each
            # destination key.
            out_dict = {}
            for dst, rel_dict in inter_dict.items():
                for rel_outputs in rel_dict.values():
                    out_dict.setdefault(dst, []).append(
                        group(rel_outputs, self.inner_aggregation)
                    )

            # Combine the per-destination results and add skip connections
            for key, value in out_dict.items():
                aggregated = group(value, self.outer_aggregation)
                if self.lins_res:
                    aggregated = aggregated + self.lins_res[key][layer_idx](
                        x_dict[key]
                    )
                x_dict[key] = aggregated

            x_dict = self._apply_layer_transforms(
                x_dict, self.dropout, self.training, layer_idx
            )

        return self.lin_edges(x_dict["1"])
