"""
Optimized distance computations with GPU acceleration for NetLSD metrics.
"""

from __future__ import annotations
from typing import Literal, Union, Callable
import subprocess
import tempfile
import numpy as np
import networkx as nx
import scipy.sparse as sps
from scipy.stats import rankdata
import os
import pandas as pd
from joblib import Parallel, delayed
from itertools import combinations

# Portrait Divergence via netrd
from netrd.distance import PortraitDivergence
from netrd.distance.portrait_divergence import portrait_divergence as _netrd_portrait_divergence
from netrd.utilities import entropy as netrd_entropy

# GPU implementations
import torch
from evaluate.distances_gpu import (
    netlsd_heat,
    netlsd_wave,
    nx_to_adjacency_tensor,
)

# A graph supplied either as a NetworkX graph or as an adjacency array
# (arrays are converted by _to_nx).
GraphLike = Union[nx.Graph, np.ndarray]
# Names of the distance metrics that pairwise_distance_matrix dispatches on.
DistanceName = Literal["netlsd_heat", "netlsd_wave", "portrait_div", "gcd"]

_DEVICE = None


def _get_device() -> torch.device:
    global _DEVICE
    if _DEVICE is None:
        if torch.cuda.is_available():
            _DEVICE = torch.device("cuda")
        else:
            _DEVICE = torch.device("cpu")
    return _DEVICE


def set_device(device: str | torch.device):
    """Override the module-wide torch device.

    Args:
        device: Either a device string (wrapped in ``torch.device``) or an
            already constructed device object, which is stored as given.
    """
    global _DEVICE
    _DEVICE = torch.device(device) if isinstance(device, str) else device



def _to_nx(g: GraphLike) -> nx.Graph:
    """Return ``g`` as a NetworkX graph.

    NetworkX graphs are passed through untouched. Anything else is read as
    an adjacency array and binarised: every entry greater than zero becomes
    an edge.

    Raises:
        ValueError: If the adjacency array is not square.
    """
    if not isinstance(g, nx.Graph):
        adjacency = np.asarray(g)
        rows, cols = adjacency.shape
        if rows != cols:
            raise ValueError("Adjacency must be square.")
        binary = (adjacency > 0).astype(int)
        return nx.from_numpy_array(binary)
    return g


def _euclidean(x: np.ndarray, y: np.ndarray) -> float:
    return float(np.linalg.norm(x - y))


def parallelify(
    workers: Parallel,
    func: Callable,
    N: int,
) -> np.ndarray:
    """Build a symmetric ``N x N`` matrix by evaluating ``func`` on index pairs.

    Args:
        workers: Callable that consumes an iterable of ``delayed`` tasks and
            returns their results (a joblib ``Parallel`` per the annotation).
        func: Called as ``func(i, j)`` for every unordered pair ``i < j``.
        N: Number of items; pairs are drawn from ``range(N)``.

    Returns:
        Float matrix with ``func(i, j)`` mirrored across a zero diagonal.
    """

    def _pair_task(a, b):
        # Carry the indices with the value so results can be placed
        # without relying on the order in which they come back.
        return a, b, func(a, b)

    index_pairs = combinations(range(N), 2)
    results = workers(delayed(_pair_task)(a, b) for a, b in index_pairs)

    matrix = np.zeros((N, N), dtype=float)
    for a, b, value in results:
        matrix[a, b] = value
        matrix[b, a] = value
    return matrix



def _netlsd_signature(G: nx.Graph, kernel: Literal["heat", "wave"]) -> np.ndarray:
    """Compute the NetLSD signature of ``G`` for the requested kernel.

    The graph is converted to an adjacency tensor on the module device, the
    matching ``netlsd_heat`` / ``netlsd_wave`` routine is applied, and the
    result is moved to the CPU and returned as a NumPy array.

    Raises:
        ValueError: If ``kernel`` is neither ``"heat"`` nor ``"wave"``.
    """
    adjacency = nx_to_adjacency_tensor(G, _get_device())

    if kernel not in ("heat", "wave"):
        raise ValueError(f"Unknown kernel: {kernel}")

    compute = netlsd_heat if kernel == "heat" else netlsd_wave
    return compute(adjacency, G).cpu().numpy()


