"""Run the complete BioDimBench experiment pipeline."""

from __future__ import annotations

import argparse
import logging
from pathlib import Path

import pandas as pd

from src.corruptions import generate_candidates
from src.evaluate import compute_aggregate_metrics, compute_error_type_recall, run_sanity_checks
from src.figures import generate_figures
from src.generate_benchmark import generate_benchmark
from src.latex_writer import write_latex_outputs
from src.train_learned_baseline import run_learned_baseline
from src.utils import ensure_output_dirs, method_label
from src.verifiers import run_all_verifiers


# Default number of base problems for each --mode preset. main() falls back to
# these only when --n is not given; _parse_args() uses the keys as --mode choices.
MODE_DEFAULTS = {"smoke": 50, "pilot": 200, "full": 500}


def main() -> None:
    """Run the full BioDimBench pipeline end to end.

    In order: generate benchmark problems and candidate solutions, run the
    verifiers (``run_all_verifiers``) and the learned baseline, check and
    aggregate the results, render figures and LaTeX tables, and print a short
    summary. Problems, candidates, per-candidate results and metric tables are
    also written as CSV files under the directories from ``ensure_output_dirs``.
    """
    args = _parse_args()

    # Configure logging before any other pipeline work. The module-level
    # logging.info()/warning() helpers call basicConfig() implicitly when the
    # root logger has no handlers, after which this call would be a silent no-op
    # and the INFO level / format below would be lost.
    logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")

    root = Path(__file__).resolve().parent
    output_dirs = ensure_output_dirs(root)

    # An explicit --n always overrides the mode preset.
    n = args.n if args.n is not None else MODE_DEFAULTS[args.mode]
    logging.info("Running BioDimBench mode=%s n=%s seed=%s", args.mode, n, args.seed)

    problems = generate_benchmark(n=n, seed=args.seed)
    candidates = generate_candidates(problems, seed=args.seed)
    problems.to_csv(output_dirs["data"] / "biodimbench_problems.csv", index=False)
    candidates.to_csv(output_dirs["data"] / "biodimbench_candidates.csv", index=False)

    verifier_results = run_all_verifiers(problems, candidates)
    learned_results = run_learned_baseline(problems, candidates, seed=args.seed)
    all_results = pd.concat([verifier_results, learned_results], ignore_index=True)
    run_sanity_checks(all_results)
    all_results.to_csv(output_dirs["metrics"] / "verification_results.csv", index=False)

    aggregate_metrics = compute_aggregate_metrics(all_results)
    error_type_recall = compute_error_type_recall(all_results)
    aggregate_metrics.to_csv(output_dirs["metrics"] / "aggregate_metrics.csv", index=False)
    error_type_recall.to_csv(output_dirs["metrics"] / "error_type_recall.csv", index=False)

    generate_figures(aggregate_metrics, error_type_recall, output_dirs["figures"])
    write_latex_outputs(
        aggregate_metrics,
        error_type_recall,
        n_problems=len(problems),
        n_candidates=len(candidates),
        latex_dir=output_dirs["latex"],
    )

    _print_summary(problems, candidates, aggregate_metrics, output_dirs)


def _positive_int(value: str) -> int:
    """Parse *value* as a strictly positive integer (argparse ``type=`` helper).

    Raises:
        argparse.ArgumentTypeError: if *value* is not an integer or is < 1.
    """
    try:
        parsed = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"expected a positive integer, got {value!r}") from None
    if parsed < 1:
        raise argparse.ArgumentTypeError(f"expected a positive integer, got {parsed}")
    return parsed


def _parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the pipeline's command-line options.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, matching the previous behavior.

    Returns:
        Namespace with ``n`` (a positive int, or ``None`` when omitted),
        ``seed`` (int, default 42) and ``mode`` (one of the ``MODE_DEFAULTS``
        keys, default ``"full"``).
    """
    parser = argparse.ArgumentParser(description="Run the BioDimBench experiment pipeline.")
    # Reject 0/negative sizes up front instead of failing deep in the pipeline.
    parser.add_argument("--n", type=_positive_int, default=None, help="Number of base problems to generate.")
    parser.add_argument("--seed", type=int, default=42, help="Random seed for deterministic generation.")
    parser.add_argument(
        "--mode",
        choices=sorted(MODE_DEFAULTS),
        default="full",
        help="Convenience mode. Used for default n when --n is omitted.",
    )
    return parser.parse_args(argv)


def _print_summary(
    problems: pd.DataFrame, candidates: pd.DataFrame, aggregate_metrics: pd.DataFrame, output_dirs: dict[str, Path]
) -> None:
    """Print a human-readable summary of a finished run to stdout.

    The reported "best verifier" is the row of ``aggregate_metrics`` with the
    highest ``invalid_f1`` among rows whose ``split`` is ``"all"``, excluding
    the ``"learned_baseline"`` method. Rows with a missing ``invalid_f1`` are
    ignored. If no row qualifies, ``n/a`` is printed instead of raising.

    Args:
        problems: Generated benchmark problems (only the row count is used).
        candidates: Generated candidate solutions (only the row count is used).
        aggregate_metrics: Metrics table with ``split``, ``method`` and
            ``invalid_f1`` columns.
        output_dirs: Output directories; the ``"latex"``, ``"figures"`` and
            ``"metrics"`` entries are printed.
    """
    eligible = aggregate_metrics[
        (aggregate_metrics["split"] == "all") & (aggregate_metrics["method"] != "learned_baseline")
    ].dropna(subset=["invalid_f1"])

    if eligible.empty:
        best_text = "n/a"
    else:
        # Stable sort: ties resolve to the first such row in aggregate_metrics.
        best = eligible.sort_values("invalid_f1", ascending=False, kind="stable").iloc[0]
        best_text = f"{method_label(best['method'])} ({best['invalid_f1']:.3f})"

    print("\nBioDimBench experiment complete")
    print(f"Problems: {len(problems)}")
    print(f"Candidate solutions: {len(candidates)}")
    print(f"Best verifier by invalid F1: {best_text}")
    print(f"LaTeX tables: {output_dirs['latex'] / 'main_results_table.tex'}")
    print(f"Error recall table: {output_dirs['latex'] / 'error_recall_table.tex'}")
    print(f"Figures: {output_dirs['figures']}")
    print(f"Metrics: {output_dirs['metrics']}")


# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
