import torch
from torch_geometric.data import Data
import pytest

from src.evaluation.smoothing_measures.node_similarity import wu_smoothness


def test_wu_smoothness_chunking_consistency():
    """
    Test that wu_smoothness gives the same result with and without chunking.

    ``wu_smoothness`` is evaluated twice on the same random embeddings:
    once with a ``chunk_size`` far larger than the number of nodes (so the
    computation runs in a single chunk) and once with a tiny ``chunk_size``
    that forces many chunks. Both calls must return the same, non-empty set
    of keys and numerically close values for every key.
    """
    # A dedicated generator keeps the inputs reproducible without mutating
    # the global RNG state that other tests may rely on.
    generator = torch.Generator().manual_seed(0)
    num_nodes = 500  # large enough that chunk_size=4 produces many chunks
    dim = 16         # embedding dimension

    # Random (not symmetrised) edge_index. Per the original note it is not
    # really used by the current wu_smoothness implementation.
    edge_index = torch.randint(
        0, num_nodes, (2, 2 * num_nodes), generator=generator
    )
    x = torch.randn(num_nodes, dim, generator=generator)

    data = Data(x=x, edge_index=edge_index)

    # Effectively no chunking (chunk_size far exceeds the node count) ...
    result_no_chunk = wu_smoothness(data, embeddings=x, chunk_size=10**9)
    # ... versus heavy chunking.
    result_chunked = wu_smoothness(data, embeddings=x, chunk_size=4)

    # Guard against a vacuous pass when both calls return an empty mapping.
    assert result_no_chunk, "wu_smoothness returned no metrics"
    assert set(result_no_chunk) == set(result_chunked)

    for key, expected in result_no_chunk.items():
        actual = result_chunked[key]
        # as_tensor accepts Python scalars and tensors alike (torch.tensor
        # warns when handed an existing tensor). A shared float64 dtype
        # avoids dtype-mismatch errors. allclose returns a plain bool, so
        # multi-element metric values are handled as well as scalars.
        expected_t = torch.as_tensor(expected, dtype=torch.float64)
        actual_t = torch.as_tensor(actual, dtype=torch.float64)
        assert torch.allclose(expected_t, actual_t, atol=1e-5, rtol=1e-5), (
            f"Mismatch in {key}: {expected} vs {actual}"
        )