"""CVXPY SDPs for the fast-gradient-method examples.

Everything here is built from plain NumPy arrays and direct CVXPY problems.
For the FGM horizon ``n``, the reduced Gram basis is

    z = (x_0 - x_*, g_0, ..., g_{n-1}, g_out),

and the function-value basis is

    s = (f_0-f_*, ..., f_{n-1}-f_*, f_n-f_*).
"""

from __future__ import annotations

from pathlib import Path
from typing import Callable, Iterable, Mapping

import cvxpy as cp
import numpy as np

from .sdp_utils import smooth_strongly_convex_interpolation_matrix, square_matrix, symmetrize, unit


# Solver statuses accepted as a usable solution by solve_result and
# require_cardinality_result.
GOOD_STATUSES = {cp.OPTIMAL, cp.OPTIMAL_INACCURATE, "optimal", "optimal_inaccurate"}
# Default magnitude above which a multiplier counts as active
# (active_indices and the LaTeX table rendering).
THRESHOLD = 1e-7
# Default activity threshold for the cardinality sweep summaries.
CARDINALITY_ACTIVITY_THRESHOLD = 1e-3
# Slack allowed when a solved rho is checked against its cap in
# require_cardinality_result.
RHO_TOL = 5e-8
# Default upper clip for normalization_bounds; also the value stored when the
# per-multiplier maximization is reported as unbounded.
NORMALIZATION_CLIP = 1e12
# Defaults for the reweighted-L1 heuristic: the epsilon added to the multipliers
# in the weight update and the number of solves.
REWEIGHT_EPS = 1e-3
REWEIGHT_ITERS = 12
# Capped-L1 schedule: tau values swept in order, the maximum number of solves
# per tau, and the weight numerator for multipliers at or above tau.
CAPPED_TAUS = (1.0, 0.5, 0.25, 0.1, 0.05, 0.02)
CAPPED_ITERS_PER_TAU = 4
CAPPED_PLATEAU_WEIGHT = 0.0
# Default horizon and smoothness constant.
# NOTE(review): the names suggest these match the paper example - confirm.
N_PAPER = 3
L_PAPER = 1.0
# Default LaTeX lengths separating tables within a row and between rows.
FGM_TABLE_COLUMN_GAP = "0.6em"
FGM_TABLE_ROW_GAP = "1.6em"


def rho_cap(n: int) -> float:
    """Return ``2 / (n^2 + 5n + 2)`` plus a ``1e-9`` margin.

    The cardinality sweep uses this value as the upper bound on ``rho``.
    """

    denominator = float(n * n + 5 * n + 2)
    return 2.0 / denominator + 1e-9


def chain_active_multiplier_labels(n: int) -> tuple[str, ...]:
    """Return the chain-pattern multiplier labels for horizon ``n``.

    The result lists ``xs->y{k}`` for every ``k < n``, then the consecutive
    links ``y{k}->y{k+1}``, and finally ``y{n-1}->x{n}`` (``2n`` labels for
    ``n >= 1``).
    """

    from_optimum = tuple(f"xs->y{k}" for k in range(n))
    along_chain = tuple(f"y{k}->y{k + 1}" for k in range(n - 1))
    closing = (f"y{n - 1}->x{n}",)
    return from_optimum + along_chain + closing


def point_names(n: int) -> tuple[str, ...]:
    """Return the point names ``xs, y0, ..., y{n-1}, x{n}`` in order."""

    names = ["xs"]
    names.extend(f"y{k}" for k in range(n))
    names.append(f"x{n}")
    return tuple(names)


