from __future__ import annotations

import argparse
import sys


def parse_args():
    """Define and parse the command-line interface of the scratch FEVER runner.

    Flags are registered in a fixed order: the two GPU-policy switches, then
    the typed numeric/string options, then ``--variants`` and ``--smoke-test``.

    Returns:
        argparse.Namespace parsed from ``sys.argv``.
    """
    parser = argparse.ArgumentParser(description="GPU runner for scratch FEVER transformer experiment")

    # Boolean switches that default to off.
    for switch in ("--require-gpu", "--cpu-smoke-ok"):
        parser.add_argument(switch, action="store_true", default=False)

    # (flag, converter, default) for every single-valued option.
    typed_options = (
        ("--train-samples", int, 50000),
        ("--eval-samples", int, 5000),
        ("--max-seq-len", int, 256),
        ("--epochs", int, 10),
        ("--batch-size", int, 32),
        ("--lr", float, 3e-4),
        ("--weight-decay", float, 0.01),
        ("--consistency-loss-weight", float, 0.5),
        ("--warmup-steps", int, 500),
        ("--d-model", int, 256),
        ("--n-heads", int, 8),
        ("--n-layers", int, 4),
        ("--d-ff", int, 1024),
        ("--dropout", float, 0.1),
        ("--seed", int, 42),
        ("--output-csv", str, "results_fever_scratch.csv"),
    )
    for flag, converter, fallback in typed_options:
        parser.add_argument(flag, type=converter, default=fallback)

    # One or more pooling variants; a fresh default list is built per call.
    default_variants = [
        "no_consistency_loss",
        "evidence_only_pooling",
        "full_sequence_pooling",
        "claim_only_pooling",
    ]
    parser.add_argument("--variants", nargs="+", default=default_variants)
    parser.add_argument("--smoke-test", action="store_true", default=False)

    namespace = parser.parse_args()
    return namespace


def gpu_check(require_gpu: bool, cpu_smoke_ok: bool, smoke_test: bool):
    """Report CUDA availability and enforce the GPU policy flags.

    Args:
        require_gpu: If true and CUDA is unavailable, print an error to
            stderr and exit the process with status 1.
        cpu_smoke_ok: Together with ``smoke_test``, silences the CPU warning.
        smoke_test: Whether this is a smoke-test run.

    When CUDA is present, the name of device 0 is printed to stdout.
    """
    import torch

    has_cuda = torch.cuda.is_available()

    if has_cuda:
        print(f"[GPU OK] {torch.cuda.get_device_name(0)}")
        return

    # From here on CUDA is unavailable.
    if require_gpu:
        print("[ERROR] --require-gpu set but CUDA unavailable", file=sys.stderr)
        sys.exit(1)

    cpu_smoke_sanctioned = smoke_test and cpu_smoke_ok
    if not cpu_smoke_sanctioned:
        print("[WARNING] No CUDA detected; full run on CPU will be slow", file=sys.stderr)


def main():
    """Run the scratch FEVER transformer experiment from the command line.

    Parses CLI flags, performs the CUDA availability check, builds a
    ``ScratchTransformerConfig``, runs the experiment, prints the results
    table, and writes a Markdown report next to ``--output-csv``.

    With ``--smoke-test`` the run is shrunk to:
    - 128 train and 64 eval samples, for one epoch;
    - a batch size capped at 8 and warmup capped at 10 steps;
    - only the ``evidence_only_pooling`` variant (``--variants`` is ignored).
    """
    from pathlib import Path

    args = parse_args()
    gpu_check(args.require_gpu, args.cpu_smoke_ok, args.smoke_test)
    # Imported only after the GPU check, so a failing --require-gpu check
    # exits before the experiment module is loaded.
    from fever_from_scratch_transformer_experiment import (
        ScratchTransformerConfig,
        format_results_markdown,
        run_experiment,
    )

    cfg = ScratchTransformerConfig(
        num_train_samples=128 if args.smoke_test else args.train_samples,
        num_eval_samples=64 if args.smoke_test else args.eval_samples,
        max_seq_len=args.max_seq_len,
        batch_size=min(8, args.batch_size) if args.smoke_test else args.batch_size,
        num_epochs=1 if args.smoke_test else args.epochs,
        lr=args.lr,
        weight_decay=args.weight_decay,
        consistency_loss_weight=args.consistency_loss_weight,
        warmup_steps=min(10, args.warmup_steps) if args.smoke_test else args.warmup_steps,
        d_model=args.d_model,
        n_heads=args.n_heads,
        n_layers=args.n_layers,
        d_ff=args.d_ff,
        dropout=args.dropout,
        seed=args.seed,
        pooling_modes=tuple(["evidence_only_pooling"] if args.smoke_test else args.variants),
        results_path=args.output_csv,
        smoke_test=args.smoke_test,
        require_gpu=args.require_gpu,
    )

    # Derive the report path by replacing only the final suffix. A plain
    # str.replace(".csv", ".md") would also rewrite ".csv" elsewhere in the
    # path. For a name without ".csv" it would return the CSV path itself,
    # so the Markdown report would overwrite the results file.
    csv_path = Path(args.output_csv)
    md_path = csv_path.with_suffix(".md")
    if md_path == csv_path:  # e.g. --output-csv report.md
        md_path = csv_path.with_name(csv_path.name + ".md")

    print("[INFO] Starting scratch FEVER run")
    print(cfg)
    df = run_experiment(cfg)
    print(df.to_string(index=False))
    # Explicit encoding so the report does not depend on the platform locale.
    md_path.write_text(format_results_markdown(df, cfg), encoding="utf-8")
    print(f"Saved results to {args.output_csv} and {md_path}")


if __name__ == "__main__":
    # Script entry point: run the CLI only when executed directly, not on import.
    main()
