import torch
import sys
import os
import json
import time
from pathlib import Path
from transformers import AutoTokenizer
import argparse
import numpy as np
from circuit_tracer import ReplacementModel
from circuit_tracer.utils.decode_url_features import decode_url_features
from sim_eval import evaluate_simulatability
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

try:
    from nanogcg.nanogcg_custom import GCGConfig, run
    from nanogcg.wrappers import FeatureAttackWrapper
except ImportError as e:
    print("Import error: please ensure nanogcg_custom.py and wrappers.py are under the nanogcg/ folder")
    print(f"Details: {e}")
    sys.exit(1)

def extract_nodes_from_graph(graph):
    """Group a circuit graph's active transcoder features by layer.

    Reads ``graph.active_features`` (converted with ``.tolist()``) when present,
    otherwise ``graph.real_features``. Each row is read as
    ``(layer, <ignored>, feature_idx, ...)``; rows with fewer than 3 entries are
    skipped. (Index 1 is presumably the token position — not used here.)

    Args:
        graph: A loaded circuit Graph object.

    Returns:
        Dict mapping ``layer -> [feature_idx, ...]`` with duplicates removed and
        indices kept in first-seen order. Returns ``{}`` (after printing the
        available attributes) when the graph exposes neither feature attribute.
    """
    if hasattr(graph, 'active_features'):
        features_data = graph.active_features.tolist()
    elif hasattr(graph, 'real_features'):
        features_data = graph.real_features
    else:
        print("Error: Graph object does not have 'active_features' or 'real_features'.")
        print(f"Available attributes: {[d for d in dir(graph) if not d.startswith('__')]}")
        return {}

    targets: dict[int, list[int]] = {}
    # O(1) duplicate check; a `x in list` test per row made the loop quadratic.
    seen: set[tuple[int, int]] = set()
    for node_info in features_data:
        if len(node_info) < 3:
            continue
        layer = int(node_info[0])
        feature_idx = int(node_info[2])
        key = (layer, feature_idx)
        if key in seen:
            continue
        seen.add(key)
        targets.setdefault(layer, []).append(feature_idx)
    return targets

def _infer_d_transcoder(model) -> int:
    """Return the transcoder feature width (``d_transcoder``) of ``model``.

    Lookup order:
      1. ``model.transcoders.d_transcoder``
      2. ``model.transcoders[0].d_transcoder``. This is best effort: any error
         while indexing the container or converting the value is ignored.

    Raises:
        RuntimeError: if neither location yields a value.
    """
    if hasattr(model, "transcoders"):
        container = model.transcoders
        if hasattr(container, "d_transcoder"):
            return int(container.d_transcoder)
        try:
            head = container[0]
            if hasattr(head, "d_transcoder"):
                return int(head.d_transcoder)
        except Exception:
            # Deliberate best-effort probe; fall through to the explicit error.
            pass
    raise RuntimeError("Unable to infer d_transcoder from model.transcoders")

def _sample_background_features(
    *,
    d_transcoder: int,
    target_features: dict[int, list[int]],
    per_layer: int,
    seed: int | None,
) -> dict[int, list[int]]:
    rng = np.random.default_rng(seed)
    bg: dict[int, list[int]] = {}
    for layer, sel in target_features.items():
        sel_set = set(int(i) for i in sel)
        candidates = np.array([i for i in range(d_transcoder) if i not in sel_set], dtype=np.int64)
        if candidates.size == 0:
            continue
        k = int(min(per_layer, candidates.size))
        picked = rng.choice(candidates, size=k, replace=False)
        bg[int(layer)] = sorted(int(x) for x in picked.tolist())
    return bg

def _normalize_supernode_name(s: str) -> str:
    """Canonicalize a supernode label for case- and spacing-insensitive matching.

    Lower-cases the text, treats ``+`` as a space, surrounds ``/`` with spaces,
    and collapses whitespace runs to single spaces. ``None`` or empty input
    yields ``""``.
    """
    table = str.maketrans({"+": " ", "/": " / "})
    tokens = (s or "").strip().lower().translate(table).split()
    return " ".join(tokens)

