import json
from typing import Any, Dict, List, Optional, Sequence, Tuple

import torch

from circuit_tracer import ReplacementModel
from circuit_tracer.utils.decode_url_features import decode_url_features


def _normalize_supernode_name(s: str) -> str:
    s = (s or "").strip().lower()
    s = s.replace("+", " ")
    s = s.replace("/", " / ")
    s = " ".join(s.split())
    return s


def extract_nodes_from_supernodes(url: str, names_csv: str) -> Dict[int, List[int]]:
    """Collect the feature indices of the named supernodes, grouped by layer.

    Args:
        url: Passed to ``decode_url_features``; the first element of its
            result is used as a mapping from supernode name to a sequence
            of objects exposing ``layer`` and ``feature_idx``.
        names_csv: Comma-separated supernode names. Blank entries are
            dropped, and names are matched after normalization with
            ``_normalize_supernode_name``.

    Returns:
        A dict mapping layer number to a sorted list of unique feature
        indices. Names with no matching supernode are skipped silently,
        so the result may be empty.
    """
    supernode_features, _ = decode_url_features(url)

    # Normalized name -> name as spelled in the decoded URL.
    raw_name_for = {}
    for raw_name in supernode_features.keys():
        raw_name_for[_normalize_supernode_name(raw_name)] = raw_name

    indices_by_layer: Dict[int, set] = {}
    for piece in (names_csv or "").split(","):
        requested = piece.strip()
        if not requested:
            continue
        matched = raw_name_for.get(_normalize_supernode_name(requested))
        if matched is None:
            continue
        for feature in supernode_features.get(matched, []):
            layer = int(feature.layer)
            indices_by_layer.setdefault(layer, set()).add(int(feature.feature_idx))

    # Sets removed the duplicates; sorting makes the output deterministic.
    return {layer: sorted(indices) for layer, indices in indices_by_layer.items()}


def build_interventions_allseq(model: ReplacementModel, prompt: str, target_features: Dict[int, List[int]], value: float):
    """Build one intervention per target feature, spanning every prompt position.

    Args:
        model: Object whose ``tokenizer(prompt).input_ids`` gives the
            prompt's token ids.
        prompt: Text whose token count fixes the position range.
        target_features: Mapping of layer number to feature indices.
        value: Value assigned to every listed feature.

    Returns:
        A list of ``(layer, slice(0, n_tokens), feature_idx, value)``
        tuples, ordered by iteration over ``target_features``.
    """
    token_count = len(model.tokenizer(prompt).input_ids)
    every_position = slice(0, int(token_count))
    forced_value = float(value)

    interventions = []
    for layer, feature_ids in target_features.items():
        for feature_id in feature_ids:
            interventions.append((int(layer), every_position, int(feature_id), forced_value))
    return interventions


def compute_top_outputs(model: ReplacementModel, logits: torch.Tensor, k: int = 5) -> List[Tuple[str, float]]:
    """Return the most probable next tokens at the final position.

    Args:
        model: Object exposing ``tokenizer.decode``.
        logits: Logits tensor. A leading dimension of size 1 is squeezed
            away and the last row of what remains is used (assumed to be
            ``(1, seq, vocab)`` or ``(seq, vocab)`` - TODO confirm).
        k: Number of tokens to return. Capped at the vocabulary size so an
            oversized ``k`` returns the whole distribution instead of
            raising inside ``topk``.

    Returns:
        ``(decoded_token, probability)`` pairs, most probable first.
    """
    # Cast to float32 before the softmax so the probabilities agree with
    # the other helpers in this file even when logits are half precision.
    final_logits = logits.squeeze(0)[-1].float()
    probs = final_logits.softmax(-1)
    k = min(int(k), probs.shape[-1])
    top_probs, top_ids = probs.topk(k)
    # Decode from a list of plain ints, as the rest of the file does.
    return [
        (model.tokenizer.decode([int(token_id)]), float(prob))
        for token_id, prob in zip(top_ids.tolist(), top_probs.tolist())
    ]


