import pytest
import torch
import numpy as np
from torch_geometric.data import Data, Batch

# Import the function under test.
# Adjust the import path if your function resides in another module.
from src.evaluation.eval_measurements import mad_madgap_batch  # <<< replace 'your_module' with the actual module name


# ---------------------------
# Helper utilities for tests
# ---------------------------

def build_data(num_nodes: int, edges: list):
    """
    Create a torch_geometric ``Data`` graph with ``num_nodes`` nodes.

    ``edges`` holds undirected (u, v) pairs using 0-based local node indices.
    Each pair is stored in both directions, (u, v) immediately followed by
    (v, u), in a long ``edge_index`` tensor of shape (2, 2 * len(edges)).
    An empty ``edges`` yields an empty (2, 0) ``edge_index``.
    """
    if len(edges) == 0:
        return Data(edge_index=torch.empty((2, 0), dtype=torch.long), num_nodes=num_nodes)

    heads, tails = [], []
    for a, b in edges:
        # forward direction then reverse direction, interleaved per edge
        heads.extend((a, b))
        tails.extend((b, a))
    return Data(edge_index=torch.tensor([heads, tails], dtype=torch.long), num_nodes=num_nodes)


def bfs_all_distances_py(adj):
    """
    All-pairs hop distances computed with one breadth-first search per source.

    ``adj[u]`` is an iterable of the neighbours of node ``u`` (0-based ids).
    Returns an n x n list of lists in which entry [s][t] is the number of
    hops from ``s`` to ``t``, or -1 when ``t`` is unreachable from ``s``.
    """
    from collections import deque

    def hops_from(source):
        hops = [-1] * len(adj)
        hops[source] = 0
        frontier = deque((source,))
        while frontier:
            node = frontier.popleft()
            for nxt in adj[node]:
                if hops[nxt] < 0:  # not visited yet
                    hops[nxt] = hops[node] + 1
                    frontier.append(nxt)
        return hops

    return [hops_from(s) for s in range(len(adj))]


def pairwise_cosine_distances_np(h):
    """
    Pairwise cosine distances (1 - cosine similarity) between the rows of ``h``.

    ``h`` is an (n, d) numpy array; the result is an (n, n) array. Row norms
    are floored at 1e-12 to avoid division by zero, and similarities are
    clipped to [-1, 1] before conversion so rounding cannot push a distance
    outside [0, 2].
    """
    row_norms = np.maximum(np.linalg.norm(h, axis=1, keepdims=True), 1e-12)
    unit_rows = h / row_norms
    cosine = (unit_rows @ unit_rows.T).clip(-1.0, 1.0)
    return 1.0 - cosine


def _mean_or_none(vals):
    """Return the mean of the non-None entries of ``vals`` as a float, or None if there are none."""
    vals = [v for v in vals if v is not None]
    return float(np.mean(vals)) if vals else None


def _masked_row_means(dist_mat, mask):
    """For each row i whose boolean ``mask[i]`` has a True entry, return the mean of
    ``dist_mat[i]`` over that mask; rows with an empty mask are skipped."""
    return [dist_mat[i, mask[i]].mean() for i in range(mask.shape[0]) if mask[i].any()]