def _parse_args() -> argparse.Namespace:
    """Build the command-line parser and parse ``sys.argv``."""
    p = argparse.ArgumentParser()
    p.add_argument("--model", default="google/gemma-2-2b")
    p.add_argument("--transcoder", default="gemma")
    p.add_argument("--circuit", default=None, help="Path to a saved circuit Graph (.pt). Required only if --supernodes-url/--supernode-names are not provided.")
    p.add_argument("--template", default="{optim_str}Fact: The capital of the state containing Dallas is")
    p.add_argument("--prompts-file", default=None, help="Optional path to a text file with one prompt per line. If provided, runs batch mode.")
    p.add_argument("--prompts-start", type=int, default=0)
    p.add_argument("--prompts-limit", type=int, default=None)
    p.add_argument("--steps", type=int, default=50)
    p.add_argument("--search-width", type=int, default=32)
    p.add_argument("--batch-size", type=int, default=32)
    p.add_argument("--topk", type=int, default=64)
    p.add_argument("--optim-str-init", default="! ! ! ! ! ! ! ! ! !")
    p.add_argument("--allow-non-ascii", action="store_true")
    p.add_argument("--no-filter-ids", action="store_true")
    p.add_argument("--attack-objective", choices=["repel", "match"], default="repel")
    p.add_argument("--intervention-value", type=float, default=0.0)
    p.add_argument("--kl-lambda", type=float, default=0.0)
    p.add_argument("--entropy-lambda", type=float, default=0.0)
    p.add_argument("--supernodes-url", default=None)
    p.add_argument("--supernode-names", default=None)
    p.add_argument("--protect-lambda", type=float, default=0.5)
    p.add_argument("--bg-per-layer", type=int, default=64)
    p.add_argument("--bg-seed", type=int, default=0)
    p.add_argument("--bg-ratio-lambda", type=float, default=0.0)
    p.add_argument("--aggregation", choices=["max", "mean", "topk_mean"], default="topk_mean")
    p.add_argument("--topk-frac", type=float, default=0.2)
    p.add_argument("--no-strict-hooks", action="store_true")
    p.add_argument("--outdir", default="outputs")
    p.add_argument("--repeat", type=int, default=1, help="Repeat GCG optimization N times for the same --template (single-prompt mode only).")
    p.add_argument("--runs-jsonl", default="runs.jsonl", help="Filename (under --outdir) to append per-run records (jsonl).")
    p.add_argument("--sim-eval", action="store_true")
    p.add_argument("--sim-eval-topk", type=int, default=5)
    p.add_argument("--sim-eval-track-texts", default=None, help="Comma-separated list of single-token strings to track (e.g. ' Austin, Texas')")
    p.add_argument("--sim-eval-include-attack-ablate", action="store_true")
    return p.parse_args()


def _parse_track_texts(csv: str | None) -> list[str] | None:
    """Split a comma-separated list of strings to track during sim-eval.

    Items are deliberately NOT stripped: a leading space is significant for
    tokens such as ``" Austin"``. Empty items are dropped. Returns ``None``
    when nothing remains.
    """
    if not csv:
        return None
    xs = [s for s in str(csv).split(",") if s]
    return xs or None


