import argparse
import csv
import os
import sys
import time
from typing import List, Optional

import torch

from src.args import parse_arguments
from src.eval import eval_single_dataset, eval_ASR
from src.task_vectors import NonLinearTaskVector


# Column layout of the results CSV: one row per evaluated dataset.
CSV_HEADERS = ["Dataset", "zt_test_CA"]


def ensure_csv(path: str):
    """Create the results CSV at *path* with a ``CSV_HEADERS`` row if it is absent.

    Missing parent directories are created. An existing file is left untouched;
    its header is not checked against ``CSV_HEADERS``.
    """
    parent = os.path.dirname(path)
    # os.path.dirname("results.csv") is "" and os.makedirs("") raises
    # FileNotFoundError, so only create a directory when the path names one.
    if parent:
        os.makedirs(parent, exist_ok=True)
    if not os.path.exists(path):
        with open(path, "w", newline="") as f:
            csv.writer(f).writerow(CSV_HEADERS)

def existing_keys(path: str) -> set:
    """Return the ``(Dataset, attack, alpha)`` keys already recorded in the CSV at *path*.

    Returns an empty set in three cases:
    - the file does not exist or is empty;
    - the header lacks any of the ``Dataset``/``attack``/``alpha`` columns;
    - every row is malformed.

    The header defined in ``CSV_HEADERS`` has no ``attack``/``alpha`` columns,
    so files with that layout always yield an empty set. Rows whose ``alpha``
    is missing or non-numeric are skipped.
    """
    keys = set()
    if not os.path.exists(path):
        return keys
    with open(path, "r", newline="") as f:
        reader = csv.DictReader(f)
        # fieldnames is None for an empty file; bail out early instead of
        # failing on every row when the key columns are absent.
        if not {"Dataset", "attack", "alpha"}.issubset(reader.fieldnames or ()):
            return keys
        for row in reader:
            try:
                keys.add((row["Dataset"], row["attack"], float(row["alpha"])))
            except (TypeError, ValueError):
                # Short rows have None for missing fields (float(None) ->
                # TypeError); a non-numeric alpha raises ValueError.
                continue
    return keys

def _test_split_name(dataset: str) -> str:
    """Map a validation-split name such as ``CIFAR100Val`` to its test-split name."""
    # Strip only a trailing "Val"; str.replace would also remove any "Val"
    # that appears elsewhere in the dataset name.
    if dataset.endswith("Val"):
        return dataset[: -len("Val")]
    return dataset


def _zeroshot_checkpoint(attack, seed, dataset: str) -> str:
    """Return the path of the zero-shot checkpoint stored under *dataset*'s Val split."""
    # NOTE(review): the backbone directory is hard-coded to ViT-B-32 although
    # --model is configurable; confirm whether it should follow args.model.
    seed_dir = "" if seed == 0 else f"seed_{seed}/"
    return f"./checkpoints/ViT-B-32/{attack}/1e-5/{seed_dir}{dataset}Val/zeroshot.pt"


def _recorded_datasets(path: str) -> set:
    """Return the non-empty ``Dataset`` values already present in the CSV at *path*."""
    if not os.path.exists(path):
        return set()
    with open(path, "r", newline="") as f:
        return {row["Dataset"] for row in csv.DictReader(f) if row.get("Dataset")}


def main():
    """Evaluate zero-shot accuracy on each dataset's test split and append it to a CSV.

    The script-specific flags ``--datasets``, ``--out-csv`` and
    ``--skip-existing`` are parsed here. All remaining arguments are forwarded
    to the project's ``parse_arguments``.
    """
    # Script-specific args first
    p = argparse.ArgumentParser(add_help=False)
    p.add_argument("--datasets", required=True,
                   help="Comma-separated, e.g. CIFAR100Val,ImageNetVal,SUN397Val")
    p.add_argument("--out-csv", required=True, help="Path to CSV to append")
    p.add_argument("--skip-existing", action="store_true",
                   help="Skip datasets already present in the Dataset column of --out-csv")
    my_args, remaining = p.parse_known_args()

    # Project args next (attack/model/seed/etc.). parse_arguments() reads
    # sys.argv, so temporarily hide the script-specific flags from it.
    orig_argv = sys.argv
    try:
        sys.argv = [sys.argv[0]] + remaining
        args = parse_arguments()
    finally:
        sys.argv = orig_argv

    datasets = [d.strip() for d in my_args.datasets.split(",") if d.strip()]

    ensure_csv(my_args.out_csv)
    # Rows written below contain only (Dataset, zt_test_CA), so the dataset
    # name is the only key available for skipping already-evaluated entries.
    seen = _recorded_datasets(my_args.out_csv) if my_args.skip_existing else set()

    # Loop-invariant; previously reassigned on every iteration.
    args.save = f"checkpoints/{args.model}"

    for raw_name in datasets:
        dataset = _test_split_name(raw_name)
        if dataset in seen:
            print(f"Skipping {dataset}: already present in {my_args.out_csv}")
            continue

        print(f"Evaluating test set on {dataset}")

        pretrained_checkpoint = _zeroshot_checkpoint(args.attack, args.seed, dataset)
        # NOTE(review): torch.load unpickles the file, so only load trusted
        # checkpoints. torch>=2.6 defaults to weights_only=True, which may
        # reject a fully pickled model; confirm the torch version in use.
        pre_trained_model = torch.load(pretrained_checkpoint)

        ca_zt = eval_single_dataset(pre_trained_model, dataset, args)["top1"]

        # Append per dataset so finished results persist if a later one fails.
        with open(my_args.out_csv, "a", newline="") as f:
            csv.writer(f).writerow([dataset, ca_zt])
        if my_args.skip_existing:
            seen.add(dataset)
        print(f"alpha=0  CA={ca_zt:.4f}")

    print("\nDone. Appended to:", my_args.out_csv)

if __name__ == "__main__":
    main()
