# filename: output_score_with_entropy_confidence.py
import gc
import argparse
import os
import os.path
import re
import json
import glob
from typing import Dict, Any

import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

from utils import (
    get_features_by_layers,
    get_sae,
    cache_logit_lens,
    _swap_layer_number_in_path,
)
from sae_utils import AmlifySAEHook


# ---------------------------
# Helpers (repo-root detection, tags, cache path)
# ---------------------------
def _detect_repo_root_from_features(features_file_path: str) -> str:
    norm = os.path.normpath(features_file_path)
    parts = norm.split(os.sep)
    if "saes-are-good-for-steering" in parts:
        idx = parts.index("saes-are-good-for-steering")
        return os.sep.join(parts[: idx + 1])
    return os.getcwd()


def _method_tag_from_trainer_class(trainer_class_name: str) -> str:
    name = (trainer_class_name or "").lower()
    if "gated" in name:
        return "gated"
    if "batchtopk" in name or "batch_topk" in name:
        return "batch_topk"
    if "topk" in name and "batch" not in name:
        return "topk"
    if "jumprelu" in name or ("jump" in name and "relu" in name):
        return "jump_relu"
    if "standard" in name and "april" in name and "update" in name:
        return "standard_april_update"
    name = re.sub(r"[^a-z0-9_]+", "_", name).strip("_")
    return name[:32] if name else "unknown"


def _trainer_id_from_path(path: str) -> str:
    patterns = [
        r"(?:^|/)(?:trainer|gated|batch[_-]?topk|top[_-]?k|jump[_-]?relu|jumprelu|standard[_-]?april[_-]?update)[_/ -]?(\d+)(?:/|$)",
        r"[_-](\d+)(?:/|$)",
    ]
    for pat in patterns:
        ms = list(re.finditer(pat, path, flags=re.IGNORECASE))
        if ms:
            return ms[-1].group(1)
    return "unk"


def _build_experiment_name(trainer_class: str, trainer_id: str) -> str:
    """Return '<method>_<trainer_id>', where <method> is the tag derived from trainer_class.

    If trainer_id already starts with '<method>_', it is returned unchanged so the
    prefix is never doubled.
    """
    prefix = _method_tag_from_trainer_class(trainer_class)
    if trainer_id.startswith(f"{prefix}_"):
        return trainer_id
    return f"{prefix}_{trainer_id}"


def _auto_cache_path(
    repo_root: str,
    model_type: str,
    layer: int,
    trainer_class: str,
    trainer_id: str,
    topk_for_filename: int,
    k_conf: int,
    custom_dir: str = "amp10_top_1",
) -> str:
    """
    Build the per-layer auto cache path and create its directory.

    Layout:
    {repo_root}/cache/results_entropy_score/{custom_dir}/{model_dir}/layer{layer}/{exp}/output_scores_plus_top{K}_kconf{k}.json

    - model_dir is the last '/'-separated component of model_type
      (e.g. 'Qwen/Qwen2.5-3B' -> 'Qwen2.5-3B'), avoiding an extra nested directory.
    - exp is _build_experiment_name(trainer_class, trainer_id).
    - K is topk_for_filename and k is k_conf, both coerced to int.
    - custom_dir defaults to 'amp10_top_1', the value previously hard-coded here.
      NOTE(review): the name suggests it encodes an amplification factor of 10. Results for a
      different amp factor written under the default would share (and be skipped against) the
      same cache, so callers should pass a distinct value in that case.

    Side effect: os.makedirs(..., exist_ok=True) on the target directory.
    Returns the full JSON file path.
    """
    cache_root = os.path.join(repo_root, "cache", "results_entropy_score")
    layer_dir = f"layer{int(layer)}"
    exp_name = _build_experiment_name(trainer_class, trainer_id)
    model_dir = model_type.split("/")[-1]
    cache_dir = os.path.join(cache_root, custom_dir, model_dir, layer_dir, exp_name)

    os.makedirs(cache_dir, exist_ok=True)
    filename = f"output_scores_plus_top{int(topk_for_filename)}_kconf{int(k_conf)}.json"
    return os.path.join(cache_dir, filename)