def build_points(n: int, L: float = 1.0) -> list[tuple[str, np.ndarray, np.ndarray, np.ndarray]]:
    """Return ``(name, x, g, f)`` coefficient tuples for ``xs, y0, ..., y_{n-1}, x_n``.

    ``x`` and ``g`` are coefficient vectors of length ``n + 2`` and ``f`` has
    length ``n + 1``. The ``xs`` entry is all zeros.

    Raises:
        ValueError: If ``n < 1`` or ``L <= 0``.
    """

    if n < 1:
        raise ValueError("Expected n >= 1.")
    if L <= 0.0:
        raise ValueError("Expected L > 0.")

    gram_dim, value_dim = n + 2, n + 1
    start = unit(gram_dim, 0)
    points = [("xs", np.zeros(gram_dim), np.zeros(gram_dim), np.zeros(value_dim))]

    previous_x = start.copy()
    current_y = start.copy()
    last_step = n - 1
    for k in range(n):
        gradient = unit(gram_dim, 1 + k)
        value = unit(value_dim, k)
        points.append((f"y{k}", current_y.copy(), gradient, value))
        # Gradient step from the current extrapolated point.
        next_x = current_y - (1.0 / L) * gradient
        if k < last_step:
            # Momentum extrapolation; skipped after the final gradient step.
            momentum = k / (k + 3.0)
            current_y = next_x + momentum * (next_x - previous_x)
            previous_x = next_x

    # The loop ran at least once (n >= 1), so next_x holds the final iterate.
    points.append((f"x{n}", next_x.copy(), unit(gram_dim, n + 1), unit(value_dim, n)))
    return points


def build_fgm_arrays(
    n: int = N_PAPER,
    L: float = L_PAPER,
) -> tuple[tuple[str, ...], tuple[np.ndarray, ...], tuple[np.ndarray, ...], np.ndarray, np.ndarray]:
    """Build the FGM interpolation arrays used by all examples.

    Returns ``(labels, matrices, vectors, Q_init, q_const)``. There is one
    entry per ordered pair of distinct points, labelled ``"source->target"``.
    Each matrix is the symmetrized interpolation matrix for the pair and each
    vector is the column ``target_f - source_f``. ``Q_init`` is built from the
    first Gram basis vector and ``q_const`` is a column of zeros ending in
    ``-1``.
    """

    points = build_points(n=n, L=L)
    gram_dim = n + 2
    value_dim = n + 1

    pair_labels: list[str] = []
    raw_matrices: list[np.ndarray] = []
    value_columns: list[np.ndarray] = []
    for src_name, src_x, src_g, src_f in points:
        for dst_name, dst_x, dst_g, dst_f in points:
            if src_name == dst_name:
                continue
            pair_labels.append(f"{src_name}->{dst_name}")
            raw_matrices.append(
                smooth_strongly_convex_interpolation_matrix(src_x, src_g, dst_x, dst_g, L=L)
            )
            value_columns.append(dst_f - src_f)

    sym_matrices = tuple(symmetrize(raw) for raw in raw_matrices)
    columns = tuple(column.reshape(-1, 1) for column in value_columns)
    init_matrix = square_matrix(unit(gram_dim, 0))
    const_column = np.concatenate([np.zeros(value_dim - 1), [-1.0]]).reshape(-1, 1)
    return tuple(pair_labels), sym_matrices, columns, init_matrix, const_column


def multiplier_balances(
    matrices: tuple[np.ndarray, ...],
    vectors: tuple[np.ndarray, ...],
    Q_init: np.ndarray,
    q_const: np.ndarray,
    rho: cp.Expression,
    lambdas: cp.Expression,
    slack: cp.Expression,
) -> tuple[cp.Expression, cp.Expression]:
    """Return the shared matrix and value balances for the FGM dual SDP.

    The matrix balance is ``rho * Q_init - slack + sum_i lambdas[i] * matrices[i]``
    and the value balance is ``q_const + sum_i lambdas[i] * vectors[i]``.
    Terms are accumulated in index order over the zipped
    ``matrices``/``vectors`` pairs.
    """

    init_matrix = np.asarray(Q_init, dtype=float)
    matrix_balance = rho * init_matrix - slack
    value_balance = np.asarray(q_const, dtype=float).copy()
    for position, pair in enumerate(zip(matrices, vectors)):
        matrix_term, value_term = pair
        weight = lambdas[position]
        matrix_balance = matrix_balance + weight * matrix_term
        value_balance = value_balance + weight * value_term
    return matrix_balance, value_balance