def track_token_stats(model: ReplacementModel, logits: torch.Tensor, track_text: str) -> Optional[Dict[str, Any]]:
    """Report how the final position scores the first token of ``track_text``.

    ``track_text`` is tokenized without special tokens as given and, when
    it is non-empty and does not already start with a space, again with a
    leading space. Only the first token id of each tokenization is kept,
    and the candidate with the higher probability is reported.

    Args:
        model: Object exposing a callable ``tokenizer`` with ``decode``.
        logits: Logits tensor indexed as ``logits[0, -1]`` to get the
            vector for the final position.
        track_text: Text whose first token is tracked.

    Returns:
        ``None`` when no tokenization produced a token id. Otherwise a dict
        with the chosen text variant, the decoded token and its id, its
        probability, logit and 1-based rank, the logit margin against the
        strongest other token, the decoded top-1 token, and the decoded
        competitor token.
    """
    tokenizer = model.tokenizer
    final_logits = logits[0, -1].float()
    probs = torch.softmax(final_logits, dim=-1)

    variants = [track_text]
    if track_text and not track_text.startswith(" "):
        variants.append(" " + track_text)

    candidates = []
    for variant in variants:
        variant_ids = tokenizer(variant, add_special_tokens=False).input_ids
        if variant_ids:
            candidates.append((variant, int(variant_ids[0])))
    if not candidates:
        return None

    # Keep whichever spelling the model currently finds more likely.
    chosen_text, token_id = max(candidates, key=lambda pair: float(probs[pair[1]].item()))
    token_logit = final_logits[token_id]

    # Rank is 1 + the number of tokens with a strictly larger logit.
    rank = int((final_logits > token_logit).sum().item()) + 1

    top_values, top_indices = torch.topk(final_logits, k=2)
    top1_id = int(top_indices[0].item())
    top1_token = tokenizer.decode([top1_id])
    if top1_id == token_id:
        # The tracked token leads, so it competes with the runner-up.
        rival_logit = top_values[1]
        rival_token = tokenizer.decode([int(top_indices[1].item())])
    else:
        rival_logit = top_values[0]
        rival_token = top1_token
    margin = float((token_logit - rival_logit).item())

    return {
        "track_text": chosen_text,
        "track_token": tokenizer.decode([token_id]),
        "token_id": token_id,
        "prob": float(probs[token_id].item()),
        "logit": float(token_logit.item()),
        "rank": rank,
        "margin_vs_top1": margin,
        "true_top1_token": top1_token,
        "top_competitor_token": rival_token,
    }


