import argparse
import csv
import os
import sys
from collections import defaultdict
from typing import List, Optional, Tuple

import torch

from src.args import parse_arguments
from src.eval import eval_single_dataset, eval_ASR
from src.task_vectors import NonLinearTaskVector

# Per-dataset target label that main() passes to eval_ASR as `target=`.
# Keys are the "...Val" dataset names read from the "Dataset" column of the
# validation CSV; a dataset missing from this table resolves to None because
# main() uses TARGETS.get().
# NOTE(review): presumably each value is the class the attack steers
# predictions toward -- confirm against src.eval.eval_ASR.
TARGETS = {
    "MNISTVal": "7",
    "SVHNVal": "7",
    "CIFAR100Val": "orange",
    "ImageNetVal": "orange",
    "SUN397Val": "river",
    "CarsVal": "Acura RL Sedan 2012",
    "PETSVal": "pug",
    "GTSRBVal": "stop",
    "EuroSATVal": "river",
    "RESISC45Val": "river",
}

# Header row of the output CSV; the order must match the row written by
# main(). The validation dataset name and the validation CA/ASR values are
# deliberately left out: only the test dataset name and test-time metrics are
# recorded. The *_low / *_high columns hold the metrics measured at the
# perturbed alphas (alpha_low / alpha_high).
OUT_HEADERS = [
    "DatasetTest","attack","init_CA","init_ASR",
    "alpha","CA_test","ASR_test",
    "alpha_low","alpha_high","CA_test_low","ASR_test_low","CA_test_high","ASR_test_high"
]

def read_rows(path: str) -> List[dict]:
    """Load a CSV file into a list of dicts with numeric metric columns.

    Every row is read first (keyed by the header line), then the "alpha",
    "CA" and "ASR" fields of each row are converted from text to ``float``
    in place. All other columns stay as strings.

    Raises:
        KeyError: if a row lacks one of the three numeric columns.
        ValueError: if one of those fields is not a valid float literal.
    """
    numeric_columns = ("alpha", "CA", "ASR")
    with open(path, "r", newline="") as handle:
        records = list(csv.DictReader(handle))
    for record in records:
        for column in numeric_columns:
            record[column] = float(record[column])
    return records

def group_by(rows: List[dict]):
    """Bucket rows by their ``(Dataset, attack)`` pair.

    Returns a ``defaultdict(list)`` mapping each ``(Dataset, attack)`` tuple
    to the rows sharing it, with every bucket ordered by ascending "alpha".
    The sort is stable, so rows with equal alpha keep their input order.
    """
    buckets = defaultdict(list)
    for row in rows:
        bucket_key = (row["Dataset"], row["attack"])
        buckets[bucket_key].append(row)
    for members in buckets.values():
        members.sort(key=lambda row: row["alpha"])
    return buckets