def _features_from_supernodes(supernodes_url: str, supernode_names: str) -> dict[int, list[int]]:
    """Collect target features from the named supernodes encoded in a graph URL.

    Names are matched after ``_normalize_supernode_name`` normalization.
    Unmatched names are reported and skipped.

    Returns:
        ``layer -> sorted unique feature indices``.

    Raises:
        ValueError: if none of the requested names match a supernode.
    """
    sn_feats, _singletons = decode_url_features(supernodes_url)
    names = [s.strip() for s in supernode_names.split(",") if s.strip()]
    norm_to_raw = {_normalize_supernode_name(k): k for k in sn_feats.keys()}
    picked: dict[int, set[int]] = {}
    missing: list[str] = []
    for name in names:
        raw = norm_to_raw.get(_normalize_supernode_name(name))
        if raw is None:
            missing.append(name)
            continue
        for f in sn_feats.get(raw, []):
            picked.setdefault(int(f.layer), set()).add(int(f.feature_idx))
    if not picked:
        available = sorted(sn_feats.keys())
        print("supernode_names_not_found")
        print(f"requested={names}")
        print(f"available={available}")
        raise ValueError("No features matched the provided --supernode-names; refusing to fall back to --circuit features")
    if missing:
        # A partial match used to be silent, so the attack ran on fewer features than requested.
        print(f"[WARN] supernode names not found (skipped): {missing}")
    return {layer: sorted(idxs) for layer, idxs in picked.items()}


def _features_from_circuit(circuit_path: str | None) -> dict[int, list[int]] | None:
    """Load a saved circuit Graph and extract its features grouped by layer.

    Returns ``None`` after printing a message when the file is missing or yields
    no features; the caller then exits quietly.
    """
    graph_path = Path(str(circuit_path or "circuits/my_final_circuit.pt"))
    if not graph_path.exists():
        print(f"Error: Graph file not found: {graph_path}")
        print("Provide --circuit, or use --supernodes-url + --supernode-names.")
        return None
    print(f"Loading Graph from {graph_path}...")
    # SECURITY: weights_only=False unpickles arbitrary objects and can execute
    # code; only load graph files you trust. map_location="cpu" lets graphs saved
    # on GPU load on CPU-only hosts. Here the graph is only used to extract feature
    # indices.
    graph = torch.load(graph_path, weights_only=False, map_location="cpu")
    target_features = dict(extract_nodes_from_graph(graph))
    if not target_features:
        print("No target features extracted; exiting.")
        return None
    return target_features


def _build_gcg_template(tokenizer, attack_prompt: str) -> str:
    """Render ``attack_prompt`` through the tokenizer's chat template, if it has one.

    A single leading BOS token is removed from the rendered text. Falls back to
    the raw prompt when there is no chat template or rendering fails. The result
    always contains the ``{optim_str}`` placeholder; it is appended when missing.
    """
    template = attack_prompt
    try:
        if getattr(tokenizer, "chat_template", None):
            messages = [{"role": "user", "content": attack_prompt}]
            template = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            bos = tokenizer.bos_token
            if bos and template.startswith(bos):
                # Strip only the leading BOS; str.replace would also delete any
                # later occurrence of the BOS text inside the prompt.
                template = template[len(bos):]
    except Exception as e:
        print(f"[WARN] tokenizer.apply_chat_template failed; falling back to raw template. err={e}")
        template = attack_prompt
    if "{optim_str}" not in template:
        template = template + "{optim_str}"
    return template


def _run_payload(template: str, run_out: dict) -> dict:
    """Build the fields shared by every per-run JSONL record."""
    result = run_out["result"]
    return {
        "template": str(template),
        "base_text": str(run_out["base_text"]),
        "adv_text": str(run_out["adv_text"]),
        "gcg": {
            "best_loss": float(result.best_loss),
            "suffix": str(result.best_string),
        },
        "sim_eval": run_out["sim_eval"],
    }


