# stage.py
"""Spectral Tangent Alignment & Geometric Embedding (STAGE)
===========================================================
Clean reference implementation with two interchangeable embedding back‑ends:

* **"laplacian"**  (default) – solves the sparse normal equation `Ly = Bᵀw` as in
  Algorithm 1 of the paper.
* **"linreg"**     – reproduces the original *spectral_lin_reg_knn* variant that
  fits an over‑determined linear system with per‑vertex offsets.

Switch via `embedding="linreg"` when calling `stage_embedding`.

SciPy ≥ 1.12 is supported via the `rtol`/`atol` solver keywords (the legacy `tol`
keyword was removed in SciPy 1.14); older versions fall back to `tol` automatically.
"""
from __future__ import annotations

from collections.abc import Sequence
from dataclasses import dataclass
from typing import Optional, Tuple

import numpy as np
from numpy.typing import ArrayLike
from scipy import sparse
from scipy.sparse import linalg as spla
from sklearn.neighbors import BallTree
from scipy.stats import rankdata
from scipy.sparse.linalg import svds
# -----------------------------------------------------------------------------
# Public API ------------------------------------------------------------------
# -----------------------------------------------------------------------------
# Names exported by ``from stage import *``; every public (non-underscore)
# helper defined in this module is listed.
__all__ = [
    "stage_embedding",
    "evaluate_kendall",
    "evaluate_kendall_abs",
    "StageResult",
]

# -----------------------------------------------------------------------------
# Result bundle ---------------------------------------------------------------
# -----------------------------------------------------------------------------
@dataclass
class StageResult:
    """Bundle of outputs from :func:`stage_embedding` with ``return_intermediates=True``.

    Only ``y`` and ``order`` are mandatory; the intermediate arrays default to
    ``None`` so the record can also be built from a bare ``(y, order)`` pair.
    """

    # 1-D embedding coordinate, one value per sample.
    y: np.ndarray
    # Ranks of ``y`` as produced by ``scipy.stats.rankdata``.
    order: np.ndarray
    # Sign-synchronised tangent vectors, one row per sample.
    tangents: Optional[np.ndarray] = None
    # Neighbourhood-graph edge list, one (i, j) pair per row with i < j.
    graph_edges: Optional[np.ndarray] = None
    # +1 / -1 flip applied to each raw tangent during synchronisation.
    signs: Optional[np.ndarray] = None

# -----------------------------------------------------------------------------
# Embedding helpers -----------------------------------------------------------
# -----------------------------------------------------------------------------

def _build_incidence(edges: np.ndarray, n: int) -> sparse.csr_matrix:
    r, c = edges[:, 0], edges[:, 1]
    m = edges.shape[0]
    data = np.tile([1.0, -1.0], m)
    row_idx = np.repeat(np.arange(m), 2)
    col_idx = np.concatenate([r, c])
    return sparse.csr_matrix((data, (row_idx, col_idx)), shape=(m, n))


def _embed_laplacian(X: np.ndarray, tangents: np.ndarray, edges: np.ndarray,
                     rtol: float) -> np.ndarray:
    """Normal‑equation solve  Ly = Bᵀw (Algorithm 1).

    Each edge ``(r, c)`` gets the target increment
    ``w = ½ (t_r + t_c) · (x_c − x_r)``.  The embedding minimises ``‖B y − w‖²``
    by running conjugate gradients on ``L = BᵀB``, which is applied
    matrix-free.

    Parameters
    ----------
    X : (n, d) sample coordinates.
    tangents : (n, d) oriented tangent vectors.
    edges : (m, 2) integer vertex pairs.
    rtol : relative residual tolerance for CG (``atol`` is fixed at 0).

    Returns
    -------
    (n,) embedding, shifted to zero mean.

    Raises
    ------
    RuntimeError
        If CG reports a non-zero ``info``.
    """
    n = X.shape[0]
    B = _build_incidence(edges, n)
    r, c = edges[:, 0], edges[:, 1]
    w = 0.5 * np.einsum("ij,ij->i", tangents[r] + tangents[c], X[c] - X[r])
    rhs = B.T @ w  # computed once, shared by both SciPy code paths

    def mv(z):
        return B.T @ (B @ z)

    # Every row of B sums to zero, so L = BᵀB is singular (constants are in its
    # null space).  rhs = Bᵀw lies in range(Bᵀ), so the system is consistent
    # and CG can still converge; the gauge is fixed by centring afterwards.
    L = spla.LinearOperator((n, n), matvec=mv, dtype=np.float64)
    try:
        y, info = spla.cg(L, rhs, rtol=rtol, atol=0.0)
    except TypeError:  # SciPy < 1.12: no `rtol` keyword, only the legacy `tol`
        # atol=0.0 keeps the same purely relative stopping rule as above.
        y, info = spla.cg(L, rhs, tol=rtol, atol=0.0)
    if info != 0:
        raise RuntimeError(f"CG did not converge (info={info}).")
    return y - y.mean()


