import argparse
import csv
import json
from pathlib import Path
from statistics import mean, pstdev
from typing import Any

try:
    from proofstate_common import REPRESENTATIONS, load_jsonl, normalize_entry
    from run_experiments import evaluate_representation, parse_representations
except ImportError:
    from scripts.proofstate_common import REPRESENTATIONS, load_jsonl, normalize_entry
    from scripts.run_experiments import evaluate_representation, parse_representations


# Metric keys this script reports for every baseline: they are the metric
# columns of the per-seed CSV (write_rows) and the keys whose mean/std are
# computed across seeds (aggregate, write_aggregate).
METRICS = ["accuracy", "macro_f1", "top1_accuracy", "top3_accuracy", "top5_accuracy"]


def parse_seeds(value: str) -> list[int]:
    """Parse a comma-separated string of integers into a list of seeds.

    Whitespace around each entry is dropped and blank entries (for example
    from a trailing comma) are skipped. Order and duplicates are preserved.

    Raises:
        ValueError: if a non-blank entry is not a valid integer literal.
    """
    seeds: list[int] = []
    for token in value.split(","):
        cleaned = token.strip()
        if not cleaned:
            # Blank piece between/after commas carries no seed.
            continue
        seeds.append(int(cleaned))
    return seeds


def flatten_result(result: dict[str, Any], seed: int) -> list[dict[str, Any]]:
    """Turn one evaluation result into flat table rows, one per baseline model.

    Args:
        result: Mapping with a ``"representation"`` entry and a ``"baselines"``
            mapping of model name -> metrics mapping.
        seed: Seed the result was produced with; stored in every row.

    Returns:
        One dict per baseline holding ``seed``, ``representation``, ``model``
        followed by that baseline's metrics. A metric whose key collides with
        one of those three overrides it.
    """
    return [
        {
            "seed": seed,
            "representation": result["representation"],
            "model": model_name,
            **model_metrics,
        }
        for model_name, model_metrics in result["baselines"].items()
    ]


def write_rows(rows: list[dict[str, Any]], path: Path) -> None:
    """Write the per-seed result rows to ``path`` as a UTF-8 CSV file.

    The columns are ``seed``, ``representation``, ``model`` followed by every
    name in ``METRICS``. Missing parent directories are created and an
    existing file is overwritten.

    Args:
        rows: Flat row dicts keyed by the column names above.
        path: Destination CSV file.
    """
    columns = ["seed", "representation", "model"] + list(METRICS)
    path.parent.mkdir(parents=True, exist_ok=True)
    # newline="" lets the csv module control line endings itself.
    with path.open("w", encoding="utf-8", newline="") as handle:
        table = csv.DictWriter(handle, fieldnames=columns)
        table.writeheader()
        for record in rows:
            table.writerow(record)


