"""LaTeX tables and paper paragraph generation."""

from __future__ import annotations

from pathlib import Path

import numpy as np
import pandas as pd

from .utils import ERROR_TYPE_DISPLAY, INVALID_CANDIDATE_TYPES, METHOD_ORDER, method_label, percent


def write_latex_outputs(
    metrics: pd.DataFrame,
    error_recall: pd.DataFrame,
    n_problems: int,
    n_candidates: int,
    latex_dir: Path,
) -> None:
    """Render the LaTeX tables and the prose paragraphs into ``latex_dir``.

    ``latex_dir`` (and any missing parents) is created first. The following
    UTF-8 files are then written in this order:

    - ``main_results_table.tex``
    - ``error_recall_table.tex``
    - ``results_paragraph.txt``
    - ``discussion_paragraph.txt``

    Each file is rendered immediately before it is written. If a renderer
    raises, the files already written stay on disk.
    """
    latex_dir.mkdir(parents=True, exist_ok=True)

    # Lazily evaluated renderers keep the original render -> write interleaving.
    renderers = (
        ("main_results_table.tex", lambda: _main_results_table(metrics)),
        ("error_recall_table.tex", lambda: _error_recall_table(error_recall)),
        (
            "results_paragraph.txt",
            lambda: _results_paragraph(metrics, error_recall, n_problems, n_candidates),
        ),
        ("discussion_paragraph.txt", lambda: _discussion_paragraph(metrics, error_recall)),
    )
    for file_name, render in renderers:
        target = latex_dir / file_name
        target.write_text(render(), encoding="utf-8")


def _main_results_table(metrics: pd.DataFrame) -> str:
    """Render the main verification-results table as a LaTeX ``table`` float.

    There is one body row per method returned by ``_ordered_all_methods``, in
    that order. The columns are the method label, then accuracy and invalid
    precision / recall / F1, each formatted by ``_fmt``.
    """
    body_lines = [
        " & ".join(
            [
                method_label(record.method),
                _fmt(record.accuracy),
                _fmt(record.invalid_precision),
                _fmt(record.invalid_recall),
                _fmt(record.invalid_f1),
            ]
        )
        + " \\\\"
        for record in _ordered_all_methods(metrics).itertuples()
    ]

    latex_lines = [
        "\\begin{table}[t]",
        "\\centering",
        "\\caption{Verification performance on BioDimBench. Values are computed over all candidate solutions.}",
        "\\label{tab:main_results}",
        "\\resizebox{\\linewidth}{!}{%",
        "\\begin{tabular}{lcccc}",
        "\\toprule",
        "Method & Accuracy $\\uparrow$ & Invalid Precision $\\uparrow$ & Invalid Recall $\\uparrow$ & Invalid F1 $\\uparrow$ \\\\",
        "\\midrule",
        "\n".join(body_lines),
        "\\bottomrule",
        "\\end{tabular}%",
        "}",
        "\\end{table}",
    ]
    return "\n".join(latex_lines) + "\n"


