# autointerp_hf/run_llm_eval.py
"""
Run the Auto-Interp (LLM-judge) evaluation on one or many SAEs.

This runner is aligned with a "delta LM loss" style SAE folder layout:

SAE discovery via --ae_root:
  - If --ae_root is a file path to ae.pt -> evaluate that SAE.
  - If --ae_root is a directory containing ae.pt (+ cfg.json/config.json) -> evaluate that SAE.
  - If --ae_root is a top-level directory -> recursively find all SAE dirs like:
        .../final_*/ae.pt   (and cfg.json or config.json in the same directory)

Per-SAE layer discovery:
  - We read cfg.json / config.json in the SAE directory and try to extract a layer index.
  - If --hook_module_path=auto, we map that layer index to a concrete dotted module path
    using hooks.guess_hook_module_path(model, layer).

Held-out data:
  - Prefer local JSONL via --heldout_jsonl. Each line must be JSON and contain --heldout_text_key.
    The field can be a string or a list[str].
  - Otherwise fall back to HF streaming dataset via --dataset_name (optional).

Outputs:
  - Results are saved inside each SAE directory:
      <sae_dir>/autointerp_llm_eval.json
      <sae_dir>/autointerp_llm_eval_ckpt.jsonl  (incremental checkpoint; safe to resume)
"""

from __future__ import annotations

import argparse
import asyncio
import json
import os
import random
from dataclasses import asdict
from typing import Any, Dict, Iterable, List, Optional, Tuple

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from tqdm import tqdm

from .autointerp import AutoInterpRunner
from .config import AutoInterpEvalConfig
from .hooks import get_feature_activation_sparsity_hf, guess_hook_module_path
from .judge import APIJudgeConfig, AsyncAPIJudge  # updated judge uses aiohttp/OpenAI-compatible ChatCompletions
from .utils import load_sae


# ----------------------------
# SAE discovery
# ----------------------------

def _is_sae_dir(d: str) -> bool:
    """Return True when `d` is a directory holding ae.pt plus cfg.json or config.json."""
    has_weights = os.path.isdir(d) and os.path.isfile(os.path.join(d, "ae.pt"))
    if not has_weights:
        return False
    # Either config file name is accepted; one is enough.
    return any(
        os.path.isfile(os.path.join(d, cfg_name))
        for cfg_name in ("cfg.json", "config.json")
    )


def find_final_sae_dirs(ae_root: str) -> List[str]:
    """
    Find SAE directories under `ae_root`.

    Accepts:
      - a path to ae.pt
      - a directory containing ae.pt
      - a top-level directory; scan for final_*/ae.pt (with cfg.json/config.json)

    A leading "~" in `ae_root` is expanded first. Returns a sorted list of
    directory paths; the list is empty when nothing qualifies (including when
    `ae_root` does not exist).
    """
    ae_root = os.path.expanduser(ae_root)

    # Case 1: ae_root is directly ae.pt
    if os.path.isfile(ae_root) and os.path.basename(ae_root) == "ae.pt":
        # dirname() is "" for a bare relative "ae.pt"; that means the current
        # directory, which must be spelled "." for the directory checks to pass.
        parent = os.path.dirname(ae_root) or "."
        return [parent] if _is_sae_dir(parent) else []

    if not os.path.isdir(ae_root):
        return []

    # Case 2: ae_root is a directory that itself is a SAE dir
    if _is_sae_dir(ae_root):
        return [ae_root]

    # Case 3: recursively scan for final_*/ae.pt
    found = [
        dirpath
        for dirpath, _, filenames in os.walk(ae_root)
        if os.path.basename(dirpath).startswith("final_")
        and "ae.pt" in filenames
        and _is_sae_dir(dirpath)
    ]
    return sorted(found)


def _load_json_if_exists(path: str) -> Optional[dict]:
    """Parse and return the JSON document at `path`, or None when `path` is not an existing file."""
    if os.path.isfile(path):
        with open(path, "r", encoding="utf-8") as fh:
            return json.loads(fh.read())
    return None