def select_alpha(cands: List[dict],
                 ca_min_frac: Optional[float],
                 asr_max_frac: Optional[float],
                 ca_min_abs: Optional[float],
                 asr_max_abs: Optional[float],
                 prefer_min_asr: bool = False,
                 ) -> Tuple[dict, str, float, float]:
    """Pick one candidate row according to optional CA / ASR constraints.

    Args:
        cands: rows carrying float "alpha", "CA" and "ASR" entries.
        ca_min_frac: if given, require CA >= ca_min_frac * baseline CA.
        asr_max_frac: if given, require ASR <= asr_max_frac * baseline ASR.
        ca_min_abs: if given, require CA >= ca_min_abs.
        asr_max_abs: if given, require ASR <= asr_max_abs.
        prefer_min_asr: when any CA constraint is set, rank the feasible rows
            by lowest ASR instead of highest CA.

    Returns:
        ``(row, strategy, ca0, asr0)`` where ``row`` is the selected element
        of ``cands``, ``strategy`` is a human-readable description of the rule
        that was applied, and ``ca0`` / ``asr0`` are the baseline metrics.

    Raises:
        IndexError: if ``cands`` is empty.

    The baseline is the first row whose alpha is (numerically) zero; without
    such a row the first candidate is used. If no row satisfies every active
    constraint, the row with the smallest summed relative violation wins.
    """
    # Baseline metrics: first alpha == 0 row, else the first candidate.
    reference = None
    for row in cands:
        if abs(row["alpha"]) < 1e-12:
            reference = row
            break
    if reference is None or reference["CA"] is None:
        reference = cands[0]
    ca0, asr0 = reference["CA"], reference["ASR"]

    # Relative thresholds depend only on the baseline, so compute them once.
    ca_frac_floor = None if ca_min_frac is None else ca_min_frac * ca0
    asr_frac_cap = None if asr_max_frac is None else asr_max_frac * asr0

    def rank_high_ca(row):
        # Highest CA first, then lowest ASR, then smallest alpha.
        return (-row["CA"], row["ASR"], row["alpha"])

    def rank_low_asr(row):
        # Lowest ASR first, then highest CA, then smallest alpha.
        return (row["ASR"], -row["CA"], row["alpha"])

    def satisfies(row):
        checks = []
        if ca_min_frac is not None:
            checks.append(row["CA"] >= ca_frac_floor)
        if asr_max_frac is not None:
            checks.append(row["ASR"] <= asr_frac_cap)
        if ca_min_abs is not None:
            checks.append(row["CA"] >= ca_min_abs)
        if asr_max_abs is not None:
            checks.append(row["ASR"] <= asr_max_abs)
        return all(checks)

    ca_bounded = not (ca_min_frac is None and ca_min_abs is None)
    asr_bounded = not (asr_max_frac is None and asr_max_abs is None)

    if ca_bounded:
        asr_clause = " AND ASR<=tolerance" if asr_bounded else ""
        if prefer_min_asr:
            strategy = "CA>=threshold" + asr_clause + " (prefer min ASR)"
            rank = rank_low_asr
        else:
            strategy = "CA>=threshold" + asr_clause + " (prefer max CA)"
            rank = rank_high_ca
        pool = [row for row in cands if satisfies(row)]
    elif asr_bounded:
        strategy = "ASR<=tolerance (prefer min ASR)"
        rank = rank_low_asr
        pool = [row for row in cands if satisfies(row)]
    else:
        strategy = "No constraints; maximizing CA"
        rank = rank_high_ca
        pool = cands

    if pool:
        ranked = sorted(pool, key=rank)
        return ranked[0], strategy, ca0, asr0

    # Nothing is feasible: fall back to the least-violating candidate.
    def relative_gap(amount, bound):
        # Violation normalised by the bound (guarded against ~0 bounds).
        return max(0.0, amount / max(1e-8, bound))

    def violation_rank(row):
        total = 0.0
        if ca_min_frac is not None:
            total += relative_gap(ca_frac_floor - row["CA"], ca_frac_floor)
        if asr_max_frac is not None:
            total += relative_gap(row["ASR"] - asr_frac_cap, asr_frac_cap)
        if ca_min_abs is not None:
            total += relative_gap(ca_min_abs - row["CA"], ca_min_abs)
        if asr_max_abs is not None:
            total += relative_gap(row["ASR"] - asr_max_abs, asr_max_abs)
        return (total, row["ASR"], -row["CA"], row["alpha"])

    least_bad = sorted(cands, key=violation_rank)
    return least_bad[0], "Fallback: minimal violation", ca0, asr0

def _evaluate_ca_asr(model, dataset_name, args, target, mode):
    """Return ``(CA, ASR)`` of *model* on *dataset_name*.

    CA is the "top1" entry of the ``eval_single_dataset`` result; ASR is the
    value returned by ``eval_ASR`` for the given target and mode.
    """
    ca = eval_single_dataset(model, dataset_name, args)["top1"]
    asr = eval_ASR(model, dataset_name, args, target=target, mode=mode)
    return ca, asr


