import math
import torch
import pytest
from torch_geometric.data import Data, Batch
from torch_geometric.utils import to_undirected, remove_self_loops

# import the function under test; adjust import path if needed
from src.evaluation.smoothing_measures.dirichlet_energy import dirichlet_energy_pyg


# ---- helper functions used by tests to compute expected values ----
def _to_undirected_expand(u, v, edge_weight=None):
    """
    Mirror every edge so that both directions are present.

    The flipped copies (v->u) are appended after the originals (u->v);
    nothing is de-duplicated. When weights are given they are repeated
    in the same order so each mirrored edge keeps its original weight.

    Returns (u_all, v_all, edge_weight_all_or_None)
    """
    sources = torch.cat((u, v), dim=0)
    targets = torch.cat((v, u), dim=0)
    doubled_weight = (
        None if edge_weight is None
        else torch.cat((edge_weight, edge_weight), dim=0)
    )
    return sources, targets, doubled_weight


def _remove_self_loops_mask(u, v, edge_weight=None):
    """
    Drop every edge whose source equals its target.

    Returns (u_kept, v_kept, edge_weight_kept_or_None); the weights, when
    given, are filtered with the same boolean mask as the endpoints.
    """
    keep = u != v
    kept_u, kept_v = u[keep], v[keep]
    kept_weight = None if edge_weight is None else edge_weight[keep]
    return kept_u, kept_v, kept_weight


def expected_global_dirichlet(H, edge_index, edge_weight=None):
    """
    Compute the reference value for the global Dirichlet energy.

    Steps:
       1. make the edge list undirected with ``to_undirected`` and drop
          self-loops with ``remove_self_loops``; edge weights (if any) are
          carried through both steps so they stay aligned with the edges
       2. per_edge_sqnorm = mean over feature dims of (H[u]-H[v])**2,
          multiplied by the matching edge weight when weights are given
       3. energy = sqrt( sum(per_edge_sqnorm) / n ) with n = H.size(0)

    Args:
        H: node embedding tensor; rows are indexed by node id.
        edge_index: tensor whose rows 0 and 1 hold edge sources and targets.
        edge_weight: optional per-edge weights for ``edge_index``; cast to
            float and moved to ``H``'s device before use.

    Returns:
        The energy as a Python float (0.0 when no edges remain).
    """
    device = H.device
    ew = None if edge_weight is None else edge_weight.to(device).float()

    edge_index, ew_all = to_undirected(edge_index, ew)
    # Feed the weights returned by to_undirected (not the raw argument) so
    # they match the undirected edge list both in length and in dtype.
    edge_index, ew_all = remove_self_loops(edge_index, ew_all)

    u = edge_index[0].to(device)
    v = edge_index[1].to(device)

    if u.numel() == 0:
        # No edges left: the sum below evaluates to zero.
        per_edge_sqnorm = torch.empty(0, device=device)
    else:
        diffs = H[u] - H[v]  # (E', d)
        per_edge_sqnorm = diffs.pow(2).mean(dim=1)
        if ew_all is not None:
            per_edge_sqnorm = per_edge_sqnorm * ew_all

    n = float(H.size(0))
    energy = per_edge_sqnorm.sum() / n
    # Clamp before the square root so a tiny negative value cannot give NaN.
    energy = energy.clamp(min=0.0).sqrt().item()
    return energy


def expected_per_graph(H, edge_index, batch_vec, edge_weight=None):
    """
    Reference computation of the per-graph energies (one entry per graph).

    - mirror the edges and drop self-loops
    - per-edge value = mean over feature dims of (H[u]-H[v])**2, times the
      edge weight when weights are supplied
    - sum the per-edge values into the graph given by batch_vec[u]
    - divide each graph's sum by its node count (clamped to at least 1);
      no square root is taken
    """
    dev = H.device
    src = edge_index[0].to(dev)
    dst = edge_index[1].to(dev)
    weights = edge_weight.to(dev).float() if edge_weight is not None else None

    expanded = _to_undirected_expand(src, dst, weights)
    src, dst, weights = _remove_self_loops_mask(*expanded)

    if src.numel() == 0:
        # Without edges every graph still gets an entry, all zeros.
        n_graphs = int(batch_vec.max().item()) + 1
        return torch.zeros(n_graphs, device=dev, dtype=torch.float32)

    sq_per_edge = (H[src] - H[dst]).pow(2).mean(dim=1)
    if weights is not None:
        sq_per_edge = sq_per_edge * weights

    # Each directed edge is credited to the graph of its source node.
    edge_graph = batch_vec[src]
    n_graphs = int(batch_vec.max().item()) + 1

    totals = torch.zeros(n_graphs, device=dev)
    totals.index_add_(0, edge_graph, sq_per_edge)

    sizes = torch.bincount(batch_vec, minlength=n_graphs).to(dev).float()
    sizes = sizes.clamp(min=1.0)
    return totals / sizes


