import torch
from sympy.physics.pring import energy
from torch_geometric.data import Data, Batch
from torch_geometric.utils import to_undirected, remove_self_loops, degree


def wu_smoothness(
    batch: "Data | Batch",
    embeddings: torch.Tensor = None,
    chunk_size=None,
) -> dict[str, float]:
    """
    Compute the "Wu" smoothness score of node embeddings, averaged over graphs.

    For every graph in ``batch`` with node-embedding matrix ``H`` of shape
    ``(N, d)``, the per-graph score is the Frobenius norm of ``H`` minus its
    column (per-feature) mean::

        score(H) = || H - (1/N) * 1 1^T H ||_F

    A score of 0 means every node of that graph has the same embedding. The
    returned value is the unweighted mean of the per-graph scores.

    Args:
        batch: graph container (e.g. a torch_geometric ``Data`` or ``Batch``).
            Attributes read: ``edge_index`` (must be present, although it is
            not used in the computation), ``x`` (used when ``embeddings`` is
            None) and ``batch`` (node-to-graph assignment vector; if missing
            or None, all nodes are treated as a single graph).
        embeddings: optional ``(N_total, d)`` tensor whose rows are aligned
            with the nodes of ``batch``. If None, ``batch.x`` is used. A 1-D
            tensor is treated as one feature per node; non-floating-point
            tensors are converted to float32.
        chunk_size: optional positive int, the number of feature columns
            processed at once. None (default) processes all ``d`` columns in
            one pass. On an out-of-memory error the chunk size is halved and
            the graph is retried; the reduced size is kept for later graphs.

    Returns:
        dict: ``{"wu_smoothness": float}``.

    Raises:
        ValueError: if no embeddings are available, ``edge_index`` is
            missing, ``chunk_size`` is not positive, or the batch has no nodes.
        RuntimeError: an out-of-memory error that persists even with a
            single-column chunk is re-raised instead of being swallowed.
    """
    if chunk_size is not None and int(chunk_size) < 1:
        raise ValueError("chunk_size must be a positive integer or None.")

    features = embeddings if embeddings is not None else getattr(batch, "x", None)
    if features is None:
        raise ValueError("No embeddings: pass `embeddings` or set `data.x` in the Data object.")
    if getattr(batch, "edge_index", None) is None:
        raise ValueError("data.edge_index is required.")

    if features.dim() == 1:
        # One scalar feature per node.
        features = features.unsqueeze(1)
    if not features.is_floating_point():
        # mean() is undefined for integer tensors; promote to float32.
        features = features.float()

    batch_vec = getattr(batch, "batch", None)
    if batch_vec is None:
        # Un-batched graph: every node belongs to graph 0.
        batch_vec = torch.zeros(features.shape[0], dtype=torch.long, device=features.device)
    else:
        batch_vec = batch_vec.to(features.device)

    graph_ids = torch.unique(batch_vec).tolist()
    if not graph_ids:
        raise ValueError("Cannot compute wu_smoothness: the batch contains no nodes.")

    per_graph_score = []
    for graph_id in graph_ids:
        H = features[batch_vec == graph_id]
        d = int(H.shape[1])

        # Never take more columns per chunk than exist, and at least one so
        # the column range stays well-defined even when d == 0.
        step = d if chunk_size is None else min(int(chunk_size), d)
        step = max(step, 1)

        score, used_step = _graph_score_with_oom_fallback(H, step)
        if used_step < step:
            # Remember the reduced chunk so later graphs skip the same OOM retries.
            chunk_size = used_step
        per_graph_score.append(score)

    return {"wu_smoothness": sum(per_graph_score) / len(per_graph_score)}


def _centered_frobenius_norm(H: torch.Tensor, chunk_size: int) -> float:
    """Return ||H - column_mean(H)||_F, processing ``chunk_size`` columns at a time."""
    squared_total = 0.0
    for start in range(0, H.shape[1], chunk_size):
        block = H[:, start:start + chunk_size]
        # Broadcasting the (1, c) column mean avoids materialising an N x N
        # all-ones matrix, so peak memory stays O(N * c).
        centered = block - block.mean(dim=0, keepdim=True)
        squared_total += centered.pow(2).sum().item()
    return squared_total ** 0.5


