import torch
import torch.nn.functional as F
from collections import deque
from typing import Optional, Dict, List, Tuple

from sklearn.metrics import pairwise_distances

from src.evaluation.metrics.cosine_distance import cosine_distance_torch


# helper: BFS from a single start node in adjacency list (returns list of distances, -1 unreachable)
def bfs_all_distances(adj: List[List[int]], start: int) -> List[int]:
    """Return unweighted shortest-path hop counts from ``start`` to every node.

    Args:
        adj: Adjacency list; ``adj[u]`` holds the indices of the nodes
            adjacent to ``u``. Nodes are numbered ``0 .. len(adj) - 1``.
        start: Index of the source node.

    Returns:
        A list ``dist`` of length ``len(adj)`` where ``dist[v]`` is the number
        of edges on a shortest path from ``start`` to ``v``, ``0`` for
        ``start`` itself and ``-1`` for nodes that cannot be reached.

    Raises:
        IndexError: If ``start`` (or a neighbour index stored in ``adj``) is
            outside ``range(len(adj))``.
    """
    dist = [-1] * len(adj)
    dist[start] = 0
    queue = deque([start])
    while queue:
        u = queue.popleft()
        for v in adj[u]:
            # -1 doubles as the "not visited yet" marker.
            if dist[v] == -1:
                dist[v] = dist[u] + 1
                queue.append(v)
    return dist


def _build_local_adjacency(
        edge_index: torch.Tensor,
        node_mask: torch.Tensor,
        node_idx: torch.Tensor,
        n_g: int,
) -> List[List[int]]:
    """Build an undirected, de-duplicated adjacency list for one graph of the batch.

    Only edges whose two endpoints both belong to the graph selected by
    ``node_mask`` are kept; endpoints are re-indexed to ``[0, n_g)`` following
    the order of ``node_idx``.
    """
    # Lookup table: global node index -> local index (-1 for nodes of other graphs).
    global_to_local = torch.full(
        (node_mask.size(0),), -1, dtype=torch.long, device=node_mask.device
    )
    global_to_local[node_idx] = torch.arange(n_g, device=node_mask.device)

    e_src, e_dst = edge_index[0], edge_index[1]
    internal_mask = node_mask[e_src] & node_mask[e_dst]
    # Translate all internal edges in one vectorised lookup instead of one
    # ``.item()`` call per endpoint.
    local_src = global_to_local[e_src[internal_mask]].tolist()
    local_dst = global_to_local[e_dst[internal_mask]].tolist()

    # Sets give O(1) duplicate suppression; edges are treated as undirected.
    neighbours = [set() for _ in range(n_g)]
    for s_local, d_local in zip(local_src, local_dst):
        neighbours[s_local].add(d_local)
        neighbours[d_local].add(s_local)
    return [sorted(nbrs) for nbrs in neighbours]


def _masked_row_sum(row: torch.Tensor, indices: List[int], device) -> Tuple[float, int]:
    """Return ``(sum of row[indices], len(indices))``; ``(0.0, 0)`` when ``indices`` is empty."""
    if not indices:
        return 0.0, 0
    idx_tensor = torch.tensor(indices, dtype=torch.long, device=device)
    return row.index_select(0, idx_tensor).sum().item(), len(indices)


def _mean_of_node_means(sums: List[float], counts: List[int]) -> Optional[float]:
    """Average the per-node means ``sum / count`` over nodes with ``count > 0`` (None if there are none)."""
    node_means = [s / c for s, c in zip(sums, counts) if c > 0]
    return float(sum(node_means) / len(node_means)) if node_means else None


