from __future__ import annotations

import re
from itertools import combinations
from typing import TYPE_CHECKING

import numpy as np
from scipy.linalg import ldl

from pepflow import expression_manager as exm
from pepflow import ipython_utils
from pepflow import pep_context as pc
from pepflow import utils
from pepflow.pep_result import MatrixWithNames

if TYPE_CHECKING:
    from pepflow.vector import Vector

# TODO: Should add unit tests.


def ldl_deompose_with_reversed_basis(
    S: MatrixWithNames, basis: list[Vector], print_output: bool = True
) -> tuple[MatrixWithNames, np.ndarray, list[Vector]]:
    """Factor ``S`` with LDL after reversing the basis order, keeping labels aligned.

    ``S`` (rows and columns) and ``basis`` are expected to list earlier iterates
    first. Both orders are flipped so later iterates come first, and
    ``scipy.linalg.ldl`` is applied to the flipped matrix. When LDL reports a
    non-identity row permutation, the factor, the labels, and the basis are
    reordered by it so that all three stay aligned.

    Args:
        S: Named symmetric matrix to decompose.
        basis: Basis vectors aligned with ``S.row_names``; ``repr(v)`` of each
            vector must equal the corresponding row label.
        print_output: If ``True``, pretty-print the labeled ``L^T`` and ``D``.

    Raises:
        ValueError: If the reversed row labels differ from the ``repr`` of the
            reversed basis vectors.

    Returns:
        ``(LT_named, d, ell)``:
        ``LT_named`` is the transposed LDL factor with rows ``\\ell_1..\\ell_n`` and
        columns labeled by the (reversed, possibly permuted) basis names;
        ``d`` is the (block-)diagonal LDL factor;
        ``ell[i]`` is row ``i`` of ``L^T`` applied to the reversed basis.
    """
    # Flip both axes so the gradients at later iterates come first.
    flipped = S.matrix[::-1, ::-1]
    labels = S.row_names[::-1]
    rev_basis = basis[::-1]

    rev_basis_labels = [repr(vec) for vec in rev_basis]
    if labels != rev_basis_labels:
        raise ValueError(
            "Reversed matrix labels and reversed basis labels do not match. "
            f"names_rev={labels}, basis_names_rev={rev_basis_labels}"
        )

    lower, d, perm = ldl(flipped)
    perm = np.asarray(perm, dtype=int)
    if not np.array_equal(perm, np.arange(len(perm))):
        # ``lower[perm]`` is the triangular factor; reorder labels and basis the
        # same way so the columns of L^T keep matching their vectors.
        lower = lower[perm]
        labels = [labels[p] for p in perm]
        rev_basis = [rev_basis[p] for p in perm]

    upper = lower.T
    ell_names = [rf"\ell_{{{row + 1}}}" for row in range(lower.shape[0])]
    LT_named = MatrixWithNames(matrix=upper, row_names=ell_names, col_names=labels)

    if print_output:
        LT_named.pprint()
        ipython_utils.pprint_matrix(d)

    # Each row of L^T gives the coefficients of one linear form over the basis.
    ell = [coeffs @ rev_basis for coeffs in upper]
    return LT_named, d, ell


def decompose_with_labels(
    S: MatrixWithNames, basis: list[Vector], print_output: bool = True
) -> tuple[MatrixWithNames, np.ndarray, list[Vector]]:
    """Backward-compatible alias of ``ldl_deompose_with_reversed_basis``.

    All arguments are forwarded unchanged and the ``(LT_named, d, ell)`` tuple
    is returned as-is.
    """
    factors = ldl_deompose_with_reversed_basis(S, basis, print_output)
    return factors


def ldl_reversed_basis_with_labels(
    S: MatrixWithNames, basis: list[Vector], print_output: bool = True
) -> tuple[MatrixWithNames, np.ndarray, list[Vector]]:
    """Backward-compatible alias of ``ldl_deompose_with_reversed_basis``.

    Forwards ``S``, ``basis`` and ``print_output`` and returns the delegate's
    result unchanged.
    """
    return ldl_deompose_with_reversed_basis(S, basis, print_output=print_output)