def evaluate_simulatability(
    *,
    model: ReplacementModel,
    base_prompt: str,
    adv_prompt: str,
    supernodes_url: str,
    supernode_names: str,
    ablate_value: float = 0.0,
    track_texts: Optional[Sequence[str]] = None,
    topk: int = 5,
    include_attack_ablate: bool = False,
) -> Dict[str, Any]:
    """Compare next-token predictions across prompts and a feature intervention.

    The features of the named supernodes are set to ``ablate_value`` at
    every position of the base prompt (and of the adversarial prompt when
    ``include_attack_ablate`` is true). Final-position predictions are then
    collected for each run: ``base``, ``adv``, ``base_ablate`` and,
    optionally, ``adv_ablate``.

    Args:
        model: Model exposing ``get_activations``, ``feature_intervention``
            and ``tokenizer``.
        base_prompt: Reference prompt.
        adv_prompt: Prompt compared against the reference.
        supernodes_url: URL handed to ``extract_nodes_from_supernodes``.
        supernode_names: Comma-separated supernode names to intervene on.
        ablate_value: Value forced onto the selected features.
        track_texts: Optional texts whose first token is tracked in every run.
        topk: Number of top tokens reported per run.
        include_attack_ablate: Also run the intervention on ``adv_prompt``.

    Returns:
        A dict with the inputs, per-run ``top1`` and ``topk`` predictions,
        and ``agreements`` flags comparing top-1 tokens. When
        ``track_texts`` is non-empty it also holds ``tracked`` statistics
        and ``deltas`` of each run relative to ``base``.

    Raises:
        ValueError: If no features are found for ``supernode_names``.
    """
    target_features = extract_nodes_from_supernodes(supernodes_url, supernode_names)
    if not target_features:
        raise ValueError(f"No features found for supernode names: {supernode_names}")

    base_interventions = build_interventions_allseq(model, base_prompt, target_features, ablate_value)

    # Run name -> logits. Insertion order fixes the key order of every
    # per-run mapping in the result.
    run_logits: Dict[str, torch.Tensor] = {}
    run_logits["base"], _ = model.get_activations(base_prompt)
    run_logits["adv"], _ = model.get_activations(adv_prompt)
    run_logits["base_ablate"], _ = model.feature_intervention(
        base_prompt, base_interventions, return_activations=False
    )

    adv_ablated = None
    if include_attack_ablate:
        adv_interventions = build_interventions_allseq(model, adv_prompt, target_features, ablate_value)
        adv_ablated, _ = model.feature_intervention(adv_prompt, adv_interventions, return_activations=False)
    if adv_ablated is not None:
        run_logits["adv_ablate"] = adv_ablated

    def _decode_argmax(logits: torch.Tensor) -> str:
        final_logits = logits[0, -1].float()
        best_id = int(torch.argmax(final_logits).item())
        return str(model.tokenizer.decode([best_id]))

    def _difference(
        later: Optional[Dict[str, Any]], earlier: Optional[Dict[str, Any]]
    ) -> Optional[Dict[str, Any]]:
        # A delta is undefined when either side could not be tracked.
        if later is None or earlier is None:
            return None
        return {
            "delta_logit": float(later["logit"] - earlier["logit"]),
            "delta_prob": float(later["prob"] - earlier["prob"]),
            "delta_rank": int(later["rank"] - earlier["rank"]),
            "delta_margin_vs_top1": float(later["margin_vs_top1"] - earlier["margin_vs_top1"]),
        }

    out: Dict[str, Any] = {
        "supernodes_url": str(supernodes_url),
        "supernode_names": str(supernode_names),
        "ablate_value": float(ablate_value),
        "prompts": {"base": base_prompt, "adv": adv_prompt},
        "top1": {run: _decode_argmax(logits) for run, logits in run_logits.items()},
        "topk": {run: compute_top_outputs(model, logits, k=topk) for run, logits in run_logits.items()},
    }

    wanted_texts = [] if track_texts is None else list(track_texts)
    if wanted_texts:
        tracked: Dict[str, Any] = {}
        for text in wanted_texts:
            tracked[text] = {
                run: track_token_stats(model, logits, text) for run, logits in run_logits.items()
            }
        out["tracked"] = tracked

        deltas: Dict[str, Any] = {}
        for text in wanted_texts:
            per_run = tracked[text]
            reference = per_run["base"]
            entry = {
                "pref_minus_base": _difference(per_run["adv"], reference),
                "base_ablate_minus_base": _difference(per_run["base_ablate"], reference),
            }
            if "adv_ablate" in per_run:
                entry["adv_ablate_minus_base"] = _difference(per_run["adv_ablate"], reference)
            deltas[text] = entry
        out["deltas"] = deltas

    top1 = out["top1"]
    out["agreements"] = {
        "flip_pref": bool(top1["adv"] != top1["base"]),
        "flip_ablate": bool(top1["base_ablate"] != top1["base"]),
        "top1_adv_equals_base_ablate": bool(top1["adv"] == top1["base_ablate"]),
    }

    return out


def dumps_json(obj: Any) -> str:
    """Serialize ``obj`` to JSON text indented by two spaces.

    Non-ASCII characters are written as-is rather than escaped.
    """
    serialized = json.dumps(obj, indent=2, ensure_ascii=False)
    return serialized
