from __future__ import annotations

import argparse
import sys
from pathlib import Path


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Build the runner's command-line parser and parse the given arguments.

    Args:
        argv: Argument strings to parse, without the program name. When left
            as ``None`` argparse falls back to ``sys.argv[1:]``, so existing
            ``parse_args()`` callers behave exactly as before. Passing an
            explicit list allows programmatic use without touching
            ``sys.argv``.

    Returns:
        The populated namespace. Option names map to attributes with dashes
        replaced by underscores (``--batch-size`` -> ``batch_size``).
    """
    parser = argparse.ArgumentParser(
        description="GPU-oriented runner for the KataGo win-probability claim-consistency experiment"
    )

    # Device policy and run-mode switches (``store_true`` flags default to False).
    parser.add_argument("--require-gpu", action="store_true")
    parser.add_argument("--cpu-smoke-ok", action="store_true")

    # Dataset locations; both default to an empty string.
    parser.add_argument("--train-path", type=str, default="")
    parser.add_argument("--eval-path", type=str, default="")

    # Training hyperparameters.
    parser.add_argument("--batch-size", type=int, default=32)
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--lr", type=float, default=3e-4)
    parser.add_argument("--weight-decay", type=float, default=0.01)

    # Model dimensions.
    parser.add_argument("--d-model", type=int, default=256)
    parser.add_argument("--n-layers", type=int, default=4)
    parser.add_argument("--n-heads", type=int, default=8)
    parser.add_argument("--d-ff", type=int, default=1024)
    parser.add_argument("--dropout", type=float, default=0.1)

    # Experiment-specific settings.
    parser.add_argument("--consistency-weight", type=float, default=0.5)
    parser.add_argument("--max-seq-len", type=int, default=256)
    parser.add_argument("--max-position-tokens", type=int, default=196)
    parser.add_argument("--counterfactual-samples", type=int, default=256)

    # Output files and variant selection; an empty variant list is the default.
    parser.add_argument("--output-csv", type=str, default="katago_winprob_results.csv")
    parser.add_argument("--output-markdown", type=str, default="")
    parser.add_argument("--variants", nargs="+", default=[])
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--smoke-test", action="store_true")

    return parser.parse_args(argv)


def gpu_check(require_gpu: bool, cpu_smoke_ok: bool, smoke_test: bool) -> None:
    """Report CUDA availability and abort when a required GPU is missing.

    Args:
        require_gpu: When true and CUDA is unavailable, an error is written
            to stderr and the process exits with status 1.
        cpu_smoke_ok: Together with ``smoke_test``, selects the CPU smoke-test
            warning instead of the generic "no GPU" warning.
        smoke_test: Whether the run is a smoke test.

    Raises:
        SystemExit: With code 1 if ``require_gpu`` is set and CUDA is not
            available.
    """
    # torch is imported here rather than at module level.
    import torch

    # Happy path first: a CUDA device exists, so just describe it.
    if torch.cuda.is_available():
        print(f"[GPU OK] CUDA available: {torch.cuda.device_count()} device(s). Primary: {torch.cuda.get_device_name(0)}")
        return

    # From here on there is no CUDA device.
    if require_gpu:
        message = (
            "\n[ERROR] --require-gpu is set but CUDA is not available.\n"
            "Run on a CUDA-capable machine for full training, or use --smoke-test --cpu-smoke-ok for CPU validation.\n"
        )
        print(message, file=sys.stderr)
        sys.exit(1)

    # Only the combination of both flags gets the smoke-test wording.
    if not (smoke_test and cpu_smoke_ok):
        print(
            "\n[WARNING] No CUDA GPU detected and --require-gpu is not set.\n"
            "Proceeding anyway, but a full training run on CPU may be slow.\n"
        )
        return

    print("[WARNING] No CUDA GPU detected. Running CPU smoke test because --cpu-smoke-ok is set.")


def main() -> None:
    """Run the experiment end to end from the command line.

    Parses the CLI options, performs the CUDA availability check, builds a
    ``Config`` from the options (with reduced sizes in smoke-test mode),
    prints the effective settings, runs the experiment, and prints the
    returned results followed by a completion message naming the CSV and
    Markdown paths.
    """
    args = parse_args()
    gpu_check(args.require_gpu, args.cpu_smoke_ok, args.smoke_test)

    # Imported only after the device check has passed.
    from katago_winprob_experiment import Config, VARIANT_NAMES, run_experiment

    # No explicit selection means every variant the experiment module exports.
    if args.variants:
        selected_variants = tuple(args.variants)
    else:
        selected_variants = VARIANT_NAMES

    # Smoke tests clamp the expensive knobs: one epoch, batch size at most 8,
    # and at most 32 counterfactual samples.
    if args.smoke_test:
        batch_size = min(args.batch_size, 8)
        epochs = 1
        counterfactual_samples = min(args.counterfactual_samples, 32)
    else:
        batch_size = args.batch_size
        epochs = args.epochs
        counterfactual_samples = args.counterfactual_samples

    cfg = Config(
        train_path=args.train_path,
        eval_path=args.eval_path,
        max_seq_len=args.max_seq_len,
        max_position_tokens=args.max_position_tokens,
        batch_size=batch_size,
        epochs=epochs,
        lr=args.lr,
        weight_decay=args.weight_decay,
        d_model=args.d_model,
        n_layers=args.n_layers,
        n_heads=args.n_heads,
        d_ff=args.d_ff,
        dropout=args.dropout,
        consistency_weight=args.consistency_weight,
        counterfactual_samples=counterfactual_samples,
        seed=args.seed,
        output_csv=args.output_csv,
        output_markdown=args.output_markdown,
        variants=selected_variants,
        smoke_test=args.smoke_test,
    )

    def markdown_target() -> str:
        # An explicit Markdown path wins; otherwise reuse the CSV path with a
        # ``.md`` suffix. Read from cfg on every call so the value reflects
        # the config at the moment it is printed.
        explicit = cfg.output_markdown
        if explicit:
            return explicit
        return str(Path(cfg.output_csv).with_suffix(".md"))

    banner = [
        "\n[INFO] Starting KataGo win-probability experiment...",
        f" train_path: {cfg.train_path or '<auto-generated smoke data>'}",
        f" eval_path: {cfg.eval_path or '<auto-generated smoke data>'}",
        f" epochs: {cfg.epochs}",
        f" batch_size: {cfg.batch_size}",
        f" lr: {cfg.lr}",
        f" weight_decay: {cfg.weight_decay}",
        f" model: d_model={cfg.d_model}, layers={cfg.n_layers}, heads={cfg.n_heads}, d_ff={cfg.d_ff}",
        f" max_seq_len: {cfg.max_seq_len}",
        f" max_position_tokens: {cfg.max_position_tokens}",
        f" consistency_weight: {cfg.consistency_weight}",
        f" variants: {', '.join(cfg.variants)}",
        f" output_csv: {cfg.output_csv}",
        f" output_markdown: {markdown_target()}",
        f" smoke_test: {cfg.smoke_test}\n",
    ]
    for line in banner:
        print(line)

    # The result object must provide ``to_string(index=False)``; presumably a
    # pandas DataFrame -- confirm against run_experiment.
    results = run_experiment(cfg)
    print(results.to_string(index=False))
    print(f"\n[DONE] Saved results to {cfg.output_csv} and {markdown_target()}")


# Run the experiment only when this file is executed as a script; importing
# the module has no side effects beyond defining the functions above.
if __name__ == "__main__":
    main()