def mad_madgap_batch(
        batch,
        embeddings: torch.Tensor,
        max_neigh_hop: int = 3,
        min_remote_hop: int = 8,
        dense_memory_threshold: int = 10_000_000,
        eps: float = 1e-12,
) -> Dict[str, Optional[float]]:
    """
    Compute MAD and MADGap for a PyG Batch and node embeddings.
    - batch: object exposing 'batch' (graph id of every node) and 'edge_index'
      (2 x N_edges, node indices global to the batch), e.g. a
      torch_geometric.data.Batch
    - embeddings: (N_total_nodes, embed_dim) tensor aligned with batch nodes
    - max_neigh_hop: node pairs with 1 <= hop <= max_neigh_hop count as "neighbor" pairs
    - min_remote_hop: node pairs with hop >= min_remote_hop count as "remote" pairs
    - dense_memory_threshold: largest n_g * n_g for which the full pairwise
      distance matrix of a graph is materialised at once; larger graphs are
      processed in row chunks of about dense_memory_threshold // n_g rows
    - eps: epsilon passed to F.normalize in the chunked path

    Returns a dict of Python floats (None when no graph provides a value):
      - 'MAD'       : mean of per-graph global MAD (all reachable node pairs)
      - 'MAD_neb'   : mean of per-graph MAD computed on node-pairs with hop <= max_neigh_hop
      - 'MAD_rmt'   : mean of per-graph MAD computed on node-pairs with hop >= min_remote_hop
      - 'MADGap'    : mean of per-graph (MAD_rmt - MAD_neb), taken over the graphs
                      for which both values exist

    Implementation notes:
      - Values are computed per graph (avoids storing full batch pairwise matrices).
      - For each graph we:
         * build adjacency lists (Python lists, edges treated as undirected)
         * obtain cosine distances, either from cosine_distance_torch (dense path;
           presumably returns the (n_g, n_g) pairwise distance matrix - defined
           outside this file) or as 1 - clamp(h_i . h_j, -1, 1) on L2-normalised
           rows (chunked path)
         * run one BFS per node to obtain exact unweighted hop distances
         * average, for every node, the distances to its neighbor / remote /
           reachable nodes, then average those per-node means over the nodes
           that have at least one such partner
      - Unreachable node pairs and self pairs are never counted.
      - If a graph has no valid neighbor pairs or no valid remote pairs, the
        corresponding MAD (and therefore its MADGap) is None and the graph is
        left out of that mean.
      - Graphs with <= 1 node contribute None to every metric.
    """
    device = embeddings.device
    batch_vec: torch.Tensor = batch.batch  # shape (N_total_nodes,)
    edge_index: torch.Tensor = batch.edge_index  # shape (2, N_edges), batch-global indices

    # One (mad, mad_neb, mad_rmt, mad_gap, num_nodes) tuple per graph.
    per_graph_results: List[
        Tuple[Optional[float], Optional[float], Optional[float], Optional[float], int]
    ] = []

    for gid in torch.unique(batch_vec).tolist():
        node_mask = batch_vec == gid
        node_idx = torch.nonzero(node_mask, as_tuple=False).view(-1)
        n_g = node_idx.numel()
        if n_g <= 1:
            # Trivial graph: no node pairs. Keep the same 5-field layout as
            # regular graphs so the node count is never read as a MADGap value.
            per_graph_results.append((None, None, None, None, int(n_g)))
            continue

        adj = _build_local_adjacency(edge_index, node_mask, node_idx, n_g)

        use_chunking = n_g * n_g > dense_memory_threshold
        if use_chunking:
            # Pick the row-chunk size so that chunk_size * n_g stays within the threshold.
            chunk_size = max(1, int(dense_memory_threshold // n_g))
            h_block = F.normalize(embeddings[node_idx].to(device), p=2, dim=1, eps=eps)
            dist_mat = None
        else:
            chunk_size = n_g  # a single chunk covering every row
            h_block = None
            dist_mat = cosine_distance_torch(embeddings[node_idx], embeddings[node_idx])
            dist_mat = dist_mat.to(device)

        # Per-node accumulators: sum of distances and number of partners per mask.
        global_sum, global_cnt = [0.0] * n_g, [0] * n_g
        neb_sum, neb_cnt = [0.0] * n_g, [0] * n_g
        rmt_sum, rmt_cnt = [0.0] * n_g, [0] * n_g

        for start in range(0, n_g, chunk_size):
            end = min(n_g, start + chunk_size)
            if use_chunking:
                sim_chunk = torch.matmul(h_block[start:end], h_block.t())  # (chunk, n_g)
                dist_chunk = 1.0 - sim_chunk.clamp(-1.0, 1.0)
            else:
                dist_chunk = dist_mat[start:end, :]

            for i_local in range(start, end):
                row = dist_chunk[i_local - start]
                # A single BFS per node feeds all three masks.
                hops = bfs_all_distances(adj, i_local)  # -1 marks unreachable nodes

                reachable: List[int] = []
                neb_indices: List[int] = []
                rmt_indices: List[int] = []
                for j, hop in enumerate(hops):
                    if j == i_local or hop == -1:
                        continue  # skip self and unreachable nodes
                    reachable.append(j)
                    if 1 <= hop <= max_neigh_hop:
                        neb_indices.append(j)
                    if hop >= min_remote_hop:
                        rmt_indices.append(j)

                global_sum[i_local], global_cnt[i_local] = _masked_row_sum(row, reachable, device)
                neb_sum[i_local], neb_cnt[i_local] = _masked_row_sum(row, neb_indices, device)
                rmt_sum[i_local], rmt_cnt[i_local] = _masked_row_sum(row, rmt_indices, device)

        mad_global = _mean_of_node_means(global_sum, global_cnt)
        mad_neb = _mean_of_node_means(neb_sum, neb_cnt)
        mad_rmt = _mean_of_node_means(rmt_sum, rmt_cnt)
        # The gap is only defined when both sides exist.
        mad_gap = mad_rmt - mad_neb if (mad_rmt is not None and mad_neb is not None) else None

        per_graph_results.append((mad_global, mad_neb, mad_rmt, mad_gap, int(n_g)))

    # Aggregate across graphs (mean over graphs that have a value)
    def mean_over_graphs(values: List[Optional[float]]) -> Optional[float]:
        vals = [v for v in values if v is not None]
        return float(sum(vals) / len(vals)) if vals else None

    return {
        "MAD": mean_over_graphs([g[0] for g in per_graph_results]),
        "MAD_neb": mean_over_graphs([g[1] for g in per_graph_results]),
        "MAD_rmt": mean_over_graphs([g[2] for g in per_graph_results]),
        # A per-graph gap exists only where both MAD_neb and MAD_rmt exist, so
        # this is None exactly when no graph provides both.
        "MADGap": mean_over_graphs([g[3] for g in per_graph_results]),
    }