def mad_reference(batch: Batch, embeddings: torch.Tensor,
                  max_neigh_hop: int = 3, min_remote_hop: int = 8):
    """
    Slow, explicit reference implementation of MAD / MADGap for a PyG batch.

    For every graph in the batch (nodes grouped by ``batch.batch``):
      * hop distances come from BFS over the edges lying entirely inside that
        graph (treated as undirected);
      * cosine distances are computed between that graph's embeddings;
      * for each node, the mean cosine distance is taken over
          - every other reachable node            -> global MAD,
          - nodes with 1 <= hop <= max_neigh_hop  -> neighbour MAD (MAD_neb),
          - reachable nodes with hop >= min_remote_hop -> remote MAD (MAD_rmt);
        a node with no partner under a mask is skipped for that mask;
      * the per-graph value is the mean of those per-node means, or None when
        no node contributed.
    Graphs with <= 1 node contribute ``(None, None, None, n_g)``.

    Batch-level values are the unweighted mean of the non-None per-graph
    values; MADGap = MAD_rmt - MAD_neb when both are defined, else None.

    NOTE(review): the original MAD definition may average only the non-zero
    masked distances; this reference averages over all masked pairs (so
    identical embeddings give 0, as the tests expect). Confirm this matches
    the intended semantics of mad_madgap_batch.

    Args:
        batch: PyG ``Batch``; ``batch.batch`` and ``batch.edge_index`` are read.
        embeddings: node embeddings, one row per global node index.
        max_neigh_hop: inclusive upper hop bound of the neighbour mask.
        min_remote_hop: inclusive lower hop bound of the remote mask.

    Returns:
        dict with keys 'MAD', 'MAD_neb', 'MAD_rmt', 'MADGap' (float or None)
        and 'per_graph' (list of ``(mad, mad_neb, mad_rmt, n_nodes)`` tuples).
    """
    batch_vec = batch.batch.cpu().numpy()
    edge_index = batch.edge_index.cpu().numpy()
    # Graph id of both endpoints of every edge, computed once; an edge belongs
    # to graph g only when both endpoints do.
    src_graph = batch_vec[edge_index[0]]
    dst_graph = batch_vec[edge_index[1]]

    per_graph = []
    for gid in np.unique(batch_vec).tolist():
        node_idx = np.where(batch_vec == gid)[0]
        n_g = len(node_idx)
        if n_g <= 1:
            # No pair of distinct nodes -> every metric is undefined here.
            per_graph.append((None, None, None, n_g))
            continue

        # Local, deduplicated, undirected adjacency over indices 0..n_g-1.
        global_to_local = {g: i for i, g in enumerate(node_idx.tolist())}
        adj = [set() for _ in range(n_g)]
        in_graph = (src_graph == gid) & (dst_graph == gid)
        for u, v in edge_index[:, in_graph].T.tolist():
            iu, iv = global_to_local[u], global_to_local[v]
            adj[iu].add(iv)
            adj[iv].add(iu)

        # All-pairs hop distances as an (n_g, n_g) int array, -1 = unreachable.
        dists = np.asarray(bfs_all_distances_py(adj))

        # detach() so the reference also accepts embeddings that require grad.
        h_block = embeddings[node_idx].detach().cpu().numpy()
        dist_mat = pairwise_cosine_distances_np(h_block)

        reachable = dists >= 0
        # Every other reachable node (self excluded explicitly).
        global_mask = reachable & ~np.eye(n_g, dtype=bool)
        # hop >= 1 already excludes both self and unreachable (-1) entries.
        neb_mask = (dists >= 1) & (dists <= max_neigh_hop)
        remote_mask = reachable & (dists >= min_remote_hop)

        per_graph.append((
            _mean_or_none(_masked_row_means(dist_mat, global_mask)),
            _mean_or_none(_masked_row_means(dist_mat, neb_mask)),
            _mean_or_none(_masked_row_means(dist_mat, remote_mask)),
            n_g,
        ))

    # Batch level: unweighted mean over graphs whose value is defined.
    mad_mean = _mean_or_none(pg[0] for pg in per_graph)
    mad_neb_mean = _mean_or_none(pg[1] for pg in per_graph)
    mad_rmt_mean = _mean_or_none(pg[2] for pg in per_graph)
    if mad_rmt_mean is not None and mad_neb_mean is not None:
        madgap = mad_rmt_mean - mad_neb_mean
    else:
        madgap = None

    return {
        "MAD": mad_mean,
        "MAD_neb": mad_neb_mean,
        "MAD_rmt": mad_rmt_mean,
        "MADGap": madgap,
        "per_graph": per_graph,
    }

def assert_results_close(res, ref, rel_tol=1e-6, abs_tol=1e-6):
    """
    Assert that the four MAD metrics in ``res`` match those in ``ref``.

    For each of 'MAD', 'MAD_neb', 'MAD_rmt' and 'MADGap':
      * ``None`` matches only ``None``;
      * numeric values must satisfy ``math.isclose(got, expected,
        rel_tol=rel_tol, abs_tol=abs_tol)``.

    Raises:
        AssertionError: describing the first metric that does not match.
    """
    import math

    for k in ("MAD", "MAD_neb", "MAD_rmt", "MADGap"):
        got, expected = res[k], ref[k]
        if got is None or expected is None:
            if not (got is None and expected is None):
                raise AssertionError(f"Mismatch for {k}: {got} vs {expected}")
        elif not math.isclose(got, expected, rel_tol=rel_tol, abs_tol=abs_tol):
            # Previously this branch raised unconditionally, failing even equal values.
            raise AssertionError(
                f"Mismatch in {k}: expected {expected:.6f}, "
                f"got {got:.6f}"
            )


