from __future__ import annotations

import argparse
import sys


def parse_args(argv: "list[str] | None" = None) -> argparse.Namespace:
    """Parse the command-line options of the FEVER GPT-2 runner.

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is the original behaviour;
            passing an explicit list allows programmatic use and testing.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(
        description="GPU-only runner for pretrained GPT-2 FEVER experiment",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    # Hardware / run-mode switches.
    parser.add_argument(
        "--require-gpu",
        action="store_true",
        default=False,
        help="Exit with an error if CUDA is unavailable.",
    )
    parser.add_argument(
        "--cpu-smoke-ok",
        action="store_true",
        default=False,
        help="With --smoke-test, suppress the no-GPU warning.",
    )
    parser.add_argument(
        "--smoke-test",
        action="store_true",
        default=False,
        help="Run a tiny, fast configuration that overrides the sample "
        "counts, epochs, layer freezing and variants.",
    )

    # Model and data sizes.
    parser.add_argument("--model-name", type=str, default="gpt2")
    parser.add_argument("--train-samples", type=int, default=50_000)
    parser.add_argument("--eval-samples", type=int, default=5_000)
    parser.add_argument("--max-seq-len", type=int, default=256)

    # Optimisation hyper-parameters.
    parser.add_argument("--epochs", type=int, default=5)
    parser.add_argument("--batch-size", type=int, default=16)
    parser.add_argument("--lr", type=float, default=5e-5)
    parser.add_argument("--consistency-loss-weight", type=float, default=0.5)
    parser.add_argument("--freeze-lower-layers-epochs", type=int, default=1)
    parser.add_argument("--seed", type=int, default=42)

    # Experiment variants and output location.
    parser.add_argument(
        "--variants",
        nargs="+",
        default=[
            "no_consistency_loss",
            "evidence_only_pooling",
            "evidence_only_strict",
            "full_sequence_pooling",
            "claim_only_pooling",
            "evidence_only_random_labels",
        ],
    )
    parser.add_argument(
        "--output-csv",
        type=str,
        default="results_fever_pretrained_gpu.csv",
    )

    return parser.parse_args(argv)


def gpu_check(require_gpu: bool, cpu_smoke_ok: bool, smoke_test: bool):
    """Verify CUDA availability and react according to the run-mode flags.

    Nothing happens when CUDA is available. Otherwise the function either
    aborts the process or prints a warning, depending on the flags.

    Args:
        require_gpu: If true and CUDA is unavailable, print an error to
            stderr and exit with status 1.
        cpu_smoke_ok: Combined with ``smoke_test``, silences the slow-CPU
            warning.
        smoke_test: Whether this run is a smoke test.

    Raises:
        SystemExit: With code 1 when ``require_gpu`` is set and CUDA is
            unavailable.
    """
    import torch

    if torch.cuda.is_available():
        return

    if require_gpu:
        print("[ERROR] --require-gpu is set but CUDA is unavailable.", file=sys.stderr)
        sys.exit(1)

    # A CPU run is only considered intentional for an explicitly allowed
    # smoke test; every other CPU run gets the slow-training warning.
    cpu_run_acknowledged = smoke_test and cpu_smoke_ok
    if not cpu_run_acknowledged:
        print("[WARNING] No CUDA GPU detected; full training on CPU will be very slow.")


def main():
    """Run the pretrained GPT-2 FEVER experiment from the command line.

    Parses the CLI flags, performs the CUDA availability check, builds a
    ``PretrainedGPT2Config`` (shrunk to a tiny configuration when
    ``--smoke-test`` is given), runs the experiment, prints the resulting
    table and writes a markdown report next to the ``--output-csv`` path.
    If ``build_pretrained_comparison_table`` returns a result, the comparison
    table is also written as CSV and markdown to fixed file names in the
    current working directory.
    """
    from pathlib import Path

    args = parse_args()
    gpu_check(args.require_gpu, args.cpu_smoke_ok, args.smoke_test)

    # Imported here rather than at module level, so the experiment module is
    # only loaded after argument parsing and the GPU check have succeeded.
    from fever_pretrained_gpt2_experiment import (
        PretrainedGPT2Config,
        build_pretrained_comparison_table,
        format_results_markdown,
        run_pretrained_gpt2_experiment,
    )

    smoke = args.smoke_test
    cfg = PretrainedGPT2Config(
        model_name=args.model_name,
        # Smoke tests override the data sizes, epochs, freezing schedule and
        # variant list with minimal values; the batch size is capped at 8.
        num_train_samples=32 if smoke else args.train_samples,
        num_eval_samples=16 if smoke else args.eval_samples,
        max_seq_len=args.max_seq_len,
        num_epochs=1 if smoke else args.epochs,
        batch_size=min(args.batch_size, 8) if smoke else args.batch_size,
        lr=args.lr,
        consistency_loss_weight=args.consistency_loss_weight,
        freeze_lower_layers_epochs=0 if smoke else args.freeze_lower_layers_epochs,
        pooling_modes=("evidence_only_strict",) if smoke else tuple(args.variants),
        results_path=args.output_csv,
        smoke_test=smoke,
        require_gpu=args.require_gpu,
        seed=args.seed,
    )

    # NOTE(review): df is presumably a pandas DataFrame (it is used via
    # .to_string() here and .to_csv() on the comparison table) - confirm.
    df = run_pretrained_gpt2_experiment(cfg)
    print(df.to_string())

    # Derive the markdown path from the file suffix only. A plain
    # str.replace(".csv", ".md") leaves a path without a ".csv" suffix
    # unchanged, so the report would be written to the very path passed as
    # results_path, and it would also rewrite ".csv" inside directory names.
    csv_path = Path(args.output_csv)
    md_path = csv_path.with_suffix(".md")
    if md_path == csv_path:
        # --output-csv itself ends in ".md"; never write onto that path.
        md_path = csv_path.with_name(csv_path.name + ".md")
    md_path.write_text(format_results_markdown(df, cfg), encoding="utf-8")

    result = build_pretrained_comparison_table(df, "results_comparison_hard.csv")
    if result is not None:
        comp_df, comp_md = result
        comp_df.to_csv("results_fever_pretrained_vs_synthetic.csv", index=False)
        Path("results_fever_pretrained_vs_synthetic.md").write_text(
            comp_md, encoding="utf-8"
        )


# Run the experiment only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