def solve_result(problem: cp.Problem, rho: cp.Variable, lambdas: cp.Variable) -> tuple[str, float | None, np.ndarray | None]:
    """Extract ``(status, rho, lambdas)`` from a solved problem.

    ``rho`` and ``lambdas`` are ``None`` when the status is not in
    ``GOOD_STATUSES`` or either variable has no value. Otherwise the
    multipliers are flattened, clipped at zero, and entries below ``1e-12``
    are set to exactly zero.
    """

    status = str(problem.status)
    usable = (
        problem.status in GOOD_STATUSES
        and rho.value is not None
        and lambdas.value is not None
    )
    if not usable:
        return status, None, None

    flat = np.asarray(lambdas.value, dtype=float).reshape(-1)
    cleaned = np.maximum(flat, 0.0)
    cleaned[np.abs(cleaned) < 1e-12] = 0.0
    return status, float(rho.value), cleaned


def solve_multiplier_sdp(
    labels: tuple[str, ...],
    matrices: tuple[np.ndarray, ...],
    vectors: tuple[np.ndarray, ...],
    Q_init: np.ndarray,
    q_const: np.ndarray,
    solver: str = "MOSEK",
    objective: str = "rate",
    cap: float | None = None,
    weights: np.ndarray | None = None,
    active_multipliers: Iterable[int] | None = None,
    multiplier_index: int | None = None,
) -> tuple[str, float | None, np.ndarray | None]:
    """Solve the shared FGM dual SDP with a small objective/restriction switch.

    Args:
        labels, matrices, vectors, Q_init, q_const: Problem data, as produced
            by ``build_fgm_arrays``.
        solver: Solver name passed to ``Problem.solve``.
        objective: ``"rate"`` minimizes ``rho``; ``"weighted_l1"`` minimizes
            ``weights @ lambdas``; ``"max_multiplier"`` maximizes
            ``lambdas[multiplier_index]``.
        cap: Optional upper bound imposed on ``rho``.
        weights: Objective weights, required for ``"weighted_l1"``.
        active_multipliers: Optional indices allowed to be nonzero; all other
            multipliers are constrained to zero.
        multiplier_index: Target index, required for ``"max_multiplier"``.

    Returns:
        ``(status, rho, lambdas)`` as produced by ``solve_result``.

    Raises:
        ValueError: On out-of-range indices, a missing objective argument, or
            an unknown ``objective``.
    """

    count = len(labels)
    gram_dim = np.asarray(Q_init).shape[0]
    lambdas = cp.Variable(count, nonneg=True)
    rho = cp.Variable()
    slack = cp.Variable((gram_dim, gram_dim), symmetric=True)

    matrix_balance, value_balance = multiplier_balances(matrices, vectors, Q_init, q_const, rho, lambdas, slack)
    constraints: list[cp.Constraint] = [slack >> 0.0, matrix_balance == 0.0, value_balance == 0.0]

    if cap is not None:
        constraints.append(rho <= float(cap))

    if active_multipliers is not None:
        allowed = set(active_multipliers)
        if not all(0 <= idx < count for idx in allowed):
            raise ValueError(f"Active multiplier indices must be in [0, {count - 1}].")
        # Pin every multiplier outside the allowed set to zero.
        inactive = [idx for idx in range(count) if idx not in allowed]
        if inactive:
            constraints.append(lambdas[inactive] == 0.0)

    if objective == "rate":
        goal = cp.Minimize(rho)
    elif objective == "weighted_l1":
        if weights is None:
            raise ValueError("weights are required for objective='weighted_l1'.")
        goal = cp.Minimize(np.asarray(weights, dtype=float) @ lambdas)
    elif objective == "max_multiplier":
        if multiplier_index is None:
            raise ValueError("multiplier_index is required for objective='max_multiplier'.")
        if not 0 <= multiplier_index < count:
            raise ValueError(f"multiplier_index must be in [0, {count - 1}].")
        goal = cp.Maximize(lambdas[int(multiplier_index)])
    else:
        raise ValueError("objective must be 'rate', 'weighted_l1', or 'max_multiplier'.")

    problem = cp.Problem(goal, constraints)
    problem.solve(solver=solver, verbose=False, warm_start=True)
    return solve_result(problem, rho, lambdas)