def _is_out_of_memory(err: BaseException) -> bool:
    """Return True if *err* is a torch out-of-memory error."""
    oom_type = getattr(torch.cuda, "OutOfMemoryError", None)
    if oom_type is not None and isinstance(err, oom_type):
        return True
    return isinstance(err, RuntimeError) and "out of memory" in str(err).lower()


def _graph_score_with_oom_fallback(H: torch.Tensor, chunk_size: int) -> tuple[float, int]:
    """Score one graph, halving ``chunk_size`` on OOM; return (score, chunk size used)."""
    while True:
        try:
            return _centered_frobenius_norm(H, chunk_size), chunk_size
        except RuntimeError as err:
            # Non-OOM errors, or OOM that cannot be helped by smaller chunks,
            # propagate instead of leaving the score undefined.
            if not _is_out_of_memory(err) or chunk_size <= 1:
                raise
        # Free cached blocks only after leaving the except clause, once the
        # traceback (and the tensors its frames reference) has been dropped.
        if H.device.type == "cuda":
            torch.cuda.empty_cache()
        chunk_size //= 2



if __name__ == "__main__":

    # Example 1: 3-node chain with 1-D embeddings.
    # Column mean is 1, deviations are (-1, 0, 1) -> sqrt(2) ≈ 1.4142.
    edge_index = torch.tensor([[0, 1, 1, 2],
                               [1, 0, 2, 1]], dtype=torch.long)  # symmetric (both directions present)
    H = torch.tensor([[0.0], [1.0], [2.0]])  # (3, 1)
    data = Data(x=H, edge_index=edge_index)
    batch = Batch.from_data_list([data])
    print(wu_smoothness(batch))  # expected ≈ 1.4142

    # Example 2: same chain with 2-D embeddings.
    # Column deviations are (-1, 0, 1) and (-2, 0, 2) -> sqrt(2 + 8) = sqrt(10) ≈ 3.1623.
    H2 = torch.tensor([[0.0, 0.0],
                       [1.0, 2.0],
                       [2.0, 4.0]])  # (3, 2)
    data2 = Data(x=H2, edge_index=edge_index)
    batch = Batch.from_data_list([data2])
    print(wu_smoothness(batch))  # expected ≈ 3.1623

    # Example 3: two separate 2-node graphs batched with Batch.from_data_list.
    # Graph 0: x = (0, 1) -> sqrt(0.5) ≈ 0.7071; graph 1: x = (0, 10) -> sqrt(50) ≈ 7.0711.
    # The function returns their unweighted mean ≈ 3.8891.
    pair_edge_index = torch.tensor([[0, 1],
                                    [1, 0]], dtype=torch.long)
    graph_a = Data(x=torch.tensor([[0.0], [1.0]]), edge_index=pair_edge_index)
    graph_b = Data(x=torch.tensor([[0.0], [10.0]]), edge_index=pair_edge_index)
    batch = Batch.from_data_list([graph_a, graph_b])
    print(wu_smoothness(batch))  # expected ≈ 3.8891

    # Example 4: same chain with 6-D embeddings, processed 2 feature columns at a time.
    # Three columns like (0, 1, 2) and three like (0, 2, 4) -> sqrt(3*2 + 3*8) = sqrt(30) ≈ 5.4772.
    H4 = torch.tensor([[0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                       [1.0, 2.0, 1.0, 2.0, 1.0, 2.0],
                       [2.0, 4.0, 2.0, 4.0, 2.0, 4.0]])  # (3, 6)
    data4 = Data(x=H4, edge_index=edge_index)
    batch = Batch.from_data_list([data4])
    print(wu_smoothness(batch, chunk_size=2))  # expected ≈ 5.4772
