"""
grad_graph_theo.py

Computes:
  1) diffusion-map functional gradient embeddings, aligned to a reference template
     with generalized Procrustes alignment;
  2) graph-level graph-theoretic metrics from functional connectivity matrices.

Dependencies:
  - numpy
  - brainspace (for GradientMaps)
  - procrustes (for generalized Procrustes alignment)
  - bctpy (imported as `bct`), for graph metrics
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple

import numpy as np


def _as_3d_connectomes(connectivity: np.ndarray) -> np.ndarray:
    """ensure input connectivity is float array with shape (S, N, N)."""
    conn = np.asarray(connectivity, dtype=float)
    if conn.ndim != 3:
        raise ValueError(f"`connectivity` shaped (S, N, N); got {conn.shape}")
    if conn.shape[-1] != conn.shape[-2]:
        raise ValueError(f"`connectivity` square on last two dims; got {conn.shape}")
    return conn


def _symmetrize_zero_diag_positive(W: np.ndarray) -> np.ndarray:
    """symmetrize, zero diagonal,"""
    W = np.array(W, dtype=float, copy=True)
    W = np.nan_to_num(W, nan=0.0, posinf=0.0, neginf=0.0)
    W = (W + W.T) / 2.0
    np.fill_diagonal(W, 0.0)
    W[W < 0] = 0.0
    return W


# diffusion-map gradients + alignment

@dataclass
class DMParams:
    """Training-set statistics returned by ``normalize_dm_embeddings``.

    Keep these to reproduce the train-only normalization on other data.
    """

    # Per-dimension mean over subjects and nodes of the training embeddings, shape (D,).
    dm_global_mean: np.ndarray
    # Single global std of the mean-centered training embeddings (1.0 if degenerate).
    dm_global_std: float


def compute_dm_features(
    connectivity: np.ndarray,
    reference_connectivity: np.ndarray,
    dm_dim: int = 10,
    sparsity: float = 0.90,
    n_components: int = 30,
    random_state: int = 0,
    align_n_iter: int = 1,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    compute Procrustes-aligned diffusion-map gradients for each subject FC

    Gradients are fit with brainspace ``GradientMaps`` (approach="dm",
    kernel="normalized_angle"), once on the reference matrix and once on the
    list of subject matrices. Each subject's gradients are then aligned to the
    reference gradients with ``procrustes.generalized`` and truncated to the
    first ``dm_dim`` columns.

    params
    connectivity:
        (S, N, N) subject connectivity matrices, S >= 1.
    reference_connectivity:
        (N, N) matrix whose diffusion-map gradients serve as the alignment template.
    dm_dim:
        aligned gradient dimensions to return (must be positive).
    sparsity:
        passed to brainspace GradientMaps.fit (higher -> sparser affinity);
        must lie in [0, 1).
    n_components:
        number of diffusion components computed; raised to ``dm_dim`` if smaller.
    random_state:
        seed for GradientMaps
    align_n_iter:
        iterations for generalized Procrustes alignment (must be >= 1).

    rtns
    dm_features:
        float32 (S, N, dm_dim) aligned gradients
    dm_range:
        float32 (S, dm_dim) per-subject gradient range (max - min over nodes)
        for each aligned dimension.

    raises
    ValueError:
        on invalid shapes or parameter values.
    ImportError:
        if ``brainspace`` or ``procrustes`` cannot be imported.
    """
    conn = _as_3d_connectomes(connectivity)
    if conn.shape[0] == 0:
        raise ValueError("`connectivity` must contain at least one subject; got S=0.")
    ref = np.asarray(reference_connectivity, dtype=float)
    if ref.ndim != 2 or ref.shape[0] != ref.shape[1]:
        raise ValueError(f"`reference_connectivity` must be (N, N); got {ref.shape}")
    if ref.shape[0] != conn.shape[1]:
        raise ValueError(
            f"reference N={ref.shape[0]} does not match connectivity N={conn.shape[1]}"
        )
    if dm_dim <= 0:
        raise ValueError("`dm_dim` must be positive.")
    # brainspace zeroes this fraction of each affinity row; 1.0 (or a percentage
    # like 90) would leave no edges and fail deep inside the embedding.
    if not (0.0 <= float(sparsity) < 1.0):
        raise ValueError(f"`sparsity` must be in [0, 1); got {sparsity}")
    if int(align_n_iter) < 1:
        raise ValueError(f"`align_n_iter` must be >= 1; got {align_n_iter}")
    # Always compute at least as many components as are returned.
    n_components = max(int(n_components), int(dm_dim))

    try:
        from brainspace.gradient import GradientMaps  # type: ignore
    except ImportError as e:
        raise ImportError(
            "Missing dependency `brainspace`. Install brainspace to compute diffusion-map gradients."
        ) from e

    try:
        from procrustes import generalized  # type: ignore
    except ImportError as e:
        raise ImportError(
            "Missing dependency `procrustes`. Install `procrustes` to align gradients."
        ) from e

    def _new_gradient_maps():
        # Identical configuration for template and subjects so they live in comparable spaces.
        return GradientMaps(
            n_components=n_components,
            approach="dm",
            kernel="normalized_angle",
            random_state=int(random_state),
        )

    # fit template gradients on the reference matrix
    G_ref = _new_gradient_maps().fit(ref, sparsity=float(sparsity))
    template = np.asarray(G_ref.gradients_, dtype=float)  # (N, n_components)

    # fit gradients for each subject (GradientMaps.fit accepts a list of matrices)
    G_subj = _new_gradient_maps().fit(list(conn), sparsity=float(sparsity))

    # align each subject's gradients to the template
    aligned, _ = generalized(G_subj.gradients_, template, n_iter=int(align_n_iter))
    aligned = np.asarray(aligned, dtype=float)  # (S, N, n_components)

    dm = aligned[:, :, : int(dm_dim)].astype(np.float32)
    dm_range = (dm.max(axis=1) - dm.min(axis=1)).astype(np.float32)  # (S, dm_dim)
    return dm, dm_range