def vectors_in_column_space(
    V: np.ndarray,
    vectors: list[Vector],
    pep_context: pc.PEPContext | None = None,
    *,
    resolve_parameters: dict[str, utils.NUMERICAL_TYPE] | None = None,
    rtol: float = 1e-7,
    atol: float = 1e-7,
) -> list[Vector]:
    """Collect vectors that lie in the column space of ``V``.

    ``V`` is evaluated with ``ExpressionManager.eval_scalar`` and the column
    space of its ``inner_prod_coords`` matrix is used. An orthonormal basis of
    that space is taken from a thin SVD, keeping singular values above
    ``atol + rtol * sigma_max``. A vector is kept when the norm of its residual
    after projection onto that basis is at most ``atol + rtol * ||v||``.

    Args:
        V: Expression whose evaluated ``inner_prod_coords`` matrix defines the
            column space.
        vectors: Candidate vectors to test for membership in ``col(V)``.
        pep_context (:class:`PEPContext` | None): The :class:`PEPContext` object
            we consider. `None` if we consider the current global
            :class:`PEPContext` object.
        resolve_parameters (dict[str, :class:`NUMERICAL_TYPE`] | None): A
            dictionary that maps the name of parameters to the numerical values.
        rtol: Relative tolerance for numerical checks.
        atol: Absolute tolerance for numerical checks.

    Raises:
        RuntimeError: If no context is given and none is active.

    Returns:
        Vectors that lie in ``col(V)`` within the given tolerances, in the
        order they appear in ``vectors``.
    """
    if pep_context is None:
        pep_context = pc.get_current_context()
    if pep_context is None:
        raise RuntimeError("Did you forget to create a context?")
    pm = exm.ExpressionManager(pep_context, resolve_parameters=resolve_parameters)
    V_coords = np.asarray(pm.eval_scalar(V).inner_prod_coords, dtype=float)

    # Thin SVD of the coordinate matrix; the leading left singular vectors
    # span its column space.
    U, sing_vals, _ = np.linalg.svd(V_coords, full_matrices=False)

    # Numerical rank. Guard the empty case: with no singular values there is
    # no ``sing_vals[0]`` to scale the relative tolerance by.
    sigma_max = sing_vals[0] if sing_vals.size else 0.0
    tol = atol + rtol * sigma_max
    rank = int(np.count_nonzero(sing_vals > tol))

    # Orthonormal basis for the column space.
    Uc = U[:, :rank]

    results = []
    for v in vectors:
        # Flatten for consistency with the other helpers in this module.
        v_coords = np.asarray(pm.eval_vector(v).coords, dtype=float).ravel()
        proj = Uc @ (Uc.T @ v_coords)
        residual = np.linalg.norm(v_coords - proj)
        if residual <= atol + rtol * np.linalg.norm(v_coords):
            results.append(v)
    return results


