import copy
import time
from typing import Optional

import torch
import torch.nn.functional as F
from torch_geometric.data import Data, Batch

from src.datasets.rewiring.khop import k_hop_pairs_batch_sparse


# requires: pip install scipy

def _is_cuda_oom(err: BaseException) -> bool:
    """Return True when *err* looks like a CUDA out-of-memory failure."""
    # torch.cuda.OutOfMemoryError is missing on older torch builds; look it up
    # defensively instead of assuming the attribute exists.
    oom_type = getattr(torch.cuda, "OutOfMemoryError", None)
    if oom_type is not None and isinstance(err, oom_type):
        return True
    return isinstance(err, RuntimeError) and "out of memory" in str(err).lower()


def _edge_cosine_distances(
    unit_emb: torch.Tensor,
    src: torch.Tensor,
    dst: torch.Tensor,
) -> torch.Tensor:
    """Return 1 - cos(src, dst) per edge, given already L2-normalised rows."""
    # For unit vectors the dot product is the cosine similarity; the clamp
    # removes round-off excursions outside [-1, 1].
    cos_sim = (unit_emb[src] * unit_emb[dst]).sum(dim=1).clamp(min=-1.0, max=1.0)
    return 1.0 - cos_sim


def _mad_single_pass(
    unit_emb: torch.Tensor,
    edge_index: torch.Tensor,
    device: torch.device,
) -> float:
    """Compute the mean edge cosine distance with all edges gathered at once."""
    src = edge_index[0].to(device, non_blocking=True)
    dst = edge_index[1].to(device, non_blocking=True)
    return _edge_cosine_distances(unit_emb, src, dst).mean().item()


def _mad_chunked(
    unit_emb: torch.Tensor,
    edge_index: torch.Tensor,
    device: torch.device,
    chunk_size: int,
) -> float:
    """Compute the mean edge cosine distance over slices of `chunk_size` edges."""
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be a positive integer, got {chunk_size}")

    num_edges = edge_index.shape[1]
    on_cuda = device.type == "cuda"
    total = torch.tensor(0.0, device=device, dtype=unit_emb.dtype)

    for chunk_no, start in enumerate(range(0, num_edges, chunk_size)):
        stop = min(start + chunk_size, num_edges)
        # Only this slice of the index tensor is moved to the device.
        src = edge_index[0, start:stop].to(device, non_blocking=True)
        dst = edge_index[1, start:stop].to(device, non_blocking=True)
        total += _edge_cosine_distances(unit_emb, src, dst).sum()
        del src, dst

        # Emptying the cache is slow, so do it only every 16th chunk.
        if on_cuda and chunk_no % 16 == 0:
            torch.cuda.empty_cache()

    mad = (total / float(num_edges)).item()
    del total
    if on_cuda:
        torch.cuda.empty_cache()
    return float(mad)


def mad_from_edge_index(
    embeddings: torch.Tensor,
    edge_index: torch.Tensor,
    *,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
    chunk_size: int = 100_000,
    eps: float = 1e-12,
) -> float:
    """
    Compute the mean cosine distance over the edges listed in `edge_index`.

    Every column (u, v) of `edge_index` contributes 1 - cos(emb[u], emb[v]);
    the result is the plain average over all E columns. Note this is a mean
    over edges, not a per-node mean followed by a mean over nodes.

    - embeddings: (N, D) tensor; rows are L2-normalised internally
    - edge_index: (2, E) integer tensor indexing rows of `embeddings`
    - device: where to run the computation (defaults to `embeddings.device`);
      a device string such as "cpu" is accepted as well
    - dtype: optional dtype the embeddings are cast to before normalising
    - chunk_size: number of edges per slice in the out-of-memory fallback
    - eps: lower bound on the row norm used when normalising, so all-zero
      rows do not divide by zero

    All edges are processed in one vectorised pass. If that pass fails with
    an out-of-memory error, the edges are re-processed in slices of
    `chunk_size`; any other error propagates unchanged.

    Returns a python float; 0.0 when there are no edges.
    Raises ValueError if the fallback is entered with chunk_size < 1.
    """
    num_edges = edge_index.shape[1]
    if num_edges == 0:
        return 0.0

    device = embeddings.device if device is None else torch.device(device)
    unit_emb = embeddings.to(device)
    if dtype is not None:
        unit_emb = unit_emb.to(dtype)

    # L2 normalize (eps keeps zero-norm rows from dividing by zero)
    unit_emb = F.normalize(unit_emb, p=2, dim=1, eps=eps)

    try:
        return _mad_single_pass(unit_emb, edge_index, device)
    except RuntimeError as err:
        if not _is_cuda_oom(err):
            raise

    # Recovery deliberately happens outside the `except` clause: while the
    # exception is alive its traceback references the failed frame and keeps
    # that frame's tensors allocated.
    if device.type == "cuda":
        torch.cuda.empty_cache()
    return _mad_chunked(unit_emb, edge_index, device, chunk_size)