def normalize_dm_embeddings(
    dm_train: np.ndarray, dm_test: np.ndarray, eps: float = 1e-8
) -> Tuple[np.ndarray, np.ndarray, DMParams]:
    """
    normalize diffusion-map embeddings using train-only statistics

    A per-dimension mean ``m`` (over subjects and nodes) and one global standard
    deviation ``s`` of the mean-centered training embeddings are estimated from
    ``dm_train``. Both arrays are then mapped to ``(x - m) / (s + eps)``. If ``s`` is
    not finite or is below ``eps``, it is replaced by 1.0.

    params
    dm_train:
        (B_train, N, D) training embeddings; must be non-empty.
    dm_test:
        (B_test, N', D) embeddings normalized with the training statistics.
    eps:
        small constant added to the std in the denominator.

    rtns
    dm_train_norm, dm_test_norm:
        float32 normalized arrays with the input shapes.
    params:
        DMParams holding ``m`` (float32) and ``s``. Note that ``s`` is stored without
        ``eps`` added.

    raises
    ValueError:
        if either input is not 3-D, the embedding dims D differ, or dm_train is empty.
    """
    dm_train = np.asarray(dm_train, dtype=float)
    dm_test = np.asarray(dm_test, dtype=float)
    if dm_train.ndim != 3 or dm_test.ndim != 3:
        raise ValueError("dm_train and dm_test must be (B, N, D).")
    # Without this check, a test array with D == 1 would silently broadcast against
    # the (D,) train mean and come back with D columns.
    if dm_train.shape[-1] != dm_test.shape[-1]:
        raise ValueError(
            "dm_train and dm_test must share the embedding dimension D; "
            f"got {dm_train.shape[-1]} and {dm_test.shape[-1]}."
        )
    if dm_train.size == 0:
        raise ValueError("dm_train must be non-empty to estimate normalization statistics.")
    m = dm_train.mean(axis=(0, 1), keepdims=False)  # (D,)
    dm_train_c = dm_train - m[None, None, :]
    dm_test_c = dm_test - m[None, None, :]
    s = float(np.std(dm_train_c))
    if (not np.isfinite(s)) or (s < eps):
        s = 1.0
    dm_train_norm = (dm_train_c / (s + eps)).astype(np.float32)
    dm_test_norm = (dm_test_c / (s + eps)).astype(np.float32)
    return dm_train_norm, dm_test_norm, DMParams(dm_global_mean=m.astype(np.float32), dm_global_std=float(s))


# graph-theoretic metrics (graph-level)