# ---------------------------
# Local snapshot resolver for offline loading (parity with main script)
# ---------------------------
def _resolve_local_snapshot_if_available(repo_or_path: str, cache_root: str, offline: bool) -> str:
    """
    If `repo_or_path` is a repo id like 'Qwen/Qwen2.5-3B(-Instruct)' and offline=True,
    try to locate the latest local snapshot under <cache_root>/hub/models--{repo}/snapshots/<hash>.
    If found, return that snapshot path; else return `repo_or_path` unchanged.
    """
    if os.path.isdir(repo_or_path):
        return repo_or_path
    if not offline:
        return repo_or_path

    base = os.path.join(cache_root, "hub", f"models--{repo_or_path.replace('/', '--')}")
    snap_root = os.path.join(base, "snapshots")
    if os.path.isdir(snap_root):
        candidates = [p for p in glob.glob(os.path.join(snap_root, "*")) if os.path.isdir(p)]
        if candidates:
            candidates.sort(key=os.path.getmtime, reverse=True)
            chosen = candidates[0]
            print(f"[INFO] Using local snapshot for {repo_or_path}: {chosen}")
            return chosen
        else:
            print(f"[WARN] No snapshots found under: {snap_root}")
    else:
        print(f"[WARN] Snapshot root not found: {snap_root}")
    return repo_or_path


# ---------------------------
# Metric functions
# ---------------------------
def _entropy_from_probs(probs: torch.Tensor) -> float:
    """Natural-log entropy over full vocab, probs shape [V]."""
    eps = 1e-12
    x = probs.clamp(min=eps)
    return float(-(x * x.log()).sum().item())


def _token_confidence_topk(probs: torch.Tensor, k: int) -> float:
    """-mean(log top-k probs)."""
    k = int(max(1, min(k, probs.numel())))
    topk_vals, _ = torch.topk(probs, k)
    eps = 1e-12
    return float(-(topk_vals.clamp(min=eps).log().mean().item()))


def _rank_weighted_prob(probs: torch.Tensor, indices: torch.Tensor) -> float:
    """
    P(M) in paper: (1 - rank_min/|V|) * max_prob among representative tokens.
    probs: [V], indices: [K] of representative tokens (logit-lens)
    """
    vocab_size = probs.shape[0]
    tokens_argsort = torch.argsort(probs, dim=0, descending=True)
    ranks = [(tokens_argsort == idx).nonzero(as_tuple=True)[0].item() for idx in indices]
    min_rank = min(ranks) if len(ranks) > 0 else vocab_size - 1
    top_prob = probs[indices].max().item() if len(indices) > 0 else 0.0
    return (1.0 - (min_rank / vocab_size)) * top_prob


# ---------------------------
# Core: compute all metrics with two passes (baseline & intervention)
# ---------------------------
def _last_token_probs(model, inputs) -> torch.Tensor:
    """Run one forward pass and return the softmax of the final position's logits, on CPU.

    Assumes batch size 1 (dim 0 is squeezed), so the result is a 1-D [V] tensor. The model
    output object is not retained, so its device tensors can be freed once this returns.
    """
    logits = model(**inputs).logits[:, -1, :].squeeze(0)
    return torch.softmax(logits, dim=-1).detach().cpu()


def _distribution_metrics(probs: torch.Tensor, ll_indices: torch.Tensor, k_conf: int) -> tuple:
    """Return (entropy, top-k confidence, rank-weighted probability) for one next-token distribution."""
    return (
        _entropy_from_probs(probs),
        _token_confidence_topk(probs, k_conf),
        _rank_weighted_prob(probs, ll_indices),
    )