def _error_recall_table(error_recall: pd.DataFrame) -> str:
    """Render per-corruption-type recall on the ``"all"`` split as a LaTeX table.

    Methods are visited in ``METHOD_ORDER``. The learned baseline and any method
    with no ``"all"``-split rows are skipped. Each row holds one cell per entry
    of ``INVALID_CANDIDATE_TYPES``; a missing (method, type) pair renders as
    ``--``.
    """
    overall = error_recall[error_recall["split"] == "all"].copy()

    body_lines = []
    for method in METHOD_ORDER:
        if method == "learned_baseline":
            continue
        method_rows = overall[overall["method"] == method]
        if method_rows.empty:
            continue
        cells = []
        for candidate_type in INVALID_CANDIDATE_TYPES:
            hits = method_rows.loc[method_rows["candidate_type"] == candidate_type, "recall"]
            cells.append("--" if hits.empty else _fmt(hits.iloc[0]))
        body_lines.append(method_label(method) + " & " + " & ".join(cells) + " \\\\")

    # NOTE(review): the header labels and the "lccccc" column spec are
    # hard-coded. They assume INVALID_CANDIDATE_TYPES holds exactly five types
    # in the order arithmetic, formula, unit, conversion, plausible scalar.
    # ERROR_TYPE_DISPLAY is imported by this module but unused; it may be the
    # intended source for these labels. Confirm.
    latex_lines = [
        "\\begin{table}[t]",
        "\\centering",
        "\\caption{Recall by corrupted solution type. Higher values indicate better detection of invalid biomedical reasoning.}",
        "\\label{tab:error_recall}",
        "\\resizebox{\\linewidth}{!}{%",
        "\\begin{tabular}{lccccc}",
        "\\toprule",
        "Method & Arithmetic & Formula & Unit & Conversion & Plausible Scalar \\\\",
        "\\midrule",
        "\n".join(body_lines),
        "\\bottomrule",
        "\\end{tabular}%",
        "}",
        "\\end{table}",
    ]
    return "\n".join(latex_lines) + "\n"


def _results_paragraph(
    metrics: pd.DataFrame, error_recall: pd.DataFrame, n_problems: int, n_candidates: int
) -> str:
    """Compose the plain-text results paragraph for the paper.

    Args:
        metrics: Per-method, per-split metrics table. Rows for ``answer_only``,
            ``numeric_plus_unit`` and ``step_aware`` on the ``"all"`` split are
            required (``_metric_row`` raises ``KeyError`` otherwise). A
            ``learned_baseline`` row on the ``"test"`` split is optional and
            adds a final sentence when present.
        error_recall: Per-error-type recall table. Not read by this function;
            accepted so existing call sites keep working.
        n_problems: Number of generated problems, quoted in the first sentence.
        n_candidates: Number of candidate solutions, quoted in the first sentence.

    Returns:
        One paragraph of sentences joined by single spaces, ending with a newline.
    """
    answer = _metric_row(metrics, "answer_only", "all")
    numeric = _metric_row(metrics, "numeric_plus_unit", "all")
    step = _metric_row(metrics, "step_aware", "all")
    # _ordered_all_methods returns rows in METHOD_ORDER. A stable sort keeps that
    # order among methods tied on invalid F1, so the verifier named as "best" is
    # deterministic; the default quicksort gives no ordering guarantee on ties.
    # NaN F1 values sort last (pandas' default na_position).
    best = (
        _ordered_all_methods(metrics)
        .sort_values("invalid_f1", ascending=False, kind="stable")
        .iloc[0]
    )
    learned = _metric_row(metrics, "learned_baseline", "test", required=False)

    pieces = [
        f"Across {n_problems} generated biomedical problems and {n_candidates} candidate solutions, "
        f"answer-only checking achieved accuracy {_fmt(answer.accuracy)}, invalid recall {_fmt(answer.invalid_recall)}, "
        f"and invalid F1 {_fmt(answer.invalid_f1)}.",
        f"The unit-aware numeric-plus-unit verifier achieved accuracy {_fmt(numeric.accuracy)}, "
        f"invalid recall {_fmt(numeric.invalid_recall)}, and invalid F1 {_fmt(numeric.invalid_f1)}.",
        f"The step-aware verifier covered {percent(step.coverage)} of candidate solutions and achieved "
        f"invalid recall {_fmt(step.invalid_recall)} with invalid F1 {_fmt(step.invalid_f1)}.",
        f"Among non-learned verifiers, the best invalid F1 was {_fmt(best.invalid_f1)} from {method_label(best.method)}.",
    ]
    if learned is not None:
        pieces.append(
            f"The optional learned baseline was evaluated on {int(learned.n_evaluated)} held-out candidate rows "
            f"split by problem_id and achieved invalid F1 {_fmt(learned.invalid_f1)}."
        )
    return " ".join(pieces) + "\n"