def solve_l1_heuristic(
    sdp: tuple[tuple[str, ...], tuple[np.ndarray, ...], tuple[np.ndarray, ...], np.ndarray, np.ndarray],
    solver: str = "MOSEK",
    objective: str = "weighted_l1",
    cap: float | None = None,
    scale: np.ndarray | None = None,
    iterations: int = REWEIGHT_ITERS,
    eps: float = REWEIGHT_EPS,
) -> tuple[str, float | None, np.ndarray | None]:
    """Solve one of the small L1-based multiplier heuristics used in the notebook.

    Every variant repeatedly calls ``solve_multiplier_sdp`` with
    ``objective="weighted_l1"`` and differs only in how the weights are chosen.

    Args:
        sdp: ``(labels, matrices, vectors, Q_init, q_const)`` problem data.
        solver: Solver name forwarded to ``solve_multiplier_sdp``.
        objective: ``"weighted_l1"`` (one solve), ``"reweighted_l1"``
            (``iterations`` solves with updated weights), or ``"capped_l1"``
            (a sweep over ``CAPPED_TAUS``).
        cap: Optional upper bound on ``rho`` forwarded to every solve.
        scale: Optional per-multiplier scale; entries are floored at ``1e-10``
            before use.
        iterations: Number of solves for ``"reweighted_l1"``.
        eps: Epsilon added to the multipliers in the reweighting update.

    Returns:
        The ``(status, rho, lambdas)`` tuple of the last solve. A failed solve
        (``lambdas is None``) is returned immediately. ``"reweighted_l1"`` with
        no iterations returns ``("not_run", None, None)``.

    Raises:
        ValueError: If ``objective`` is not one of the three supported names.
    """

    labels = sdp[0]

    def solve_weights(weights: np.ndarray) -> tuple[str, float | None, np.ndarray | None]:
        # Single weighted-L1 solve sharing the outer solver and cap.
        return solve_multiplier_sdp(
            *sdp,
            solver=solver,
            objective="weighted_l1",
            cap=cap,
            weights=weights,
        )

    if objective == "weighted_l1":
        # Uniform weights, or inverse-scale weights when a scale is supplied.
        weights = np.ones(len(labels)) if scale is None else 1.0 / np.maximum(np.asarray(scale, dtype=float), 1e-10)
        return solve_weights(weights)

    if objective == "reweighted_l1":
        safe_scale = None if scale is None else np.maximum(np.asarray(scale, dtype=float), 1e-10)
        weights = np.ones(len(labels)) if safe_scale is None else 1.0 / safe_scale
        current: tuple[str, float | None, np.ndarray | None] = ("not_run", None, None)
        for _ in range(iterations):
            current = solve_weights(weights)
            _, _, lambdas = current
            if lambdas is None:
                return current
            # Small multipliers get large weights on the next solve; eps (scaled
            # when a scale is given) keeps the weights finite.
            weights = 1.0 / (lambdas + eps) if safe_scale is None else 1.0 / (lambdas + eps * safe_scale)
        return current

    if objective == "capped_l1":
        safe_scale = np.ones(len(labels)) if scale is None else np.maximum(np.asarray(scale, dtype=float), 1e-10)
        # Initial solve with inverse-scale weights seeds the tau sweep.
        current = solve_weights(1.0 / safe_scale)
        _, _, lambdas = current
        if lambdas is None:
            return current

        # last_active carries over between tau values, so an unchanged active
        # set also ends the first iteration of the next tau.
        last_active: tuple[int, ...] | None = None
        for tau in CAPPED_TAUS:
            for _ in range(CAPPED_ITERS_PER_TAU):
                normalized = np.maximum(lambdas, 0.0) / safe_scale
                # Multipliers below tau keep unit weight; those at or above tau
                # get CAPPED_PLATEAU_WEIGHT. Both are divided by the scale.
                weights = np.where(normalized < tau, 1.0, CAPPED_PLATEAU_WEIGHT) / safe_scale
                current = solve_weights(weights)
                _, _, lambdas = current
                if lambdas is None:
                    return current
                active = tuple(active_indices(lambdas))
                if active == last_active:
                    break
                last_active = active
        return current

    raise ValueError(
        "objective must be 'weighted_l1', 'reweighted_l1', or 'capped_l1'."
    )


