"""CVXPY SDPs for the one-step gradient-descent example.

The notebook owns the experiment loop, exhaustive sparsification scan,
closed-form certificate, and plots.  This module keeps the SDP algebra out of
the narrative cells.
"""

from __future__ import annotations

from typing import Iterable

import cvxpy as cp
import numpy as np

from .sdp_utils import (
    inner_matrix,
    smooth_strongly_convex_interpolation_matrix,
    square_matrix,
    symmetrize,
    unit,
)


# Solver statuses that the solve functions below accept as a usable solution.
# The bare strings sit next to the cvxpy constants; presumably they spell the
# same values and are kept as a fallback (TODO confirm against cvxpy).
GOOD_STATUSES = {cp.OPTIMAL, cp.OPTIMAL_INACCURATE, "optimal", "optimal_inaccurate"}
# Default candidate slots for solve_fitted_curvature_candidate_sdp.  These are
# the positions that closed_form_multipliers assigns; in the pair ordering of
# _gd_interpolation_pairs they are "xs->x0", "xs->x1", and "x0->x1".
ACTIVE_FITTED_CURVATURE_INDICES = (0, 1, 3)
# One interpolation point as three coordinate vectors, unpacked throughout the
# module as (point coordinates, gradient coordinates, function-value coordinates).
PointData = tuple[np.ndarray, np.ndarray, np.ndarray]
# A labeled ordered pair of points: (label, source point, target point).
InterpolationPair = tuple[str, PointData, PointData]


def closed_form_rate(gamma: float, L: float = 1.0, mu: float = 0.1) -> float:
    """Return the closed-form rate ``max((1 - gamma*mu)^2, (1 - gamma*L)^2)``."""

    strongly_convex_term = (1.0 - gamma * mu) ** 2
    smooth_term = (1.0 - gamma * L) ** 2
    return max(strongly_convex_term, smooth_term)


def closed_form_multipliers(gamma: float, L: float = 1.0, mu: float = 0.1) -> np.ndarray:
    """Return the six closed-form multipliers for step size ``gamma``.

    Only slots 0, 1, and 3 are ever assigned; the other entries stay zero.
    The formulas switch at ``gamma = 2 / (L + mu)``: step sizes up to that
    threshold are expressed through ``gamma * mu``, larger ones through
    ``gamma * L``.
    """

    if gamma <= 2.0 / (L + mu):
        scaled = gamma * mu
        slot_0 = scaled * (1.0 - scaled)
        slot_1 = scaled
        slot_3 = 1.0 - scaled
    else:
        scaled = gamma * L
        slot_0 = (2.0 - scaled) * (scaled - 1.0)
        slot_1 = 2.0 - scaled
        slot_3 = scaled - 1.0
    return np.array([slot_0, slot_1, 0.0, slot_3, 0.0, 0.0], dtype=float)


def fitted_curvature_parameters(gamma: float, L: float = 1.0, mu: float = 0.1) -> tuple[float, float]:
    """Return the curvature pair ``(L_fit, mu_fit)`` associated with ``gamma``.

    For ``gamma <= 2 / (L + mu)`` the lower curvature ``mu`` is kept and the
    upper one becomes ``2 / gamma - mu``; otherwise ``L`` is kept and the lower
    one becomes ``2 / gamma - L``.
    """

    threshold = 2.0 / (L + mu)
    if gamma <= threshold:
        upper, lower = 2.0 / gamma - mu, mu
    else:
        upper, lower = L, 2.0 / gamma - L
    return upper, lower


def _gd_interpolation_pairs(gamma: float) -> tuple[InterpolationPair, ...]:
    """Return labeled point pairs in the reduced GD basis ``(x0-xs, g0, g1)``.

    Three points are described by coordinate vectors: ``xs`` (all-zero point,
    gradient, and value coordinates), ``x0``, and ``x1 = x0 - gamma * g0``.
    Function values use a separate two-entry coordinate vector, ``[1, 0]`` for
    ``x0`` and ``[0, 1]`` for ``x1``.  Every ordered pair of distinct points is
    returned as ``(label, source, target)`` in the order ``xs->x0``,
    ``xs->x1``, ``x0->xs``, ``x0->x1``, ``x1->xs``, ``x1->x0``.
    """

    origin = np.zeros(3)
    # unit(3, k) is presumably the k-th standard basis vector of length 3.
    start_offset, grad_at_x0, grad_at_x1 = (unit(3, k) for k in range(3))

    points = {
        "xs": (origin, origin, np.zeros(2)),
        "x0": (start_offset, grad_at_x0, np.array([1.0, 0.0])),
        "x1": (start_offset - gamma * grad_at_x0, grad_at_x1, np.array([0.0, 1.0])),
    }

    # Dict insertion order reproduces the fixed pair ordering that callers
    # index into by position.
    return tuple(
        (f"{src}->{dst}", points[src], points[dst])
        for src in points
        for dst in points
        if src != dst
    )