# Column order of the feature matrix built by ``compute_graph_metrics``.
# Names ending in "_b" are computed on the binarized, density-thresholded graph;
# "_pre" marks a quantity taken before thresholding.
GRAPH_FEATURE_NAMES: Tuple[str, ...] = (
    "Eglob_b",            # mean inverse shortest-path distance (binary)
    "CPL_b",              # characteristic path length: mean finite distance (binary)
    "Clust_b",            # average clustering coefficient (binary)
    "Q_free_b",           # modularity Q of the Louvain partition (binary)
    "n_modules_free",     # module count of that Louvain partition
    "SW_phi_b",           # small-world phi (binary)
    "SW_sigma_b",         # small-world sigma (binary)
    "MeanStrength_pre",   # mean node strength of the positive-weight matrix, pre-threshold
)


def threshold_to_density_positive(W: np.ndarray, density: float) -> np.ndarray:
    """
    threshold a weighted matrix to a target density using positive edges only

    - requires a square (N, N) matrix
    - replaces NaN/inf with 0, symmetrizes, zeroes the diagonal, discards negative weights
    - keeps the top-m positive edges by weight, where m = floor(density * N(N-1)/2);
      ties at the cut-off are broken arbitrarily (np.argpartition)
    - if there are at most m positive edges, all of them are kept
    - if m == 0, the result is an all-zero (N, N) matrix

    Returns a symmetric float (N, N) matrix holding the kept weights.
    Raises ValueError if ``density`` is outside (0, 1] or ``W`` is not square 2-D.
    """
    if not (0.0 < float(density) <= 1.0):
        raise ValueError(f"`density` must be in (0, 1]; got {density}")
    A = np.asarray(W)
    if A.ndim != 2 or A.shape[0] != A.shape[1]:
        raise ValueError(f"`W` must be a square (N, N) matrix; got {A.shape}")

    n = A.shape[0]
    m = int(np.floor(float(density) * n * (n - 1) / 2))
    if m == 0:
        # Nothing may be kept. This must be handled explicitly:
        # np.argpartition(x, -0)[-0:] would select EVERY index, not none.
        return np.zeros((n, n), dtype=float)

    W = _symmetrize_zero_diag_positive(A)
    iu = np.triu_indices(n, 1)
    vals = W[iu]
    pos_idx = np.flatnonzero(vals > 0)

    W_thr = np.zeros_like(W)
    if pos_idx.size <= m:
        # Fewer positive edges than the budget: keep them all (non-positives are already 0).
        W_thr[iu] = vals
    else:
        top = pos_idx[np.argpartition(vals[pos_idx], -m)[-m:]]
        kept = np.zeros_like(vals)
        kept[top] = vals[top]
        W_thr[iu] = kept
    return W_thr + W_thr.T


def _dist_bin(G: np.ndarray) -> np.ndarray:
    """Binary shortest-path distance matrix of ``G`` computed with ``bct.distance_bin``.

    If that call returns a tuple, the first element is taken as the distance
    matrix (assumed to match older/newer bct variants); otherwise the result is
    returned unchanged.
    """
    import bct

    result = bct.distance_bin(G)
    if isinstance(result, tuple):
        return result[0]
    return result


def _path_metrics_from_D(D: np.ndarray) -> Tuple[float, float]:
    """
    given a shortest path distance matrix D (inf for disconnected), compute:
      - CPL: mean finite distance over off-diagonal pairs
      - Eglob: mean inverse distance over finite off-diagonal pairs
    """
    n = D.shape[0]
    mask = np.isfinite(D) & (~np.eye(n, dtype=bool))
    if mask.sum() == 0:
        return float("nan"), float("nan")
    L = float(D[mask].mean())
    with np.errstate(divide="ignore", invalid="ignore"):
        invD = 1.0 / D[mask]
    invD = invD[np.isfinite(invD)]
    E = float(invD.mean()) if invD.size else float("nan")
    return L, E


def _C_and_L_harmonic_bin(A: np.ndarray) -> Tuple[float, float]:
    """
    Mean binary clustering coefficient and harmonic characteristic path length of ``A``.

    The harmonic path length is N(N-1) divided by the sum of inverse shortest-path
    distances over off-diagonal pairs, with unreachable pairs contributing 0 to
    that sum. It is ``inf`` when the sum is 0 (e.g. a graph with no edges).
    """
    import bct

    mean_clustering = float(bct.clustering_coef_bu(A).mean())
    dist = _dist_bin(A)

    n_nodes = dist.shape[0]
    with np.errstate(divide="ignore", invalid="ignore"):
        inv_dist = 1.0 / dist
    # 1/inf -> 0 already; this also clears the 1/0 = inf produced on the diagonal.
    inv_dist[~np.isfinite(inv_dist)] = 0.0
    np.fill_diagonal(inv_dist, 0.0)

    total = float(inv_dist.sum())
    if total > 0:
        harmonic_length = (n_nodes * (n_nodes - 1)) / total
    else:
        harmonic_length = float("inf")
    return mean_clustering, float(harmonic_length)


