import torch
from torch import nn
from copy import deepcopy
from torch.nn import Linear
import torch.nn.functional as F
from src.models.base_ssn import BaseSSN
from src.models.rel_routing import RelationRouter
from pytorch_lightning import LightningModule
from torch_geometric.nn import global_mean_pool
from src.models.ssn_heteroconv import SSNHeteroConv, group
from utils.utils import compute_multiclass_accuracy
from src.models.noise_scheduler import NoiseScheduler


class DynamicsClassifier(LightningModule):
    """Lightning module that trains an ``SSN`` or ``RSSN`` classifier.

    The wrapped network is stored in ``self.model``; which variant is built
    is decided by the ``routing`` flag. Training, validation and test steps
    share one code path (``shared_step``).
    """

    def __init__(
        self,
        edge_types,
        hidden_sizes,
        n_classes,
        num_layers,
        dropout=0,
        jumping_knowledge=None,
        normalize=False,
        learning_rate=0.003,
        wd=5e-2,
        loss=None,
        conv_type="SAGE",
        device="gpu",
        in_aggr="sum",
        out_aggr="sum",
        lin_res=False,
        ln=False,
        bn=False,
        input_dims=None,
        routing=False,
        k=None,
    ):
        """
        Initialize the DynamicsClassifier.

        Args:
            edge_types: List of edge types.
            hidden_sizes: Hidden sizes used in the model.
            n_classes: Number of output classes.
            num_layers: Number of layers in the network.
            dropout: Dropout probability.
            jumping_knowledge: Whether to use jumping knowledge.
            normalize: If True, normalize the input features.
            learning_rate: Learning rate for the optimizer.
            wd: Weight decay passed to the optimizer.
            loss: Loss function. ``None`` (the default) means a fresh
                ``nn.NLLLoss()``, which matches the log-softmax outputs
                produced by the models in this file.
            conv_type: Type of convolution (e.g. "SAGE").
            device: Accelerator to use (e.g. "gpu"); forwarded to the model.
            in_aggr: Aggregation method for incoming messages.
            out_aggr: Aggregation method for outgoing messages.
            lin_res: Whether to use a linear residual connection.
            ln: Whether to use layer normalization.
            bn: Whether to use batch normalization.
            input_dims: Input dimensions (only used if routing is enabled).
            routing: If True, build the routing variant (``RSSN``);
                otherwise build ``SSN``.
            k: Additional parameter for routing (only used if routing is
                enabled).
        """

        super().__init__()
        self.save_hyperparameters(
            ignore=["loss"]
        )  # Save hyperparameters for reproducibility

        # Do not call .to(self.device) here; Lightning handles device placement.
        # The default loss is created per instance rather than in the
        # signature, so separate classifiers never share one module object.
        self.loss = nn.NLLLoss() if loss is None else loss
        self.routing = routing
        self.weight_decay = wd
        self.learning_rate = learning_rate
        self.noise_scheduler = NoiseScheduler()

        common_args = {
            "edge_types": edge_types,
            "hidden_sizes": hidden_sizes,
            "n_classes": n_classes,
            "num_layers": num_layers,
            "dropout": dropout,
            "jumping_knowledge": jumping_knowledge,
            "normalize": normalize,
            "conv_type": conv_type,
            "device": device,
            "in_aggr": in_aggr,
            "out_aggr": out_aggr,
            "lin_res": lin_res,
            "bn": bn,
            "ln": ln,
        }
        routing_args = {"input_dims": input_dims, "k": k}

        # Instantiate the correct model variant.
        if routing:
            self.model = RSSN(**common_args, **routing_args)
        else:
            self.model = SSN(**common_args)

    def forward(self, x_dict, edge_index_dict, extra_arg):
        """
        Forward pass through the wrapped model.

        The extra argument is used differently depending on whether routing
        is enabled:
          - If routing is True, extra_arg is the noise level handed to the
            model as ``noise_epsilon``. ``None`` keeps the model's default.
          - If routing is False, extra_arg is handed to the model as its
            ``batch_dict`` argument.

        Returns:
            Whatever the wrapped model returns: a tensor for ``SSN``, a
            ``(tensor, moe_loss)`` pair for ``RSSN``.
        """
        if not self.routing:
            return self.model(x_dict, edge_index_dict, extra_arg)

        # RSSN.forward is (x_dict, edge_index_dict, batch_dict, noise_epsilon):
        # the noise level must be passed by keyword, otherwise it would land
        # in the batch_dict slot and the routers would only ever see the
        # default noise.
        # NOTE(review): the batch tuple carries no batch vectors, so None is
        # passed for batch_dict -- confirm RSSN.forward accepts that, or
        # supply per-type batch vectors here.
        noise_kwargs = {} if extra_arg is None else {"noise_epsilon": extra_arg}
        return self.model(x_dict, edge_index_dict, None, **noise_kwargs)

    def compute_loss(self, out, y):
        """
        Compute the loss between model output and target.

        Both tensors are squeezed before being passed to ``self.loss``.
        """
        return self.loss(out.squeeze(), y.squeeze())

    def shared_step(self, batch, stage):
        """
        Common logic for training, validation, and testing.

        Args:
            batch: A 3-tuple ``(x_dict, edge_index_dict, y)``.
            stage: One of "train", "val", or "test"; used as the prefix of
                the logged metric names.

        Returns:
            The computed loss for the batch (including the routing loss when
            routing is enabled).
        """

        x_dict, edge_index_dict, y = batch

        # If routing is enabled, adjust the forward call accordingly.
        if self.routing:
            noise_epsilon = self.noise_scheduler.get_noise(self.current_epoch)
            out, moe_loss = self(x_dict, edge_index_dict, noise_epsilon)
            loss_val = self.compute_loss(out, y) + moe_loss
        else:
            out = self(x_dict, edge_index_dict, None)
            loss_val = self.compute_loss(out, y)

        acc = compute_multiclass_accuracy(out, y).detach().cpu().item()

        # Prepare logging parameters.
        log_kwargs = (
            {"prog_bar": True}
            if stage == "train"
            else {"prog_bar": True, "batch_size": 1}
        )
        self.log(f"{stage}_loss", loss_val, **log_kwargs)
        self.log(f"{stage}_acc", acc, **log_kwargs)

        return loss_val

    def training_step(self, batch, batch_idx):
        """Run ``shared_step`` with the "train" prefix."""
        return self.shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        """Run ``shared_step`` with the "val" prefix."""
        return self.shared_step(batch, "val")

    def test_step(self, batch, batch_idx):
        """Run ``shared_step`` with the "test" prefix."""
        return self.shared_step(batch, "test")

    def configure_optimizers(self):
        """Return an Adam optimizer over the wrapped model's parameters."""
        return torch.optim.Adam(
            self.model.parameters(),
            lr=self.learning_rate,
            weight_decay=self.weight_decay,
        )


