import argparse
import csv
import os
import sys
import time
from typing import List, Optional

import torch

from src.args import parse_arguments
from src.eval import eval_single_dataset, eval_ASR
from src.task_vectors import NonLinearTaskVector

# Per-dataset attack target. Keys are the dataset names accepted via
# --datasets (main() rejects any dataset that is missing here); the value is
# passed unchanged as the `target` argument of eval_ASR.
# NOTE(review): values look like class names of the respective dataset --
# confirm against eval_ASR.
TARGETS = {
    "MNISTVal": "7",
    "SVHNVal": "7",
    "CIFAR100Val": "orange",
    "ImageNetVal": "orange",
    "SUN397Val": "river",
    "CarsVal": "Acura RL Sedan 2012",
    "PETSVal": "pug",
    "GTSRBVal": "stop",
    "EuroSATVal": "river",
    "RESISC45Val": "river",
}

# Column order of the results CSV. ensure_csv() writes this as the header row,
# existing_keys() reads rows back by these names, and main() appends rows in
# the same order.
CSV_HEADERS = ["Dataset", "attack", "alpha", "CA", "ASR"]

def ensure_csv(path: str) -> None:
    """Create the results CSV with its header row if it does not exist yet.

    Missing parent directories are created. A file that already exists is
    left untouched so that later rows can be appended to it.

    Args:
        path: Destination CSV path. May be a bare filename, in which case the
            file is created in the current working directory.
    """
    parent = os.path.dirname(path)
    # dirname() is "" for a bare filename and os.makedirs("") raises
    # FileNotFoundError, so only create a directory when there is one.
    if parent:
        os.makedirs(parent, exist_ok=True)
    if not os.path.exists(path):
        with open(path, "w", newline="") as f:
            csv.writer(f).writerow(CSV_HEADERS)

def existing_keys(path: str) -> set:
    """Collect the (Dataset, attack, alpha) keys already recorded in a CSV.

    Args:
        path: Path of the results CSV. A missing file yields an empty set.

    Returns:
        A set of ``(dataset, attack, alpha)`` tuples where ``alpha`` is a
        float. Rows that lack one of the key columns or whose alpha is not a
        number are ignored.
    """
    keys = set()
    if not os.path.exists(path):
        return keys
    with open(path, "r", newline="") as f:
        for row in csv.DictReader(f):
            try:
                keys.add((row["Dataset"], row["attack"], float(row["alpha"])))
            except (KeyError, TypeError, ValueError):
                # Best effort: a malformed row (missing column, short row
                # whose alpha is None, or non-numeric alpha) simply does not
                # count as "already done". Anything else is a real error and
                # is allowed to propagate.
                continue
    return keys

def parse_alpha_grid(alpha_grid: Optional[str],
                     alpha_start: Optional[float],
                     alpha_stop: Optional[float],
                     alpha_step: Optional[float]) -> List[float]:
    """Build the list of scaling coefficients to sweep.

    Args:
        alpha_grid: Comma-separated floats. When non-empty it takes precedence
            over the start/stop/step triple; empty items are ignored.
        alpha_start: First value of the range (used when no grid is given).
        alpha_stop: Inclusive upper bound of the range.
        alpha_step: Positive increment between consecutive values.

    Returns:
        The explicit grid values in the given order, or the inclusive range
        ``start, start + step, ...`` up to ``stop`` with each value rounded to
        10 decimal places. The range is empty when ``start`` exceeds ``stop``.

    Raises:
        ValueError: If the grid contains no values or a non-numeric item, if
            neither a grid nor a complete start/stop/step triple is given, or
            if the step is not positive.
    """
    if alpha_grid:
        vals = [float(tok) for tok in alpha_grid.split(",") if tok.strip() != ""]
        if not vals:
            raise ValueError("Empty --alpha-grid.")
        return vals
    if alpha_start is None or alpha_stop is None or alpha_step is None:
        raise ValueError("Provide --alpha-grid OR --alpha-start/--alpha-stop/--alpha-step.")
    if alpha_step <= 0:
        raise ValueError("--alpha-step must be > 0.")

    # Small slack so that a stop value that is itself a grid point is kept
    # despite floating-point representation error.
    upper = alpha_stop + 1e-12
    out: List[float] = []
    index = 0
    x = alpha_start
    while x <= upper:
        out.append(round(x, 10))
        index += 1
        # Recompute from the start instead of accumulating `x += step`:
        # repeated addition lets rounding error grow with the number of
        # steps, which can overshoot `upper` and drop the endpoint.
        x = alpha_start + index * alpha_step
    return out

def _checkpoint_paths(attack, seed, dataset):
    """Return the (zeroshot, finetuned) checkpoint paths for one dataset."""
    # Seed 0 checkpoints sit directly under the "1e-5" directory; any other
    # seed has its own "seed_<n>" subdirectory.
    seed_dir = "" if seed == 0 else f"seed_{seed}/"
    root = f"./checkpoints/convnext_base/{attack}/1e-5/{seed_dir}poison_only/{dataset}"
    return f"{root}/zeroshot.pt", f"{root}/finetuned.pt"