def _netlsd_signature_batch(
    graphs: list[nx.Graph],
    kernel: Literal["heat", "wave"],
) -> list[np.ndarray]:
    device = _get_device()
    signatures = []
    
    for G in graphs:
        A = nx_to_adjacency_tensor(G, device)
        
        if kernel == "heat":
            sig = netlsd_heat(A, G)
        else:
            sig = netlsd_wave(A, G)
        
        signatures.append(sig.cpu().numpy())
    
    return signatures

def _normalize_rows(X: np.ndarray) -> np.ndarray:
    X = X.astype(float, copy=False)
    mu = X.mean(axis=1, keepdims=True)
    Xn = X - mu
    denom = np.sqrt((Xn * Xn).sum(axis=1, keepdims=True) + 1e-8)
    Xn /= denom
    return Xn


def _spearman_matrix(X: np.ndarray) -> np.ndarray:
    """Return the rank-correlation matrix between the rows of ``X``.

    Each row is replaced by its ranks (``scipy.stats.rankdata``), the ranked
    rows are centred and normalised, and their pairwise dot products are
    clipped into ``[-1, 1]``.
    """
    row_ranks = np.apply_along_axis(rankdata, 1, X).astype(float)
    unit_rows = _normalize_rows(row_ranks)
    correlation = np.matmul(unit_rows, unit_rows.T)
    # Clamp in place so rounding error cannot leave values outside [-1, 1].
    correlation.clip(-1.0, 1.0, out=correlation)
    return correlation


def _nx_to_orca_edgelist(G: nx.Graph) -> tuple[list[tuple[int, int]], int]:
    """Convert a NetworkX graph into the edge list handed to ORCA.

    Nodes are relabelled with consecutive integers in iteration order.

    Returns:
        ``(edges, node_count)`` where ``edges`` holds the relabelled
        endpoint pairs and ``node_count`` is ``G.number_of_nodes()``.
    """
    index_of = {node: position for position, node in enumerate(G.nodes())}
    relabelled = [(index_of[src], index_of[dst]) for src, dst in G.edges()]
    node_count = G.number_of_nodes()
    return relabelled, node_count


def _run_orca_node(
    edge_list: list[tuple[int, int]],
    nodes_num: int,
    orca_prefix: str,
    graphlet_size: int = 4,
) -> np.ndarray:
    """
    Run ORCA to compute graphlet counts.
    """
    cmd = os.path.join(orca_prefix, "orca")
    if not os.path.exists(cmd):
        raise FileNotFoundError(f"ORCA binary not found at {cmd}")

    fd_in, in_file = tempfile.mkstemp()
    fd_out, out_file = tempfile.mkstemp()

    try:
        with os.fdopen(fd_in, "w") as fp:
            fp.write(f"{nodes_num} {len(edge_list)}\n")
            for u, v in edge_list:
                fp.write(f"{u} {v}\n")

        proc = subprocess.Popen(
            [cmd, "node", str(graphlet_size), in_file, out_file],
            stdin=None,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.PIPE,
            universal_newlines=True,
        )
        err = proc.stderr.read().strip()
        proc.stderr.close()
        proc.wait()
        if proc.returncode != 0:
            raise RuntimeError(f"ORCA failed: {err}")

        G = pd.read_table(out_file, header=None, sep=r"\s+")
        arr = G.values
        if arr.ndim == 1:
            arr = arr[None, :]
        return arr.astype(float)
    finally:
        os.close(fd_out)
        for p in (in_file, out_file):
            try:
                os.remove(p)
            except OSError:
                pass


# GCM orbit order (matches the original implementation)
_GCM_ORDER = [0, 2, 5, 7, 8, 10, 11, 6, 9, 4, 1]


def _gcm_orca_like_theirs(
    G: nx.Graph, orca_prefix: str, graphlet_size: int = 4
) -> np.ndarray:
    """Compute the graphlet correlation matrix (GCM) of ``G`` via ORCA.

    ORCA's count table is restricted to the columns listed in ``_GCM_ORDER``
    and transposed so each selected orbit becomes a row; the rank correlation
    between those rows is returned with NaNs replaced in place by
    ``np.nan_to_num``.

    Args:
        G: Graph to analyse.
        orca_prefix: Directory containing the ``orca`` executable.
        graphlet_size: Graphlet size argument passed to ORCA.
    """
    edges, node_count = _nx_to_orca_edgelist(G)
    orbit_counts = _run_orca_node(
        edges,
        nodes_num=node_count,
        orca_prefix=orca_prefix,
        graphlet_size=graphlet_size,
    )
    selected_orbits = orbit_counts[:, _GCM_ORDER].T
    correlation = _spearman_matrix(selected_orbits)
    return np.nan_to_num(correlation, copy=False)



