import argparse
import csv
import os
import sys
import time
from collections import defaultdict
from typing import Dict, List, Optional, Tuple

import torch

from src.args import parse_arguments
from src.eval import eval_single_dataset, eval_ASR
from src.task_vectors import NonLinearTaskVector

# Validation-split dataset name -> class label handed to ``eval_ASR`` as the
# ``target`` argument when measuring attack success rate.
TARGETS = {
    # digits
    "MNISTVal": "7",
    "SVHNVal": "7",
    # natural images
    "CIFAR100Val": "orange",
    "ImageNetVal": "orange",
    "SUN397Val": "river",
    "CarsVal": "Acura RL Sedan 2012",
    "PETSVal": "pug",
    "GTSRBVal": "stop",
    # remote sensing
    "EuroSATVal": "river",
    "RESISC45Val": "river",
}

# Header row written once when the output CSV is first created; the order
# matches the row appended for every evaluated dataset.
OUT_HEADERS = [
    "DatasetVal",
    "DatasetTest",
    "attack",
    "init_CA",
    "init_ASR",
    "alpha",
    "CA_val",
    "ASR_val",
    "CA_test",
    "ASR_test",
]

def read_rows(path: str) -> List[dict]:
    """Load the CSV at *path* into a list of per-row dictionaries.

    The ``alpha``, ``CA`` and ``ASR`` columns are converted to ``float`` in
    place; every other column keeps the string value produced by
    ``csv.DictReader``.

    Raises:
        KeyError: if one of the three numeric columns is missing.
        ValueError: if a numeric column holds text ``float`` cannot parse.
    """
    numeric_columns = ("alpha", "CA", "ASR")

    with open(path, "r", newline="") as handle:
        records = list(csv.DictReader(handle))

    for record in records:
        for column in numeric_columns:
            record[column] = float(record[column])
    return records

def group_by(rows: List[dict]):
    """Bucket *rows* by their ``(Dataset, attack)`` pair.

    Returns a ``defaultdict`` mapping each pair to the list of its rows,
    with every list sorted in ascending ``alpha`` order. Keys appear in the
    order in which each pair is first seen in *rows*.
    """
    def by_alpha(row):
        return row["alpha"]

    buckets = defaultdict(list)
    for row in rows:
        bucket_key = (row["Dataset"], row["attack"])
        buckets[bucket_key].append(row)

    for members in buckets.values():
        members.sort(key=by_alpha)
    return buckets

def select_alpha(cands: List[dict],
                 ca_min_frac: Optional[float],
                 asr_max_frac: Optional[float],
                 ca_min_abs: Optional[float],
                 asr_max_abs: Optional[float],
                 prefer_min_asr: bool = False,
                 ) -> Tuple[dict, str, float, float]:
    """Pick one candidate row from *cands* under optional CA/ASR constraints.

    Args:
        cands: rows carrying numeric ``alpha``, ``CA`` and ``ASR`` entries.
        ca_min_frac: if given, require ``CA >= ca_min_frac * CA0``.
        asr_max_frac: if given, require ``ASR <= asr_max_frac * ASR0``.
        ca_min_abs: if given, require ``CA >= ca_min_abs``.
        asr_max_abs: if given, require ``ASR <= asr_max_abs``.
        prefer_min_asr: when a CA constraint is present, rank feasible rows
            by lowest ASR instead of highest CA.

    ``CA0``/``ASR0`` are taken from the first row whose ``alpha`` is zero
    (within 1e-12), or from ``cands[0]`` when there is no such row.

    Returns:
        ``(row, strategy, CA0, ASR0)`` where *strategy* is a human-readable
        description of the rule that produced *row*. If no row satisfies
        every active constraint, the row with the smallest summed relative
        violation is returned with the strategy
        ``"Fallback: minimal violation"``.

    Raises:
        IndexError: if *cands* is empty.
    """
    # Baseline metrics: the alpha == 0 row, else the first candidate.
    zero_row = next((row for row in cands if abs(row["alpha"]) < 1e-12), None)
    if zero_row is None:
        ca0 = asr0 = None
    else:
        ca0, asr0 = zero_row["CA"], zero_row["ASR"]
    if ca0 is None:
        ca0, asr0 = cands[0]["CA"], cands[0]["ASR"]

    def rank_by_ca(row):
        # highest CA first, then lowest ASR, then smallest alpha
        return (-row["CA"], row["ASR"], row["alpha"])

    def rank_by_asr(row):
        # lowest ASR first, then highest CA, then smallest alpha
        return (row["ASR"], -row["CA"], row["alpha"])

    ca_constrained = not (ca_min_frac is None and ca_min_abs is None)
    asr_constrained = not (asr_max_frac is None and asr_max_abs is None)

    def satisfies(row):
        # Every active check is evaluated (no short-circuit) and all must hold.
        checks = []
        if ca_min_frac is not None:
            checks.append(row["CA"] >= ca_min_frac * ca0)
        if asr_max_frac is not None:
            checks.append(row["ASR"] <= asr_max_frac * asr0)
        if ca_min_abs is not None:
            checks.append(row["CA"] >= ca_min_abs)
        if asr_max_abs is not None:
            checks.append(row["ASR"] <= asr_max_abs)
        return all(checks)

    feasible = [row for row in cands if satisfies(row)]

    if ca_constrained:
        asr_clause = " AND ASR<=tolerance" if asr_constrained else ""
        if prefer_min_asr:
            ranker, preference = rank_by_asr, " (prefer min ASR)"
        else:
            ranker, preference = rank_by_ca, " (prefer max CA)"
        strategy = "CA>=threshold" + asr_clause + preference
    elif asr_constrained:
        ranker = rank_by_asr
        strategy = "ASR<=tolerance (prefer min ASR)"
    else:
        ranker = rank_by_ca
        strategy = "No constraints; maximizing CA"
        feasible = cands

    if feasible:
        best = sorted(feasible, key=ranker)[0]
        return best, strategy, ca0, asr0

    # Nothing is feasible: rank by how far each row misses the thresholds.
    def relative_excess(amount, scale):
        # Miss distance relative to the threshold, clamped at zero; the
        # 1e-8 floor avoids dividing by a zero threshold.
        return max(0.0, amount / max(1e-8, scale))

    def violation(row):
        total = 0.0
        if ca_min_frac is not None:
            total += relative_excess(ca_min_frac * ca0 - row["CA"], ca_min_frac * ca0)
        if asr_max_frac is not None:
            total += relative_excess(row["ASR"] - asr_max_frac * asr0, asr_max_frac * asr0)
        if ca_min_abs is not None:
            total += relative_excess(ca_min_abs - row["CA"], ca_min_abs)
        if asr_max_abs is not None:
            total += relative_excess(row["ASR"] - asr_max_abs, asr_max_abs)
        return (total, row["ASR"], -row["CA"], row["alpha"])

    least_bad = sorted(cands, key=violation)[0]
    return least_bad, "Fallback: minimal violation", ca0, asr0