class SSN(BaseSSN):
    """Heterogeneous message-passing classifier without relation routing.

    Each layer is one ``SSNHeteroConv`` holding a convolution per edge type.
    After the last layer every node type is mean-pooled, the pooled vectors
    are concatenated and a linear head produces class log-probabilities.
    """

    def __init__(self, *args, **kwargs):
        """Build the per-layer convolutions and the classification head.

        All arguments are forwarded unchanged to ``BaseSSN``.
        """
        super().__init__(*args, **kwargs)
        self.direction_experts = self._build_direction_experts()
        self.linear = nn.Linear(sum(self.hidden_sizes.values()), self.num_classes)

    def _build_direction_experts(self):
        """Create one ``SSNHeteroConv`` per layer.

        Returns:
            An ``nn.ModuleList`` of length ``self.num_layers``. Every entry
            maps each edge type to a ``self.ConvLayer`` whose output width is
            the hidden size of the edge type's last element (its destination
            node type).
        """
        # Input widths are given as -1; GAT gets a (source, target) pair.
        in_channels = (-1, -1) if self.conv_type == "GAT" else -1

        convs = nn.ModuleList()
        for _ in range(self.num_layers):
            convs.append(
                SSNHeteroConv(
                    {
                        edge_type: self.ConvLayer(
                            in_channels=in_channels,
                            out_channels=self.hidden_sizes[edge_type[-1]],
                            **self.conv_kwargs,
                        )
                        for edge_type in self.edge_types
                    }
                )
            )
        return convs

    def forward(self, x_dict, edge_index_dict, batch_dict):
        """Forward pass for the heterogeneous GNN.

        Parameters:
            x_dict: Node features for each node type.
            edge_index_dict: Edge indices for each edge type.
            batch_dict: Unused; accepted so the signature matches the
                routing variant.

        Returns:
            Log-probabilities over the classes (log-softmax along dim 1),
            computed from the mean over dim 1 of each node type's features.
        """
        # Work on a copy so the caller's feature dict is left untouched.
        x_dict = deepcopy(x_dict)

        for layer_idx, conv in enumerate(self.direction_experts):
            inter_dict = conv(x_dict, edge_index_dict)

            # Inner aggregation: combine the messages of each relation group
            # arriving at a destination type.
            out_dict = {}
            for dst, rel_dict in inter_dict.items():
                for messages in rel_dict.values():
                    out_dict.setdefault(dst, []).append(
                        group(messages, self.inner_aggregation)
                    )

            # Outer aggregation: combine the groups per destination type and
            # optionally add the linear residual of the previous features.
            for key, value in out_dict.items():
                aggregated = group(value, self.outer_aggregation)
                if self.lins_res:
                    aggregated = aggregated + self.lins_res[key][layer_idx](
                        x_dict[key]
                    )
                x_dict[key] = aggregated

            x_dict = self._apply_layer_transforms(
                x_dict, self.dropout, self.training, layer_idx
            )

        # Concatenate average embeddings of different simplices;
        # put zeros if a level does not have the expected hidden width
        # (e.g. it never received a message and still holds input features).
        pooled = []
        for k, feats in x_dict.items():
            mean = feats.mean(1)
            if mean.shape[1] != self.hidden_sizes[k]:
                # new_zeros keeps the placeholder on the same device and
                # dtype as the tensors it is concatenated with.
                mean = mean.new_zeros(mean.shape[0], self.hidden_sizes[k])
            pooled.append(mean)

        x = torch.cat(pooled, dim=1)
        x = self.linear(x)
        x = F.log_softmax(x, dim=1)

        return x


