from __future__ import annotations

import argparse
import sys


def parse_args():
    parser = argparse.ArgumentParser(
        description="GPU-oriented runner for the generated-rationale scalar-verifier experiment"
    )
    parser.add_argument("--require-gpu", action="store_true", default=False)
    parser.add_argument("--cpu-smoke-ok", action="store_true", default=False)
    parser.add_argument("--num-train", type=int, default=2048)
    parser.add_argument("--num-eval", type=int, default=512)
    parser.add_argument("--batch-size", type=int, default=64)
    parser.add_argument("--epochs", type=int, default=20)
    parser.add_argument("--lr", type=float, default=3e-4)
    parser.add_argument("--d-model", type=int, default=128)
    parser.add_argument("--n-layers", type=int, default=2)
    parser.add_argument("--n-heads", type=int, default=4)
    parser.add_argument("--d-ff", type=int, default=256)
    parser.add_argument("--consistency-weight", type=float, default=0.5)
    parser.add_argument("--output-csv", type=str, default="generated_rationale_scalar_results.csv")
    parser.add_argument("--smoke-test", action="store_true", default=False)
    parser.add_argument("--seed", type=int, default=42)
    return parser.parse_args()


def gpu_check(require_gpu: bool, cpu_smoke_ok: bool, smoke_test: bool) -> bool:
    """Report CUDA availability and enforce ``--require-gpu``.

    Args:
        require_gpu: When True, exit with status 1 if CUDA is unavailable.
        cpu_smoke_ok: Acknowledges that a smoke test may run on CPU. It only
            selects which warning is printed; it never aborts a run and does
            not override ``require_gpu``.
        smoke_test: Whether the caller is about to run the reduced smoke test.

    Returns:
        True if CUDA is available, otherwise False. The previous version
        returned None, so callers that ignore the result are unaffected.

    Raises:
        SystemExit: With code 1 when ``require_gpu`` is set and CUDA is missing.
    """
    # Deferred import: torch is loaded only when the check actually runs
    # (e.g. not when --help exits inside parse_args).
    import torch

    cuda_available = torch.cuda.is_available()

    if cuda_available:
        # Same report whether or not --require-gpu was given.
        print(
            f"[GPU OK] CUDA available: {torch.cuda.device_count()} device(s). "
            f"Primary: {torch.cuda.get_device_name(0)}"
        )
        return True

    if require_gpu:
        # --require-gpu always wins, even together with --smoke-test --cpu-smoke-ok,
        # so the suggested CPU path must tell the user to drop --require-gpu.
        print(
            "\n[ERROR] --require-gpu is set but CUDA is not available.\n"
            "Run on a CUDA-capable machine for full training, or drop --require-gpu "
            "and use --smoke-test --cpu-smoke-ok for CPU validation.\n",
            file=sys.stderr,
        )
        sys.exit(1)

    if smoke_test and cpu_smoke_ok:
        print("[WARNING] No CUDA GPU detected. Running CPU smoke test because --cpu-smoke-ok is set.")
    elif smoke_test:
        print(
            "\n[WARNING] No CUDA GPU detected and --require-gpu is not set.\n"
            "Proceeding with the smoke test on CPU; pass --cpu-smoke-ok to acknowledge this.\n"
        )
    else:
        print(
            "\n[WARNING] No CUDA GPU detected and --require-gpu is not set.\n"
            "Proceeding anyway, but full training on CPU may be slow.\n"
        )
    return False


def main():
    """Parse CLI options, check the GPU, build a ``Config``, and run the experiment.

    ``--smoke-test`` replaces num_train/num_eval/epochs with small fixed values
    (256/64/2) and caps batch_size at 32. Every other option is passed through
    unchanged in both modes.
    """
    args = parse_args()
    gpu_check(args.require_gpu, args.cpu_smoke_ok, args.smoke_test)

    # Imported after the GPU check, so a failing --require-gpu exits before the
    # experiment module is loaded (presumably deliberate).
    from generated_rationale_scalar_verifier_experiment import Config, main as run_main

    # Options shared by both modes, built once so the two paths cannot drift apart.
    config_kwargs = dict(
        lr=args.lr,
        d_model=args.d_model,
        n_layers=args.n_layers,
        n_heads=args.n_heads,
        d_ff=args.d_ff,
        consistency_weight=args.consistency_weight,
        output_csv=args.output_csv,
        seed=args.seed,
    )
    if args.smoke_test:
        # Tiny fixed sizes: user-supplied --num-train/--num-eval/--epochs are ignored.
        config_kwargs.update(
            num_train=256,
            num_eval=64,
            batch_size=min(args.batch_size, 32),
            epochs=2,
        )
    else:
        config_kwargs.update(
            num_train=args.num_train,
            num_eval=args.num_eval,
            batch_size=args.batch_size,
            epochs=args.epochs,
        )
    cfg = Config(**config_kwargs)
    if args.smoke_test:
        print("[SMOKE TEST] Using tiny settings for validation.")

    print("\n[INFO] Starting generated-rationale scalar-verifier experiment...")
    print(f" num_train: {cfg.num_train}")
    print(f" num_eval: {cfg.num_eval}")
    print(f" epochs: {cfg.epochs}")
    print(f" batch_size: {cfg.batch_size}")
    print(f" lr: {cfg.lr}")
    print(f" model: d_model={cfg.d_model}, layers={cfg.n_layers}, heads={cfg.n_heads}, d_ff={cfg.d_ff}")
    print(f" consistency_weight: {cfg.consistency_weight}")
    print(f" output_csv: {cfg.output_csv}")
    print(f" smoke_test: {args.smoke_test}\n")

    run_main(cfg)
    print("\n[DONE]")


if __name__ == "__main__":
    # Run the CLI only when this file is executed as a script, not when imported.
    main()