def _embed_linreg(X: np.ndarray, tangents: np.ndarray, edges: np.ndarray,
                  rtol: float) -> np.ndarray:
    """Over‑determined linear regression variant (original code path)."""
    n = X.shape[0]
    # Build adjacency list
    adj = [[] for _ in range(n)]
    for i, j in edges:
        adj[i].append(j)
        adj[j].append(i)

    rows, cols, data = [], [], []
    t_vec = []
    row_id = 0
    for i, neigh in enumerate(adj):
        if not neigh:
            continue
        mu_i = X[neigh].mean(axis=0)
        v_i = tangents[i]
        for j in neigh:
            rows.extend([row_id, row_id])
            cols.extend([j, n + i])  # y_j minus c_i
            data.extend([1.0, -1.0])
            t_vec.append(np.dot(X[j] - mu_i, v_i))
            row_id += 1

    m = row_id
    A = sparse.csr_matrix((data, (rows, cols)), shape=(m, 2 * n))
    t_vec = np.asarray(t_vec, dtype=np.float64)

    res = spla.lsqr(A, t_vec, atol=0.0, btol=rtol)
    y = res[0][:n]
    return y - y.mean()


# Registry of embedding back-ends, keyed by the ``embedding=`` name accepted by
# ``stage_embedding``; each entry has signature (X, tangents, edges, rtol) -> y.
EMBED_METHODS = dict(
    laplacian=_embed_laplacian,
    linreg=_embed_linreg,
)

# -----------------------------------------------------------------------------
# Core algorithm --------------------------------------------------------------
# -----------------------------------------------------------------------------

def stage_embedding(
    X: ArrayLike,
    *,
    r: Optional[float] = None,
    k: Optional[int] = None,
    rtol: float = 1e-6,
    neighbour_min: int = 6,
    pca_full_dim: bool = False,
    embedding: str = "laplacian",
    return_intermediates: bool = False,
) -> StageResult | Tuple[np.ndarray, np.ndarray]:
    """Run STAGE on *X* and return ``(y, order)`` or a :class:`StageResult`.

    Parameters
    ----------
    X : array_like of shape (n_samples, n_features)
        Input samples; converted to a C-contiguous float64 array.
    r, k : Graph parameters (exactly one must be provided).
        ``r`` builds a radius graph; ``k`` a k-nearest-neighbour graph.  The
        neighbour relation is symmetrised: i and j are joined if either one
        lists the other.
    rtol : Solver tolerance forwarded to the embedding back-end.
    neighbour_min : Forwarded to the tangent estimator (which currently does
        not enforce it).
    pca_full_dim : Forwarded to the tangent estimator to select its SVD path.
    embedding : "laplacian" or "linreg".
    return_intermediates : If true, return a :class:`StageResult` that also
        carries the oriented tangents, graph edges and signs.

    Returns
    -------
    ``(y, order)`` or :class:`StageResult`
        ``y`` is the 1-D embedding.  ``order`` is
        ``rankdata(y, method="average")``: 1-based, possibly fractional ranks,
        not an argsort permutation.

    Raises
    ------
    ValueError
        If X is not 2-D, if not exactly one of r/k is given, or if
        ``embedding`` is unknown.
    """
    X = np.asarray(X, dtype=np.float64, order="C")
    if X.ndim != 2:
        raise ValueError("X must be 2‑D array (n_samples, n_features)")
    if (r is None) == (k is None):
        raise ValueError("Specify exactly one of *r* or *k*.")
    # Validate the back-end up front, before any expensive graph/PCA work.
    if embedding not in EMBED_METHODS:
        raise ValueError(f"embedding must be one of {list(EMBED_METHODS)}")

    # --- 1. Neighbourhood graph ---------------------------------------------
    edges = _neighbourhood_edges(X, r, k)

    # --- 2. Local PCA (tangent estimation) -----------------------------------
    tangents = _estimate_local_tangents(X, edges, neighbour_min, pca_full_dim)

    # --- 3. Z₂ synchronisation ----------------------------------------------
    tangents_oriented, signs = _z2_sync(tangents, edges)

    # --- 4. Global embedding --------------------------------------------------
    y = EMBED_METHODS[embedding](X, tangents_oriented, edges, rtol)
    order = rankdata(y, method="average")

    if return_intermediates:
        return StageResult(y=y, order=order, tangents=tangents_oriented,
                           graph_edges=edges, signs=signs)
    return y, order