def _pair_quadratic_parts(
    source: PointData,
    target: PointData,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Split a two-point interpolation inequality into its matrix pieces.

    The article states the inequality as ``H >= 0``, whereas the SDP code works
    with ``-H``: the value part is ``f_j-f_i`` and the quadratic terms enter
    the matrix with a positive sign.

    Returns, in order: the base piece built from the target gradient and the
    displacement ``x_i - x_j``, the displacement-squared piece, the mixed
    displacement/gradient-difference piece, and the gradient-difference-squared
    piece, where the gradient difference is ``g_i - g_j``.
    """

    x_source, g_source, _ = source
    x_target, g_target, _ = target
    dx = x_source - x_target
    dg = g_source - g_target

    base_part = inner_matrix(g_target, dx)
    xx_part = square_matrix(dx)
    xg_part = inner_matrix(dx, dg)
    gg_part = square_matrix(dg)
    return base_part, xx_part, xg_part, gg_part


def build_gd_sdp(
    L: float,
    mu: float,
    gamma: float,
) -> tuple[tuple[str, ...], tuple[np.ndarray, ...], tuple[np.ndarray, ...]]:
    """Return the SDP coefficients for the basis ``(x0-xs, g0, g1)``.

    The result is ``(labels, value_vectors, matrices)`` with one entry per
    ordered point pair from ``_gd_interpolation_pairs``: the pair label, the
    function-value difference ``f_target - f_source``, and the matrix produced
    by ``smooth_strongly_convex_interpolation_matrix`` for that pair.

    Raises ``ValueError`` unless ``0 <= mu < L`` and ``gamma > 0``.
    """

    curvature_ok = 0.0 <= mu < L
    if not curvature_ok:
        raise ValueError("Expected 0 <= mu < L.")
    if gamma <= 0.0:
        raise ValueError("Expected gamma > 0.")

    pairs = _gd_interpolation_pairs(gamma)
    labels = tuple(pair[0] for pair in pairs)
    value_vectors = tuple(target[2] - source[2] for _, source, target in pairs)
    matrices = tuple(
        smooth_strongly_convex_interpolation_matrix(
            source[0], source[1], target[0], target[1], L=L, mu=mu
        )
        for _, source, target in pairs
    )
    return labels, value_vectors, matrices


def solve_gd_dual(
    L: float,
    mu: float,
    gamma: float,
    inactive_multipliers: Iterable[int] = (),
    solver: str = "MOSEK",
) -> tuple[str, float | None, np.ndarray | None]:
    """Solve the GD dual SDP, optionally setting interpolation multipliers to zero.

    The problem minimizes a nonnegative rate ``rho`` over nonnegative
    multipliers, one per interpolation pair from ``build_gd_sdp``, subject to a
    linear balance on the function-value coordinates and a PSD constraint on
    the multiplier-weighted sum of the pair matrices.  Indices listed in
    ``inactive_multipliers`` have their multiplier pinned to zero.

    Returns ``(status, rho, multipliers)``.  When the solver status is not in
    ``GOOD_STATUSES`` or no values are available, the last two entries are
    ``None``.  Returned multipliers are clipped at zero and entries below
    ``1e-12`` are set to exactly zero.

    Raises ``ValueError`` if an inactive index is out of range.
    """

    labels, value_vectors, matrices = build_gd_sdp(L=L, mu=mu, gamma=gamma)
    pair_count = len(value_vectors)
    switched_off = set(inactive_multipliers)
    for idx in switched_off:
        if idx < 0 or idx >= pair_count:
            raise ValueError(f"Inactive multiplier indices must be in [0, {len(labels) - 1}].")

    rate = cp.Variable(nonneg=True)
    weights = cp.Variable(pair_count, nonneg=True)

    # Accumulate the value balance and matrix balance pair by pair.
    value_residual = rate * np.array([1.0, 0.0]) - np.array([0.0, 1.0])
    matrix_residual = np.zeros((3, 3))
    for idx, value_i in enumerate(value_vectors):
        value_residual = value_residual + weights[idx] * value_i
        matrix_residual = matrix_residual + weights[idx] * matrices[idx]

    constraints: list[cp.Constraint] = [value_residual == 0.0, matrix_residual >> 0.0]
    constraints.extend(weights[idx] == 0.0 for idx in sorted(switched_off))

    problem = cp.Problem(cp.Minimize(rate), constraints)
    problem.solve(solver=solver, verbose=False, warm_start=True)

    status = str(problem.status)
    has_solution = (
        problem.status in GOOD_STATUSES
        and rate.value is not None
        and weights.value is not None
    )
    if not has_solution:
        return status, None, None

    # Remove small negative values and round-off noise from the solver output.
    multipliers = np.maximum(np.asarray(weights.value), 0.0)
    negligible = np.abs(multipliers) < 1e-12
    multipliers[negligible] = 0.0
    return status, float(rate.value), multipliers


def identify_interpolation_curvatures(
    matrix: np.ndarray,
    gamma: float,
    pair_index: int,
    coefficient_tol: float = 1e-10,
) -> dict:
    """Identify ``(L, mu)`` from a code-side two-point interpolation matrix.

    The pair at ``pair_index`` (in ``_gd_interpolation_pairs`` order) supplies
    a base matrix and three quadratic parts.  After subtracting the base and
    symmetrizing, the remainder is fitted by least squares as a combination of
    the three parts.  The curvatures follow from the fitted coefficients as
    ``mu = -coeff_xg / (2 coeff_gg)`` and ``L = mu + 1 / (2 coeff_gg)``; they
    are NaN when ``coeff_gg`` does not exceed ``coefficient_tol``.

    Returns a dict with the pair label, ``L``, ``mu``, the three coefficients,
    the largest absolute entry of the fit residual (``reconstruction_error``),
    and ``relation_error = |coeff_xx - coeff_gg * L * mu|``.
    """

    label, source, target = _gd_interpolation_pairs(gamma)[pair_index]
    base_part, xx_part, xg_part, gg_part = _pair_quadratic_parts(source, target)

    remainder = symmetrize(matrix - base_part)
    design = np.column_stack([part.reshape(-1) for part in (xx_part, xg_part, gg_part)])
    solution = np.linalg.lstsq(design, remainder.reshape(-1), rcond=None)[0]
    coeff_xx, coeff_xg, coeff_gg = solution

    fitted = coeff_xx * xx_part + coeff_xg * xg_part + coeff_gg * gg_part
    reconstruction_error = float(np.max(np.abs(remainder - fitted)))

    # Curvatures are only recoverable when the gradient-squared coefficient is
    # clearly positive; otherwise everything derived from it stays NaN.
    L_hat = mu_hat = relation_error = np.nan
    if coeff_gg > coefficient_tol:
        half_inverse = 1.0 / (2.0 * coeff_gg)
        mu_hat = -coeff_xg / (2.0 * coeff_gg)
        L_hat = mu_hat + half_inverse
        relation_error = float(abs(coeff_xx - coeff_gg * L_hat * mu_hat))

    summary = {
        "label": label,
        "L": float(L_hat),
        "mu": float(mu_hat),
        "coeff_xx": float(coeff_xx),
        "coeff_xg": float(coeff_xg),
        "coeff_gg": float(coeff_gg),
        "reconstruction_error": reconstruction_error,
        "relation_error": relation_error,
    }
    return summary


def solve_fitted_curvature_candidate_sdp(
    L: float,
    mu: float,
    gamma: float,
    multipliers: np.ndarray,
    target_rate: float | None = None,
    candidate_indices: Iterable[int] = ACTIVE_FITTED_CURVATURE_INDICES,
    solver: str = "MOSEK",
    value_residual_tol: float = 1e-8,
) -> dict:
    """Search for the singleton candidate interpolation inequality.

    The candidate-lemma SDP uses one shared two-point interpolation inequality
    across the selected singleton slots.  The final slot weights are fixed to
    the sparse SDP multipliers already found on the same grid point; the SDP
    then chooses the derived inequalities themselves.

    Concretely, three scalar coefficients ``(coeff_xx, coeff_xg, coeff_gg)``
    are shared by every selected pair.  Each pair's candidate matrix is its
    base matrix plus the coefficient-weighted quadratic parts from
    ``_pair_quadratic_parts``.  The SDP minimizes the trace of the
    ``eta``-weighted sum of candidate matrices, subject to that sum being PSD,
    every ``original - candidate`` slack being PSD, and the sign constraints
    ``coeff_xx >= 0``, ``coeff_xg <= 0``, ``coeff_gg >= 0``.

    Parameters
    ----------
    L, mu, gamma
        Problem curvatures and step size; require ``0 <= mu < L`` and
        ``gamma > 0``.
    multipliers
        One weight per interpolation pair (``build_gd_sdp`` order).  Only the
        entries at ``candidate_indices`` are used, clipped at zero.
    target_rate
        Rate the fixed weights must certify; defaults to
        ``closed_form_rate(gamma, L=L, mu=mu)``.
    candidate_indices
        Pair slots that share the candidate inequality.
    solver
        Solver name passed to ``cvxpy.Problem.solve``.
    value_residual_tol
        Largest allowed absolute entry of the function-value balance.

    Returns
    -------
    dict
        Solver status, the inputs echoed back, the fitted coefficients, the
        recovered ``L`` / ``mu`` with ``relation_error``, per-slot candidate
        matrices and validity slacks, and eigenvalue diagnostics of the
        aggregate matrix.  ``matrix_fit_error`` is NaN when the recovered
        curvatures do not satisfy ``0 <= mu < L`` (including NaN), because the
        comparison matrices cannot be rebuilt in that case.

    Raises
    ------
    ValueError
        For invalid ``L`` / ``mu`` / ``gamma``, a wrong ``multipliers`` shape,
        out-of-range candidate indices, no positive selected weight (this
        includes an empty selection), or a value balance above tolerance.
    RuntimeError
        If the solver status is not in ``GOOD_STATUSES``.
    """

    if not 0.0 <= mu < L:
        raise ValueError("Expected 0 <= mu < L.")
    if gamma <= 0.0:
        raise ValueError("Expected gamma > 0.")

    # Selected weights at or below this threshold count as inactive slots.
    active_tol = 1e-10

    candidate_indices = tuple(candidate_indices)
    labels, value_vectors, original_matrices = build_gd_sdp(L=L, mu=mu, gamma=gamma)
    interpolation_pairs = _gd_interpolation_pairs(gamma)
    multipliers = np.asarray(multipliers, dtype=float)
    if multipliers.shape != (len(labels),):
        raise ValueError(f"Expected {len(labels)} multipliers.")
    if any(idx < 0 or idx >= len(labels) for idx in candidate_indices):
        raise ValueError(f"Candidate indices must be in [0, {len(labels) - 1}].")

    # An explicit integer dtype keeps an empty selection indexable, so it falls
    # through to the ValueError below instead of an opaque IndexError from a
    # float-typed empty index array.
    index_array = np.array(candidate_indices, dtype=np.intp)
    eta = np.maximum(multipliers[index_array], 0.0)
    if not np.any(eta > active_tol):
        raise ValueError("At least one selected candidate multiplier must be positive.")

    # The fixed weights must already balance the function-value coordinates at
    # the requested rate; the SDP below only touches the matrix part.
    target_rate = closed_form_rate(gamma, L=L, mu=mu) if target_rate is None else float(target_rate)
    value_balance = target_rate * np.array([1.0, 0.0]) - np.array([0.0, 1.0])
    for eta_i, idx in zip(eta, candidate_indices):
        value_balance = value_balance + eta_i * value_vectors[idx]
    value_residual_norm = float(np.max(np.abs(value_balance)))
    if value_residual_norm > value_residual_tol:
        raise ValueError(
            "Fixed multipliers do not certify the requested target rate: "
            f"value residual {value_residual_norm:.3e} exceeds {value_residual_tol:.3e}."
        )

    coeff_xx = cp.Variable()
    coeff_xg = cp.Variable()
    coeff_gg = cp.Variable()
    aggregate_slack = np.zeros((3, 3))
    constraints: list[cp.Constraint] = [coeff_xx >= 0.0, coeff_xg <= 0.0, coeff_gg >= 0.0]
    candidate_matrices = []
    validity_slacks = []

    for eta_i, idx in zip(eta, candidate_indices):
        _, source, target = interpolation_pairs[idx]
        base_matrix, xx_matrix, xg_matrix, gg_matrix = _pair_quadratic_parts(source, target)
        candidate_matrix = (
            base_matrix
            + coeff_xx * xx_matrix
            + coeff_xg * xg_matrix
            + coeff_gg * gg_matrix
        )
        # The candidate is constrained against the original inequality for
        # this pair even when its weight eta_i is zero.
        validity_slack = original_matrices[idx] - candidate_matrix
        aggregate_slack = aggregate_slack + eta_i * candidate_matrix
        candidate_matrices.append(candidate_matrix)
        validity_slacks.append(validity_slack)
        constraints.append(validity_slack >> 0.0)

    constraints.append(aggregate_slack >> 0.0)
    problem = cp.Problem(cp.Minimize(cp.trace(aggregate_slack)), constraints)
    problem.solve(solver=solver, verbose=False, warm_start=True)

    if problem.status not in GOOD_STATUSES:
        raise RuntimeError(f"Candidate-lemma SDP failed with status {problem.status}.")

    coeffs = {
        "xx": float(coeff_xx.value),
        "xg": float(coeff_xg.value),
        "gg": float(coeff_gg.value),
    }
    # Curvatures are only defined for a strictly positive gg coefficient.
    if coeffs["gg"] > 0.0:
        mu_hat = -coeffs["xg"] / (2.0 * coeffs["gg"])
        L_hat = mu_hat + 1.0 / (2.0 * coeffs["gg"])
    else:
        mu_hat = np.nan
        L_hat = np.nan
    relation_error = float(abs(coeffs["xx"] - coeffs["gg"] * L_hat * mu_hat)) if np.isfinite(L_hat) else np.nan

    candidate_numeric = [np.asarray(matrix.value, dtype=float) for matrix in candidate_matrices]
    validity_numeric = [np.asarray(slack.value, dtype=float) for slack in validity_slacks]
    aggregate_numeric = np.asarray(aggregate_slack.value, dtype=float)

    # build_gd_sdp rejects curvatures outside 0 <= mu < L (NaN included).
    # Solver round-off can push the recovered pair just outside that range, so
    # the refit comparison is skipped there instead of discarding a solved SDP
    # with a misleading input-validation error.
    refit_possible = bool(
        np.isfinite(L_hat) and np.isfinite(mu_hat) and 0.0 <= mu_hat < L_hat
    )
    if refit_possible:
        _, _, fitted_matrices = build_gd_sdp(L=L_hat, mu=mu_hat, gamma=gamma)
        matrix_fit_error = max(
            float(np.max(np.abs(candidate_numeric[k] - fitted_matrices[idx])))
            for k, idx in enumerate(candidate_indices)
        )
    else:
        matrix_fit_error = np.nan

    identifications = [
        identify_interpolation_curvatures(candidate_numeric[k], gamma=gamma, pair_index=idx)
        for k, idx in enumerate(candidate_indices)
    ]
    validity_min_eigenvalue = min(
        float(np.min(np.linalg.eigvalsh(symmetrize(slack))))
        for slack in validity_numeric
    )
    aggregate_eigenvalues = np.linalg.eigvalsh(symmetrize(aggregate_numeric))
    # Count eigenvalues above a tolerance scaled by the largest magnitude
    # (with a floor of 1.0).
    rank_threshold = 1e-7 * max(1.0, float(np.max(np.abs(aggregate_eigenvalues))))

    return {
        "status": str(problem.status),
        "solver": solver,
        "gamma": float(gamma),
        "target_rate": target_rate,
        "candidate_indices": candidate_indices,
        "candidate_labels": tuple(labels[idx] for idx in candidate_indices),
        "active_candidate_labels": tuple(
            labels[idx]
            for idx, eta_i in zip(candidate_indices, eta)
            if eta_i > active_tol
        ),
        "eta": eta,
        "value_residual_norm": value_residual_norm,
        "coefficients": coeffs,
        "L": float(L_hat),
        "mu": float(mu_hat),
        "relation_error": relation_error,
        "matrix_fit_error": matrix_fit_error,
        "identifications": identifications,
        "candidate_matrices": candidate_numeric,
        "validity_slacks": validity_numeric,
        "validity_min_eigenvalue": validity_min_eigenvalue,
        "aggregate_slack": aggregate_numeric,
        "aggregate_eigenvalues": aggregate_eigenvalues,
        "aggregate_rank": int(np.sum(aggregate_eigenvalues > rank_threshold)),
    }