def normalization_bounds(
    sdp: tuple[tuple[str, ...], tuple[np.ndarray, ...], tuple[np.ndarray, ...], np.ndarray, np.ndarray],
    solver: str = "MOSEK",
    cap: float | None = None,
    clip: float = NORMALIZATION_CLIP,
    progress_factory: Callable | None = None,
    progress_label: str | None = None,
) -> np.ndarray:
    """Return per-multiplier upper bounds used to scale the normalized heuristics.

    One ``"max_multiplier"`` SDP is solved per multiplier. An unbounded status
    yields ``clip``; otherwise the maximized multiplier is clamped to
    ``[0, clip]``. The returned array is floored at ``1e-10``.

    Args:
        sdp: ``(labels, matrices, vectors, Q_init, q_const)`` problem data.
        solver: Solver name forwarded to ``solve_multiplier_sdp``.
        cap: Optional upper bound on ``rho`` forwarded to every solve.
        clip: Largest bound that is reported.
        progress_factory: Optional wrapper called as
            ``progress_factory(indices, total=..., desc=...)`` whose result is
            iterated instead of the plain index range.
        progress_label: Description passed to ``progress_factory``.

    Raises:
        RuntimeError: If a solve neither succeeds nor reports unboundedness.
    """

    labels = sdp[0]
    total = len(labels)
    bounds = np.zeros(total)

    sweep: Iterable[int] = range(total)
    if progress_factory is not None:
        sweep = progress_factory(sweep, total=total, desc=progress_label)

    unbounded_statuses = {"unbounded", "unbounded_inaccurate"}
    for idx in sweep:
        label = labels[idx]
        status, _, lambdas = solve_multiplier_sdp(
            *sdp,
            solver=solver,
            objective="max_multiplier",
            cap=cap,
            multiplier_index=idx,
        )
        if status.lower() in unbounded_statuses:
            bounds[idx] = clip
            continue
        if lambdas is None:
            raise RuntimeError(f"Failed to compute normalization bound for {label}: {status}")
        attained = max(0.0, float(lambdas[idx]))
        bounds[idx] = min(float(clip), attained)
    return np.maximum(bounds, 1e-10)


def multiplier_indices(all_labels: tuple[str, ...], selected_labels: Iterable[str]) -> tuple[int, ...]:
    """Map ``selected_labels`` to their positions in ``all_labels``.

    Positions are returned in the order of ``selected_labels``.

    Raises:
        KeyError: If any selected label does not occur in ``all_labels``.
    """

    positions: dict[str, int] = {}
    for position, label in enumerate(all_labels):
        positions[label] = position

    wanted = tuple(selected_labels)
    unknown = [label for label in wanted if label not in positions]
    if unknown:
        raise KeyError(f"Unknown labels: {unknown}")
    return tuple(positions[label] for label in wanted)


def active_indices(lambdas: np.ndarray, threshold: float = THRESHOLD) -> list[int]:
    """Return the indices whose multiplier magnitude strictly exceeds ``threshold``.

    The input is flattened first, so 1-D arrays and column vectors are both
    accepted; indices refer to the flattened order. The comparison is done in
    one vectorized pass and the result is a list of plain Python ints.
    """

    magnitudes = np.abs(np.asarray(lambdas, dtype=float).reshape(-1))
    return np.flatnonzero(magnitudes > threshold).tolist()


def require_cardinality_result(
    name: str,
    result: tuple[str, float | None, np.ndarray | None],
    cap: float,
) -> tuple[str, float, np.ndarray]:
    """Return a checked SDP result for the cardinality sweep.

    Raises:
        RuntimeError: If the status is not in ``GOOD_STATUSES``, a value is
            missing, or ``rho`` exceeds ``cap`` by more than ``RHO_TOL``.
    """

    status, rho, lambdas = result
    solve_failed = status not in GOOD_STATUSES or rho is None or lambdas is None
    if solve_failed:
        raise RuntimeError(f"{name} solve failed: {status}")

    rho_limit = cap + RHO_TOL
    if rho > rho_limit:
        raise RuntimeError(f"{name} returned rho={rho:.12g} above cap={cap:.12g}")

    checked_lambdas = np.asarray(lambdas, dtype=float)
    return status, float(rho), checked_lambdas


