import sys
from pathlib import Path

import pytest
import torch
from fixtures import (
    create_conv_layer,
    create_graph_sample,
    device,
    karate_like_club_graph,
    set_default_device,
    small_graph_data,
)

from src.backends.cugraph_backend import CugraphBackend
from src.backends.registry import BackendRegistry

# Probe for pylibcugraphops. HAS_CUGRAPH stays False unless the import below
# succeeds. The module-level ``pytestmark`` then marks every test in this file
# as skipped when it is unavailable.
HAS_CUGRAPH = False
try:
    from pylibcugraphops.pytorch import CSC, operators
except ImportError:
    pass
else:
    HAS_CUGRAPH = True

pytestmark = pytest.mark.skipif(not HAS_CUGRAPH, reason="cugraph not installed")


class TestCugraphBasicAggregation:
    """Compare cugraph's basic aggregation layers (sum/mean/min/max) against DGL."""

    @pytest.mark.parametrize("aggr_type", ["sum", "mean", "min", "max"])
    def test_mean_aggregation_matches_dgl(
        self, aggr_type, karate_like_club_graph, create_graph_sample, create_conv_layer
    ):
        """Check that the cugraph ``{aggr_type}_aggr`` layer matches DGL's ``copy_u_{aggr_type}``.

        Despite the ``mean`` in its name, this test is parametrized over every
        aggregation type listed in the decorator. DGL is used only as a
        reference implementation, so the test is skipped when it is missing.
        """
        try:
            import dgl
            import dgl.ops as dgl_ops
        except ImportError:
            pytest.skip("DGL not installed - cannot verify correctness")

        data = karate_like_club_graph
        features = data["features"]

        # Build the DGL reference graph from the same edge list, on the same
        # device. Self-loops are added so each node's own features take part
        # in the DGL aggregation.
        # NOTE(review): assumes the cugraph path (create_graph_sample and/or the
        # conv layer) also includes self-loops -- confirm against the fixture.
        dgl_graph = dgl.graph(
            (data["edge_index"][0], data["edge_index"][1]), num_nodes=data["num_nodes"]
        ).to(data["device"])
        dgl_graph = dgl.add_self_loop(dgl_graph)

        graph_sample = create_graph_sample(
            edge_index=data["edge_index"],
            features=features,
            backend="cugraph",
            num_nodes=data["num_nodes"],
        )

        # Explicit dispatch table. An unsupported aggr_type fails loudly with a
        # KeyError instead of leaving the reference op unbound.
        dgl_reference_ops = {
            "sum": dgl_ops.copy_u_sum,
            "mean": dgl_ops.copy_u_mean,
            "min": dgl_ops.copy_u_min,
            "max": dgl_ops.copy_u_max,
        }
        dgl_op = dgl_reference_ops[aggr_type]

        dgl_output = dgl_op(dgl_graph, features)
        conv = create_conv_layer("cugraph", f"{aggr_type}_aggr", feature_dim=data["in_channels"], bias=False)
        cugraph_output = conv(features, graph_sample.graph_repr)

        # torch.allclose broadcasts its operands, so compare shapes explicitly
        # first. Otherwise e.g. an (N, 1) output could "match" an (N, F) one.
        assert cugraph_output.shape == dgl_output.shape, (
            f"CuGraph {aggr_type} aggregation shape {tuple(cugraph_output.shape)} "
            f"doesn't match DGL shape {tuple(dgl_output.shape)}"
        )
        assert torch.allclose(cugraph_output, dgl_output, atol=1e-6, rtol=1e-5), (
            f"CuGraph {aggr_type} aggregation doesn't match DGL "
            f"(max abs diff: {(cugraph_output - dgl_output).abs().max().item():.3e})"
        )


class TestCugraphGCN:
    """Smoke-test the cugraph ``gcn`` conv layer.

    The original description of this layer is "sum aggregation with edge
    weights". This test checks only output shape, finiteness, and that
    gradients flow back to the input features; the weighting itself is not
    asserted here.
    """

    def test_gcn_basic(self, karate_like_club_graph, create_graph_sample, create_conv_layer):
        """Run a GCN forward and backward pass and check outputs and gradients are well formed."""
        data = karate_like_club_graph
        # detach() guarantees a fresh leaf tensor even if the fixture's features
        # already require grad. Otherwise clone() would return a non-leaf whose
        # .grad is never populated by backward().
        features = data["features"].detach().clone().requires_grad_(True)

        graph_sample = create_graph_sample(
            edge_index=data["edge_index"],
            features=features,
            backend="cugraph",
            num_nodes=data["num_nodes"],
        )

        conv = create_conv_layer("cugraph", "gcn", feature_dim=data["in_channels"], bias=False)

        output = conv(features, graph_sample.graph_repr)
        # Expect one row per node with the input feature width.
        assert output.shape == (data["num_nodes"], data["in_channels"])
        # isfinite rejects both NaN and +/-inf; isnan alone would let inf through.
        assert torch.isfinite(output).all()

        loss = output.sum()
        loss.backward()
        assert features.grad is not None
        assert features.grad.shape == features.shape
        assert torch.isfinite(features.grad).all()


if __name__ == "__main__":
    # Propagate pytest's exit status so running this file directly reports
    # failures to the shell/CI instead of always exiting with status 0.
    sys.exit(pytest.main([__file__, "-v", "--tb=short"]))