def load_sae_cfg(sae_dir: str) -> dict:
    """
    Return the parsed SAE config stored in `sae_dir`.

    cfg.json is tried first, then config.json.

    Raises:
        FileNotFoundError: neither file yielded a config.
    """
    for candidate in ("cfg.json", "config.json"):
        loaded = _load_json_if_exists(os.path.join(sae_dir, candidate))
        if loaded is not None:
            return loaded
    raise FileNotFoundError(f"No cfg.json/config.json found in SAE dir: {sae_dir}")


def _find_first_int_by_keys(obj: Any, keys: Tuple[str, ...]) -> Optional[int]:
    """
    Search a nested dict/list structure for the first integer stored under one of `keys`.

    At each dict, the candidate keys are checked (in the order given) before
    descending into the dict's values; lists are searched element by element.
    A value counts as a hit when it is an int (booleans excluded) or a string
    of decimal digits, optionally surrounded by whitespace.

    Returns:
        The integer found, or None when no key holds an integer-like value.
    """
    if isinstance(obj, dict):
        for k in keys:
            v = obj.get(k)
            # bool is a subclass of int; a flag such as {"layer": true} is not an index.
            if isinstance(v, int) and not isinstance(v, bool):
                return v
            if isinstance(v, str):
                text = v.strip()
                # isdecimal() (unlike isdigit()) only accepts characters int() can parse.
                if text.isdecimal():
                    return int(text)
        children = obj.values()
    elif isinstance(obj, list):
        children = obj
    else:
        return None

    for child in children:
        hit = _find_first_int_by_keys(child, keys)
        if hit is not None:
            return hit
    return None


def extract_layer_from_cfg(cfg: dict) -> Optional[int]:
    """
    Best-effort lookup of the layer index recorded in an SAE config.

    The common location cfg["trainer"]["layer"] is consulted first (an int, or
    a string of digits). If that yields nothing, the whole config is searched
    for the generic keys "layer", "layer_idx" and "layer_id".

    Returns:
        The layer index, or None when none could be found.
    """
    trainer = cfg.get("trainer")
    if isinstance(trainer, dict):
        candidate = trainer.get("layer")
        if isinstance(candidate, int):
            return candidate
        if isinstance(candidate, str):
            digits = candidate.strip()
            if digits.isdigit():
                return int(digits)

    # Fall back to a generic search over the nested config.
    return _find_first_int_by_keys(cfg, keys=("layer", "layer_idx", "layer_id"))


def resolve_hook_module_path(
    model: torch.nn.Module,
    hook_module_path_arg: str,
    layer: Optional[int],
) -> str:
    """
    Turn the --hook_module_path argument into a concrete module path for one SAE.

    Three forms are handled:
      - "auto": delegate to guess_hook_module_path(model, layer)
      - a template containing "{layer}": substitute the layer index
      - anything else: used verbatim

    Raises:
        ValueError: the "auto" or template form is used but `layer` is None.
    """
    if hook_module_path_arg != "auto":
        if "{layer}" not in hook_module_path_arg:
            # Literal path: nothing to resolve.
            return hook_module_path_arg
        if layer is None:
            raise ValueError("hook_module_path template requires a layer index in cfg.json/config.json.")
        return hook_module_path_arg.format(layer=layer)

    if layer is None:
        raise ValueError("hook_module_path=auto requires a layer index in cfg.json/config.json.")
    return guess_hook_module_path(model, layer)


# ----------------------------
# Held-out data loading
# ----------------------------

def iter_texts_from_jsonl(jsonl_path: str, text_key: str) -> Iterable[str]:
    """
    Stream text strings out of a local JSONL file.

    Blank lines and records lacking `text_key` are skipped. The field may be a
    single string or a list; non-string list items and fields of any other
    type are ignored.
    """
    with open(jsonl_path, "r", encoding="utf-8") as fh:
        # Lazily parse every non-blank line.
        records = (json.loads(stripped) for stripped in map(str.strip, fh) if stripped)
        for record in records:
            if text_key not in record:
                continue
            field = record[text_key]
            if isinstance(field, str):
                candidates = [field]
            elif isinstance(field, list):
                candidates = field
            else:
                candidates = []
            for item in candidates:
                if isinstance(item, str):
                    yield item