def _discussion_paragraph(metrics: pd.DataFrame, error_recall: pd.DataFrame) -> str:
    """Compose the plain-text discussion paragraph from per-error-type recall.

    Only ``error_recall`` is read (through ``_error_recall``, i.e. the ``"all"``
    split). ``metrics`` is accepted but unused. A recall that is missing
    (NaN) propagates as NaN into ``percent``.
    """
    answer_only_plausible = _error_recall(error_recall, "answer_only", "plausible_scalar_wrong_unit")
    unit_aware_wrong_unit = _error_recall(error_recall, "numeric_plus_unit", "wrong_unit")
    unit_aware_plausible = _error_recall(error_recall, "numeric_plus_unit", "plausible_scalar_wrong_unit")
    stepwise_formula = _error_recall(error_recall, "step_aware", "wrong_formula")
    stepwise_conversion = _error_recall(error_recall, "step_aware", "missing_conversion")

    # The miss rate is the complement of recall; keep NaN when recall is unknown.
    if np.isfinite(answer_only_plausible):
        missed_share = 1 - answer_only_plausible
    else:
        missed_share = np.nan

    sentences = [
        "These results support the paper's central claim: answer-only checking missed "
        f"{percent(missed_share)} of plausible scalar wrong-unit corruptions because it ignored dimensions.",
        "By requiring dimensional compatibility and conversion before numeric comparison, numeric-plus-unit rejected "
        f"{percent(unit_aware_wrong_unit)} of wrong-unit cases and {percent(unit_aware_plausible)} of plausible scalar wrong-unit cases.",
        f"The structured step-aware verifier rejected {percent(stepwise_formula)} of formula corruptions and "
        f"{percent(stepwise_conversion)} of missing-conversion corruptions in this generated benchmark.",
        "This indicates that lightweight dimensional verification can act as a practical safety layer for scientific agents, "
        "especially when final answers contain right-looking numbers paired with inconsistent biomedical units.",
    ]
    return " ".join(sentences) + "\n"


def _ordered_all_methods(metrics: pd.DataFrame) -> pd.DataFrame:
    """Return the ``"all"``-split rows of every non-learned method in display order.

    Rows are ordered by each method's position in ``METHOD_ORDER``. Methods not
    listed there sort after every listed one and keep their relative input
    order. ``learned_baseline`` rows are excluded. The result has a fresh
    ``RangeIndex``; the input frame is not modified (a copy is taken).
    """
    frame = metrics[(metrics["split"] == "all") & (metrics["method"] != "learned_baseline")].copy()
    rank = {name: position for position, name in enumerate(METHOD_ORDER)}
    # Unlisted methods get a rank just past the known ones (instead of the magic
    # constant 99). The stable sort keeps ties, including all unlisted methods,
    # in input order rather than leaving it to an unstable quicksort.
    frame["method_order"] = frame["method"].map(rank).fillna(len(METHOD_ORDER))
    return (
        frame.sort_values("method_order", kind="stable")
        .drop(columns=["method_order"])
        .reset_index(drop=True)
    )


def _metric_row(metrics: pd.DataFrame, method: str, split: str, required: bool = True):
    match = metrics[(metrics["method"] == method) & (metrics["split"] == split)]
    if match.empty:
        if required:
            raise KeyError(f"Missing metrics for {method}/{split}")
        return None
    return match.iloc[0]


def _error_recall(error_recall: pd.DataFrame, method: str, candidate_type: str) -> float:
    match = error_recall[
        (error_recall["method"] == method)
        & (error_recall["split"] == "all")
        & (error_recall["candidate_type"] == candidate_type)
    ]
    if match.empty:
        return np.nan
    return float(match["recall"].iloc[0])


def _fmt(value: float) -> str:
    if value is None or not np.isfinite(value):
        return "--"
    return f"{float(value):.3f}"
