from __future__ import annotations

import os
import subprocess
import tempfile
from dataclasses import dataclass
from itertools import combinations
from typing import Callable, Iterable, Literal, Sequence, Tuple, Union

import networkx as nx
import numpy as np
import pandas as pd
import scipy.sparse as sps
from joblib import Parallel, delayed
from scipy.stats import rankdata
from scipy.stats import spearmanr

import netlsd
from netrd.distance import PortraitDivergence

# Graph input accepted by _to_nx: either a networkx graph (used as-is) or an
# adjacency array, which _to_nx binarizes so that entries > 0 become edges.
GraphLike = Union[nx.Graph, np.ndarray]
# Distance names dispatched on by pairwise_distance_matrix.
DistanceName = Literal["netlsd_heat", "netlsd_wave", "portrait_div", "gcd"]



def _to_nx(g: GraphLike) -> nx.Graph:
    """Return ``g`` as a networkx graph.

    A networkx graph is returned unchanged (not copied). Any other input is
    treated as an adjacency matrix. It must be a 2-D square array, and every
    entry ``> 0`` becomes an edge of the graph built by
    ``nx.from_numpy_array``. Zero and negative entries produce no edge.

    Raises:
        ValueError: if the adjacency input is not a 2-D square array.
    """
    if isinstance(g, nx.Graph):
        return g
    arr = np.asarray(g)
    # Check the rank explicitly. Tuple-unpacking ``arr.shape`` would fail
    # with an unhelpful "values to unpack" error for non-2-D input.
    if arr.ndim != 2 or arr.shape[0] != arr.shape[1]:
        raise ValueError(
            f"Adjacency must be square; got array of shape {arr.shape}."
        )
    # Binarize so that edge presence, not magnitude, defines the graph.
    return nx.from_numpy_array((arr > 0).astype(int))


def _lap_sparse(G: nx.Graph) -> sps.spmatrix:
    """Return the normalized Laplacian of ``G`` as computed by networkx.

    Thin wrapper over ``nx.normalized_laplacian_matrix``.
    NOTE(review): recent networkx releases return a scipy sparse *array*
    rather than an ``spmatrix``, so the annotation may be stale. Confirm
    against the installed version.
    """
    laplacian = nx.normalized_laplacian_matrix(G)
    return laplacian


def _netlsd_signature(
    G: nx.Graph,
    kernel: Literal["heat", "wave"],
    timescales: np.ndarray | None = None,
) -> np.ndarray:
    """Compute the NetLSD signature of ``G`` and return it as a float array.

    Args:
        G: graph to describe.
        kernel: ``"heat"`` selects ``netlsd.heat``; ``"wave"`` selects
            ``netlsd.wave``.
        timescales: optional time grid forwarded to netlsd as
            ``timescales=``. When ``None``, netlsd's own default grid is used,
            exactly as before this parameter existed.

    Raises:
        ValueError: if ``kernel`` is neither ``"heat"`` nor ``"wave"``.
    """
    if kernel == "heat":
        signature_fn = netlsd.heat
    elif kernel == "wave":
        signature_fn = netlsd.wave
    else:
        raise ValueError(f"Unknown kernel: {kernel!r}")
    # Forward timescales only when given, so netlsd's defaults stay in effect
    # for existing callers.
    kwargs = {} if timescales is None else {"timescales": timescales}
    return np.asarray(signature_fn(G, **kwargs), dtype=float)


def _euclidean(x: np.ndarray, y: np.ndarray) -> float:
    """Return the Euclidean (L2) distance between ``x`` and ``y``.

    The result is converted to a built-in ``float`` so it matches the
    annotated return type (``np.linalg.norm`` returns a NumPy scalar).
    """
    return float(np.linalg.norm(x - y))


def parallelify(
    workers: Parallel,
    func: Callable,
    N: int,
) -> np.ndarray:
    """Evaluate ``func(i, j)`` for every pair ``i < j`` in ``range(N)``.

    The calls are dispatched through the joblib ``workers`` object. Each
    result is written to both ``[i, j]`` and ``[j, i]`` of an ``N x N``
    float matrix, and the diagonal is left at zero.
    """

    def tagged(i, j):
        # Return the indices alongside the value, so placement never
        # depends on the order in which results come back.
        return i, j, func(i, j)

    out = np.zeros((N, N), dtype=float)
    jobs = (delayed(tagged)(i, j) for i, j in combinations(range(N), 2))
    for i, j, value in workers(jobs):
        out[i, j] = out[j, i] = value
    return out


def _pairwise_euclidean(vectors: Sequence[np.ndarray]) -> np.ndarray:
    """Return the ``N x N`` matrix of Euclidean distances between ``vectors``.

    Entry ``[i, j]`` (and ``[j, i]``) is ``_euclidean(vectors[i], vectors[j])``.
    The diagonal is left at zero.
    """
    n = len(vectors)
    D = np.zeros((n, n), dtype=float)
    for i, j in combinations(range(n), 2):
        D[i, j] = D[j, i] = _euclidean(vectors[i], vectors[j])
    return D