@torch.inference_mode()
def compute_all_metrics(
    layer: int,
    feature: int,
    logit_lens_indices,
    sentence: str,
    sae,
    tokenizer,
    model,
    device: str,
    amp_factor: int = 10,
    k_conf: int = 3,
) -> Dict[str, Any]:
    """Score one SAE feature by comparing next-token distributions with and without an SAE hook.

    The model is run twice on `sentence`:
      1. a baseline pass with no hook;
      2. an intervention pass with AmlifySAEHook(layer, sae, [feature], amp_factor, device)
         registered as a forward hook on model.model.layers[layer]. It presumably amplifies
         `feature` by `amp_factor`; see sae_utils to confirm.
    For each pass, the last-position softmax yields the entropy, the top-`k_conf` token
    confidence, and the rank-weighted probability of the `logit_lens_indices` tokens.

    Args:
        layer: index into model.model.layers to hook.
        feature: SAE feature id to intervene on.
        logit_lens_indices: token ids (list of ints) representing the feature.
        sentence: prompt whose next-token distribution is measured.
        sae, tokenizer, model: SAE module, HF tokenizer, HF causal LM.
        device: device string used for both passes.
        amp_factor: amplification factor forwarded to AmlifySAEHook.
        k_conf: k for the top-k token confidence.

    Returns:
        Dict with keys P_base, P_int, score (= P_int - P_base), entropy_base, entropy_int,
        delta_entropy, confidence_base, confidence_int, delta_confidence (all floats).

    Side effects: calls model.to(device) / sae.to(device) before the passes and
    model.cpu() / sae.cpu() afterwards. For torch nn.Modules this relocates the caller's
    objects in place.
    """
    model.eval()
    model = model.to(device)
    sae = sae.to(device)

    inputs = {name: tensor.to(device) for name, tensor in tokenizer(sentence, return_tensors="pt").items()}
    ll_indices = torch.tensor(logit_lens_indices, dtype=torch.long)

    # -------- Baseline pass (no hook) --------
    probs_base = _last_token_probs(model, inputs)
    H_base, C_base, P_base = _distribution_metrics(probs_base, ll_indices, k_conf)

    # -------- Intervention pass (with SAE hook) --------
    sae_hook = AmlifySAEHook(layer, sae, [feature], amp_factor, device)
    model_block_to_hook = model.model.layers[layer]
    handle = model_block_to_hook.register_forward_hook(sae_hook, always_call=True)
    try:
        probs_int = _last_token_probs(model, inputs)
    finally:
        # Always detach, even if the forward pass raises. The caller reuses the same model for
        # every feature, so a leaked hook would silently alter all later passes.
        handle.remove()
        # Best-effort safety sweep kept from the original implementation.
        # NOTE(review): PyTorch's `_forward_hooks` maps ids to the hook callables themselves
        # (not RemovableHandles), so this only does something for hooks exposing `.remove()`.
        for hook in list(model_block_to_hook._forward_hooks.values()):
            try:
                hook.remove()
            except Exception:
                pass

    H_int, C_int, P_int = _distribution_metrics(probs_int, ll_indices, k_conf)

    # -------- Cleanup --------
    # Drop device-resident inputs before empty_cache so their memory can actually be released.
    del inputs
    model.cpu()
    sae.cpu()
    torch.cuda.empty_cache()
    gc.collect()

    return {
        "P_base": float(P_base),
        "P_int": float(P_int),
        "score": float(P_int - P_base),
        "entropy_base": float(H_base),
        "entropy_int": float(H_int),
        "delta_entropy": float(H_int - H_base),
        "confidence_base": float(C_base),
        "confidence_int": float(C_int),
        "delta_confidence": float(C_int - C_base),
    }


# ---------------------------
# CLI & main loop
# ---------------------------
def parse_args():
    """Parse the command-line options for the entropy/confidence scoring run.

    Required: --features_file and --dl_local_dir. Everything else has a default
    (see the table below).
    """
    parser = argparse.ArgumentParser()

    # (flag, add_argument keyword options), registered in this order.
    option_specs = (
        ("--model_type", dict(type=str, default="gemma2_2b")),
        ("--logit_lens_top_k", dict(type=int, default=20)),
        ("--confidence_top_k", dict(type=int, default=3, help="k for Token Confidence (top-k)")),
        ("--features_file", dict(type=str, required=True)),
        ("--amp_factor", dict(type=int, default=10, help="SAE feature amplification factor")),
        ("--device", dict(
            type=str, default=None,
            help="cpu | cuda | cuda:0 | cuda:1; default picks cuda:0 if available",
        )),
        ("--cache_path", dict(
            type=str, default=None, required=False,
            help="Optional combined JSON path; if omitted, auto path per-layer will be used.",
        )),
        ("--dl_local_dir", dict(
            type=str, required=True,
            help="Local DL-SAE folder (contains config.json + ae.pt); layer number will be auto-swapped.",
        )),
        ("--neutral_sentence", dict(
            type=str, default="From my experience,",
            help="Neutral prefix sentence used for the next-token distribution.",
        )),
    )
    for flag, options in option_specs:
        parser.add_argument(flag, **options)

    return parser.parse_args()