def iter_texts_from_hf(dataset_name: str, split: str, text_key: str) -> Iterable[str]:
    """
    Stream text strings from a Hugging Face dataset opened in streaming mode.

    For each row, the `text_key` field is yielded if it is a string; if it is a
    list, its string items are yielded in order. Anything else is skipped.
    """
    # Imported lazily so `datasets` is only needed when this source is used.
    from datasets import load_dataset

    stream = load_dataset(dataset_name, split=split, streaming=True)
    for row in stream:
        field = row.get(text_key, None)
        if isinstance(field, list):
            yield from (item for item in field if isinstance(item, str))
        elif isinstance(field, str):
            yield field


def build_token_chunks(
    tokenizer,
    texts: Iterable[str],
    total_tokens: int,
    context_length: int,
    skip_first_n_examples: int = 0,
    add_eos_between_texts: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Pack a stream of texts into fixed-length token rows.

    Procedure:
      - discard the first `skip_first_n_examples` texts
      - tokenize each remaining non-blank text without special tokens
      - optionally follow each text with the tokenizer's eos_token_id
      - concatenate everything and keep the first N * context_length tokens,
        where N = total_tokens // context_length
      - reshape to [N, context_length]

    Returns:
        (input_ids, attention_mask), both long tensors of shape
        [N, context_length]; the mask is all ones.

    Raises:
        ValueError: total_tokens is smaller than context_length.
        RuntimeError: the texts do not supply enough tokens.
    """
    source = iter(texts)
    # zip() stops as soon as either side runs out, so a short stream is not an error here.
    for _ in zip(range(skip_first_n_examples), source):
        pass

    n_rows = total_tokens // context_length
    target_tokens = n_rows * context_length
    if target_tokens <= 0:
        raise ValueError("total_tokens must be >= context_length.")

    eos_id = getattr(tokenizer, "eos_token_id", None)
    use_separator = add_eos_between_texts and eos_id is not None
    stream_ids: List[int] = []

    pbar = tqdm(total=target_tokens, desc="Tokenizing held-out data", unit="tok", dynamic_ncols=True)
    for txt in source:
        if len(stream_ids) >= target_tokens:
            break
        if not (isinstance(txt, str) and txt.strip()):
            continue
        ids = tokenizer(txt, add_special_tokens=False).get("input_ids", [])
        if not ids:
            continue
        stream_ids.extend(ids)
        if use_separator:
            stream_ids.append(int(eos_id))
        # Trim any overshoot so the stream never exceeds the target length.
        del stream_ids[target_tokens:]
        pbar.n = len(stream_ids)
        pbar.refresh()
    pbar.close()

    if len(stream_ids) < target_tokens:
        raise RuntimeError(
            f"Not enough tokens in held-out data: got {len(stream_ids)}, need {target_tokens}. "
            f"Consider reducing --total_tokens or using a larger dataset."
        )

    ids_tensor = torch.tensor(stream_ids, dtype=torch.long).view(-1, context_length)
    return ids_tensor, torch.ones_like(ids_tensor, dtype=torch.long)


# ----------------------------
# Resume / checkpoint helpers
# ----------------------------

def load_checkpoint_jsonl(ckpt_path: str) -> Dict[int, dict]:
    """
    Load incremental results from a JSONL checkpoint file.

    Each non-blank line is expected to be a JSON object carrying the latent
    index under "latent" (preferred) or "latent_id". Later lines override
    earlier ones for the same latent.

    Lines that cannot be used (for example a line truncated by a crash in the
    middle of an append, or a record without a latent id) are reported and
    skipped, so one bad line does not prevent resuming from the rest.

    Returns:
        Mapping from latent index to its stored record; empty when the file
        does not exist.
    """
    results: Dict[int, dict] = {}
    if not os.path.exists(ckpt_path):
        return results
    with open(ckpt_path, "r", encoding="utf-8") as f:
        for lineno, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                continue
            try:
                obj = json.loads(line)
                latent_id = int(obj.get("latent", obj.get("latent_id")))
            except (ValueError, TypeError, AttributeError) as exc:
                # ValueError covers json.JSONDecodeError and non-numeric ids;
                # TypeError a missing id (int(None)); AttributeError a non-object line.
                print(f"[Resume] Skipping unreadable checkpoint line {lineno} in {ckpt_path}: {exc!r}")
                continue
            results[latent_id] = obj
    return results


def append_checkpoint_jsonl(ckpt_path: str, latent_result: dict) -> None:
    """
    Append one latent's result to the JSONL checkpoint as a single line.

    The parent directory is created if needed. Non-ASCII characters are
    written as-is (UTF-8) rather than escaped.
    """
    parent = os.path.dirname(ckpt_path)
    # A bare file name has an empty dirname, and os.makedirs("") raises.
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(ckpt_path, "a", encoding="utf-8") as f:
        f.write(json.dumps(latent_result, ensure_ascii=False) + "\n")


# ----------------------------
# Core evaluation routines
# ----------------------------

def choose_alive_latents(
    sparsity: torch.Tensor,
    total_tokens: int,
    dead_latent_threshold: int,
) -> List[int]:
    """
    Return the indices of latents considered "alive".

    The expected activation count of a latent is sparsity[latent] * total_tokens;
    a latent is alive when that count reaches `dead_latent_threshold`.
    """
    expected_counts = sparsity * float(total_tokens)
    alive_mask = expected_counts >= float(dead_latent_threshold)
    # First index tensor of the nonzero positions (same as torch.where(mask)[0]).
    alive_idx = torch.nonzero(alive_mask, as_tuple=True)[0]
    return alive_idx.tolist()


async def run_latents_in_chunks(
    cfg: AutoInterpEvalConfig,
    runner: AutoInterpRunner,
    latents: List[int],
    latent_batch_size: int,
    ckpt_path: str,
    resume_results: Dict[int, dict],
) -> Dict[int, dict]:
    """
    Run auto-interp over `latents` in batches of `latent_batch_size`.

    Latents already present in `resume_results` are not re-run. For each batch
    the runner's `latents` attribute is overwritten and `runner.run()` is
    awaited; once a batch returns, every result in it is appended to the
    checkpoint at `ckpt_path`. `cfg` is accepted but not used here.

    Returns:
        A new dict holding `resume_results` plus all newly computed results,
        keyed by integer latent index.
    """
    merged: Dict[int, dict] = dict(resume_results)

    # Work out which latents still need to be evaluated.
    pending = [lat for lat in latents if int(lat) not in merged]
    if pending:
        for offset in range(0, len(pending), latent_batch_size):
            runner.latents = pending[offset : offset + latent_batch_size]
            batch_results = await runner.run()

            # Persist right away so a crash in a later batch keeps this one.
            for latent_id, record in batch_results.items():
                merged[int(latent_id)] = record
                append_checkpoint_jsonl(ckpt_path, record)

    return merged


def derive_output_paths(sae_dir: str) -> Tuple[str, str]:
    """
    Build the output locations for one SAE, both inside `sae_dir`.

    Returns:
        (final results JSON path, incremental JSONL checkpoint path)
    """
    filenames = ("autointerp_llm_eval.json", "autointerp_llm_eval_ckpt.jsonl")
    out_path, ckpt_path = (os.path.join(sae_dir, name) for name in filenames)
    return out_path, ckpt_path


async def eval_single_sae_dir(
    *,
    sae_dir: str,
    model: torch.nn.Module,
    tokenizer,
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
    base_cfg: AutoInterpEvalConfig,
    judge: AsyncAPIJudge,
    hook_module_path: str,
    n_latents: int,
    latent_batch_size: int,
    dead_latent_threshold: int,
    seed: int,
) -> None:
    """
    Evaluate one SAE directory and write outputs into that directory.

    Steps:
      1. load <sae_dir>/ae.pt and move the SAE to base_cfg.device
      2. clone base_cfg and set this SAE's hook module path on the clone
      3. gather earlier results from the final JSON and the JSONL checkpoint
         (checkpoint entries take precedence)
      4. compute activation sparsity and keep only "alive" latents
      5. sample up to `n_latents` alive latents using a RNG seeded with `seed`
      6. run auto-interp on the sampled latents that have no stored result
      7. write all results to <sae_dir>/autointerp_llm_eval.json

    Returns early, writing nothing, when no latent is alive.
    """
    out_path, ckpt_path = derive_output_paths(sae_dir)

    # Load SAE weights
    sae_pt = os.path.join(sae_dir, "ae.pt")
    sae, _sae_cfg = load_sae(sae_pt)

    device = base_cfg.device
    sae = sae.to(device=device)
    sae.eval()

    # Make per-SAE config clone (so each SAE can have its own hook path)
    cfg_local = AutoInterpEvalConfig(**asdict(base_cfg))
    cfg_local.hook_module_path = hook_module_path

    # Resume
    resume_results: Dict[int, dict] = {}
    if os.path.exists(out_path):
        with open(out_path, "r", encoding="utf-8") as f:
            try:
                resume_results = {int(k): v for k, v in json.load(f).items()}
            except Exception as exc:
                # Best effort: an unreadable results file must not block the run,
                # but say so, because the file is overwritten at the end.
                print(f"[SAE] Warning: ignoring unreadable results file {out_path}: {exc!r}")
                resume_results = {}
    resume_results.update(load_checkpoint_jsonl(ckpt_path))

    # Compute sparsity to filter dead latents
    print(f"[SAE] Computing sparsity for: {sae_dir}")
    sparsity = get_feature_activation_sparsity_hf(
        input_ids=input_ids,
        attention_mask=attention_mask,
        model=model,
        sae=sae,
        batch_size=cfg_local.batch_size,
        hook_module_path=cfg_local.hook_module_path,
        tokenizer=tokenizer,
    )
    alive_latents = choose_alive_latents(
        sparsity=sparsity,
        total_tokens=cfg_local.total_tokens,
        dead_latent_threshold=dead_latent_threshold,
    )
    print(f"[SAE] Alive latents: {len(alive_latents)}/{int(sparsity.shape[0])}")

    if not alive_latents:
        print(f"[SAE] No alive latents found; skipping: {sae_dir}")
        return

    # Sample latents with a dedicated RNG so the choice depends only on `seed`
    # and the alive set, not on global random state.
    rng = random.Random(seed)
    if len(alive_latents) >= n_latents:
        selected = rng.sample(alive_latents, k=n_latents)
    else:
        selected = alive_latents[:]  # fewer than requested

    selected = sorted(set(int(x) for x in selected))
    print(f"[SAE] Selected {len(selected)} latents for interpretation.")

    # Run auto-interp
    runner = AutoInterpRunner(
        cfg=cfg_local,
        model=model,
        sae=sae,
        tokenizer=tokenizer,
        input_ids=input_ids,
        attention_mask=attention_mask,
        latents=selected,
        judge=judge,
    )

    results = await run_latents_in_chunks(
        cfg=cfg_local,
        runner=runner,
        latents=selected,
        latent_batch_size=latent_batch_size,
        ckpt_path=ckpt_path,
        resume_results=resume_results,
    )

    # Save final results (JSON object keys must be strings)
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump({str(k): v for k, v in results.items()}, f, indent=2, ensure_ascii=False)

    print(f"[SAE] Saved results: {out_path}")


# ----------------------------
# CLI
# ----------------------------

def _dtype_from_str(s: str) -> torch.dtype:
    """
    Map a case-insensitive dtype name to the matching torch dtype.

    Accepted names: bf16/bfloat16, fp16/float16, fp32/float32.

    Raises:
        ValueError: the name is not one of the accepted spellings.
    """
    key = s.lower()
    aliases = {
        "bf16": torch.bfloat16,
        "bfloat16": torch.bfloat16,
        "fp16": torch.float16,
        "float16": torch.float16,
        "fp32": torch.float32,
        "float32": torch.float32,
    }
    resolved = aliases.get(key)
    if resolved is None:
        raise ValueError(f"Unknown torch dtype: {key}")
    return resolved


def _normalize_chat_completions_url(base_url: str) -> str:
    """
    Make `base_url` point at an OpenAI-compatible ChatCompletions endpoint.

    Trailing slashes are always removed. A URL ending in "/v1" gets
    "/chat/completions" appended; any other URL (including one that already
    ends in "/chat/completions") is assumed to be a full endpoint and is
    returned otherwise unchanged.

    Examples:
      - https://api.foo/v1                  -> https://api.foo/v1/chat/completions
      - https://api.foo/v1/                 -> https://api.foo/v1/chat/completions
      - https://api.foo/v1/chat/completions -> unchanged
    """
    trimmed = base_url.rstrip("/")
    suffix = "/chat/completions" if trimmed.endswith("/v1") else ""
    return trimmed + suffix


async def main_async() -> None:
    """
    Parse CLI arguments, then evaluate every SAE found under --ae_root.

    Cheap validation (data-source flags, SAE discovery, judge API key) runs
    before the model is loaded and the held-out data is tokenized, so a
    misconfigured invocation fails before the expensive work starts. The
    model, tokenizer and held-out tokens are loaded once and shared by all SAEs.

    Raises:
        ValueError: neither --heldout_jsonl nor --dataset_name was given.
        RuntimeError: no SAE directory was found, or no judge API key is available.
    """
    parser = argparse.ArgumentParser()

    # Model
    parser.add_argument("--model_name_or_path", type=str, required=True)
    parser.add_argument("--device", type=str, default="cuda:0")
    # Short aliases are accepted because _dtype_from_str understands them too.
    parser.add_argument(
        "--torch_dtype",
        type=str,
        default="bfloat16",
        choices=["float16", "bfloat16", "float32", "fp16", "bf16", "fp32"],
    )
    # Enabled by default; a store_true flag alone could never turn it off, so
    # a paired --no_* switch writes False into the same destination.
    parser.add_argument("--trust_remote_code", action="store_true", default=True)
    parser.add_argument("--no_trust_remote_code", dest="trust_remote_code", action="store_false")

    # SAE
    parser.add_argument("--ae_root", type=str, required=True)
    parser.add_argument("--hook_module_path", type=str, default="auto")

    # Data
    parser.add_argument("--heldout_jsonl", type=str, default=None)
    parser.add_argument("--heldout_text_key", type=str, default="text")
    parser.add_argument("--skip_first_n_examples", type=int, default=0)
    parser.add_argument("--dataset_name", type=str, default=None)
    parser.add_argument("--dataset_split", type=str, default="train")
    parser.add_argument("--context_length", type=int, default=128)
    parser.add_argument("--total_tokens", type=int, default=2_000_000)
    parser.add_argument("--batch_size", type=int, default=64)

    # Latent selection
    parser.add_argument("--n_latents", type=int, default=100)
    parser.add_argument("--latent_batch_size", type=int, default=100)
    parser.add_argument("--dead_latent_threshold", type=int, default=15)
    parser.add_argument("--seed", type=int, default=3407)

    # Judge (new API calling standard: OpenAI-compatible ChatCompletions)
    parser.add_argument("--judge_model", type=str, required=True)
    parser.add_argument("--judge_base_url", type=str, required=True)
    parser.add_argument("--judge_api_key", type=str, default=None)  # may also be provided via env
    parser.add_argument("--judge_timeout", type=float, default=60.0)
    parser.add_argument("--judge_max_retries", type=int, default=3)
    parser.add_argument("--judge_max_concurrent", type=int, default=10)
    parser.add_argument("--judge_stream", action="store_true", default=False)
    parser.add_argument("--judge_provider", type=str, default=None)
    # Same on-by-default pattern as --trust_remote_code.
    parser.add_argument("--judge_close_dash_inspect", action="store_true", default=True)
    parser.add_argument(
        "--no_judge_close_dash_inspect", dest="judge_close_dash_inspect", action="store_false"
    )
    parser.add_argument("--judge_temperature", type=float, default=0.0)
    parser.add_argument("--judge_max_tokens_expl", type=int, default=512)
    parser.add_argument("--judge_max_tokens_score", type=int, default=16)

    args = parser.parse_args()

    # ---- Cheap validation first: fail before loading the model ----
    if args.heldout_jsonl is None and args.dataset_name is None:
        raise ValueError("Either --heldout_jsonl or --dataset_name must be provided.")

    # Discover SAE dirs
    sae_dirs = find_final_sae_dirs(args.ae_root)
    if len(sae_dirs) == 0:
        raise RuntimeError(
            "No SAE directories found under --ae_root. Expected:\n"
            "  - ae.pt\n"
            "  - a dir containing ae.pt + (cfg.json/config.json)\n"
            "  - a top-level dir containing final_*/ae.pt + cfg/config\n"
            f"Got: {args.ae_root}"
        )

    print(f"[Scan] Found {len(sae_dirs)} SAE dirs.")
    for d in sae_dirs[:10]:
        print(f"  - {d}")
    if len(sae_dirs) > 10:
        print("  ...")

    # The CLI value wins; otherwise fall back to the environment.
    api_key = args.judge_api_key or os.environ.get("JUDGE_API_KEY") or os.environ.get("OPENAI_API_KEY")
    if not api_key:
        raise RuntimeError("Missing judge API key. Provide --judge_api_key or set JUDGE_API_KEY / OPENAI_API_KEY.")

    # Seed everything deterministic that matters here
    random.seed(args.seed)
    torch.manual_seed(args.seed)

    device = torch.device(args.device)
    dtype = _dtype_from_str(args.torch_dtype)

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(
        args.model_name_or_path,
        trust_remote_code=args.trust_remote_code,
        use_fast=True,
    )
    if tokenizer.pad_token_id is None:
        # Keep consistent with many AR models
        tokenizer.pad_token = tokenizer.eos_token

    # Load model
    model = AutoModelForCausalLM.from_pretrained(
        args.model_name_or_path,
        trust_remote_code=args.trust_remote_code,
        torch_dtype=dtype,
    )
    model.to(device)
    model.eval()

    # Load held-out tokens ONCE (reused for all SAEs); local JSONL is preferred
    if args.heldout_jsonl is not None:
        text_iter = iter_texts_from_jsonl(args.heldout_jsonl, text_key=args.heldout_text_key)
    else:
        text_iter = iter_texts_from_hf(args.dataset_name, split=args.dataset_split, text_key=args.heldout_text_key)

    input_ids, attention_mask = build_token_chunks(
        tokenizer=tokenizer,
        texts=text_iter,
        total_tokens=args.total_tokens,
        context_length=args.context_length,
        skip_first_n_examples=args.skip_first_n_examples,
        add_eos_between_texts=True,
    )
    input_ids = input_ids.to(device)
    attention_mask = attention_mask.to(device)

    # Build judge (new API calling style)
    judge_cfg = APIJudgeConfig(
        model=args.judge_model,
        api_key=api_key,
        base_url=_normalize_chat_completions_url(args.judge_base_url),
        timeout=args.judge_timeout,
        max_retries=args.judge_max_retries,
        max_concurrent=args.judge_max_concurrent,
        stream=args.judge_stream,
        provider=args.judge_provider,
        close_dash_inspect=args.judge_close_dash_inspect,
        temperature=args.judge_temperature,
        max_tokens_explanation=args.judge_max_tokens_expl,
        max_tokens_scoring=args.judge_max_tokens_score,
        debug=False,
    )
    judge = AsyncAPIJudge(judge_cfg)

    # Base config for auto-interp
    base_cfg = AutoInterpEvalConfig(
        model_name_or_path=args.model_name_or_path,
        dataset_name=args.heldout_jsonl or (args.dataset_name or "local_jsonl"),
        hook_module_path=args.hook_module_path,  # per-SAE override happens later
        total_tokens=args.total_tokens,
        llm_context_size=args.context_length,
        batch_size=args.batch_size,
        random_seed=args.seed,
        device=device,
        dtype=dtype,
    )

    # Run for each SAE dir
    for idx, sae_dir in enumerate(sae_dirs, start=1):
        print(f"\n[Eval {idx}/{len(sae_dirs)}] SAE dir: {sae_dir}")

        # The hook location is resolved per SAE from the layer in its own config.
        cfg = load_sae_cfg(sae_dir)
        layer = extract_layer_from_cfg(cfg)
        hook_path = resolve_hook_module_path(model, args.hook_module_path, layer)
        print(f"[Eval] layer={layer} hook_module_path={hook_path}")

        await eval_single_sae_dir(
            sae_dir=sae_dir,
            model=model,
            tokenizer=tokenizer,
            input_ids=input_ids,
            attention_mask=attention_mask,
            base_cfg=base_cfg,
            judge=judge,
            hook_module_path=hook_path,
            n_latents=args.n_latents,
            latent_batch_size=args.latent_batch_size,
            dead_latent_threshold=args.dead_latent_threshold,
            seed=args.seed,
        )


def main() -> None:
    """Synchronous entry point: run the async CLI to completion on a new event loop."""
    cli_coroutine = main_async()
    asyncio.run(cli_coroutine)


# Run the CLI only when this file is executed as the main module, not on import.
# NOTE(review): this module uses package-relative imports, so it presumably has to
# be launched as a module (python -m <package>.run_llm_eval) — confirm.
if __name__ == "__main__":
    main()