def aggregate(rows: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Summarise per-seed rows into one row per (representation, model) pair.

    Args:
        rows: Flat rows carrying ``representation``, ``model`` and a value
            convertible to ``float`` for every name in ``METRICS``.

    Returns:
        Summary dicts sorted by (representation, model). Each holds the pair,
        ``n_seeds`` (number of rows in the group) and, per metric,
        ``<metric>_mean`` and ``<metric>_std`` (population standard deviation,
        ``0.0`` when the group has a single row).
    """
    buckets: dict[tuple[str, str], list[dict[str, Any]]] = {}
    for row in rows:
        key = (row["representation"], row["model"])
        if key not in buckets:
            buckets[key] = []
        buckets[key].append(row)

    summaries = []
    for key in sorted(buckets):
        representation, model = key
        members = buckets[key]
        summary = {
            "representation": representation,
            "model": model,
            "n_seeds": len(members),
        }
        for metric in METRICS:
            samples = [float(member[metric]) for member in members]
            summary[f"{metric}_mean"] = mean(samples)
            if len(samples) > 1:
                summary[f"{metric}_std"] = pstdev(samples)
            else:
                summary[f"{metric}_std"] = 0.0
        summaries.append(summary)
    return summaries


def write_aggregate(rows: list[dict[str, Any]], path: Path) -> None:
    """Write the aggregated summary rows to ``path`` as a UTF-8 CSV file.

    The columns are ``representation``, ``model``, ``n_seeds`` and then a
    ``<metric>_mean`` / ``<metric>_std`` pair for every name in ``METRICS``.
    Missing parent directories are created and an existing file is
    overwritten.

    Args:
        rows: Summary dicts keyed by the column names above.
        path: Destination CSV file.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    columns = ["representation", "model", "n_seeds"] + [
        column
        for metric in METRICS
        for column in (f"{metric}_mean", f"{metric}_std")
    ]
    # newline="" lets the csv module control line endings itself.
    with path.open("w", encoding="utf-8", newline="") as handle:
        table = csv.DictWriter(handle, fieldnames=columns)
        table.writeheader()
        table.writerows(rows)


def main() -> None:
    """Run every requested representation over every seed and write the tables.

    For each (seed, representation) pair the full result returned by
    ``evaluate_representation`` is dumped to
    ``<output-dir>/<data stem>_<representation>_seed<seed>.json``. The
    per-baseline metrics of all runs are then written to ``--table-output``
    and their per-(representation, model) mean/std to ``--aggregate-output``.

    A malformed or empty ``--seeds`` value is reported through
    ``parser.error`` (usage message, exit status 2) before any data is loaded.
    """
    parser = argparse.ArgumentParser(description="Run representation ablations over multiple theorem split seeds.")
    parser.add_argument("--data", type=Path, default=Path("data/leandojo_steps_checked.jsonl"))
    parser.add_argument("--representations", default="all")
    parser.add_argument("--seeds", default="0,1,2,3,4,42")
    parser.add_argument("--test-ratio", type=float, default=0.3)
    parser.add_argument("--retrieved-premise-top-n", type=int, default=8)
    parser.add_argument("--output-dir", type=Path, default=Path("results/classification"))
    parser.add_argument("--table-output", type=Path, default=Path("results/tables/classification_multiseed.csv"))
    parser.add_argument("--aggregate-output", type=Path, default=Path("results/tables/classification_aggregate.csv"))
    args = parser.parse_args()

    # Validate the cheap string arguments before reading the dataset, so a
    # typo fails immediately instead of after the whole file has been loaded.
    try:
        seeds = parse_seeds(args.seeds)
    except ValueError as exc:
        parser.error(f"--seeds must be a comma-separated list of integers: {exc}")
    if not seeds:
        # Without seeds nothing would run and only header-only tables would
        # be written, which looks like a successful sweep.
        parser.error("--seeds must contain at least one integer")
    representations = parse_representations(args.representations)

    rows = [normalize_entry(row, include_optional=True) for row in load_jsonl(args.data)]
    all_rows = []
    dataset_name = args.data.stem
    args.output_dir.mkdir(parents=True, exist_ok=True)

    for seed in seeds:
        for representation in representations:
            result = evaluate_representation(
                rows,
                representation,
                args.test_ratio,
                seed,
                retrieved_premise_top_n=args.retrieved_premise_top_n,
            )
            # Keep the complete result of each run next to the summary tables.
            output_path = args.output_dir / f"{dataset_name}_{representation}_seed{seed}.json"
            with output_path.open("w", encoding="utf-8") as f:
                json.dump(result, f, indent=2, ensure_ascii=False, sort_keys=True)
            all_rows.extend(flatten_result(result, seed))
            print(f"{representation} seed={seed}: wrote {output_path}")

    write_rows(all_rows, args.table_output)
    write_aggregate(aggregate(all_rows), args.aggregate_output)
    print(f"Wrote multiseed table to: {args.table_output}")
    print(f"Wrote aggregate table to: {args.aggregate_output}")


# Run the sweep only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