def _model_name_for_type(model_type: str) -> str:
    """Map a --model_type value to its Hugging Face repo id; raise ValueError if unsupported."""
    known = {
        "gemma2_2b": "google/gemma-2-2b",
        "gemma2_9b": "google/gemma-2-9b",
        "gemma2_9b_it": "google/gemma-2-9b-it",
        "Qwen/Qwen2.5-3B": "Qwen/Qwen2.5-3B",
        "Qwen/Qwen2.5-3B-Instruct": "Qwen/Qwen2.5-3B-Instruct",
    }
    try:
        return known[model_type]
    except KeyError:
        raise ValueError(f"Model type not supported {model_type}") from None


def _ensure_parent_dir(path: str) -> None:
    """Create the parent directory of `path` if needed; a bare filename means the current directory."""
    # os.path.dirname("scores.json") == "" and os.makedirs("") raises, hence the "." fallback.
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)


def _load_scores(path: str) -> Dict[str, Any]:
    """Return the JSON object stored at `path` (UTF-8), or an empty dict if the file does not exist."""
    if os.path.exists(path):
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)
    return {}


def _save_scores(path: str, scores: Dict[str, Any]) -> None:
    """Write `scores` as UTF-8 JSON through a temp file + os.replace.

    The temp-file swap means an interrupted write cannot leave a truncated cache
    that would break the next resume. UTF-8 is explicit because ensure_ascii=False
    keeps non-ASCII token strings verbatim.
    """
    _ensure_parent_dir(path)
    tmp_path = f"{path}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(scores, f, ensure_ascii=False, indent=2)
    os.replace(tmp_path, path)


