import pytest
import torch
import numpy as np
from torch_geometric.data import Data, Batch
from sklearn.metrics import pairwise_distances

# Import your functions
from src.evaluation.smoothing_measures.mad import (
    mad_batch,
    mad_from_edge_index,
    build_data
)

# ---------- Helper to build adjacency matrix ----------
def adj_from_edge_index(edge_index, num_nodes):
    """Return a dense ``num_nodes x num_nodes`` adjacency matrix as a NumPy array.

    Entry ``[u, v]`` is set to 1.0 for every pair ``(edge_index[0][k],
    edge_index[1][k])`` and is 0.0 elsewhere. The matrix is float32 and no
    symmetrisation is applied.
    """
    sources, targets = edge_index[0], edge_index[1]
    dense = torch.zeros(num_nodes, num_nodes, dtype=torch.float32)
    dense[sources, targets] = 1
    return dense.numpy()

# ---------- Fixtures ----------
@pytest.fixture
def simple_batch():
    """Return ``(batch, embeddings)`` for a two-graph batch with fixed embeddings.

    The graphs are requested from ``build_data`` with 4 nodes each and the
    edges 0-1, 1-2, 2-3. Both graphs use the same four 2-d embedding rows.
    """
    graphs = build_data(num_nodes=4, edges=[(0, 1), (1, 2), (2, 3)], num_graphs=2)
    batch = Batch.from_data_list(graphs)

    # One graph's worth of rows; the second graph repeats them verbatim.
    rows_per_graph = [
        [1.0, 0.0],
        [0.0, 1.0],
        [1.0, 1.0],
        [0.5, 0.5],
    ]
    embeddings = torch.tensor(rows_per_graph * 2, dtype=torch.float32)
    return batch, embeddings

# ---------- Helper MAD function (original implementation) -----------

def mad_value(embeddings, edge_index, distance_metric='cosine', target_idx=None):
    """Reference MAD: mean pairwise distance over the node pairs in ``edge_index``.

    Args:
        embeddings: Tensor of node embeddings, one row per node. It is moved
            to CPU and converted to NumPy before distances are computed.
        edge_index: Two-row index structure; ``edge_index[0][k]`` and
            ``edge_index[1][k]`` select the k-th pair in the distance matrix.
        distance_metric: Metric name forwarded to
            ``sklearn.metrics.pairwise_distances``.
        target_idx: Unused; kept so existing call signatures stay valid.

    Returns:
        The mean of the selected distances. If ``edge_index`` selects no
        pairs, NumPy's mean of an empty selection yields NaN.
    """
    node_features = embeddings.cpu().numpy()

    # Full node-by-node distance matrix; only the edge entries are averaged.
    distances = pairwise_distances(node_features, node_features, metric=distance_metric)

    return distances[edge_index[0], edge_index[1]].mean()

def build_line_graph(num_nodes: int, num_graphs: int):
    """Build ``num_graphs`` identical path ("line") graphs of ``num_nodes`` nodes.

    Node ``i`` is connected to node ``i + 1`` using 0-based local indices, and
    every edge is stored in both directions to make the graph undirected.

    Args:
        num_nodes: Number of nodes in each graph.
        num_graphs: Number of graphs to create.

    Returns:
        A list of ``num_graphs`` ``torch_geometric.data.Data`` objects, each
        created with only ``edge_index`` and ``num_nodes``.
    """
    # The edge list is the same for every graph, so build it once.
    src = []
    dst = []
    for u in range(num_nodes - 1):
        v = u + 1
        src += [u, v]
        dst += [v, u]

    graphs = []
    for _ in range(num_graphs):
        if src:
            # A fresh tensor per graph keeps the Data objects independent.
            edge_index = torch.tensor([src, dst], dtype=torch.long)
        else:
            # Edgeless graph: explicit empty index tensor of shape (2, 0).
            edge_index = torch.empty((2, 0), dtype=torch.long)
        graphs.append(Data(edge_index=edge_index, num_nodes=num_nodes))
    return graphs

# ---------- Tests ----------

def test_mad_batch_matches_reference(simple_batch):
    """``mad_batch`` on the fixture batch agrees with the mean of per-graph ``mad_value``."""
    batch, embeddings = simple_batch

    # Value under test.
    mad_test = mad_batch(batch, embeddings, max_neigh_hop=1)

    # Reference: score each graph on its own embeddings, then average the scores.
    graph_ids = torch.unique(batch.batch).tolist()
    per_graph_mads = [
        mad_value(embeddings[batch.batch == gid], batch[gid].edge_index)
        for gid in graph_ids
    ]
    mad_true = sum(per_graph_mads) / len(per_graph_mads)

    assert np.isclose(mad_test, mad_true, atol=1e-2), \
        f"Expected {mad_true}, got {mad_test}"

def test_mad_random():
    """``mad_batch`` matches the ``mad_value`` reference on random 64-d embeddings.

    Two 4-node line graphs are batched together. The embeddings come from a
    locally seeded generator so a failure can be reproduced exactly.
    """
    n_nodes = 4
    data = build_line_graph(n_nodes, num_graphs=2)
    batch = Batch.from_data_list(data)

    # Local generator: deterministic values without touching the global RNG state.
    rng = torch.Generator().manual_seed(0)
    embeddings = torch.randn((n_nodes * 2, 64), generator=rng, dtype=torch.float32)

    # Compute mad_batch result
    mad_test = mad_batch(batch, embeddings, max_neigh_hop=1)

    # Reference computation: per-graph MAD, averaged over the graphs
    per_graph_mads = []
    for graph_id in torch.unique(batch.batch).tolist():
        emb = embeddings[batch.batch == graph_id]
        edge_index = batch[graph_id].edge_index
        per_graph_mads.append(mad_value(emb, edge_index))
    mad_true = sum(per_graph_mads) / len(per_graph_mads)

    assert np.isclose(mad_test, mad_true, atol=1e-4), \
        f"Expected {mad_true}, got {mad_test}"