def test_two_node_identical_embeddings():
    """Two connected nodes sharing one embedding: every cosine distance is 0.

    The graph is too small to contain remote pairs, so MAD_rmt and MADGap
    are expected to be undefined (None).
    """
    batch = Batch.from_data_list([build_data(2, edges=[(0, 1)])])
    same_vec = [1.0, 0.0]
    embeddings = torch.tensor([same_vec, same_vec], dtype=torch.float32)
    out = mad_madgap_batch(batch, embeddings)
    for key in ("MAD", "MAD_neb"):
        assert out[key] == pytest.approx(0.0)
    for key in ("MAD_rmt", "MADGap"):
        assert out[key] is None

def test_two_node_opposite_embeddings():
    # Graph with 2 connected nodes whose embeddings are orthogonal unit vectors:
    # cosine similarity 0 -> cosine distance 1.
    # NOTE: despite the test name, the vectors are orthogonal, not opposite
    # (opposite vectors would give cosine distance 2).
    data = build_data(2, edges=[(0, 1)])
    batch = Batch.from_data_list([data])
    embeddings = torch.tensor([[1.0, 0.0], [0.0, 1.0]], dtype=torch.float32)  # orthogonal
    res = mad_madgap_batch(batch, embeddings)
    assert res["MAD"] == pytest.approx(1.0)
    assert res["MAD_neb"] == pytest.approx(1.0)
    # Too small to contain any remote pair, so remote MAD and MADGap are undefined.
    assert res["MAD_rmt"] is None
    assert res["MADGap"] is None


def test_single_node_graph():
    """A lone node has no other node to compare with, so every metric is None."""
    batch = Batch.from_data_list([build_data(1, edges=[])])
    embeddings = torch.randn((1, 4), dtype=torch.float32)
    out = mad_madgap_batch(batch, embeddings)
    for key in ("MAD", "MAD_neb", "MAD_rmt", "MADGap"):
        assert out[key] is None


def test_chain_graph_distance_masks():
    """Path graph 0-1-...-8 with one-hot embeddings.

    The endpoints 0 and 8 are exactly 8 hops apart, so the neighbour and
    remote masks are both non-empty. This assumes mad_madgap_batch's default
    hop thresholds match mad_reference's 3 / 8 (TODO confirm). One-hot
    (standard basis) embeddings are mutually orthogonal, so every pairwise
    cosine distance is exactly 1 and MADGap is 0.
    """
    n = 9
    edges = [(i, i + 1) for i in range(n - 1)]
    data = build_data(n, edges=edges)
    batch = Batch.from_data_list([data])
    embeddings = torch.from_numpy(np.eye(n, dtype=np.float32))

    res = mad_madgap_batch(batch, embeddings)
    assert res["MAD"] == pytest.approx(1.0)
    assert res["MAD_neb"] == pytest.approx(1.0)
    assert res["MAD_rmt"] == pytest.approx(1.0)
    assert res["MADGap"] == pytest.approx(0.0)

    # The slow reference must agree as well.
    ref = mad_reference(batch, embeddings)
    for key in ("MAD", "MAD_neb", "MAD_rmt", "MADGap"):
        assert res[key] == pytest.approx(ref[key])


def test_disconnected_components():
    """One graph with two components: triangle 0-1-2 and 4-node path 3-4-5-6.

    No path crosses the components, so cross-component pairs must be ignored
    (reachable-only semantics). The triangle also exercises a cycle. Results
    are checked against mad_reference.
    """
    # (2, 0) closes the triangle; without it 0-1-2 would only be a path.
    edges = [(0, 1), (1, 2), (2, 0), (3, 4), (4, 5), (5, 6)]
    data = build_data(7, edges)
    batch = Batch.from_data_list([data])
    # Random but deterministic embeddings.
    torch.manual_seed(0)
    embeddings = torch.randn((7, 8), dtype=torch.float32)
    res = mad_madgap_batch(batch, embeddings)
    ref = mad_reference(batch, embeddings)
    for key in ("MAD", "MAD_neb", "MAD_rmt", "MADGap"):
        assert res[key] == pytest.approx(ref[key])