def independent_subset(
    vecs: list[Vector],
    pep_context: pc.PEPContext | None = None,
    *,
    resolve_parameters: dict[str, utils.NUMERICAL_TYPE] | None = None,
    tol: float = 1e-7,
    sort_vectors: bool = False,
) -> tuple[list[Vector], list[int]]:
    """Greedily pick a linearly independent subset of ``vecs`` (Gram-Schmidt).

    Vectors are visited in input order, or in natural order of ``str(v)`` when
    ``sort_vectors`` is set. Each visited vector is orthogonalized against the
    directions accepted so far. It is accepted when the remaining norm exceeds
    ``tol`` times its original norm. Vectors with exactly zero coordinates are
    skipped.

    Args:
        vecs: Candidate vectors.
        pep_context (:class:`PEPContext` | None): The :class:`PEPContext` object
            we consider. `None` if we consider the current global
            :class:`PEPContext` object.
        resolve_parameters (dict[str, :class:`NUMERICAL_TYPE`] | None): A
            dictionary that maps the name of parameters to the numerical values.
        tol: Relative tolerance for independence checks.
        sort_vectors: If ``True``, visit vectors in deterministic natural order
            of ``str(v)`` before greedy selection.

    Raises:
        RuntimeError: If no context is given and none is active.

    Returns:
        ``(chosen, idx)``: the accepted vectors (in visiting order) and their
        positions in ``vecs``.
    """
    if pep_context is None:
        pep_context = pc.get_current_context()
    if pep_context is None:
        raise RuntimeError("Did you forget to create a context?")
    manager = exm.ExpressionManager(pep_context, resolve_parameters=resolve_parameters)

    visit_order = list(range(len(vecs)))
    if sort_vectors:
        # list.sort is stable, so equal keys keep their input order.
        visit_order.sort(key=lambda pos: _natural_sort_key(str(vecs[pos])))

    directions: list[np.ndarray] = []
    picked_vecs: list[Vector] = []
    picked_idx: list[int] = []

    for pos in visit_order:
        candidate = vecs[pos]
        coords = np.asarray(manager.eval_vector(candidate).coords, dtype=float)
        base_norm = np.linalg.norm(coords)
        if base_norm == 0:
            continue

        # Remove components along every accepted direction (updated residual).
        residual = coords.copy()
        for direction in directions:
            residual -= direction * (direction @ residual)

        residual_norm = np.linalg.norm(residual)
        if residual_norm > tol * base_norm:
            directions.append(residual / residual_norm)
            picked_vecs.append(candidate)
            picked_idx.append(pos)

    return picked_vecs, picked_idx


def _natural_sort_key(text: str) -> list[int | str]:
    """Return a natural-sort key for ``text`` (e.g. ``x_2`` sorts before ``x_10``).

    ``re.split`` with a capturing group alternates non-digit chunks (even
    positions, possibly empty) with the captured digit runs (odd positions).
    Only the odd positions are converted to ``int``. Every key therefore has
    the shape ``str, int, str, ...``, so keys always compare without
    ``TypeError``.

    The conversion is chosen by position rather than ``str.isdigit``, because
    ``isdigit`` also accepts characters such as ``"²"`` that ``\\d`` does not
    capture and ``int`` rejects.
    """
    return [
        int(tok) if pos % 2 else tok
        for pos, tok in enumerate(re.split(r"(\d+)", text))
    ]


def sorted_independent_subset_for_print(
    subset: tuple[list[Vector], list[int]],
) -> tuple[list[Vector], list[int]]:
    """Reorder an ``(chosen, idx)`` independent-subset result for display.

    Independence is not re-checked. The vectors and their indices are permuted
    together, stably, by natural order of ``str(vector)``.

    Raises:
        ValueError: If the two lists in ``subset`` have different lengths.
    """
    vectors, indices = subset
    if len(vectors) != len(indices):
        raise ValueError("`chosen` and `idx` must have the same length.")

    order = sorted(
        range(len(vectors)),
        key=lambda pos: _natural_sort_key(str(vectors[pos])),
    )
    return [vectors[pos] for pos in order], [indices[pos] for pos in order]