def main():
    """Score every (layer, feature) pair from --features_file and persist the metrics as JSON.

    Per layer:
      1. load the layer's local SAE;
      2. pick the cache file (the auto per-layer path, or the single --cache_path file);
      3. skip the layer when every listed feature already has a "score" entry;
      4. build the logit-lens top-k table;
      5. run compute_all_metrics for each not-yet-scored feature;
      6. write the updated dict back to disk.
    Already-scored features are skipped, so an interrupted run can be resumed.
    """
    args = parse_args()
    model_type = args.model_type
    logit_lens_k = int(args.logit_lens_top_k)
    k_conf = int(args.confidence_top_k)
    amp_factor = int(args.amp_factor)
    device = args.device if args.device is not None else ("cuda:0" if torch.cuda.is_available() else "cpu")
    # Treat an empty --cache_path the same as an omitted one.
    combined_cache_path = args.cache_path or None

    model_name = _model_name_for_type(model_type)

    if combined_cache_path is None and amp_factor != 10:
        # _auto_cache_path places results under a fixed 'amp10_top_1' directory; with resume/skip
        # logic, entries scored there with another amp factor would be silently reused.
        print(
            f"[WARN] --amp_factor={amp_factor} but the auto cache directory is fixed to 'amp10_top_1'; "
            "features already scored there (possibly with a different amp factor) will be skipped. "
            "Pass --cache_path to keep results separate."
        )

    features_by_layers = get_features_by_layers(args.features_file)

    # ===== Model/Tokenizer loading (parity with main script) =====
    cache_dir = (
        os.environ.get("HF_HOME")
        or os.environ.get("HF_CACHE_DIR")
        or os.path.expanduser("~/.cache/huggingface")
    )
    offline = (os.environ.get("TRANSFORMERS_OFFLINE", "0") == "1") or (os.environ.get("HF_LOCAL_ONLY", "0") == "1")

    # Qwen checkpoints are loaded with trust_remote_code=True.
    trust_remote = ("qwen" in model_name.lower())

    # If offline and a local snapshot exists, load from the snapshot path instead of the repo id.
    model_name_resolved = _resolve_local_snapshot_if_available(model_name, cache_dir, offline)

    tokenizer = AutoTokenizer.from_pretrained(
        model_name_resolved,
        trust_remote_code=trust_remote,
        cache_dir=cache_dir,
        local_files_only=offline,
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_name_resolved,
        trust_remote_code=trust_remote,
        cache_dir=cache_dir,
        local_files_only=offline,
    ).to(device)

    final_layer_norm = model.model.norm
    lm_head = model.lm_head
    print(final_layer_norm.weight.shape, lm_head.weight.shape)

    saes = dict()
    neutral_sentence = args.neutral_sentence
    repo_root = _detect_repo_root_from_features(args.features_file)

    # Combined mode: one dict shared across all layers and rewritten after each layer.
    output_scores_combined = None
    if combined_cache_path is not None:
        _ensure_parent_dir(combined_cache_path)
        output_scores_combined = _load_scores(combined_cache_path)

    for layer in features_by_layers:
        features = features_by_layers[layer]

        # Load SAE for this layer (local)
        sae = get_sae(
            model_type=model_type,
            layer=layer,
            saes=saes,
            backend="dl_local",
            dl_local_dir=args.dl_local_dir,
            device=device,
        )

        # Pretty print SAE info; path resolution is best-effort and only used for logging/naming.
        try:
            resolved_dir = _swap_layer_number_in_path(args.dl_local_dir, layer)
        except Exception:
            resolved_dir = args.dl_local_dir
        trainer_class = getattr(sae, "trainer_class_name", "") or ""
        print(f"\n[SAE-LOAD] layer={layer} path='{resolved_dir}' trainer='{trainer_class}'")

        # Decide cache path and load the per-layer dict
        if combined_cache_path is None:
            trainer_id = _trainer_id_from_path(resolved_dir)
            cache_path_this_layer = _auto_cache_path(
                repo_root=repo_root,
                model_type=model_type,
                layer=layer,
                trainer_class=trainer_class,
                trainer_id=trainer_id,
                topk_for_filename=logit_lens_k,
                k_conf=k_conf,
            )
            output_scores = _load_scores(cache_path_this_layer)
        else:
            cache_path_this_layer = combined_cache_path
            output_scores = output_scores_combined

        # Early-exit if every feature of this layer is already scored (in either mode), which
        # avoids the costly logit-lens computation.
        needed = {f"{layer}_{int(f)}" for f in features}
        if all(isinstance(output_scores.get(k), dict) and ("score" in output_scores[k]) for k in needed):
            print(f"[SKIP LAYER] layer {layer} already done at {cache_path_this_layer} (features={len(features)})")
            continue

        # Build logit-lens table (for representative tokens); the model is moved to CPU
        # around this call and back to the device afterwards.
        model = model.cpu()
        logit_lens_topk, logit_lens_confidence, logit_lens_raw_logits = cache_logit_lens(
            layer=layer,
            saes=saes,
            model_type=model_type,
            final_layer_norm=final_layer_norm,
            lm_head=lm_head,
            k=logit_lens_k,
            backend="dl_local",
            dl_local_dir=args.dl_local_dir,
            device=device,
        )
        model = model.to(device)

        for feature in tqdm(features):
            feature = int(feature)
            key = f"{layer}_{feature}"

            # Skip if done
            existing = output_scores.get(key)
            if isinstance(existing, dict) and ("score" in existing):
                print(f"[SKIP] already scored {key}")
                continue

            # logit-lens representative tokens
            ll_indices = logit_lens_topk.indices[feature, :].tolist()
            tokens = tokenizer.convert_ids_to_tokens(ll_indices)
            print(f"[Layer {layer} | Feature {feature}] top-{logit_lens_k} tokens: {tokens[:8]} ...")

            # Compute metrics (two passes)
            metrics = compute_all_metrics(
                layer=layer,
                feature=feature,
                logit_lens_indices=ll_indices,
                sentence=neutral_sentence,
                sae=sae,
                tokenizer=tokenizer,
                model=model,
                device=device,
                amp_factor=amp_factor,
                k_conf=k_conf,
            )

            # Merge into any existing (non-scored) entry rather than overwriting it.
            entry = output_scores.get(key, {})
            if not isinstance(entry, dict):
                entry = {}
            entry.update(metrics)
            entry["tokens"] = tokens
            output_scores[key] = entry

            print(
                f"[Layer {layer} | Feature {feature}] score={metrics['score']:.6f} "
                f"dH={metrics['delta_entropy']:.4f} dC={metrics['delta_confidence']:.4f}"
            )

            torch.cuda.empty_cache()
            gc.collect()

        # Save file. In combined mode output_scores IS output_scores_combined (same object),
        # so later layers keep accumulating into it.
        _save_scores(cache_path_this_layer, output_scores)


# Run the CLI only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