def mad_batch(
        batch,
        embeddings: torch.Tensor,
        max_neigh_hop: int = 1,
        eps: float = 1e-12,
    ) -> float:
    """
    Average the per-graph MAD score over every graph in `batch`.

    - batch: object exposing `.batch` (graph id per node) and integer indexing
      that yields a graph with an `.edge_index` attribute
      (presumably a torch_geometric `Batch` -- TODO confirm against callers)
    - embeddings: node embeddings, one row per entry of `batch.batch`
    - max_neigh_hop: when not 1, each graph's edge list is replaced by the
      output of `k_hop_pairs_batch_sparse(..., k=max_neigh_hop)` before scoring
      (presumably the node pairs within k hops -- verify in that helper)
    - eps: forwarded unchanged to `mad_from_edge_index`

    Returns the unweighted mean of the per-graph scores as a python float,
    or 0.0 when the batch contains no graphs.
    """
    device = embeddings.device
    batch_vec: torch.Tensor = batch.batch  # shape (N_total_nodes,)

    per_graph_mad = []

    # iterate graphs by id
    for graph_id in torch.unique(batch_vec).tolist():
        embeddings_graph = embeddings[batch_vec == graph_id]
        # Per-graph edge list, taken from the individual graph rather than the
        # batch-level edge_index.
        edge_index = batch[graph_id].edge_index

        if max_neigh_hop != 1:
            edge_index = k_hop_pairs_batch_sparse(
                edge_index=edge_index,
                num_nodes=len(embeddings_graph),
                k=max_neigh_hop,
            )

        per_graph_mad.append(
            mad_from_edge_index(
                embeddings=embeddings_graph,
                edge_index=edge_index,
                device=device,
                eps=eps,
            )
        )

    # An empty batch has nothing to average; return 0.0 instead of dividing
    # by zero.
    if not per_graph_mad:
        return 0.0

    return sum(per_graph_mad) / len(per_graph_mad)


def build_data(num_nodes: int, edges: list, num_graphs):
    """
    Build a list of `num_graphs` torch_geometric.data.Data objects that all
    share the same topology.

    `edges` is a list of (u, v) undirected edges using 0-based local indices;
    each pair is stored in both directions (u -> v and v -> u). Every returned
    graph gets its own freshly created `edge_index` tensor and is declared
    with `num_nodes` nodes.
    """
    graphs = []
    for _ in range(num_graphs):
        if len(edges) != 0:
            sources, targets = [], []
            for u, v in edges:
                # forward direction first, then the reverse one
                sources.extend((u, v))
                targets.extend((v, u))
            edge_index = torch.tensor([sources, targets], dtype=torch.long)
        else:
            # no edges: keep the (2, 0) layout
            edge_index = torch.empty((2, 0), dtype=torch.long)
        graphs.append(Data(edge_index=edge_index, num_nodes=num_nodes))
    return graphs

def main() -> None:
    """Time `mad_batch` on two identical 12,800-node path graphs with random embeddings."""
    n_nodes = 12800
    path_edges = [(i, i + 1) for i in range(n_nodes - 1)]
    graphs = build_data(n_nodes, edges=path_edges, num_graphs=2)
    batch = Batch.from_data_list(graphs)
    embeddings = torch.randn((n_nodes * 2, 64), dtype=torch.float32)

    # perf_counter is monotonic and high-resolution, unlike wall-clock
    # time.time(), so it is the right clock for measuring elapsed time.
    started = time.perf_counter()
    score = mad_batch(batch, embeddings, max_neigh_hop=3)
    run_time = time.perf_counter() - started

    print(f"Run time: {run_time:.5f}s Score: {score:.2f}")


if __name__ == "__main__":
    main()