def decompose_rankr_symmetric(
    S,
    vecs: list[Vector],
    pep_context: pc.PEPContext | None = None,
    *,
    resolve_parameters: dict[str, utils.NUMERICAL_TYPE] | None = None,
    sym_tol: float = 1e-10,
    indep_tol: float = 1e-7,
) -> np.ndarray:
    """Decompose a symmetric matrix over pairwise symmetric vector products.

    Finds ``C[i, j]`` (by least squares) such that:
    ``S ~= sum_{i<=j} C[i,j] * A_{ij}``
    where ``A_{ii} = v_i v_i^T`` and
    ``A_{ij} = v_i v_j^T + v_j v_i^T`` for ``i < j``.

    Args:
        S: Expression evaluated with ``eval_scalar``; its ``inner_prod_coords``
            must be a symmetric square matrix.
        vecs: Vector list used to build the decomposition basis.
        pep_context (:class:`PEPContext` | None): The :class:`PEPContext` object
            we consider. `None` if we consider the current global
            :class:`PEPContext` object.
        resolve_parameters (dict[str, :class:`NUMERICAL_TYPE`] | None): A
            dictionary that maps the name of parameters to the numerical values.
        sym_tol: Tolerance for symmetry validation of ``S``.
        indep_tol: Tolerance for linear-independence rank checks of ``vecs``.

    Raises:
        RuntimeError: If no context is given and none is active.
        ValueError: If shape checks fail, ``S`` is not symmetric, or ``vecs`` are dependent.

    Returns:
        Symmetric ``(len(vecs), len(vecs))`` coefficient matrix ``C``; an empty
        ``(0, 0)`` array when ``vecs`` is empty.
    """
    if pep_context is None:
        pep_context = pc.get_current_context()
    if pep_context is None:
        raise RuntimeError("Did you forget to create a context?")
    pm = exm.ExpressionManager(pep_context, resolve_parameters=resolve_parameters)

    S_coords = np.asarray(pm.eval_scalar(S).inner_prod_coords, dtype=float)
    V = [np.asarray(pm.eval_vector(u).coords, dtype=float).ravel() for u in vecs]

    # Share the numerical backend so both entry points validate identically.
    # The backend also returns an empty (0, 0) matrix for empty ``vecs``
    # instead of failing inside ``np.stack`` on an empty list.
    return _decompose_rankr_symmetric_from_coords(
        S_coords,
        V,
        sym_tol=sym_tol,
        indep_tol=indep_tol,
    )


def _decompose_rankr_symmetric_from_coords(
    S_coords: np.ndarray,
    V: list[np.ndarray],
    *,
    sym_tol: float = 1e-10,
    indep_tol: float = 1e-7,
) -> np.ndarray:
    """Least-squares symmetric decomposition from pre-evaluated coordinates.

    Solves ``min_C || S_coords - sum_{i<=j} C[i,j] A_ij ||_F`` with
    ``A_ii = v_i v_i^T`` and ``A_ij = v_i v_j^T + v_j v_i^T`` (``i < j``).

    Args:
        S_coords: Square ``(n, n)`` matrix to decompose.
        V: Vectors, each of shape ``(n,)``.
        sym_tol: Allowed Frobenius-norm asymmetry of ``S_coords``, relative to
            ``max(1, ||S_coords||_F)``.
        indep_tol: ``tol`` passed to ``np.linalg.matrix_rank`` for the
            independence check of ``V``.

    Raises:
        ValueError: If ``S_coords`` is not square, a vector has the wrong
            shape, ``S_coords`` is not symmetric, or ``V`` is linearly dependent.

    Returns:
        Symmetric ``(len(V), len(V))`` coefficient matrix; ``(0, 0)`` when
        ``V`` is empty.
    """
    size = S_coords.shape[0]
    if S_coords.shape != (size, size):
        raise ValueError("S must be square.")
    if not all(vec.shape == (size,) for vec in V):
        raise ValueError("All vectors must have shape (n,).")

    asymmetry = np.linalg.norm(S_coords - S_coords.T, ord="fro")
    scale = max(1.0, np.linalg.norm(S_coords, ord="fro"))
    if asymmetry > sym_tol * scale:
        raise ValueError("S is not symmetric within tolerance.")

    count = len(V)
    if count == 0:
        return np.zeros((0, 0))

    rank = np.linalg.matrix_rank(np.stack(V, axis=1), tol=indep_tol)
    if rank < count:
        raise ValueError(
            f"vecs are linearly dependent: {count} vectors but rank is {rank}."
        )

    # Upper-triangle pairs (i <= j) in row-major order: (0,0), (0,1), ...
    rows, cols = np.triu_indices(count)
    design_columns = []
    for a, b in zip(rows, cols):
        block = np.outer(V[a], V[b])
        if a != b:
            block = block + block.T
        design_columns.append(block.reshape(-1))

    design = np.stack(design_columns, axis=1)
    coeffs, *_ = np.linalg.lstsq(design, S_coords.reshape(-1), rcond=None)

    C = np.zeros((count, count))
    C[rows, cols] = coeffs
    C[cols, rows] = coeffs
    return C