def summarize_cardinality_method(
    name: str,
    result: tuple[str, float | None, np.ndarray | None],
    cap: float,
    activity_threshold: float = CARDINALITY_ACTIVITY_THRESHOLD,
) -> dict[str, object]:
    """Summarize one method row for the FGM cardinality line plot.

    The result is validated with ``require_cardinality_result`` and reduced to
    its status, its ``rho`` and the number of multipliers whose magnitude
    exceeds ``activity_threshold``.
    """

    status, rho, lambdas = require_cardinality_result(name, result, cap)
    active_count = len(active_indices(lambdas, threshold=activity_threshold))
    return dict(status=status, rho=rho, active_multipliers=active_count)


def solve_cardinality_row(
    n: int,
    solver: str = "MOSEK",
    L: float = L_PAPER,
    activity_threshold: float = CARDINALITY_ACTIVITY_THRESHOLD,
    cap_margin: float = 0.0,
    progress_factory: Callable | None = None,
) -> dict[str, object]:
    """Generate one row of active-multiplier counts for the FGM line plot.

    Solves the uncapped rate SDP and four L1-based heuristics capped at
    ``rho_cap(n) - cap_margin``, then summarizes each against ``rho_cap(n)``.

    Returns:
        A dict with ``n``, ``target`` (``2 * n``), ``num_multipliers``,
        ``rho_cap`` and a per-method ``methods`` summary.
    """

    sdp = build_fgm_arrays(n=n, L=L)
    cap = rho_cap(n)
    solve_cap = cap - cap_margin

    # Insertion order fixes both the solve order and the summary order.
    outcomes: dict[str, tuple[str, float | None, np.ndarray | None]] = {}
    outcomes["raw"] = solve_multiplier_sdp(*sdp, solver=solver)
    outcomes["l1"] = solve_l1_heuristic(sdp, solver=solver, objective="weighted_l1", cap=solve_cap)
    outcomes["reweighted"] = solve_l1_heuristic(sdp, solver=solver, objective="reweighted_l1", cap=solve_cap)

    # The scaled heuristics share one set of per-multiplier bounds.
    scales = normalization_bounds(
        sdp,
        solver=solver,
        cap=solve_cap,
        progress_factory=progress_factory,
        progress_label=f"SDP solves: normalization bounds (N={n})",
    )
    for method_name, heuristic in (("normalized", "reweighted_l1"), ("capped", "capped_l1")):
        outcomes[method_name] = solve_l1_heuristic(
            sdp,
            solver=solver,
            objective=heuristic,
            cap=solve_cap,
            scale=scales,
        )

    methods = {
        method_name: summarize_cardinality_method(
            method_name,
            outcome,
            cap,
            activity_threshold=activity_threshold,
        )
        for method_name, outcome in outcomes.items()
    }
    return {
        "n": n,
        "target": 2 * n,
        "num_multipliers": len(sdp[0]),
        "rho_cap": cap,
        "methods": methods,
    }


def generate_cardinality_payload(
    n_max: int,
    solver: str = "MOSEK",
    exact_n_max: int = 4,
    L: float = L_PAPER,
    activity_threshold: float = CARDINALITY_ACTIVITY_THRESHOLD,
    progress_factory: Callable | None = None,
) -> dict[str, object]:
    """Generate the FGM cardinality data consumed by the line plot.

    One row is solved for every horizon ``1..n_max``. The
    ``exact_verification`` entry records ``2 * n`` as the exact cardinality
    for horizons up to ``exact_n_max`` and ``None`` beyond it.
    """

    horizons = range(1, n_max + 1)

    rows = []
    for n in horizons:
        rows.append(
            solve_cardinality_row(
                n=n,
                solver=solver,
                L=L,
                activity_threshold=activity_threshold,
                progress_factory=progress_factory,
            )
        )

    exact = {}
    for n in horizons:
        verified = n <= exact_n_max
        exact[n] = {
            "n": n,
            "exact_cardinality": 2 * n if verified else None,
            "source": "recorded exhaustive verification" if verified else "conjectured only",
        }

    return {
        "solver": solver,
        "activity_threshold": activity_threshold,
        "exact_verification": exact,
        "rows": rows,
    }


