import torch
from sympy.physics.pring import energy
from torch_geometric.data import Data, Batch
from torch_geometric.utils import to_undirected, remove_self_loops, degree

# Dirichlet-energy definitions accepted by `dirichlet_energy_pyg` via its ``types`` argument.
POSSIBLE_TYPES = ["rusch", "cai"]

def dirichlet_energy_pyg(
    batch: Data|Batch,
    types: str|list|None = None,
    embeddings: torch.Tensor = None,
    edge_weight: torch.Tensor = None,
    per_graph: bool = False
) -> dict[str, float|torch.Tensor]:
    """
    Compute the Dirichlet energy of node embeddings under one or more definitions.

    The graph is first symmetrised and coalesced, so an ``edge_index`` that already
    contains both directions is fine. Self-loops are then removed. The sums below run
    over every directed edge ``(i, j)`` of that symmetric graph. ``w_ij`` is the edge
    weight (1 when no weights are given), and ``mean_d`` averages over the feature
    dimension. Because features are averaged rather than summed, both values equal the
    textbook squared-L2 forms divided by the embedding width ``d``.

    * ``"rusch"``: ``sqrt((1 / N) * sum_{(i,j)} w_ij * mean_d (h_i - h_j)^2)``.
      The un-normalised numerator is also returned as ``sq_sum_diff``.
    * ``"cai"``: ``0.5 * sum_{(i,j)} w_ij * mean_d (h_i / sqrt(1 + d_i) - h_j / sqrt(1 + d_j))^2``,
      where ``d_i`` is the degree of node ``i``. It is the weighted degree when
      ``edge_weight`` is given.

    Args:
        batch: torch_geometric ``Data``/``Batch`` with ``edge_index`` and optionally ``x`` and ``batch``.
        types: definition(s) to compute, any of ``POSSIBLE_TYPES`` (``"rusch"``, ``"cai"``).
            Accepts a single string or an iterable of strings. ``None`` (default) computes all of them.
        embeddings: optional ``(N, d)`` (or ``(N,)``) tensor. If None, ``batch.x`` is used.
        edge_weight: optional ``(E,)`` (or ``(E, 1)``) tensor of weights aligned with the
            columns of ``batch.edge_index``. If an edge is listed in both directions, its
            two weights are averaged. ``batch.edge_attr`` is never used implicitly; pass
            weights explicitly.
        per_graph: if True, compute every value separately for each graph in ``batch.batch``.
            When that attribute is missing, all nodes are treated as one graph. The
            ``"rusch"`` energy of each graph is normalised by that graph's node count.

    Returns:
        dict: keys ``'dirichlet_energy_rusch'`` and ``'sq_sum_diff'`` (for ``"rusch"``) and
        ``'dirichlet_energy_cai'`` (for ``"cai"``). Values are Python floats, or detached
        tensors of shape ``(num_graphs,)`` when ``per_graph`` is True.

    Raises:
        ValueError: if any of these hold:
            - there are no embeddings or no ``edge_index``;
            - a type is unknown;
            - ``edge_weight`` does not match ``edge_index``;
            - ``per_graph=True`` and ``batch.batch`` has the wrong length.
    """
    H = embeddings if embeddings is not None else getattr(batch, 'x', None)
    if H is None:
        raise ValueError("No embeddings: pass `embeddings` or set `data.x` in the Data object.")
    if getattr(batch, 'edge_index', None) is None:
        raise ValueError("data.edge_index is required.")

    # Normalise `types` to a list (a bare string must not be iterated char by char).
    if types is None:
        types = list(POSSIBLE_TYPES)
    elif isinstance(types, str):
        types = [types]
    else:
        types = list(types)
    for energy_type in types:
        if energy_type not in POSSIBLE_TYPES:
            raise ValueError(f"Dirichlet energy of type {energy_type} not implemented.")

    # Integer features would break `.mean()`; a 1D signal is treated as one feature per node.
    if not H.is_floating_point():
        H = H.float()
    if H.dim() == 1:
        H = H.unsqueeze(1)
    device = H.device
    num_nodes = H.size(0)

    edge_index, edge_weight = _prepare_edges(batch.edge_index, edge_weight, num_nodes, device, H.dtype)
    v, u = edge_index                                                    # each (E,)

    # Global mode: sum over all edges and normalise by the total node count.
    # Per-graph mode: bucket each edge by the graph of its source node
    # (both endpoints belong to the same graph in a valid batch).
    if per_graph:
        graph_index, num_graphs = _graph_assignment(batch, num_nodes, device)
        edge_graph = graph_index[v]
        node_counts = torch.bincount(graph_index, minlength=num_graphs).clamp(min=1).to(H.dtype)
    else:
        edge_graph, num_graphs = None, 1
        node_counts = float(max(num_nodes, 1))  # guard against 0/0 on an empty graph

    results = {}

    ##### Dirichlet energy (Rusch): sqrt of node-count-normalised squared differences #####
    if 'rusch' in types:
        per_edge = (H[v] - H[u]).pow(2).mean(dim=1)                      # (E,)
        if edge_weight is not None:
            per_edge = per_edge * edge_weight
        sq_sum_diff = _reduce_edges(per_edge, edge_graph, num_graphs)
        energy_rusch = (sq_sum_diff / node_counts).sqrt()
        results['dirichlet_energy_rusch'] = _finalize(energy_rusch, per_graph)
        results['sq_sum_diff'] = _finalize(sq_sum_diff, per_graph)

    ##### Dirichlet energy (Cai): degree-normalised squared differences #####
    if 'cai' in types:
        if edge_weight is None:
            deg = degree(v, num_nodes=num_nodes, dtype=H.dtype)            # (N,)
        else:
            # Weighted degree d_i = sum_j w_ij, consistent with the weighted differences.
            deg = H.new_zeros(num_nodes).index_add_(0, v, edge_weight)    # (N,)
        # Scale each node once by 1/sqrt(1 + d_i) instead of once per incident edge.
        H_scaled = H / (deg + 1).sqrt().unsqueeze(1)                      # (N, d)
        per_edge = (H_scaled[v] - H_scaled[u]).pow(2).mean(dim=1)         # (E,)
        if edge_weight is not None:
            per_edge = per_edge * edge_weight
        # Every undirected edge is present in both directions, hence the factor 1/2.
        energy_cai = 0.5 * _reduce_edges(per_edge, edge_graph, num_graphs)
        results['dirichlet_energy_cai'] = _finalize(energy_cai, per_graph)

    return results