def _append_row(csv_path, row):
    """Append a single result row to the CSV at ``csv_path``."""
    with open(csv_path, "a", newline="") as f:
        csv.writer(f).writerow(row)

def main():
    """Sweep task-vector scaling coefficients and append CA/ASR rows to a CSV.

    Script-specific options (--datasets, --alpha-*, --out-csv,
    --skip-existing) are parsed here; every other command-line argument is
    forwarded to the project's ``parse_arguments``. For each dataset one
    alpha=0 row is written for the pretrained checkpoint, followed by one row
    per alpha for the model obtained by applying the negated task vector with
    that scaling coefficient.

    Raises:
        ValueError: If a requested dataset has no entry in ``TARGETS`` or the
            alpha options are invalid (see ``parse_alpha_grid``).
    """
    # Script-specific args first
    p = argparse.ArgumentParser(add_help=False)
    p.add_argument("--datasets", required=True,
                   help="Comma-separated, e.g. CIFAR100Val,ImageNetVal,SUN397Val")
    p.add_argument("--alpha-grid", default=None,
                   help="Comma-separated floats, e.g. 1,1.1,1.2,2,3")
    p.add_argument("--alpha-start", type=float, default=None)
    p.add_argument("--alpha-stop", type=float, default=None)
    p.add_argument("--alpha-step", type=float, default=None)
    p.add_argument("--out-csv", required=True, help="Path to CSV to append")
    p.add_argument("--skip-existing", action="store_true",
                   help="Skip rows already present for (Dataset,attack,alpha)")
    my_args, remaining = p.parse_known_args()

    # Project args next (attack/model/seed/etc.). parse_arguments reads
    # sys.argv, so temporarily hand it only the arguments this script did not
    # consume and restore the original argv afterwards.
    orig_argv = sys.argv
    try:
        sys.argv = [sys.argv[0]] + remaining
        args = parse_arguments()
    finally:
        sys.argv = orig_argv

    # NOTE(review): attack_mode is derived before "SIG" is normalised to
    # "sig", so a "SIG" run records/evaluates mode "SIG" while its checkpoint
    # path uses "sig" -- confirm this is what eval_ASR expects.
    attack_mode = "random" if getattr(args, "attack", "") == "badnet" else getattr(args, "attack", "")
    if getattr(args, "attack", "") == "SIG":
        args.attack = "sig"

    datasets = [d.strip() for d in my_args.datasets.split(",") if d.strip()]
    alphas = parse_alpha_grid(my_args.alpha_grid, my_args.alpha_start, my_args.alpha_stop, my_args.alpha_step)

    ensure_csv(my_args.out_csv)
    seen = existing_keys(my_args.out_csv) if my_args.skip_existing else set()

    for dataset in datasets:
        if dataset not in TARGETS:
            raise ValueError(f"Missing target for {dataset} in TARGETS.")
        print("=" * 80)
        print(f"[VAL] {dataset}")

        # args.save = f"checkpoints/{args.model}"
        args.save = "checkpoints/convnext_base"
        pretrained_checkpoint, finetuned_checkpoint = _checkpoint_paths(
            args.attack, args.seed, dataset)
        target = TARGETS[dataset]

        # Write base alpha=0 row
        base_key = (dataset, attack_mode, 0.0)
        if (not my_args.skip_existing) or (base_key not in seen):
            # Load the pretrained model only when its row is really evaluated,
            # so a resumed run does not pay for an unused checkpoint load.
            # NOTE(review): torch.load unpickles the file; only point this at
            # trusted checkpoints.
            pre_trained_model = torch.load(pretrained_checkpoint)
            ca0 = eval_single_dataset(pre_trained_model, dataset, args)["top1"]
            asr0 = eval_ASR(pre_trained_model, dataset, args, target=target, mode=attack_mode)
            _append_row(my_args.out_csv, [dataset, attack_mode, 0.0, ca0, asr0])
            print(f"alpha=0  CA={ca0:.4f}  ASR={asr0:.6f}")

        # Sweep alphas
        task_vector = None
        for a in alphas:
            key = (dataset, attack_mode, float(a))
            if my_args.skip_existing and key in seen:
                print(f"skip existing alpha={a}")
                continue
            if task_vector is None:
                # Built on first use: when every alpha is skipped there is no
                # need to construct the task vector at all.
                task_vector = -NonLinearTaskVector(pretrained_checkpoint, finetuned_checkpoint)
            model = task_vector.apply_to(pretrained_checkpoint, scaling_coef=float(a))
            ca = eval_single_dataset(model, dataset, args)["top1"]
            asr = eval_ASR(model, dataset, args, target=target, mode=attack_mode)
            _append_row(my_args.out_csv, [dataset, attack_mode, float(a), ca, asr])
            print(f"alpha={a:<6} CA={ca:.4f}  ASR={asr:.6f}")

    print("\nDone. Appended to:", my_args.out_csv)

# Run the sweep only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