def latex_point_label(name: str) -> str:
    """Return the LaTeX math label for a point name.

    ``"xs"`` maps to ``x_\\star``. ``"y<digits>"`` and ``"x<digits>"`` map to
    the letter with the integer as a braced subscript. Any other name is
    returned unchanged.
    """

    if name == "xs":
        return r"x_\star"
    letter, suffix = name[:1], name[1:]
    if letter in ("y", "x") and suffix.isdigit():
        return rf"{letter}_{{{int(suffix)}}}"
    return name


def format_latex_threshold(threshold: float) -> str:
    """Format a finite nonnegative threshold as LaTeX scientific notation.

    Zero is rendered as ``"0"``, an exact power of ten as ``10^{e}``, and any
    other value as a three-significant-digit mantissa times ``10^{e}``. The
    displayed mantissa always lies in ``[1, 10)``.

    Raises:
        ValueError: If ``threshold`` is negative or not finite.
    """

    if threshold < 0.0 or not np.isfinite(threshold):
        raise ValueError(f"Expected a finite nonnegative threshold, got {threshold}.")
    if threshold == 0.0:
        return "0"
    exponent = int(np.floor(np.log10(threshold)))
    mantissa = threshold / (10.0**exponent)
    # Rounding to three significant digits (or a log10 result that lands just
    # below an integer) can make the displayed mantissa 10; shift one decade
    # so the output never reads "10\cdot 10^{e}".
    if float(f"{mantissa:.3g}") >= 10.0:
        exponent += 1
        mantissa = threshold / (10.0**exponent)
    if np.isclose(mantissa, 1.0, rtol=1e-12, atol=1e-14):
        return rf"10^{{{exponent}}}"
    return rf"{mantissa:.3g}\cdot 10^{{{exponent}}}"