def main():
    """CLI entry point: select target features, load the model, run GCG, write results.

    Single-prompt mode optimizes ``--template`` ``--repeat`` times. Batch mode
    (``--prompts-file``) runs once per prompt and always performs sim-eval.
    Each run appends one record to ``<outdir>/<runs-jsonl>``.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    args = _parse_args()
    model_name = args.model
    transcoder_name = args.transcoder

    use_supernodes = bool(args.supernodes_url and args.supernode_names)
    # Batch mode always runs sim-eval, and sim-eval needs the supernode args.
    # Fail now rather than after loading the model and finishing a full GCG run.
    run_sim_eval = bool(args.sim_eval) or bool(args.prompts_file)
    if run_sim_eval and not use_supernodes:
        raise ValueError("sim-eval requires --supernodes-url and --supernode-names")

    if use_supernodes:
        target_features = _features_from_supernodes(args.supernodes_url, args.supernode_names)
    else:
        if args.supernodes_url or args.supernode_names:
            print("[WARN] --supernodes-url and --supernode-names must be used together; falling back to --circuit.")
        target_features = _features_from_circuit(args.circuit)
        if target_features is None:
            return

    print(f"Loading model: {model_name}...")
    model = ReplacementModel.from_pretrained(
        model_name,
        transcoder_name,
        dtype=torch.float16 if device == "cuda" else torch.float32,
    ).to(device)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    d_transcoder = _infer_d_transcoder(model)
    background_features = _sample_background_features(
        d_transcoder=d_transcoder,
        target_features=target_features,
        per_layer=args.bg_per_layer,
        seed=args.bg_seed,
    )
    n_target = sum(len(v) for v in target_features.values())
    n_bg = sum(len(v) for v in background_features.values())
    print(f"d_transcoder={d_transcoder} | target_features={n_target} | background_features={n_bg}")

    print("Wrapping model for feature attack...")
    attack_model = FeatureAttackWrapper(
        model,
        target_features,
        background_features,
        aggregation=args.aggregation,
        topk_frac=args.topk_frac,
        strict_hooks=(not args.no_strict_hooks),
    )

    config = GCGConfig(
        num_steps=args.steps,
        search_width=args.search_width,
        batch_size=args.batch_size,
        topk=args.topk,
        optim_str_init=args.optim_str_init,
        verbosity="INFO",
        allow_non_ascii=bool(args.allow_non_ascii),
        filter_ids=(not bool(args.no_filter_ids)),
        attack_objective=args.attack_objective,
        intervention_value=args.intervention_value,
        kl_lambda=float(args.kl_lambda),
        entropy_lambda=float(args.entropy_lambda),
        protect_lambda=args.protect_lambda,
        bg_ratio_lambda=args.bg_ratio_lambda,
    )

    track_texts = _parse_track_texts(args.sim_eval_track_texts)

    def _run_one(attack_prompt_local: str, outdir_local: str | None = None) -> dict:
        """Run one GCG optimization and, if enabled, a sim-eval of its result.

        Writes ``sim_eval.json`` under ``outdir_local`` when sim-eval runs and a
        directory is given. Returns the GCG result plus the clean and adversarial
        texts.
        """
        # The builder guarantees the {optim_str} placeholder is present.
        gcg_template_local = _build_gcg_template(tokenizer, attack_prompt_local)

        print("Starting GCG attack...")
        if args.attack_objective == "match":
            print(
                f"Goal: Make target features approach feature_intervention(value={args.intervention_value}) state, "
                f"while keeping background features unchanged (protect_lambda={args.protect_lambda}, bg_ratio_lambda={args.bg_ratio_lambda})"
            )
        else:
            print(
                f"Goal: Maximize target features' deviation from reference, while suppressing background features' disturbance "
                f"(protect_lambda={args.protect_lambda}, bg_ratio_lambda={args.bg_ratio_lambda})"
            )

        result = run(
            model=attack_model,
            tokenizer=tokenizer,
            messages=[{"role": "user", "content": attack_prompt_local}],
            target="",
            config=config,
        )

        print("\n" + "=" * 40)
        print("Attack complete")
        print("=" * 40)
        print(f"Best loss: {result.best_loss:.4f}")
        print(f"Adversarial suffix: {result.best_string}")

        final_input_local = gcg_template_local.replace("{optim_str}", result.best_string)
        print(f"Complete adversarial input: \"{final_input_local}\"")

        # Clean baseline = template with every placeholder removed. The previous
        # two-way split() unpacking raised ValueError on templates with more than
        # one placeholder.
        base_text_local = gcg_template_local.replace("{optim_str}", "")

        sim_out = None
        if run_sim_eval:
            sim_out = evaluate_simulatability(
                model=model,
                base_prompt=base_text_local,
                adv_prompt=final_input_local,
                supernodes_url=str(args.supernodes_url),
                supernode_names=str(args.supernode_names),
                ablate_value=float(args.intervention_value),
                track_texts=track_texts,
                topk=int(args.sim_eval_topk),
                include_attack_ablate=bool(args.sim_eval_include_attack_ablate),
            )
            sim_out["gcg"] = {"best_loss": float(result.best_loss), "suffix": str(result.best_string)}
            if outdir_local:
                os.makedirs(outdir_local, exist_ok=True)
                sim_json = os.path.join(outdir_local, "sim_eval.json")
                with open(sim_json, "w", encoding="utf-8") as f:
                    json.dump(sim_out, f, ensure_ascii=False, indent=2)
                print(f"Saved: {sim_json}")

        return {"result": result, "base_text": base_text_local, "adv_text": final_input_local, "sim_eval": sim_out}

    if args.prompts_file:
        if int(args.repeat) > 1:
            print("[WARN] --repeat applies to single-prompt mode only; ignoring it for --prompts-file.")
        os.makedirs(args.outdir, exist_ok=True)
        with open(str(args.prompts_file), "r", encoding="utf-8") as f:
            prompts = [ln.strip() for ln in f.read().splitlines() if ln.strip()]
        start = int(args.prompts_start)
        end = len(prompts)
        if args.prompts_limit is not None:
            end = min(end, start + int(args.prompts_limit))
        prompts = prompts[start:end]
        if not prompts:
            raise ValueError("No prompts loaded from --prompts-file (after applying start/limit)")

        out_jsonl = os.path.join(args.outdir, "sim_eval.jsonl")
        runs_jsonl = os.path.join(args.outdir, str(args.runs_jsonl))
        # sim_eval.jsonl is rewritten per invocation; the runs log is appended.
        with open(out_jsonl, "w", encoding="utf-8") as jf, open(runs_jsonl, "a", encoding="utf-8") as rjf:
            for i, ptxt in enumerate(prompts, start=start):
                tpl = ptxt if "{optim_str}" in ptxt else "{optim_str} " + ptxt
                subdir = os.path.join(args.outdir, f"prompt_{i:04d}")
                run_out = _run_one(tpl, outdir_local=subdir)
                payload = _run_payload(tpl, run_out)

                rec = {"prompt_index": int(i), **payload}
                jf.write(json.dumps(rec, ensure_ascii=False) + "\n")
                jf.flush()

                run_rec = {
                    "time": float(time.time()),
                    "mode": "prompts_file",
                    "outdir": str(args.outdir),
                    "run_dir": str(subdir),
                    "prompt_index": int(i),
                    **payload,
                }
                rjf.write(json.dumps(run_rec, ensure_ascii=False) + "\n")
                rjf.flush()
        print(f"Saved: {out_jsonl}")
        return

    os.makedirs(args.outdir, exist_ok=True)
    runs_jsonl = os.path.join(args.outdir, str(args.runs_jsonl))
    n_repeat = max(1, int(args.repeat))
    for r in range(n_repeat):
        subdir = os.path.join(args.outdir, f"run_{r:04d}") if n_repeat > 1 else str(args.outdir)
        run_out = _run_one(args.template, outdir_local=subdir)
        run_rec = {
            "time": float(time.time()),
            "mode": "single_prompt",
            "outdir": str(args.outdir),
            "run_dir": str(subdir),
            "repeat_index": int(r),
            **_run_payload(args.template, run_out),
        }
        with open(runs_jsonl, "a", encoding="utf-8") as rjf:
            rjf.write(json.dumps(run_rec, ensure_ascii=False) + "\n")


if __name__ == "__main__":
    main()