def test_mad_0():
    """Identical embeddings on every node give a MAD of 0.

    Both ``mad_batch`` and the ``mad_value`` reference are checked against 0
    and against each other on two batched 4-node line graphs.
    """
    n_nodes = 4
    data = build_line_graph(n_nodes, num_graphs=2)
    batch = Batch.from_data_list(data)
    embeddings = torch.ones((n_nodes * 2, 64), dtype=torch.float32)

    # Compute mad_batch result
    mad_test = mad_batch(batch, embeddings, max_neigh_hop=1)

    # Reference computation: per-graph MAD, averaged over the graphs
    per_graph_mads = []
    for graph_id in torch.unique(batch.batch).tolist():
        emb = embeddings[batch.batch == graph_id]
        edge_index = batch[graph_id].edge_index
        per_graph_mads.append(mad_value(emb, edge_index))
    mad_true = sum(per_graph_mads) / len(per_graph_mads)

    # Compare with a tolerance rather than exact float equality, using the
    # same atol as the other fixed-value MAD tests in this module.
    assert np.isclose(mad_test, 0, atol=1e-4), f"Expected 0, got {mad_test}"
    assert np.isclose(mad_true, 0, atol=1e-4), f"Expected 0, got {mad_true}"
    assert np.isclose(mad_test, mad_true, atol=1e-4), \
        f"Expected {mad_true}, got {mad_test}"


def test_mad_1():
    """Two orthogonal embeddings joined by an edge give a MAD of 1.

    Checked for ``mad_batch`` and for the ``mad_value`` reference, and the two
    must agree with each other.
    """
    n_nodes = 2
    batch = Batch.from_data_list(build_line_graph(n_nodes, num_graphs=1))
    embeddings = torch.tensor([[0, 1], [1, 0]], dtype=torch.float32)

    # Value under test.
    mad_test = mad_batch(batch, embeddings, max_neigh_hop=1)

    # Reference: per-graph MAD, averaged over the graphs in the batch.
    graph_ids = torch.unique(batch.batch).tolist()
    per_graph_mads = [
        mad_value(embeddings[batch.batch == gid], batch[gid].edge_index)
        for gid in graph_ids
    ]
    mad_true = sum(per_graph_mads) / len(per_graph_mads)

    assert np.isclose(mad_test, 1, atol=1e-4)
    assert np.isclose(mad_true, 1, atol=1e-4)
    assert np.isclose(mad_test, mad_true, atol=1e-4), \
        f"Expected {mad_true}, got {mad_test}"


def test_mad_2():
    """Two opposite embeddings joined by an edge give a MAD of 2.

    Checked for ``mad_batch`` and for the ``mad_value`` reference, and the two
    must agree with each other.
    """
    n_nodes = 2
    batch = Batch.from_data_list(build_line_graph(n_nodes, num_graphs=1))
    embeddings = torch.tensor([[-1, -1], [1, 1]], dtype=torch.float32)

    # Value under test.
    mad_test = mad_batch(batch, embeddings, max_neigh_hop=1)

    # Reference: per-graph MAD, averaged over the graphs in the batch.
    graph_ids = torch.unique(batch.batch).tolist()
    per_graph_mads = [
        mad_value(embeddings[batch.batch == gid], batch[gid].edge_index)
        for gid in graph_ids
    ]
    mad_true = sum(per_graph_mads) / len(per_graph_mads)

    assert np.isclose(mad_test, 2, atol=1e-4)
    assert np.isclose(mad_true, 2, atol=1e-4)
    assert np.isclose(mad_test, mad_true, atol=1e-4), \
        f"Expected {mad_true}, got {mad_test}"



def test_empty_graph():
    """``mad_batch`` returns exactly 0.0 for a 3-node graph built with no edges."""
    edgeless_graphs = build_data(num_nodes=3, edges=[], num_graphs=1)
    edgeless_batch = Batch.from_data_list(edgeless_graphs)
    # The embedding values are arbitrary: there are no edges to measure over.
    node_embeddings = torch.randn((3, 4))
    assert mad_batch(edgeless_batch, node_embeddings) == 0.0


def test_single_node_graph():
    """``mad_batch`` returns exactly 0.0 for a graph with one node and no edges."""
    lone_node_graphs = build_data(num_nodes=1, edges=[], num_graphs=1)
    lone_node_batch = Batch.from_data_list(lone_node_graphs)
    # A single random 8-d embedding row for the only node.
    node_embedding = torch.randn((1, 8))
    assert mad_batch(lone_node_batch, node_embedding) == 0.0


@pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU not available")
def test_gpu_consistency(simple_batch):
    """``mad_batch`` gives the same value for CUDA inputs as for CPU inputs."""
    batch, embeddings = simple_batch

    # The CPU run comes first, before anything is moved to the GPU.
    on_cpu = mad_batch(batch, embeddings, max_neigh_hop=1)

    cuda_batch = batch.to("cuda")
    cuda_embeddings = embeddings.to("cuda")
    on_gpu = mad_batch(cuda_batch, cuda_embeddings, max_neigh_hop=1)

    assert np.isclose(on_cpu, on_gpu, atol=1e-5)

