from typing import Any, Literal, Optional

import torch
import torch.nn as nn
from pylibcugraphops.pytorch import operators

from ..base import BaseBackend, BaseConvolution
from ..registry import BackendRegistry

# Human-readable description of this backend module.
# NOTE(review): this string is bound to the name `doc` after the imports, so it
# is a plain module attribute and not the module's `__doc__`.
# NOTE(review): the text names `pylibcugraph`, while this file imports from
# `pylibcugraphops` - confirm which library name is intended.
doc = """
Cugraph Backend: implementations using `pylibcugraph` library.
"""

class _SimpleAggrGraphConv(BaseConvolution):
    """Graph convolution expressed as one sparse neighbourhood aggregation.

    The layer delegates all computation to ``operators.agg_simple_n2n`` with a
    fixed aggregation mode (sum/max/min/mean), optionally forwarding the edge
    weights that travel together with the graph (used for the GCN variant).
    """

    def __init__(
        self,
        bias: bool = True,
        use_edge_weights: bool = False,
        aggr_type: Literal["sum", "max", "min", "mean"] = "sum",
        **kwargs: Any,
    ) -> None:
        """Record the aggregation settings used by ``forward``.

        Args:
            bias (bool): Forwarded to ``BaseConvolution.__init__``; this class
                does not read it itself.
            use_edge_weights (bool): When true, ``forward`` hands the weights
                bundled with the graph to the aggregation operator; otherwise
                the operator receives ``None`` as edge weights.
            aggr_type (Literal["sum", "max", "min", "mean"]): Aggregation mode
                passed to the operator as ``aggr``.
            **kwargs (Any): Forwarded unchanged to ``BaseConvolution.__init__``.
        """
        super().__init__(bias=bias, **kwargs)
        self.use_edge_weights = use_edge_weights
        self.aggr_type = aggr_type

    def forward(
        self,
        x: torch.Tensor,
        graph: Any,
        *,
        edge_weight: torch.Tensor | None = None,  # accepted but never read here
        **kwargs: Any,
    ) -> torch.Tensor:
        """Aggregate neighbour features with ``operators.agg_simple_n2n``.

        Args:
            x (torch.Tensor): Node features, passed to the operator as ``feat``
                (presumably shaped [N, F] - confirm against the operator docs).
            graph (Any): Two-item iterable unpacked as ``(graph, weights)``.
                The first item is passed to the operator as ``graph``
                (presumably a cugraph-ops CSC graph - confirm with the caller);
                the second is used as edge weights only when
                ``use_edge_weights`` is set.
            edge_weight (Optional[torch.Tensor]): Not used by this layer; edge
                weights are taken from ``graph`` instead.
            **kwargs (Any): Extra keyword arguments; ignored.

        Returns:
            torch.Tensor: The result of ``operators.agg_simple_n2n``.
        """
        adjacency, bundled_weights = graph

        # Only the weights shipped with the graph are ever forwarded; the
        # `edge_weight` keyword argument plays no part in this choice.
        operator_weights = None
        if self.use_edge_weights:
            operator_weights = bundled_weights

        return operators.agg_simple_n2n(
            feat=x,
            graph=adjacency,
            aggr=self.aggr_type,
            edge_weight=operator_weights,
        )


@BackendRegistry.register_backend("cugraph")
class CugraphBackend(BaseBackend):
    """Backend that builds the cugraph-based convolution layers of this module."""

    def create_conv(
        self,
        conv_type: str,
        **kwargs: Any,
    ):
        """Build the convolution layer selected by ``conv_type``.

        Recognised values are ``"mean_aggr"``, ``"sum_aggr"``, ``"min_aggr"``,
        ``"max_aggr"``, ``"gcn"``, ``"gat_v2"`` and ``"gt"``.

        Args:
            conv_type (str): Name of the convolution to create.
            **kwargs (Any): Must contain ``feature_dim``, which is popped for
                every ``conv_type`` even though only ``"gat_v2"`` and ``"gt"``
                use it. ``"gt"`` additionally pops ``heads`` and forwards the
                remaining keyword arguments to the transformer layer. For all
                other types the remaining keyword arguments are ignored.

        Returns:
            The created convolution layer. If ``conv_type`` matches none of the
            recognised values, execution falls off the end of the visible
            branches, i.e. the method returns ``None``.

        Raises:
            KeyError: If ``feature_dim`` is missing, or if ``heads`` is missing
                for ``conv_type == "gt"``.
        """
        feature_dim = kwargs.pop("feature_dim")

        # (conv_type, aggregation mode, use edge weights) for every layer that
        # is a plain single-pass aggregation. "gcn" is a weighted sum.
        simple_aggr_variants = (
            ("mean_aggr", "mean", False),
            ("sum_aggr", "sum", False),
            ("min_aggr", "min", False),
            ("max_aggr", "max", False),
            ("gcn", "sum", True),
        )
        for variant_name, aggr_mode, weighted in simple_aggr_variants:
            if conv_type == variant_name:
                return _SimpleAggrGraphConv(
                    aggr_type=aggr_mode,
                    use_edge_weights=weighted,
                )

        if conv_type == "gat_v2":
            return _CugraphGATv2Conv(feature_dim)

        if conv_type == "gt":
            # NOTE(review): a disabled guard here used to raise
            # NotImplementedError stating that `mha_simple_n2n` is broken and
            # doesn't work with correct inputs - confirm this path is usable.
            num_heads = kwargs.pop("heads")
            return _CugraphGraphTransfomerConv(feature_dim, num_heads, **kwargs)

        # NOTE(review): no branch handles an unrecognised conv_type, so the
        # caller silently receives None - confirm whether that is intended.