def _degree_preserving_random(A: np.ndarray, n_iter: int = 10) -> np.ndarray:
    """Randomized surrogate of undirected ``A`` produced by ``bct.randmio_und``.

    A copy of ``A`` is handed to bct, so the caller's array is not touched. Only
    the first returned value (the randomized matrix) is kept; the second is discarded.
    """
    import bct

    randomized, _unused = bct.randmio_und(A.copy(), n_iter)
    return randomized


def _lattice_like(A: np.ndarray, n_iter: int = 10) -> np.ndarray:
    """
    attempt bct.latmio_und; fall back to a ring-lattice with matched edge count
    """
    import bct  
    try:
        L, _ = bct.latmio_und(A.copy(), n_iter)
        return L
    except Exception:
        n = A.shape[0]
        E = int(np.count_nonzero(np.triu(A, 1)))
        L = np.zeros((n, n), dtype=int)
        s, added = 1, 0
        while added + n <= E and s <= n // 2:
            for i in range(n):
                j = (i + s) % n
                if L[i, j] == 0:
                    L[i, j] = L[j, i] = 1
            added += n
            s += 1
        rem = E - added
        if rem > 0 and s <= n // 2:
            for i in range(rem):
                j = (i + s) % n
                if L[i, j] == 0:
                    L[i, j] = L[j, i] = 1
        np.fill_diagonal(L, 0)
        return L


def small_world_phi_sigma_binary(
    A: np.ndarray,
    n_rand: int = 10,
    n_latt: int = 10,
    min_denom: float = 1e-8,
) -> Tuple[float, float]:
    """
    Small-world phi and sigma of a binary graph ``A``.

    Clustering C and harmonic path length L of ``A`` are compared with their means
    over ``n_rand`` degree-preserving random surrogates and ``n_latt`` lattice-like
    surrogates (each built with n_iter=10):

      dC    = clip((C_latt - C_obs) / |C_latt - C_rand|, 0, 1)
      dL    = clip((L_obs - L_rand) / |L_latt - L_rand|, 0, 1)
      phi   = 1 - sqrt((dC^2 + dL^2) / 2)
      sigma = (C_obs / C_rand) / (L_obs / L_rand)

    Every denominator above is floored at ``min_denom``. phi has the form of the
    small-world propensity (Muldoon et al., 2016), except that absolute differences
    are used in its denominators. With zero surrogates the corresponding means are NaN.
    """
    A = np.asarray(A, dtype=int)
    floor = float(min_denom)
    C_obs, L_obs = _C_and_L_harmonic_bin(A)

    def _mean_C_L(make_surrogate, n_surrogates) -> Tuple[float, float]:
        # Mean (C, L) over freshly generated surrogates of A.
        pairs = [
            _C_and_L_harmonic_bin(make_surrogate(A, n_iter=10))
            for _ in range(int(n_surrogates))
        ]
        if not pairs:
            return float("nan"), float("nan")
        clustering = [pair[0] for pair in pairs]
        path_len = [pair[1] for pair in pairs]
        return float(np.mean(clustering)), float(np.mean(path_len))

    # Random surrogates are generated before lattice surrogates, as before.
    C_rand, L_rand = _mean_C_L(_degree_preserving_random, n_rand)
    C_latt, L_latt = _mean_C_L(_lattice_like, n_latt)

    delta_C = np.clip((C_latt - C_obs) / max(abs(C_latt - C_rand), floor), 0.0, 1.0)
    delta_L = np.clip((L_obs - L_rand) / max(abs(L_latt - L_rand), floor), 0.0, 1.0)
    phi = 1.0 - float(((delta_C ** 2 + delta_L ** 2) / 2.0) ** 0.5)

    sigma = (C_obs / max(C_rand, floor)) / (L_obs / max(L_rand, floor))
    return float(phi), float(sigma)