def _prepare_edges(edge_index, edge_weight, num_nodes, device, dtype):
    """Symmetrise, coalesce and drop self-loops, keeping ``edge_weight`` aligned with the edges."""
    edge_index = edge_index.to(device)
    if edge_weight is None:
        edge_index = to_undirected(edge_index, num_nodes=num_nodes)
    else:
        edge_weight = edge_weight.to(device=device, dtype=dtype)
        if edge_weight.dim() > 1:
            edge_weight = edge_weight.squeeze(-1)
        if edge_weight.dim() != 1 or edge_weight.numel() != edge_index.size(1):
            raise ValueError("edge_weight must be a 1D tensor with one entry per column of edge_index.")
        # The weights must pass through to_undirected as well, because it reorders and merges edges.
        # 'mean' keeps w_ij unchanged whether an edge was listed in one direction or in both.
        edge_index, edge_weight = to_undirected(edge_index, edge_weight, num_nodes=num_nodes, reduce='mean')
    return remove_self_loops(edge_index, edge_weight)


def _graph_assignment(batch, num_nodes, device):
    """Return ``(graph_index, num_graphs)``; all nodes form graph 0 when ``batch.batch`` is absent."""
    graph_index = getattr(batch, 'batch', None)
    if graph_index is None:
        return torch.zeros(num_nodes, dtype=torch.long, device=device), 1
    graph_index = graph_index.to(device=device, dtype=torch.long)
    if graph_index.numel() != num_nodes:
        raise ValueError("per_graph=True requires `data.batch` to have one entry per node.")
    num_graphs = int(graph_index.max().item()) + 1 if num_nodes > 0 else 1
    return graph_index, num_graphs


def _reduce_edges(values, edge_graph, num_graphs):
    """Sum per-edge values into one scalar (``edge_graph`` is None) or into one entry per graph."""
    if edge_graph is None:
        return values.sum()
    return values.new_zeros(num_graphs).index_add_(0, edge_graph, values)


def _finalize(value, per_graph):
    """Convert a result tensor to the public format: a float, or a detached per-graph tensor."""
    return value.detach() if per_graph else value.item()



if __name__ == "__main__":
    # Small smoke-test examples; run this module directly to print the energies.
    # `types` is passed explicitly in every call.

    # Example 1: simple 3-node chain, 1D embeddings
    edge_index = torch.tensor([[0, 1, 1, 2],
                               [1, 0, 2, 1]], dtype=torch.long)  # symmetric (both dirs present)
    H = torch.tensor([[0.0], [1.0], [2.0]])  # (3,1)
    data = Data(x=H, edge_index=edge_index)
    # expected: rusch = sqrt(4/3) ≈ 1.1547, sq_sum_diff = 4.0, cai ≈ 1.0337
    print(dirichlet_energy_pyg(data, types=['rusch', 'cai']))

    # Example 2: 3-node chain, 2D embeddings
    H2 = torch.tensor([[0.0, 0.0],
                       [1.0, 2.0],
                       [2.0, 4.0]])  # shape (3,2)
    data2 = Data(x=H2, edge_index=edge_index)
    # Squared differences are averaged over the 2 feature dims (2.5 per directed edge):
    # expected: rusch = sqrt(10/3) ≈ 1.8257, sq_sum_diff = 10.0, cai ≈ 2.5842
    print(dirichlet_energy_pyg(data2, types=['rusch', 'cai']))

    # Example 3: batched graphs (per_graph=True)
    # two graphs: graph0 nodes [0,1], graph1 nodes [2,3]
    edge_index_batch = torch.tensor([[0, 1, 2, 3],
                                     [1, 0, 3, 2]], dtype=torch.long)
    H_batch = torch.tensor([[0.0], [1.0], [0.0], [10.0]])  # (4,1)
    batch_vec = torch.tensor([0, 0, 1, 1], dtype=torch.long)
    data_batch = Data(x=H_batch, edge_index=edge_index_batch, batch=batch_vec)
    # per_graph=True requests one value per graph (graph0, graph1):
    # rusch [1.0, 10.0], sq_sum_diff [2.0, 200.0], cai [0.5, 50.0]
    print(dirichlet_energy_pyg(data_batch, types=['rusch', 'cai'], per_graph=True))
