import pytest
import torch
import numpy as np
import scipy.sparse as sp

from datasets.rewiring.khop import edge_index_to_csr, k_hop_pairs_batch, k_hop_pairs_batch_sparse


def test_edge_index_to_csr_simple():
    """edge_index_to_csr builds a 3x3 CSR adjacency matrix that keeps edge direction."""
    # Directed 3-cycle: 0→1, 1→2, 2→0
    src = [0, 1, 2]
    dst = [1, 2, 0]
    adj = edge_index_to_csr(np.array([src, dst]), num_nodes=3)

    assert isinstance(adj, sp.csr_matrix)
    assert adj.shape == (3, 3)
    assert all(adj[i, j] == 1 for i, j in zip(src, dst))
    # The reverse of 0→1 must be absent: the matrix is not symmetrised.
    assert adj[1, 0] == 0


def test_khop_1hop_directed():
    """With k=1 the emitted pairs are exactly the input edges."""
    edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])
    result = k_hop_pairs_batch_sparse(edge_index, num_nodes=4, k=1)

    # Column order is not part of the contract, so compare as sets of (src, dst).
    got = {tuple(col) for col in result.numpy().T}
    want = {(0, 1), (1, 2), (2, 3)}
    assert got == want


def test_khop_2hop_directed_chain():
    """On the chain 0→1→2→3, k=2 yields every forward pair at distance 1 or 2."""
    edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])
    result = k_hop_pairs_batch_sparse(edge_index, num_nodes=4, k=2)

    one_hop = {(i, i + 1) for i in range(3)}
    two_hop = {(i, i + 2) for i in range(2)}
    assert {tuple(col) for col in result.numpy().T} == one_hop | two_hop


def test_khop_undirected():
    """With undirected=True each triangle edge is reported in both directions."""
    # Triangle 0–1–2, given with one orientation per edge
    edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]])
    result = k_hop_pairs_batch_sparse(edge_index, num_nodes=3, k=1, undirected=True)

    forward = {(0, 1), (1, 2), (0, 2)}
    expected = forward | {(v, u) for u, v in forward}
    assert {tuple(col) for col in result.numpy().T} == expected


@pytest.mark.parametrize("k", [1, 2, 3])
def test_khop_self_exclusion(k):
    """Ensure nodes are never paired with themselves.

    On the 2-cycle 0 <-> 1, every walk of even length returns to its start node.
    So ``k >= 2`` is what actually exercises self-exclusion. With ``k=1`` alone a
    self-pair can never arise, and the check would pass vacuously.
    """
    edge_index = torch.tensor([[0, 1], [1, 0]])  # undirected
    out = k_hop_pairs_batch_sparse(edge_index, num_nodes=2, k=k, undirected=True)
    assert not torch.any(out[0] == out[1])  # no (i, i) pairs
    # Guard against a vacuous pass (e.g. an empty result): the real neighbour
    # pairs must still be present.
    assert set(map(tuple, out.numpy().T)) == {(0, 1), (1, 0)}


def test_khop_batching_equivalence():
    """Splitting source nodes into batches must not change the resulting pair set."""
    # Chain 0→1→2→3→4. batch_size=2 does not divide 5, so a ragged last batch is covered.
    edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]])

    def pair_set(batch_size):
        out = k_hop_pairs_batch_sparse(edge_index, num_nodes=5, k=2, batch_size=batch_size)
        return {tuple(col) for col in out.numpy().T}

    single_batch = pair_set(5)
    multi_batch = pair_set(2)
    assert single_batch == multi_batch


def test_khop_disconnected_graph():
    """No pairs are produced across disconnected components."""
    # Two components: 0→1 and 2→3
    edge_index = torch.tensor([[0, 2], [1, 3]])
    result = k_hop_pairs_batch_sparse(edge_index, num_nodes=4, k=2)

    got = {tuple(col) for col in result.numpy().T}
    assert got == {(0, 1), (2, 3)}


def test_khop_empty_graph():
    """A graph with no edges yields an empty pair tensor of shape (2, 0)."""
    no_edges = torch.empty((2, 0), dtype=torch.long)
    result = k_hop_pairs_batch_sparse(no_edges, num_nodes=3, k=2)

    assert result.numel() == 0
    # The (2, num_pairs) layout must hold even when there are no pairs.
    assert result.shape == (2, 0)


def test_khop_large_k_reaches_all():
    """A k larger than the graph diameter returns every reachable pair, nothing more."""
    # Chain 0→1→2: all reachable pairs are found within 2 hops, so k=5 adds none.
    edge_index = torch.tensor([[0, 1], [1, 2]])
    result = k_hop_pairs_batch_sparse(edge_index, num_nodes=3, k=5)

    reachable = {(0, 1), (1, 2), (0, 2)}
    assert {tuple(col) for col in result.numpy().T} == reachable


def test_output_device_matches_input():
    """The returned pairs live on the same device as the input edge_index."""
    source_device = "cpu"
    edge_index = torch.tensor([[0, 1], [1, 2]], device=source_device)
    result = k_hop_pairs_batch_sparse(edge_index, num_nodes=3, k=1)
    assert result.device == edge_index.device