def find_sparsest_decompositions(
    S,
    vecs: list[Vector],
    *,
    pep_context: pc.PEPContext | None = None,
    resolve_parameters: dict[str, utils.NUMERICAL_TYPE] | None = None,
    fixed_vectors: list[Vector] | None = None,
    zero_tol: float = 1e-6,
    indep_tol: float = 1e-7,
) -> tuple[list[Vector], np.ndarray]:
    """Find an independent subset yielding the sparsest symmetric decomposition.

    Among linearly independent subsets of size ``r = rank(vecs)`` that contain
    every vector in ``fixed_vectors``, this function decomposes ``S`` over each
    subset. It keeps the subset whose coefficient matrix has the most entries
    with absolute value below ``zero_tol``. On ties, the first subset in
    ``itertools.combinations`` order wins. The search is exhaustive, so the
    number of subsets grows combinatorially with ``len(vecs)``.

    Args:
        S: Expression evaluated with ``eval_scalar``; its ``inner_prod_coords``
            matrix is decomposed.
        vecs: Candidate vectors used for subset search.
        pep_context (:class:`PEPContext` | None): The :class:`PEPContext` object
            we consider. `None` if we consider the current global
            :class:`PEPContext` object.
        resolve_parameters (dict[str, :class:`NUMERICAL_TYPE`] | None): A
            dictionary that maps the name of parameters to the numerical values.
        fixed_vectors: Vectors that must be included in every searched subset.
            Each must be the same object (identity) as an entry of ``vecs``.
        zero_tol: Absolute threshold for counting near-zero coefficients.
        indep_tol: Tolerance for linear-independence rank checks.

    Raises:
        RuntimeError: If no context is given and none is active.
        ValueError: If a fixed vector is not in ``vecs``, there are more
            distinct fixed vectors than ``rank(vecs)``, no independent subset
            contains all fixed vectors, or the evaluated coordinates fail the
            shape/symmetry checks of the decomposition.

    Returns:
        A tuple ``(best_vectors, best_coefficients)`` where
        ``best_vectors`` is one best subset (in ``vecs`` order) and
        ``best_coefficients`` is its coefficient matrix. Both are empty
        (``[]`` and a ``(0, 0)`` array) when ``vecs`` is empty or all of its
        vectors are numerically zero.
    """
    if pep_context is None:
        pep_context = pc.get_current_context()
    if pep_context is None:
        raise RuntimeError("Did you forget to create a context?")
    pm = exm.ExpressionManager(pep_context, resolve_parameters=resolve_parameters)

    fixed_vectors = fixed_vectors or []

    # Fixed vectors are located by object identity. Validate them before the
    # empty-``vecs`` shortcut so a non-empty ``fixed_vectors`` is never
    # silently ignored.
    idx_by_id = {id(v): i for i, v in enumerate(vecs)}
    missing = [v for v in fixed_vectors if id(v) not in idx_by_id]
    if missing:
        missing_s = ", ".join(str(v) for v in missing)
        raise ValueError(
            f"`fixed_vectors` must be contained in `vecs`. Missing: {missing_s}"
        )
    # Deduplicate while preserving first-occurrence order.
    fixed_idx: list[int] = list(
        dict.fromkeys(idx_by_id[id(v)] for v in fixed_vectors)
    )

    if not vecs:
        return [], np.zeros((0, 0))

    S_coords = np.asarray(pm.eval_scalar(S).inner_prod_coords, dtype=float)
    V_all = [np.asarray(pm.eval_vector(u).coords, dtype=float).ravel() for u in vecs]
    r = np.linalg.matrix_rank(np.stack(V_all, axis=1), tol=indep_tol)

    if len(fixed_idx) > r:
        raise ValueError(f"Too many fixed vectors: {len(fixed_idx)} > rank(vecs)={r}.")

    # All candidates are numerically zero, so the only independent subset is
    # the empty one. Without this guard the search below would call
    # ``np.stack`` on an empty list and crash.
    if r == 0:
        return [], np.zeros((0, 0))

    best_vectors: list[Vector] = []
    best_coefficients = np.zeros((0, 0))
    max_zeros = -1

    fixed_idx_set = set(fixed_idx)
    free_idx = [i for i in range(len(vecs)) if i not in fixed_idx_set]
    n_free_pick = r - len(fixed_idx)

    for picked_free in combinations(free_idx, n_free_pick):
        idx = sorted(fixed_idx + list(picked_free))
        sub_V = [V_all[i] for i in idx]
        # Skip subsets that lose rank (e.g. picks dependent on fixed vectors).
        if np.linalg.matrix_rank(np.stack(sub_V, axis=1), tol=indep_tol) < r:
            continue

        C = _decompose_rankr_symmetric_from_coords(
            S_coords,
            sub_V,
            indep_tol=indep_tol,
        )
        n_zeros = int(np.sum(np.abs(C) < zero_tol))

        # Strict ``>`` keeps the first subset among ties.
        if n_zeros > max_zeros:
            max_zeros = n_zeros
            best_vectors = [vecs[i] for i in idx]
            best_coefficients = C

    if max_zeros < 0 and fixed_vectors:
        fixed_s = ", ".join(str(v) for v in fixed_vectors)
        raise ValueError(
            f"No feasible independent subset satisfies `fixed_vectors`: {fixed_s}"
        )

    return best_vectors, best_coefficients