def test_batch_mixed_graphs_and_counts():
    """Batch a 1-node graph, a single edge and a 5-node path; compare with the reference."""
    graphs = [
        build_data(1, []),
        build_data(2, [(0, 1)]),
        build_data(5, [(0, 1), (1, 2), (2, 3), (3, 4)]),
    ]
    batch = Batch.from_data_list(graphs)
    # One deterministic embedding block per graph, each from its own seed.
    blocks = [
        np.random.RandomState(seed).randn(rows, 4).astype(np.float32)
        for seed, rows in ((1, 1), (2, 2), (3, 5))
    ]
    embeddings = torch.from_numpy(np.vstack(blocks))
    res = mad_madgap_batch(batch, embeddings)
    ref = mad_reference(batch, embeddings)

    for key in ("MAD", "MAD_neb", "MAD_rmt", "MADGap"):
        assert res[key] == pytest.approx(ref[key])



def test_chunking_path_matches_dense():
    """Dense and chunked code paths must agree, and both must match the reference.

    A 6-node path with random embeddings is evaluated twice: once with a
    very large ``dense_memory_threshold`` (dense path) and once with a
    threshold of 1, intended to force the chunked path.
    """
    n = 6
    batch = Batch.from_data_list([build_data(n, [(i, i + 1) for i in range(n - 1)])])
    torch.manual_seed(42)
    embeddings = torch.randn((n, 5), dtype=torch.float32)
    keys = ("MAD", "MAD_neb", "MAD_rmt", "MADGap")

    dense = mad_madgap_batch(batch, embeddings, dense_memory_threshold=10_000_000)
    chunked = mad_madgap_batch(batch, embeddings, dense_memory_threshold=1)
    for key in keys:
        assert dense[key] == pytest.approx(chunked[key])

    ref = mad_reference(batch, embeddings)
    for key in keys:
        assert dense[key] == pytest.approx(ref[key])


def test_no_remote_nodes_returns_none_for_mad_rmt():
    """A 5-node path has diameter 4 < min_remote_hop=8, so no remote pairs exist."""
    n = 5
    edges = [(i, i + 1) for i in range(n - 1)]  # path 0-1-2-3-4 => max hop = 4 < 8
    data = build_data(n, edges)
    batch = Batch.from_data_list([data])
    # Seeded private generator: deterministic embeddings without touching
    # the global torch RNG state.
    gen = torch.Generator().manual_seed(0)
    embeddings = torch.randn((n, 4), generator=gen, dtype=torch.float32)
    res = mad_madgap_batch(batch, embeddings, max_neigh_hop=3, min_remote_hop=8)
    # MAD_rmt should be None (no node pairs with hop >= 8)
    assert res["MAD_rmt"] is None
    # MAD_neb should be defined (there are many pairs within 3 hops)
    assert res["MAD_neb"] is not None
    # MADGap should be None because MAD_rmt is None
    assert res["MADGap"] is None


def test_identical_embeddings_global_mad_zero():
    """All nodes of two small path graphs share one embedding, so every distance is 0."""
    data1 = build_data(3, [(0, 1), (1, 2)])
    data2 = build_data(4, [(0, 1), (1, 2), (2, 3)])
    batch = Batch.from_data_list([data1, data2])
    embeddings = torch.ones((3 + 4, 6), dtype=torch.float32)  # all rows identical
    res = mad_madgap_batch(batch, embeddings)
    # Global MAD: every reachable pair has cosine distance 0.
    assert res["MAD"] == pytest.approx(0.0, abs=1e-6)
    # Both graphs are connected paths, so every node has 1-hop neighbours.
    # MAD_neb must therefore be defined (a None here is a bug, not an
    # acceptable outcome) and equal to 0.
    assert res["MAD_neb"] is not None
    assert res["MAD_neb"] == pytest.approx(0.0, abs=1e-6)


# Optional: a test for reproducibility / deterministic numerics
def test_reproducible_on_repeat_calls():
    """Two calls on equal inputs (the second on a cloned tensor) give matching results."""
    n = 7
    chain = [(k, k + 1) for k in range(n - 1)]
    batch = Batch.from_data_list([build_data(n, chain)])
    torch.manual_seed(123)
    embeddings = torch.randn((n, 4), dtype=torch.float32)
    first = mad_madgap_batch(batch, embeddings)
    second = mad_madgap_batch(batch, embeddings.clone())
    assert first == pytest.approx(second, abs=1e-6)