def pairwise_distance_matrix(
    graphs: Iterable[GraphLike],
    distance: DistanceName,
    workers: Parallel,
    orca_path: str = "orca/orca",
    orca_prefix: str | None = None,
) -> np.ndarray:
    """Compute the symmetric matrix of pairwise distances between ``graphs``.

    Args:
        graphs: networkx graphs and/or square adjacency arrays. Each one is
            converted with ``_to_nx``.
        distance: which distance to use:
            - ``"netlsd_heat"`` / ``"netlsd_wave"``: Euclidean distance between
              NetLSD heat / wave signatures.
            - ``"portrait_div"``: netrd ``PortraitDivergence``, evaluated over
              all pairs through ``workers``.
            - ``"gcd"``: Euclidean distance between the strict upper
              triangles of the ORCA-based graphlet correlation matrices.
        workers: joblib ``Parallel`` object. Only ``"portrait_div"`` uses it.
        orca_path: consulted only when ``orca_prefix`` is ``None``. If its
            basename is ``"orca"`` and the path exists, its directory becomes
            the prefix. Otherwise the prefix falls back to ``"orca"``.
        orca_prefix: directory containing the ``orca`` executable
            (``"gcd"`` only).

    Returns:
        A float array of shape ``(N, N)`` with a zero diagonal.

    Raises:
        ValueError: for an unknown ``distance`` or a non-square adjacency.
        FileNotFoundError, RuntimeError: propagated from the ORCA runner for
            ``"gcd"``.
    """
    from typing import get_args

    # Fail fast, before converting graphs or starting any expensive work.
    if distance not in get_args(DistanceName):
        raise ValueError(f"Unknown distance: {distance}")

    Gs = [_to_nx(g) for g in graphs]
    N = len(Gs)

    if distance in ("netlsd_heat", "netlsd_wave"):
        kernel = "heat" if distance == "netlsd_heat" else "wave"
        sigs = [_netlsd_signature(G, kernel) for G in Gs]
        return _pairwise_euclidean(sigs)

    if distance == "portrait_div":
        # Deliberately not named ``pd``: that would shadow the pandas alias.
        portrait = PortraitDivergence()
        return parallelify(
            workers, lambda i, j: float(portrait.dist(Gs[i], Gs[j])), N
        )

    if distance == "gcd":
        if orca_prefix is None:
            if os.path.basename(orca_path) == "orca" and os.path.exists(orca_path):
                orca_prefix = os.path.dirname(os.path.abspath(orca_path))
            else:
                orca_prefix = "orca"
        gcms = [
            _gcm_orca_like_theirs(G, orca_prefix=orca_prefix, graphlet_size=4)
            for G in Gs
        ]
        # Compare only the strict upper triangle (diagonal excluded) of each GCM.
        vecs = [gcm[np.triu_indices_from(gcm, k=1)] for gcm in gcms]
        return _pairwise_euclidean(vecs)

    # Safety net in case DistanceName gains a value without a branch here.
    raise ValueError(f"Unknown distance: {distance}")




def _normalize_rows(X: np.ndarray) -> np.ndarray:
    """Center each row of ``X`` and scale it to (almost) unit L2 norm.

    Returns a new float array; the input is not modified. The ``1e-8``
    added under the square root keeps constant rows finite: they come out
    as all-zero rows instead of dividing by zero.
    """
    data = X.astype(float, copy=False)
    centered = data - data.mean(axis=1, keepdims=True)
    norms = np.sqrt(np.square(centered).sum(axis=1, keepdims=True) + 1e-8)
    centered /= norms
    return centered


def _pearson_allpairs(X: np.ndarray, Y: np.ndarray) -> np.ndarray:
    """Pearson correlation between every row of ``X`` and every row of ``Y``.

    Both inputs are row-normalized with ``_normalize_rows``. Entry ``[a, b]``
    of the returned ``(len(X), len(Y))`` matrix is the dot product of row
    ``a`` of ``X`` with row ``b`` of ``Y``, clipped to ``[-1, 1]`` to absorb
    rounding.
    """
    corr = np.matmul(_normalize_rows(X), _normalize_rows(Y).T)
    return np.clip(corr, -1.0, 1.0, out=corr)


def _spearman_matrix(X: np.ndarray) -> np.ndarray:
    """Spearman rank correlation between all pairs of rows of ``X``.

    Each row is ranked independently with ``scipy.stats.rankdata`` (default
    method). The Pearson correlation of the rank rows is returned as a
    ``(len(X), len(X))`` matrix clipped to ``[-1, 1]``. A constant row
    centers to zeros, so it correlates 0 with every row, itself included.
    """
    row_ranks = np.apply_along_axis(rankdata, 1, X).astype(float)
    unit_rows = _normalize_rows(row_ranks)
    corr = np.matmul(unit_rows, unit_rows.T)
    return np.clip(corr, -1.0, 1.0, out=corr)