def main():
    """Select an alpha per dataset from a validation CSV and evaluate on test.

    Steps:
      1. Parse this script's own options; every unrecognised argument is
         forwarded to the project-level ``parse_arguments``.
      2. Read the validation CSV, optionally restrict it to ``--datasets``,
         and group the rows by ``(Dataset, attack)``.
      3. For each group whose attack matches the requested one, choose an
         alpha with ``select_alpha``, then measure CA/ASR on the test dataset
         for the pretrained model, the chosen alpha, and alpha * (1 -/+ frac).
      4. Append one row per dataset to ``--out-csv`` (header written when the
         file is new or empty).
    """
    # Script args
    p = argparse.ArgumentParser(add_help=False)
    p.add_argument("--val-csv", required=True, help="Existing CSV with columns Dataset,attack,alpha,CA,ASR")
    p.add_argument("--out-csv", required=True, help="CSV to append test results")
    p.add_argument("--datasets", default=None, help="Optional comma-separated subset of Dataset values")
    p.add_argument("--ca-min-frac", type=float, default=None)
    p.add_argument("--asr-max-frac", type=float, default=None)
    p.add_argument("--ca-min-abs", type=float, default=None)
    p.add_argument("--asr-max-abs", type=float, default=None)
    p.add_argument("--print_choices", action="store_true")
    p.add_argument(
        "--prefer-min-asr",
        action="store_true",
        help="When a CA tolerance is provided, pick the alpha that MINIMIZES ASR "
             "among candidates that satisfy CA (and optional ASR) constraints; "
             "tie-break by higher CA, then smaller alpha."
    )
    # Sensitivity knob (±10% by default)
    p.add_argument(
        "--alpha-perturb-frac", type=float, default=0.10,
        help="Relative perturbation ±X around chosen alpha to measure sensitivity (e.g., 0.10 for ±10%)."
    )
    my_args, remaining = p.parse_known_args()

    # Project args (attack/model/seed affect checkpoints). parse_arguments()
    # is called with sys.argv temporarily reduced to the leftover arguments,
    # and the original argv is always restored.
    orig_argv = sys.argv
    try:
        sys.argv = [sys.argv[0]] + remaining
        args = parse_arguments()
    finally:
        sys.argv = orig_argv

    # "badnet" runs are matched against / evaluated as "random".
    attack_mode = "random" if getattr(args, "attack", "") == "badnet" else getattr(args, "attack", "")
    # NOTE(review): attack_mode is computed before this normalisation, so for
    # "SIG" the CSV filter and eval mode keep the upper-case spelling while
    # the checkpoint path uses "sig" -- confirm this is intended.
    if getattr(args, "attack", "") == "SIG":
        args.attack = "sig"

    rows = read_rows(my_args.val_csv)
    if my_args.datasets:
        allowed = {name.strip() for name in my_args.datasets.split(",") if name.strip()}
        rows = [r for r in rows if r["Dataset"] in allowed]
    groups = group_by(rows)

    # os.path.dirname() is "" for a bare filename, and os.makedirs("") raises
    # FileNotFoundError, so only create the directory when there is one.
    out_dir = os.path.dirname(my_args.out_csv)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    # A pre-existing but empty file still needs its header row.
    write_header = (
        not os.path.exists(my_args.out_csv)
        or os.path.getsize(my_args.out_csv) == 0
    )
    if write_header:
        with open(my_args.out_csv, "w", newline="") as f:
            csv.writer(f).writerow(OUT_HEADERS)

    for (dataset, atk), cands in groups.items():
        if atk != attack_mode:
            continue

        chosen, strategy, ca0, asr0 = select_alpha(
            cands,
            my_args.ca_min_frac, my_args.asr_max_frac,
            my_args.ca_min_abs,  my_args.asr_max_abs,
            prefer_min_asr=my_args.prefer_min_asr,
        )

        alpha = chosen["alpha"]

        # Build checkpoints for this dataset
        # NOTE(review): the checkpoint path hard-codes "ViT-B-32" while
        # args.save follows args.model -- confirm other models are not expected.
        args.save = f"checkpoints/{args.model}"
        if args.seed == 0:
            base = f"./checkpoints/ViT-B-32/{args.attack}/1e-5/poison_only/{dataset}"
        else:
            base = f"./checkpoints/ViT-B-32/{args.attack}/1e-5/seed_{args.seed}/poison_only/{dataset}"
        pretrained_checkpoint = f"{base}/zeroshot.pt"
        finetuned_checkpoint  = f"{base}/finetuned.pt"

        # Test dataset name and target (target is None for unknown datasets)
        target = TARGETS.get(dataset)
        test_dataset = dataset.replace("Val", "")

        # Evaluate initial CA and ASR on TEST (pretrained)
        pre_trained_model = torch.load(pretrained_checkpoint)
        init_ca, init_asr = _evaluate_ca_asr(
            pre_trained_model, test_dataset, args, target, attack_mode
        )

        # Build the (negated) task vector once and reuse it for every alpha
        task_vector = -NonLinearTaskVector(pretrained_checkpoint, finetuned_checkpoint)

        # Sensitivity alphas (± frac)
        frac = my_args.alpha_perturb_frac
        alpha_low  = alpha * (1.0 - frac)
        alpha_high = alpha * (1.0 + frac)

        # TEST @ chosen alpha
        model_chosen = task_vector.apply_to(pretrained_checkpoint, scaling_coef=float(alpha))
        ca_test, asr_test = _evaluate_ca_asr(
            model_chosen, test_dataset, args, target, attack_mode
        )

        # TEST @ alpha_low
        model_low = task_vector.apply_to(pretrained_checkpoint, scaling_coef=float(alpha_low))
        ca_test_low, asr_test_low = _evaluate_ca_asr(
            model_low, test_dataset, args, target, attack_mode
        )

        # TEST @ alpha_high
        model_high = task_vector.apply_to(pretrained_checkpoint, scaling_coef=float(alpha_high))
        ca_test_high, asr_test_high = _evaluate_ca_asr(
            model_high, test_dataset, args, target, attack_mode
        )

        if my_args.print_choices:
            print("-" * 80)
            print(f"{dataset} → {test_dataset}")
            print(f"Chosen alpha: {alpha}  | Strategy: {strategy}")
            print(f"Initial TEST (pretrained): CA={init_ca:.4f}  ASR={init_asr:.6f}")
            print(f"TEST @alpha:      CA={ca_test:.4f}  ASR={asr_test:.6f}")
            print(f"Sensitivity (±{frac*100:.1f}% of α):")
            print(f"  TEST @α_low={alpha_low:.6g}:   CA={ca_test_low:.4f}  ASR={asr_test_low:.6f}")
            print(f"  TEST @α_high={alpha_high:.6g}:  CA={ca_test_high:.4f} ASR={asr_test_high:.6f}")

        # Append immediately so finished datasets survive a later failure.
        with open(my_args.out_csv, "a", newline="") as f:
            csv.writer(f).writerow([
                test_dataset, attack_mode, init_ca, init_asr,
                alpha, ca_test, asr_test,
                alpha_low, alpha_high, ca_test_low, asr_test_low, ca_test_high, asr_test_high
            ])

    print("\nDone. Test results appended to:", my_args.out_csv)

# Run the selection/evaluation pipeline only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    main()