def compute_portrait_single(G: nx.Graph) -> np.ndarray:
    """Compute the network portrait matrix of ``G``.

    A breadth-first search is run from every node. Entry ``[l, k]`` of the
    result counts the starting nodes that have exactly ``k`` nodes at BFS
    depth ``l``; column 0 additionally counts starting nodes whose search
    ended before depth ``l``. Rows beyond the longest observed depth (at
    least 1) are trimmed.

    Returns:
        The portrait matrix, or ``[[1.0]]`` for a graph without nodes.
    """
    num_nodes = G.number_of_nodes()
    if num_nodes == 0:
        return np.array([[1.0]])

    try:
        diameter = nx.diameter(G)
    except nx.NetworkXError:
        # Diameter is undefined (e.g. disconnected graph): fall back to the
        # node count as an upper bound on any shortest-path length.
        diameter = num_nodes

    portrait = np.zeros((diameter + 1, num_nodes))
    longest_path = 1
    neighbours = G.adj

    for source in G.nodes():
        # Level-by-level BFS recording the depth of every reachable node.
        depth_of = {source: 0}
        frontier = [source]
        depth = 1
        while frontier:
            upcoming = []
            for node in frontier:
                fresh = [nb for nb in neighbours[node] if nb not in depth_of]
                upcoming.extend(fresh)
                for nb in fresh:
                    depth_of[nb] = depth
            frontier = upcoming
            depth += 1

        farthest = max(depth_of.values())
        longest_path = max(longest_path, farthest)

        # Histogram: how many nodes sit at each depth for this source.
        shell_sizes = {}
        for distance in depth_of.values():
            shell_sizes[distance] = shell_sizes.get(distance, 0) + 1

        for shell, size in shell_sizes.items():
            # Skip entries that would fall outside the allocated matrix.
            if shell <= diameter and size < num_nodes:
                portrait[shell][size] += 1

        # Shells beyond this source's reach contain zero nodes.
        for empty_shell in range(diameter, farthest, -1):
            portrait[empty_shell][0] += 1

    return portrait[: longest_path + 1, :]


def compute_portraits_batch(
    graphs: list[nx.Graph],
    n_jobs: int = -1,
    verbose: bool = True,
) -> list[np.ndarray]:
    """Compute portrait matrices for many graphs with a joblib thread pool.

    Args:
        graphs: Graphs to process.
        n_jobs: Forwarded to ``joblib.Parallel``.
        verbose: When true, print a header line and wrap the input in a
            ``tqdm`` progress bar.

    Returns:
        One portrait matrix per graph, as produced by
        ``compute_portrait_single``.
    """
    source = graphs
    if verbose:
        from tqdm import tqdm

        print(f"  Computing portrait matrices for {len(graphs)} graphs...")
        source = tqdm(graphs, desc="  Portraits")

    pool = Parallel(n_jobs=n_jobs, prefer="threads")
    return pool(delayed(compute_portrait_single)(G) for G in source)


def _portrait_divergence(B1: np.ndarray, B2: np.ndarray) -> float:
    """Return netrd's portrait divergence of ``B1`` and ``B2`` as a Python float."""
    divergence = _netrd_portrait_divergence(B1, B2)
    return float(divergence)



def _js_divergence_fast(P: np.ndarray, Q: np.ndarray) -> float:
    M = 0.5 * (P + Q)
    
    mask_p = P > 0
    kl_pm = np.sum(P[mask_p] * np.log2(P[mask_p] / M[mask_p]))
    
    mask_q = Q > 0
    kl_qm = np.sum(Q[mask_q] * np.log2(Q[mask_q] / M[mask_q]))
    
    return 0.5 * (kl_pm + kl_qm)