def _neighbourhood_edges(X: np.ndarray, r: Optional[float],
                         k: Optional[int]) -> np.ndarray:
    """Return the undirected edges ``(i, j)``, ``i < j``, of the radius/kNN graph on X."""
    n = X.shape[0]
    tree = BallTree(X)
    if r is not None:
        neigh_ind = tree.query_radius(X, r)
    else:
        # k + 1 because each sample is normally returned as its own neighbour.
        neigh_ind = tree.query(X, k=k + 1, return_distance=False)

    rows, cols = [], []
    for i, neigh in enumerate(neigh_ind):
        neigh = neigh[neigh != i]
        rows.append(np.full(neigh.size, i, dtype=np.int64))
        cols.append(neigh.astype(np.int64))
    rows = np.concatenate(rows)
    cols = np.concatenate(cols)

    # kNN is asymmetric: j may be among i's neighbours without i being among
    # j's.  Filtering with rows < cols alone would drop every such edge that
    # was reported only by its larger endpoint.  Canonicalise each pair to
    # (min, max) and deduplicate instead (same edge set for the radius graph).
    lo = np.minimum(rows, cols)
    hi = np.maximum(rows, cols)
    keys = np.unique(lo * n + hi)
    return np.stack((keys // n, keys % n), axis=1).astype(np.int32)

# -----------------------------------------------------------------------------
# 1. Local PCA ----------------------------------------------------------------
# -----------------------------------------------------------------------------

def _estimate_local_tangents(
    X: np.ndarray,
    edges: np.ndarray,
    neighbour_min: int,
    pca_full_dim: bool,
) -> np.ndarray:
    n, d = X.shape
    adj = [[] for _ in range(n)]
    for i, j in edges:
        adj[i].append(j)
        adj[j].append(i)

    tangents = np.empty((n, d))
    rng = np.random.default_rng()

    for i, neigh in enumerate(adj):
        if len(neigh) < 2:                     # 0 or 1 neighbour → cannot fit a line
            tangents[i] = np.zeros(d)          # mark as bad – will be re‑oriented later
            continue                           # (or copy a random unit vector)
        # if len(neigh) < neighbour_min:
        #     raise RuntimeError(
        #         f"Point {i} has only {len(neigh)} neighbours < {neighbour_min}.")
        Y = X[neigh] - X[neigh].mean(axis=0, keepdims=True)
        m = Y.shape[0]
        if pca_full_dim or m < d:
            # Full SVD fallback
            _, _, vt_top = svds(Y, k=1, which='LM', tol=1e-4) # tol can be adjusted
            tangents[i] = vt_top[0] # vt_top will have shape (1, d)
        else:
            # 3‑step power iteration
            v = rng.standard_normal(d)
            v /= np.linalg.norm(v)
            for _ in range(3):
                v = Y.T @ (Y @ v)
                v /= np.linalg.norm(v)
            tangents[i] = v
    return tangents

# -----------------------------------------------------------------------------
# 2. Z₂ synchronisation -------------------------------------------------------
# -----------------------------------------------------------------------------

def _z2_sync(tangents: np.ndarray, edges: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    n = tangents.shape[0]
    r, c = edges[:, 0], edges[:, 1]
    dot = np.einsum("ij,ij->i", tangents[r], tangents[c])
    Q = sparse.csr_matrix((np.concatenate([dot, dot]),
                           (np.concatenate([r, c]), np.concatenate([c, r]))), shape=(n, n))
    eigval, eigvec = spla.eigsh(Q, k=1, which="LA")
    signs = np.sign(eigvec[:, 0])
    signs[signs == 0] = 1.0
    return tangents * signs[:, None], signs

# -----------------------------------------------------------------------------
# Evaluation ------------------------------------------------------------------
# -----------------------------------------------------------------------------
def evaluate_kendall_abs(order: Sequence[int], truth: Sequence[int]) -> float:
    """Return ``|τ|``, the absolute Kendall rank correlation of *order* vs *truth*.

    Taking the absolute value scores a fully reversed ordering the same as
    the forward one.
    """
    from scipy.stats import kendalltau

    tau, _pvalue = kendalltau(order, truth)
    return abs(tau)

def evaluate_kendall(order: Sequence[int], truth: Sequence[int]) -> float:
    """Return the signed Kendall rank correlation ``τ`` of *order* vs *truth*.

    The value is in [-1, 1]; -1 means *order* ranks the samples in exactly
    the reverse of *truth*.
    """
    from scipy.stats import kendalltau

    tau, _pvalue = kendalltau(order, truth)
    return tau