def _checkpoint_paths(model_path: str, attack: str, seed, dataset: str) -> Tuple[str, str]:
    """Return the ``(zeroshot, finetuned)`` checkpoint paths for *dataset*."""
    root = f"./checkpoints/{model_path}/{attack}/1e-5"
    if seed == 0:
        ckpt_dir = f"{root}/poison_only/{dataset}"
    else:
        # Any seed other than 0 gets its own sub-directory.
        ckpt_dir = f"{root}/seed_{seed}/poison_only/{dataset}"
    return f"{ckpt_dir}/zeroshot.pt", f"{ckpt_dir}/finetuned.pt"


def main():
    """Select one alpha per dataset from a validation CSV and re-evaluate it.

    Steps:
      1. Parse this script's own options; every option it does not recognise
         is forwarded to the project-level ``parse_arguments``.
      2. Read ``--val-csv``, optionally keep only ``--datasets``, and group
         the rows by ``(Dataset, attack)``.
      3. For every group whose attack matches the requested one, choose a
         row with ``select_alpha``, then compute CA/ASR for the zeroshot
         checkpoint and for the checkpoint with the negated task vector
         applied at the chosen alpha, on the dataset name with ``"Val"``
         removed.
      4. Append one result row per dataset to ``--out-csv``, writing the
         header first if the file is missing or empty.
    """
    # Script args. add_help=False leaves -h/--help among the unparsed
    # arguments that are handed on to the project parser.
    p = argparse.ArgumentParser(add_help=False)
    p.add_argument("--val-csv", required=True, help="Existing CSV with columns Dataset,attack,alpha,CA,ASR")
    p.add_argument("--out-csv", required=True, help="CSV to append test results")
    p.add_argument("--datasets", default=None, help="Optional comma-separated subset of Dataset values")
    p.add_argument("--ca-min-frac", type=float, default=None)
    p.add_argument("--asr-max-frac", type=float, default=None)
    p.add_argument("--ca-min-abs", type=float, default=None)
    p.add_argument("--asr-max-abs", type=float, default=None)
    # The underscore spelling is kept for existing invocations; the dashed
    # alias matches every other flag. Both store into ``print_choices``.
    p.add_argument("--print_choices", "--print-choices", action="store_true")
    p.add_argument(
        "--prefer-min-asr",
        action="store_true",
        help="When a CA tolerance is provided, pick the alpha that MINIMIZES ASR "
            "among candidates that satisfy CA (and optional ASR) constraints; "
            "tie-break by higher CA, then smaller alpha."
    )
    my_args, remaining = p.parse_known_args()

    # Project args (attack/model/seed affect checkpoints). parse_arguments
    # is called with sys.argv temporarily reduced to the leftover options.
    orig_argv = sys.argv
    try:
        sys.argv = [sys.argv[0]] + remaining
        args = parse_arguments()
    finally:
        sys.argv = orig_argv

    # attack_mode is what gets compared with the CSV "attack" column and
    # passed to eval_ASR as ``mode``: "badnet" becomes "random", anything
    # else is used verbatim. It is computed BEFORE the "SIG" -> "sig"
    # rewrite below, so that rewrite only affects the checkpoint paths.
    attack_mode = "random" if getattr(args, "attack", "") == "badnet" else getattr(args, "attack", "")
    if getattr(args, "attack", "") == "SIG":
        args.attack = "sig"

    rows = read_rows(my_args.val_csv)
    if my_args.datasets:
        allowed = {name.strip() for name in my_args.datasets.split(",") if name.strip()}
        rows = [r for r in rows if r["Dataset"] in allowed]
    groups = group_by(rows)

    # os.makedirs("") raises FileNotFoundError, so only create the parent
    # directory when --out-csv actually names one.
    out_dir = os.path.dirname(my_args.out_csv)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    # A missing or zero-byte file still needs its header row.
    if not os.path.exists(my_args.out_csv) or os.path.getsize(my_args.out_csv) == 0:
        with open(my_args.out_csv, "w", newline="") as f:
            csv.writer(f).writerow(OUT_HEADERS)

    # The checkpoint directory depends only on the model, so resolve it once.
    if args.model == 'convnext_base__pretrained__laion400m_s13b_b51k':
        args.save = "checkpoints/convnext_base/"
        model_path = "convnext_base"
    else:
        args.save = f"checkpoints/{args.model}"
        model_path = args.model

    evaluated = 0
    for (dataset, atk), cands in groups.items():
        if atk != attack_mode:
            continue

        chosen, strategy, ca0, asr0 = select_alpha(
            cands,
            my_args.ca_min_frac, my_args.asr_max_frac,
            my_args.ca_min_abs, my_args.asr_max_abs,
            prefer_min_asr=my_args.prefer_min_asr,
        )

        alpha = chosen["alpha"]
        ca_val, asr_val = chosen["CA"], chosen["ASR"]

        pretrained_checkpoint, finetuned_checkpoint = _checkpoint_paths(
            model_path, args.attack, args.seed, dataset
        )

        # Test dataset name and target.
        # NOTE(review): TARGETS.get returns None for a dataset that is not
        # in the table; confirm eval_ASR accepts target=None.
        target = TARGETS.get(dataset)
        test_dataset = dataset.replace("Val", "")

        # Evaluate initial CA and ASR on the unmodified zeroshot checkpoint.
        pre_trained_model = torch.load(pretrained_checkpoint)
        init_ca = eval_single_dataset(pre_trained_model, test_dataset, args)["top1"]
        init_asr = eval_ASR(pre_trained_model, test_dataset, args, target=target, mode=attack_mode)

        # Evaluate on test: apply the NEGATED task vector to the zeroshot
        # checkpoint, scaled by the selected alpha.
        task_vector = -NonLinearTaskVector(pretrained_checkpoint, finetuned_checkpoint)
        model = task_vector.apply_to(pretrained_checkpoint, scaling_coef=float(alpha))
        ca_test = eval_single_dataset(model, test_dataset, args)["top1"]
        asr_test = eval_ASR(model, test_dataset, args, target=target, mode=attack_mode)

        if my_args.print_choices:
            print("-" * 80)
            print(f"{dataset} → {test_dataset}")
            print(f"Initial CA: {init_ca:.4f}  ASR: {init_asr:.6f}")
            print(f"Chosen alpha: {alpha}  | Strategy: {strategy}")
            print(f"Baseline (alpha=0): CA0={ca0:.4f}  ASR0={asr0:.6f}")
            print(f"VAL @alpha: CA={ca_val:.4f}  ASR={asr_val:.6f}")
            print(f"TEST      : CA={ca_test:.4f} ASR={asr_test:.6f}")

        # Append after each dataset so finished results survive a later failure.
        with open(my_args.out_csv, "a", newline="") as f:
            csv.writer(f).writerow([
                dataset, test_dataset, attack_mode, init_ca, init_asr, alpha, ca_val, asr_val, ca_test, asr_test
            ])
        evaluated += 1

    if not evaluated:
        # Without this the script reports success having written no rows.
        print(
            f"Warning: no rows in {my_args.val_csv} matched attack {attack_mode!r}; nothing was evaluated.",
            file=sys.stderr,
        )

    print("\nDone. Test results appended to:", my_args.out_csv)

if __name__ == "__main__":
    main()