def _get_portrait_prob_distribution(B: np.ndarray, N: int = None) -> np.ndarray:
    if N is None:
        N = int(B[0, 1]) if B.shape[1] > 1 and B[0, 1] > 0 else B.shape[1]
    
    d, K = B.shape
    
    v = np.arange(0, K)
    f = (B * v).sum(axis=1)
    f_sum = f.sum()
    if f_sum > 0:
        P_L = f / f_sum
    else:
        P_L = np.zeros(d)
    
    P_KgL = B / N if N > 0 else B
    
    P_KaL = P_KgL * P_L[:, None]
    
    return P_KaL.ravel()


def preprocess_portraits_for_divergence(
    portraits: list[np.ndarray],
    verbose: bool = False,
) -> tuple[list[np.ndarray], tuple[int, int]]:
    """Pad portraits to one shape and convert each to a flat distribution.

    The common shape uses the largest row count among the portraits and the
    smallest column count that still holds every non-zero entry (at least
    one column). Each portrait is zero-padded or column-trimmed to that
    shape and passed to ``_get_portrait_prob_distribution``.

    Args:
        portraits: 2-D portrait matrices; may be empty.
        verbose: Print progress information.

    Returns:
        ``(distributions, common_shape)``. For an empty input this is
        ``([], (0, 1))``.
    """
    if verbose:
        print(f"  Preprocessing {len(portraits)} portraits for fast divergence...")

    # default=0 keeps an empty portrait list from raising inside max().
    max_rows = max((B.shape[0] for B in portraits), default=0)

    max_cols = 1
    for B in portraits:
        nonzero_cols = np.nonzero(B)[1]
        if nonzero_cols.size:
            # int() keeps the shape tuple free of NumPy scalar types.
            max_cols = max(max_cols, int(nonzero_cols.max()) + 1)

    common_shape = (max_rows, max_cols)

    if verbose:
        print(f"    Common shape: {common_shape}")

    distributions = []
    for B in portraits:
        padded = np.zeros(common_shape)
        rows, cols = B.shape
        cols_to_copy = min(cols, max_cols)
        padded[:rows, :cols_to_copy] = B[:, :cols_to_copy]
        distributions.append(_get_portrait_prob_distribution(padded))

    # Guarded: there is no first distribution to measure for empty input.
    if verbose and distributions:
        print(f"    Distribution vector length: {len(distributions[0])}")

    return distributions, common_shape


def portrait_divergence_fast(dist1: np.ndarray, dist2: np.ndarray) -> float:
    """Return the Jensen-Shannon divergence of two preprocessed portrait vectors.

    Both inputs are expected to be equally sized flat distributions such as
    those returned by ``preprocess_portraits_for_divergence``.
    """
    divergence = _js_divergence_fast(dist1, dist2)
    return divergence


import multiprocessing
# CPU count reported by multiprocessing when this module is imported.
_CPU_COUNT = multiprocessing.cpu_count()

# Default worker count used by compute_gcd_embeddings_parallel when n_workers is None.
GCD_PARALLEL_WORKERS = _CPU_COUNT


def _compute_gcd_embedding_single(G: nx.Graph, orca_path: str) -> np.ndarray:
    """Return the GCD embedding of ``G``.

    The embedding is the strict upper triangle of the graphlet correlation
    matrix, flattened to 1-D. ORCA is looked up in the directory that
    contains ``orca_path``.
    """
    binary_dir = os.path.dirname(os.path.abspath(orca_path))
    correlation = _gcm_orca_like_theirs(G, orca_prefix=binary_dir, graphlet_size=4)
    upper_triangle = np.triu_indices_from(correlation, k=1)
    return correlation[upper_triangle]


def compute_gcd_embeddings_parallel(
    graphs: list[nx.Graph],
    orca_path: str,
    n_workers: int | None = None,
    verbose: bool = True,
) -> list[np.ndarray]:
    """Compute GCD embeddings for many graphs with joblib.

    Args:
        graphs: Graphs to embed.
        orca_path: Path of the ORCA executable.
        n_workers: Number of joblib jobs; defaults to
            ``GCD_PARALLEL_WORKERS`` when ``None``.
        verbose: When true, print status lines and show a ``tqdm`` progress
            bar over the inputs.

    Returns:
        One embedding vector per graph, in input order.
    """
    job_count = GCD_PARALLEL_WORKERS if n_workers is None else n_workers

    source = graphs
    if verbose:
        from tqdm import tqdm

        print(f"  Computing GCD embeddings for {len(graphs)} graphs...")
        print(f"  Using {job_count} parallel workers")
        source = tqdm(graphs, desc="  GCD embeddings")

    pool = Parallel(n_jobs=job_count, batch_size=1)
    return pool(delayed(_compute_gcd_embedding_single)(G, orca_path) for G in source)