def infer_k_dependent_basis_templates(
    lyap: list,
    candidate_vectors: list[Vector],
    pep_context: pc.PEPContext | None = None,
    *,
    resolve_parameters: dict[str, utils.NUMERICAL_TYPE] | None = None,
    k_start: int = 2,
    zero_tol: float = 1e-6,
    indep_tol: float = 1e-7,
) -> tuple[dict[int, list[Vector]], list[str]]:
    """Infer a k-dependent basis template from sparse decompositions.

    For each ``k`` in ``[k_start, len(lyap) - 1]``, this function:
    1) keeps candidate vectors in ``col(lyap[k])``,
    2) selects one sparsest decomposition basis,
    3) infers a shared string template across ``k``. It replaces the concrete
       token of the current index (e.g. ``x_5`` when ``k = 4``) with the
       literal placeholder ``x_{k+1}`` and checks whether the resulting
       strings match.

    Before templates are built, every basis is reordered to align with a
    reference ``k``: the one whose normalized basis overlaps most with the
    others. A one-line summary of which ``k`` follow the reference pattern is
    printed to stdout.

    Args:
        lyap: Sequence of Lyapunov expressions/matrices.
        candidate_vectors: Candidate basis vectors.
        pep_context (:class:`PEPContext` | None): The :class:`PEPContext` object
            we consider. `None` if we consider the current global
            :class:`PEPContext` object.
        resolve_parameters (dict[str, :class:`NUMERICAL_TYPE`] | None): A
            dictionary that maps the name of parameters to the numerical values.
        k_start: First k index to use for inference.
        zero_tol: Absolute threshold for counting near-zero coefficients.
        indep_tol: Tolerance for linear-independence rank checks.

    Raises:
        RuntimeError: If no context is given and none is active.
        ValueError: If the inferred bases have different lengths across ``k``.

    Returns:
        A tuple ``(basis_by_k, templates)`` where:
        ``basis_by_k`` maps each k to the inferred sparse basis vectors,
        reordered to align with the reference k. ``templates`` holds one entry
        per basis position: the shared normalized string when all k agree, or
        the ``" | "``-joined raw strings otherwise. Both are empty when
        ``len(lyap) <= k_start``.
    """
    if pep_context is None:
        pep_context = pc.get_current_context()
    if pep_context is None:
        raise RuntimeError("Did you forget to create a context?")

    basis_by_k: dict[int, list[Vector]] = {}
    for k in range(k_start, len(lyap)):
        aligned = vectors_in_column_space(
            lyap[k],
            candidate_vectors,
            pep_context=pep_context,
            resolve_parameters=resolve_parameters,
            rtol=indep_tol,
            atol=indep_tol,
        )
        best_vectors, _ = find_sparsest_decompositions(
            lyap[k],
            aligned,
            pep_context=pep_context,
            resolve_parameters=resolve_parameters,
            zero_tol=zero_tol,
            indep_tol=indep_tol,
        )
        basis_by_k[k] = best_vectors

    if not basis_by_k:
        return {}, []

    ks = sorted(basis_by_k.keys())
    lengths = {len(basis_by_k[k]) for k in ks}
    if len(lengths) != 1:
        raise ValueError("Inferred bases do not have consistent lengths across k.")

    def _normalize_with_k(s: str, k: int) -> str:
        # Replace the whole-word token for index k+1 (``x_5`` when k=4, but not
        # the ``x_5`` prefix of ``x_50``) with the literal placeholder "x_{k+1}".
        token = f"x_{k + 1}"
        return re.sub(rf"\b{re.escape(token)}\b", "x_{k+1}", s)

    # Normalize per k, then align basis order across k by a reference k.
    normalized_by_k: dict[int, list[str]] = {
        k: [_normalize_with_k(str(v), k) for v in basis_by_k[k]] for k in ks
    }

    # Pick the k whose normalized basis has the largest overlap with others.
    reference_k = max(
        ks,
        key=lambda k0: sum(
            len(set(normalized_by_k[k0]).intersection(set(normalized_by_k[k1])))
            for k1 in ks
            if k1 != k0
        ),
    )
    reference_order = normalized_by_k[reference_k]

    for k in ks:
        pairs = list(zip(normalized_by_k[k], basis_by_k[k]))
        used = [False] * len(pairs)
        reordered_pairs: list[tuple[str, Vector]] = []

        # First, match reference order exactly when possible.
        for ref in reference_order:
            matched_idx = next(
                (
                    i
                    for i, (norm, _) in enumerate(pairs)
                    if (not used[i]) and norm == ref
                ),
                None,
            )
            if matched_idx is not None:
                used[matched_idx] = True
                reordered_pairs.append(pairs[matched_idx])

        # Then append unmatched terms in deterministic order.
        unmatched = [pairs[i] for i in range(len(pairs)) if not used[i]]
        unmatched.sort(key=lambda x: x[0])
        reordered_pairs.extend(unmatched)

        normalized_by_k[k] = [norm for norm, _ in reordered_pairs]
        basis_by_k[k] = [vec for _, vec in reordered_pairs]

    basis_len = lengths.pop()
    templates: list[str] = []
    for i in range(basis_len):
        normalized_seq = [normalized_by_k[k][i] for k in ks]
        if len(set(normalized_seq)) == 1:
            templates.append(normalized_seq[0])
            continue

        seq = [str(basis_by_k[k][i]) for k in ks]
        templates.append(" | ".join(seq))

    # TODO: printing string related to result in source code is not our style. Should modify it.
    # Report match coverage. A k follows the inferred rule when its whole
    # aligned, normalized basis equals the reference k's. Comparing only the
    # positions that already agree across every k would match all k by
    # construction, which made the partial-match branches unreachable.
    comparable_positions = [i for i, t in enumerate(templates) if " | " not in t]
    if not comparable_positions:
        print("No pattern was found.")
    else:
        reference_normalized = normalized_by_k[reference_k]
        matched_ks = [k for k in ks if normalized_by_k[k] == reference_normalized]

        if len(matched_ks) == len(ks):
            print(f"Inferred rule matches all k indices: {matched_ks}")
        elif len(matched_ks) > 1:
            print(f"Inferred rule matches k indices: {matched_ks}")
        else:
            print("No pattern was found.")

    return basis_by_k, templates