def render_active_multiplier_tables(
    labels: tuple[str, ...],
    multiplier_map: Mapping[str, np.ndarray],
    heuristic_labels: Mapping[str, str],
    heuristic_order: tuple[str, ...],
    n: int = N_PAPER,
    threshold: float = THRESHOLD,
    tables_per_row: int = 2,
    column_gap: str = FGM_TABLE_COLUMN_GAP,
    row_gap: str = FGM_TABLE_ROW_GAP,
    table_position: str = "H",
) -> str:
    """Render the FGM active-multiplier table fragment.

    One source-by-target ``tabular`` is produced per method in
    ``heuristic_order`` and the tabulars are laid out ``tables_per_row`` per
    line inside a single ``table`` environment.

    Args:
        labels: Multiplier labels of the form ``"source->target"``.
        multiplier_map: Method name to 1-D multiplier array aligned with
            ``labels``.
        heuristic_labels: Method name to the title shown above its table.
        heuristic_order: Method names in display order.
        n: Horizon used for the point names and the caption.
        threshold: A multiplier is marked active when its magnitude exceeds
            this value.
        tables_per_row: Number of tabulars per row; must be at least 1.
        column_gap: LaTeX length between tabulars in a row.
        row_gap: LaTeX length between rows of tabulars.
        table_position: Float placement specifier of the ``table``.

    Returns:
        The LaTeX fragment as a single string without a trailing newline.

    Raises:
        ValueError: If ``heuristic_order`` is empty, ``tables_per_row < 1``,
            ``threshold`` is invalid, a method lacks multipliers or a title,
            or a multiplier array has the wrong shape.
    """

    if not heuristic_order:
        raise ValueError("Expected at least one heuristic in heuristic_order.")
    # A zero step makes range() fail cryptically and a negative step silently
    # yields an empty grid, so reject both up front.
    if tables_per_row < 1:
        raise ValueError(f"Expected tables_per_row >= 1, got {tables_per_row}.")
    threshold_latex = format_latex_threshold(threshold)
    missing_multipliers = [name for name in heuristic_order if name not in multiplier_map]
    if missing_multipliers:
        raise ValueError(f"Missing multiplier arrays for methods: {missing_multipliers}.")
    missing_titles = [name for name in heuristic_order if name not in heuristic_labels]
    if missing_titles:
        raise ValueError(f"Missing display labels for methods: {missing_titles}.")

    points = point_names(n)
    column_spec = "@{}l" + "c" * len(points) + "@{}"
    header = " & ".join([""] + [f"${latex_point_label(point)}$" for point in points]) + r" \\"
    method_tables = []
    for name in heuristic_order:
        values = np.asarray(multiplier_map[name], dtype=float)
        if values.ndim != 1 or values.shape[0] != len(labels):
            raise ValueError(f"Expected {len(labels)} multipliers for {name}, got shape {values.shape}.")
        active_labels = {labels[idx] for idx in active_indices(values, threshold=threshold)}
        lines = [
            rf"\begin{{tabular}}[t]{{{column_spec}}}",
            rf"\multicolumn{{{len(points) + 1}}}{{c}}{{{heuristic_labels[name]} (total {len(active_labels)})}} \\",
            r"\toprule",
            header,
            r"\midrule",
        ]
        for source in points:
            cells = []
            for target in points:
                if source == target:
                    # Self-pairs have no multiplier.
                    cells.append("--")
                elif f"{source}->{target}" in active_labels:
                    cells.append(r"$\bullet$")
                else:
                    cells.append("")
            lines.append(" & ".join([f"${latex_point_label(source)}$", *cells]) + r" \\")
        lines.extend([r"\bottomrule", r"\end{tabular}"])
        method_tables.append("\n".join(lines))

    # Group the per-method tabulars into centered rows.
    rows = []
    for first in range(0, len(method_tables), tables_per_row):
        row = rf"\hspace{{{column_gap}}}".join(method_tables[first : first + tables_per_row])
        rows.append("\n".join([r"\noindent\makebox[\textwidth][c]{%", row, "}"]))
    table_grid = f"\n\\par\\vspace{{{row_gap}}}\n".join(rows)

    return "\n".join(
        [
            "% Generated by examples.fgm.render_active_multiplier_tables; do not edit by hand.",
            rf"\begin{{table}}[{table_position}]",
            r"\centering",
            r"\normalsize",
            r"\renewcommand{\arraystretch}{1.15}",
            r"\setlength{\tabcolsep}{4.0pt}",
            rf"\caption{{Active interpolation-multiplier patterns for the FGM $N={n}$ example. "
            rf"Rows are sources and columns are targets; a bullet marks "
            rf"$|\lambda_{{i\to j}}|>{threshold_latex}$ and dashes mark self-pairs.}}",
            r"\label{tab:fgm-active-multiplier-patterns}",
            table_grid,
            r"\end{table}",
        ]
    )


def write_active_multiplier_tables(
    labels: tuple[str, ...],
    multiplier_map: Mapping[str, np.ndarray],
    heuristic_labels: Mapping[str, str],
    heuristic_order: tuple[str, ...],
    active_tables_path: Path,
    n: int = N_PAPER,
    threshold: float = THRESHOLD,
) -> Path:
    """Write the FGM active-multiplier table fragment.

    The fragment comes from ``render_active_multiplier_tables`` and is written
    with a trailing newline. Missing parent directories are created.

    Args:
        labels, multiplier_map, heuristic_labels, heuristic_order, n,
            threshold: Forwarded to ``render_active_multiplier_tables``.
        active_tables_path: Destination file; a path-like value is accepted.

    Returns:
        The destination as a ``Path``.
    """

    fragment = render_active_multiplier_tables(
        labels,
        multiplier_map,
        heuristic_labels,
        heuristic_order,
        n=n,
        threshold=threshold,
    )

    destination = Path(active_tables_path)
    destination.parent.mkdir(parents=True, exist_ok=True)
    # Fix the encoding so the output does not depend on the platform locale.
    destination.write_text(fragment + "\n", encoding="utf-8")
    return destination