# ---- Tests ----

def test_raises_when_no_embeddings_and_no_data_x():
    """A graph with edges but neither ``x`` nor explicit embeddings must raise."""
    graph = Data(edge_index=torch.tensor([[0, 1], [1, 0]]))
    with pytest.raises(ValueError, match="No embeddings"):
        dirichlet_energy_pyg(graph)


def test_raises_when_no_edge_index():
    """A graph that has features but no ``edge_index`` must raise."""
    graph = Data(x=torch.randn(3, 4))
    with pytest.raises(ValueError, match="data.edge_index is required"):
        dirichlet_energy_pyg(graph)


def test_simple_two_node_edge_global_value():
    """Two nodes joined by one undirected edge give the reference energy."""
    # 2-D embeddings so the mean across feature dims is exercised.
    feats = torch.tensor([[0.0, 0.0], [1.0, 0.0]])
    # The 0-1 edge is listed in both directions.
    edges = torch.tensor([[0, 1], [1, 0]])

    result = dirichlet_energy_pyg(Data(x=feats.clone(), edge_index=edges))
    reference = expected_global_dirichlet(feats, edges)

    assert isinstance(result, dict) and "dirichlet_energy_rusch" in result
    assert result["dirichlet_energy_rusch"] == pytest.approx(reference, abs=1e-8)


def test_embeddings_argument_overrides_data_x():
    """An explicit ``embeddings`` argument takes precedence over ``data.x``."""
    stored_feats = torch.zeros(2, 1)  # identical rows -> zero energy if used
    override_feats = torch.tensor([[0.0], [3.0]])
    graph = Data(x=stored_feats, edge_index=torch.tensor([[0, 1], [1, 0]]))

    overridden = dirichlet_energy_pyg(graph, embeddings=override_feats)
    default = dirichlet_energy_pyg(graph)  # falls back to data.x

    # The default run sees only zeros; the override run must see a difference.
    assert default["dirichlet_energy_rusch"] == 0.0
    assert overridden["dirichlet_energy_rusch"] > 0.0


def test_self_loops_are_removed():
    """A self-loop on node 0 must not contribute to the energy."""
    feats = torch.tensor([[0.0], [2.0]])
    # Columns: 0->0 (self-loop), 0->1, 1->0
    edges = torch.tensor([[0, 0, 1], [0, 1, 0]])
    graph = Data(x=feats, edge_index=edges)

    # The reference drops self-loops, so only the 0<->1 pair counts.
    reference = expected_global_dirichlet(feats, edges)
    result = dirichlet_energy_pyg(graph)
    assert result["dirichlet_energy_rusch"] == pytest.approx(reference, rel=1e-6)


def test_edge_weights_are_applied_and_int_casted():
    """
    Weighted single edge 0-1: the function's result must match the reference.

    The reference helper receives the original integer weight (one entry for
    the one directed edge) and casts it to float itself; the function under
    test receives one float weight per direction.
    """
    H = torch.tensor([[0.0, 0.0], [1.0, 1.0]])
    # Single directed edge (0,1). It must really be one column: the length-1
    # weight tensor handed to the reference helper has to align with it.
    edge_index = torch.tensor([[0],
                               [1]])
    # Original edge count E = 1 -> 2 edges once both directions are present,
    # so the function gets the original weight repeated for each direction.
    w_orig = torch.tensor([2], dtype=torch.int32)
    w_all = torch.cat([w_orig, w_orig]).float()
    data = Data(x=H, edge_index=edge_index)

    # call function with edge_weight of length 2 (matching doubled edges)
    out = dirichlet_energy_pyg(data, edge_weight=w_all)
    expected = expected_global_dirichlet(H, edge_index, w_orig)
    assert out["dirichlet_energy_rusch"] == pytest.approx(expected, rel=1e-6)


def test_per_graph_computation_matches_manual():
    """
    With ``per_graph=True`` the result is a tensor holding one energy per
    graph, and each entry must match the manual reference computation.
    """
    # Two small graphs in one batch, each with a single internal edge.
    H = torch.tensor([
        [0.0],  # node 0 -- graph 0
        [1.0],  # node 1 -- graph 0
        [2.0],  # node 2 -- graph 1
        [4.0],  # node 3 -- graph 1
    ])
    # edges: 0-1 and 2-3 (single directed entries)
    edge_index = torch.tensor([[0, 2],
                               [1, 3]])
    batch_vec = torch.tensor([0, 0, 1, 1])

    data = Data(x=H, edge_index=edge_index, batch=batch_vec)
    out = dirichlet_energy_pyg(data, per_graph=True)
    expected = expected_per_graph(H, edge_index, batch_vec)
    assert isinstance(out, torch.Tensor)
    # `out` is a tensor, so index it by graph id (not by a dict key) and
    # compare plain Python floats: pytest.approx is not meant for 0-d tensors.
    assert out[0].item() == pytest.approx(expected[0].item(), rel=1e-6)
    assert out[1].item() == pytest.approx(expected[1].item(), rel=1e-6)