class RSSN(BaseSSN):
    """Heterogeneous message-passing classifier with relation routing.

    Edge types are grouped by source type and relation prefix. A group with
    several relations is handled by a ``RelationRouter`` over one expert per
    relation; a group with a single relation uses a plain ``SSNHeteroConv``.
    """

    def __init__(self, *args, **kwargs) -> None:
        """Build relation groups, per-layer experts and the output head.

        ``k`` and ``input_dims`` are consumed here; every other argument is
        forwarded to ``BaseSSN``.
        """
        self.k = kwargs.pop("k", None)
        self.input_dims = kwargs.pop("input_dims", None)
        super().__init__(*args, **kwargs)
        self.linear = Linear(sum(self.hidden_sizes.values()), self.num_classes)
        self.topo_rels = self._initialize_topological_relations()
        self.direction_experts = self._build_direction_experts()

    def _initialize_topological_relations(self):
        """Group edge types by source type and relation prefix.

        Each relation name is split on "_"; the group key is
        ``"{src}_{first part}"`` and the stored relation name is rebuilt from
        the first two parts only.

        Returns:
            Dict mapping group key to a list of ``(src, relation, dst)``
            tuples.
        """
        topo_rels = {}
        for src, rel, dst in self.edge_types:
            rel_parts = rel.split("_")
            key = f"{src}_{rel_parts[0]}"
            topo_rels.setdefault(key, []).append(
                (src, f"{rel_parts[0]}_{rel_parts[1]}", dst)
            )
        return topo_rels

    def _single_relation_conv(self, edge_type):
        """Build the ``SSNHeteroConv`` that handles exactly one edge type."""
        return SSNHeteroConv(
            {
                edge_type: self.ConvLayer(
                    in_channels=-1,
                    out_channels=self.hidden_sizes[edge_type[-1]],
                    **self.conv_kwargs,
                )
            },
            unique=True,
        ).to(self.device)

    def _build_direction_experts(self):
        """Initializes directed experts with Mixture of Experts or HeteroConv layers.

        Returns:
            An ``nn.ModuleList`` with one ``nn.ModuleDict`` per layer, keyed
            by relation-group name.
        """
        directed_experts = nn.ModuleList()
        for layer_idx in range(self.num_layers):
            layer_dict = {}
            for adj, adjs in self.topo_rels.items():
                src = adj.split("_")[0]
                if len(adjs) > 1:
                    # Several relations share this group: route between one
                    # expert per relation.
                    module = RelationRouter(
                        experts=nn.ModuleList(
                            [
                                self._single_relation_conv(edge_type)
                                for edge_type in adjs
                            ]
                        ),
                        # The first layer sees raw input features, later
                        # layers see hidden features.
                        input_size=(
                            self.input_dims[src]
                            if layer_idx == 0
                            else self.hidden_sizes[src]
                        ),
                        output_size=self.hidden_sizes[src],
                        accelerator=self.device,
                        k=self.k,
                    )
                else:
                    # Use HeteroConv for Single Relation
                    module = self._single_relation_conv(adjs[0])

                layer_dict[str(adj)] = module

            directed_experts.append(nn.ModuleDict(layer_dict))

        return directed_experts

    @staticmethod
    def _mean_pool(x, batch):
        """Mean-pool node features, with or without a batch vector.

        Without a batch vector the features are averaged over the node axis
        (second to last); 2-D input keeps a leading dimension of size 1 so
        the result is always at least 2-D.
        """
        if batch is None:
            return x.mean(dim=-2, keepdim=x.dim() <= 2)
        return global_mean_pool(x, batch)

    def forward(self, x_dict, edge_index_dict, batch_dict, noise_epsilon=1e-2):
        """Forward pass for Direction Routing SNN.

        Parameters:
            x_dict: Node features for each node type.
            edge_index_dict: Edge indices for each edge type.
            batch_dict: Per node type batch assignment used for mean
                pooling, or ``None`` to average over the node axis instead.
            noise_epsilon: Noise level passed to every ``RelationRouter``.

        Returns:
            Tuple ``(log_probs, moe_loss)``: class log-probabilities
            (log-softmax along dim 1) and the mean of the router losses, or
            0 when the model contains no router.
        """
        moe_losses = []

        # Work on a copy so the caller's feature dict is left untouched.
        x_dict = deepcopy(x_dict)

        for layer_idx, layers in enumerate(self.direction_experts):
            out_dict = {}
            for adj, layer in layers.items():
                # Destination type is read from the first relation of the
                # group; presumably all relations of a group share it.
                dst = self.topo_rels[adj][0][-1]
                if isinstance(layer, RelationRouter):
                    y_adj, loss = layer(
                        x_dict,
                        edge_index_dict,
                        adj.split("_")[0],
                        self.training,
                        layer_idx,
                        noise_epsilon,
                    )
                    moe_losses.append(loss)
                else:
                    y_adj = layer(x_dict, edge_index_dict)

                out_dict.setdefault(dst, []).append(y_adj)

            # Combine the group outputs per destination type and optionally
            # add the linear residual of the previous features.
            for key, value in out_dict.items():
                aggregated = group(value, self.outer_aggregation)
                if self.lins_res:
                    aggregated = aggregated + self.lins_res[key][layer_idx](
                        x_dict[key]
                    )
                x_dict[key] = aggregated

            x_dict = self._apply_layer_transforms(
                x_dict, self.dropout, self.training, layer_idx
            )

        # Pool every node type; types whose pooled width is not the expected
        # hidden size contribute zeros instead.
        pooled = []
        for k, feats in x_dict.items():
            batch = None if batch_dict is None else batch_dict[k]
            mean = self._mean_pool(feats, batch)
            if mean.shape[1] != self.hidden_sizes[k]:
                # new_zeros keeps the placeholder on the same device and
                # dtype as the tensors it is concatenated with.
                mean = mean.new_zeros(mean.shape[0], self.hidden_sizes[k])
            pooled.append(mean)

        x = torch.cat(pooled, dim=1)
        x = self.linear(x)
        x = F.log_softmax(x, dim=1)

        moe_loss = sum(moe_losses) / len(moe_losses) if moe_losses else 0

        return x, moe_loss