def compute_graph_metrics(
    connectivity: np.ndarray,
    density: float = 0.20,
    louvain_gamma: float = 1.0,
    sw_n_random: int = 10,
    sw_n_lattice: int = 10,
) -> Tuple[np.ndarray, Tuple[str, ...]]:
    """
    compute graph-level metrics for each subject connectome

    Per subject the matrix is cleaned (NaN/inf -> 0, symmetrized, zero diagonal,
    negative weights -> 0), and its mean node strength is recorded. It is then
    thresholded to ``density`` using positive edges only and binarized. All other
    metrics are computed on that binary graph.

    params
    connectivity:
        (S, N, N).
    density:
        target density in (0, 1] for thresholding using positive edges only
    louvain_gamma:
        resolution parameter for bct.community_louvain.
    sw_n_random / sw_n_lattice:
        number of random/lattice surrogates for small-world metrics

    rtns
    graph_metrics:
        float32 array (S, n_features) following GRAPH_FEATURE_NAMES.
    feature_names:
        tuple of feature names.

    raises
    ValueError:
        if ``connectivity`` is not (S, N, N) or ``density`` is outside (0, 1].
    ImportError:
        if bctpy (``bct``) cannot be imported.
    """
    conn = _as_3d_connectomes(connectivity)
    # Validate once up front rather than only inside the first subject's threshold call.
    if not (0.0 < float(density) <= 1.0):
        raise ValueError(f"`density` must be in (0, 1]; got {density}")

    try:
        import bct
    except ImportError as e:
        raise ImportError("Missing dependency `bct` (bctpy). Install bctpy to compute graph metrics.") from e

    feats = np.full((conn.shape[0], len(GRAPH_FEATURE_NAMES)), np.nan, dtype=np.float32)

    for i, fc in enumerate(conn):
        W = _symmetrize_zero_diag_positive(fc)
        mean_strength_pre = float(W.sum(axis=1).mean())

        W_thr = threshold_to_density_positive(W, density=float(density))
        G = (W_thr > 0).astype(int)

        cpl_b, eglob_b = _path_metrics_from_D(_dist_bin(G))
        clust_b = float(bct.clustering_coef_bu(G).mean())

        ci_free, q_free = bct.community_louvain(G.astype(float), gamma=float(louvain_gamma))
        n_modules = int(len(np.unique(ci_free))) if ci_free is not None else 0

        phi, sigma = small_world_phi_sigma_binary(
            G, n_rand=int(sw_n_random), n_latt=int(sw_n_lattice)
        )

        # Column order must match GRAPH_FEATURE_NAMES.
        feats[i, :] = np.array(
            [
                eglob_b, cpl_b, clust_b,
                float(q_free), float(n_modules),
                phi, sigma, mean_strength_pre,
            ],
            dtype=np.float32,
        )

    return feats, GRAPH_FEATURE_NAMES


def compute_connectome_features(
    connectivity: np.ndarray,
    reference_connectivity: np.ndarray,
    *,
    density: float = 0.20,
    dm_dim: int = 10,
    dm_sparsity: float = 0.90,
    dm_n_components: int = 30,
    dm_random_state: int = 0,
    align_n_iter: int = 1,
    louvain_gamma: float = 1.0,
    sw_n_random: int = 10,
    sw_n_lattice: int = 5,
) -> Dict[str, Any]:
    """
    Run ``compute_dm_features`` and then ``compute_graph_metrics`` on the same
    connectomes, and bundle the results in a dict.

    ``dm_*`` options (with the prefix dropped) and ``align_n_iter`` are forwarded to
    ``compute_dm_features``. ``density``, ``louvain_gamma`` and ``sw_*`` are forwarded
    to ``compute_graph_metrics``. Note that ``sw_n_lattice`` defaults to 5 here but
    to 10 in ``compute_graph_metrics``.

    Returned keys: "dm_features" (S, N, dm_dim), "dm_range" (S, dm_dim),
    "graph_metrics" (S, n_graph_features), "graph_feature_names" (tuple of str).
    """
    dm_options = dict(
        dm_dim=dm_dim,
        sparsity=dm_sparsity,
        n_components=dm_n_components,
        random_state=dm_random_state,
        align_n_iter=align_n_iter,
    )
    graph_options = dict(
        density=density,
        louvain_gamma=louvain_gamma,
        sw_n_random=sw_n_random,
        sw_n_lattice=sw_n_lattice,
    )

    dm_features, dm_range = compute_dm_features(
        connectivity=connectivity,
        reference_connectivity=reference_connectivity,
        **dm_options,
    )
    graph_metrics, graph_feature_names = compute_graph_metrics(
        connectivity=connectivity, **graph_options
    )

    features: Dict[str, Any] = {}
    features["dm_features"] = dm_features
    features["dm_range"] = dm_range
    features["graph_metrics"] = graph_metrics
    features["graph_feature_names"] = graph_feature_names
    return features