def _nx_to_orca_edgelist(G: nx.Graph) -> tuple[list[tuple[int, int]], int]:
    """Relabel ``G``'s nodes to ``0..n-1`` and return ``(edges, n)`` for ORCA.

    Nodes are numbered in ``G.nodes()`` iteration order. ORCA counts
    graphlets on a simple undirected graph, so self-loops are dropped and
    each undirected edge is emitted only once. This matters for DiGraph and
    MultiGraph input, or for adjacency matrices with a non-zero diagonal.
    Otherwise ``G.edges()`` order is preserved.
    """
    mapping = {u: i for i, u in enumerate(G.nodes())}
    seen: set[tuple[int, int]] = set()
    edges: list[tuple[int, int]] = []
    for u, v in G.edges():
        a, b = mapping[u], mapping[v]
        if a == b:
            continue  # self-loop: not part of a simple graph
        key = (a, b) if a < b else (b, a)
        if key in seen:
            continue  # reverse direction or parallel edge already emitted
        seen.add(key)
        edges.append((a, b))
    return edges, G.number_of_nodes()


def _run_orca_node(
    edge_list: list[tuple[int, int]],
    nodes_num: int,
    orca_prefix: str,
    graphlet_size: int = 4,
) -> np.ndarray:
    """Run ORCA in node-orbit mode on an edge list and return its output matrix.

    The graph is written to a temporary file: a ``"<nodes> <edges>"`` header
    line followed by one ``"u v"`` line per edge. Then
    ``<orca_prefix>/orca node <graphlet_size> <in> <out>`` runs, and ORCA's
    whitespace-separated output is parsed into a 2-D float array, one row
    per output line. Both temporary files are removed afterwards on a
    best-effort basis.

    Args:
        edge_list: edges as pairs of integer node ids.
        nodes_num: node count written to the header line.
        orca_prefix: directory containing the ``orca`` executable.
        graphlet_size: graphlet size passed on ORCA's command line.

    Raises:
        FileNotFoundError: if ``<orca_prefix>/orca`` does not exist.
        RuntimeError: if ORCA exits with a non-zero status. The message
            carries ORCA's stderr.
    """
    cmd = os.path.join(orca_prefix, "orca")
    if not os.path.exists(cmd):
        raise FileNotFoundError(f"ORCA binary not found at {cmd}")

    # Record each temp path as soon as it exists, so ``finally`` removes it
    # even when a later step (including the second mkstemp) fails.
    tmp_paths: list[str] = []
    try:
        fd_in, in_file = tempfile.mkstemp()
        tmp_paths.append(in_file)
        with os.fdopen(fd_in, "w") as fp:
            fp.write(f"{nodes_num} {len(edge_list)}\n")
            fp.writelines(f"{u} {v}\n" for u, v in edge_list)

        fd_out, out_file = tempfile.mkstemp()
        tmp_paths.append(out_file)
        # Only the reserved path is needed; ORCA opens and writes it itself.
        os.close(fd_out)

        proc = subprocess.run(
            [cmd, "node", str(graphlet_size), in_file, out_file],
            stdin=None,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.PIPE,
            universal_newlines=True,
        )
        if proc.returncode != 0:
            raise RuntimeError(f"ORCA failed: {proc.stderr.strip()}")

        # DataFrame.to_numpy() is always 2-D, so no reshaping is needed.
        counts = pd.read_table(out_file, header=None, sep=r"\s+")
        return counts.to_numpy().astype(float)
    finally:
        for path in tmp_paths:
            try:
                os.remove(path)
            except OSError:
                pass  # best-effort cleanup; never mask the original error


# Columns of ORCA's node-orbit output that form the rows of the graphlet
# correlation matrix, in this order: the 11 orbit indices 0-11 except 3.
# NOTE(review): presumably the non-redundant orbits of the GCD-11 definition,
# ordered to match the reference implementation ("theirs"). Confirm.
# The order fixes the GCM's row/column layout. It does not change the
# Euclidean GCD, because both graphs' upper triangles are permuted identically.
_GCM_ORDER = [0, 2, 5, 7, 8, 10, 11, 6, 9, 4, 1]


def _gcm_orca_like_theirs(
    G: nx.Graph, orca_prefix: str, graphlet_size: int = 4
) -> np.ndarray:
    """Graphlet correlation matrix (GCM) of ``G``, using ORCA orbit counts.

    Per-node orbit counts from ORCA are restricted to the ``_GCM_ORDER``
    columns. Orbits are then correlated with each other (Spearman, across
    nodes), giving a square matrix with one row/column per selected orbit.
    Non-finite entries are replaced via ``np.nan_to_num`` (NaN -> 0).
    NOTE(review): ``_GCM_ORDER`` indexes 4-node orbits; confirm it is
    meaningful for other ``graphlet_size`` values.
    """
    edges, n_nodes = _nx_to_orca_edgelist(G)
    orbit_counts = _run_orca_node(
        edges, nodes_num=n_nodes, orca_prefix=orca_prefix, graphlet_size=graphlet_size
    )
    # Transpose so each row is one orbit and each column one node.
    per_orbit = orbit_counts[:, _GCM_ORDER].T
    return np.nan_to_num(_spearman_matrix(per_orbit), copy=False)