def _pairwise_symmetric(items, pair_distance) -> np.ndarray:
    """Fill a symmetric zero-diagonal matrix with ``pair_distance`` over all unordered pairs."""
    count = len(items)
    D = np.zeros((count, count), dtype=float)
    for i, j in combinations(range(count), 2):
        D[i, j] = D[j, i] = pair_distance(items[i], items[j])
    return D


def pairwise_distance_matrix(
    graphs,
    distance: DistanceName,
    workers: Parallel,
    orca_path: str = "orca/orca",
    orca_prefix: str | None = None,
) -> np.ndarray:
    """Compute the symmetric pairwise distance matrix for a set of graphs.

    Args:
        graphs: NetworkX graphs or square adjacency arrays.
        distance: One of ``"netlsd_heat"``, ``"netlsd_wave"``,
            ``"portrait_div"`` or ``"gcd"``.
        workers: Currently unused; kept so existing callers keep working.
        orca_path: Path of the ORCA executable (``"gcd"`` only).
        orca_prefix: Directory containing the ``orca`` executable. When
            given, it takes precedence over ``orca_path`` (``"gcd"`` only).

    Returns:
        ``N x N`` float matrix with zeros on the diagonal.

    Raises:
        ValueError: If ``distance`` is not a known metric name.
    """
    Gs = [_to_nx(g) for g in graphs]

    if distance in ("netlsd_heat", "netlsd_wave"):
        kernel = "heat" if distance == "netlsd_heat" else "wave"
        sigs = _netlsd_signature_batch(Gs, kernel)
        return _pairwise_symmetric(sigs, _euclidean)

    if distance == "portrait_div":
        if not Gs:
            # Nothing to compare; portrait preprocessing needs at least one
            # portrait, so return the empty matrix directly.
            return np.zeros((0, 0), dtype=float)
        portraits = compute_portraits_batch(Gs, verbose=False)
        distributions, _ = preprocess_portraits_for_divergence(portraits, verbose=False)
        return _pairwise_symmetric(distributions, portrait_divergence_fast)

    if distance == "gcd":
        if orca_prefix is not None:
            # Honour the directory-style argument, which used to be silently
            # ignored: the embedding code looks for "orca" inside the
            # directory that contains orca_path.
            orca_path = os.path.join(orca_prefix, "orca")
        vecs = compute_gcd_embeddings_parallel(Gs, orca_path, verbose=False)
        return _pairwise_symmetric(vecs, _euclidean)

    raise ValueError(f"Unknown distance: {distance}")


# CONVENIENCE FUNCTIONS FOR POSTPROCESS_NEW.PY


def compute_embeddings_gpu(
    graphs: list[nx.Graph],
    metric: DistanceName,
    orca_path: str = "orca/orca",
    verbose: bool = True,
) -> np.ndarray:
    """Compute one embedding vector per graph for an embedding-based metric.

    Args:
        graphs: Graphs to embed.
        metric: ``"netlsd_heat"``, ``"netlsd_wave"`` or ``"gcd"``.
        orca_path: Path of the ORCA executable (``"gcd"`` only).
        verbose: Print progress information, including the active device.

    Returns:
        The embeddings gathered with ``np.array``.

    Raises:
        ValueError: For ``"portrait_div"`` (not embedding-based here) and
            for any unknown metric name.
    """
    if metric == "portrait_div":
        raise ValueError("portrait_div should use incremental greedy selection")

    if verbose:
        print(f"  Computing {metric} embeddings for {len(graphs)} graphs...")
        print(f"  Using device: {_get_device()}")

    if metric == "gcd":
        vectors = compute_gcd_embeddings_parallel(graphs, orca_path, verbose=verbose)
        return np.array(vectors)

    if metric in ("netlsd_heat", "netlsd_wave"):
        kernel = "wave" if metric == "netlsd_wave" else "heat"
        return np.array(_netlsd_signature_batch(graphs, kernel))

    raise ValueError(f"Unknown metric: {metric}")