def test_per_graph_requires_batch_attribute():
    """Requesting per-graph output on data without ``batch`` must raise."""
    feats = torch.randn(3, 4)
    edges = torch.tensor([[0, 1], [1, 2]])
    graph = Data(x=feats, edge_index=edges)  # deliberately no `batch`
    with pytest.raises(ValueError, match="per_graph=True requires `data.batch` to exist"):
        dirichlet_energy_pyg(graph, per_graph=True)


def test_empty_edge_index_returns_zero():
    """A graph without any edges has zero energy."""
    feats = torch.randn(5, 3)
    no_edges = torch.empty((2, 0), dtype=torch.long)
    result = dirichlet_energy_pyg(Data(x=feats, edge_index=no_edges))
    assert result["dirichlet_energy_rusch"] == 0.0


def test_cross_graph_edges_assigned_to_source_graph():
    """An edge spanning two graphs is credited to the graph of its source."""
    feats = torch.tensor([[0.0], [1.0], [5.0], [6.0]])
    # One directed cross-graph edge 0->2.
    edges = torch.tensor([[0], [2]])
    graph_of_node = torch.tensor([0, 0, 1, 1])
    batch = Data(x=feats, edge_index=edges, batch=graph_of_node)

    result = dirichlet_energy_pyg(batch, per_graph=True)
    # The reference mirrors the edge into (0->2) and (2->0): the first lands
    # in graph 0 (source node 0), the second in graph 1 (source node 2).
    reference = expected_per_graph(feats, edges, graph_of_node)
    torch.testing.assert_close(result, reference)


def test_per_edge_sqnorm_is_mean_over_dims_not_sum():
    """The per-edge squared difference is averaged over feature dims."""
    # Row difference is [-1, -3]; squared -> [1, 9]; mean = 5 (a sum gives 10).
    feats = torch.tensor([[0.0, 0.0], [1.0, 3.0]])
    edges = torch.tensor([[0], [1]])

    result = dirichlet_energy_pyg(Data(x=feats, edge_index=edges))
    # The reference helper averages over dims, so agreement implies a mean.
    reference = expected_global_dirichlet(feats, edges)
    assert result["dirichlet_energy_rusch"] == pytest.approx(reference, rel=1e-6)


def test_integer_edge_weights_length_must_match_expanded_edges():
    """Weights sized for the doubled edge list are accepted and applied."""
    feats = torch.tensor([[0.0], [10.0]])
    edges = torch.tensor([[0], [1]])  # one directed edge
    # One original edge becomes two after mirroring, hence two weights
    # (already floats, the same value for each direction).
    per_direction_weights = torch.tensor([3.0, 3.0])
    graph = Data(x=feats, edge_index=edges)

    result = dirichlet_energy_pyg(graph, edge_weight=per_direction_weights)
    reference = expected_global_dirichlet(feats, edges, torch.tensor([3.0]))
    assert result["dirichlet_energy_rusch"] == pytest.approx(reference, rel=1e-6)

def test_low_energy_case():
    """Nearly identical embeddings on a complete 4-node graph: small energy."""
    feats = torch.tensor([[1.0, 1.0], [1.001, 1.0], [1.0, 1.002], [1.001, 1.001]])
    # Every unordered node pair, one column per pair.
    edges = torch.combinations(torch.arange(4), r=2).t()

    result = dirichlet_energy_pyg(Data(x=feats, edge_index=edges))
    reference = expected_global_dirichlet(feats, edges)

    energy = result["dirichlet_energy_rusch"]
    # Small in absolute terms, and equal to the reference value.
    assert energy < 1e-2
    assert math.isclose(energy, reference, rel_tol=1e-6)

def test_high_energy_case():
    """Star graph whose centre lies far from its leaves: large energy."""
    feats = torch.tensor([
        [0.0, 0.0],      # centre (node 0)
        [100.0, 100.0],  # leaves (nodes 1..4)
        [90.0, 100.0],
        [110.0, 90.0],
        [95.0, 105.0],
    ])
    # Node 0 connected to each of nodes 1..4.
    edges = torch.tensor([[0, 0, 0, 0], [1, 2, 3, 4]])

    result = dirichlet_energy_pyg(Data(x=feats, edge_index=edges))
    reference = expected_global_dirichlet(feats, edges)

    energy = result["dirichlet_energy_rusch"]
    # Well above 1, and equal to the reference value.
    assert energy > 10.0
    assert math.isclose(energy, reference, rel_tol=1e-6)

