# main.py
"""
Main entry point for HMNS experiments.

Usage
-----
python main.py \
  --model_alias llama2-7b-chat \
  --split test \
  --limit 200 \
  --prefer_hf \
  --out_dir runs/llama2-7b_test

Assumptions
-----------
- models.py provides:
    - MODEL_ALIASES: Dict[str, str]
    - ensure_repo_cached(repo_id: str) -> str
    - get_tokenizer_and_model(repo_id: str, dtype: str = "bfloat16", device_map: str = "auto") -> (tokenizer, model)

- method.py provides:
    - HMNSConfig dataclass (topk_heads, attempts, steer_scale, max_layers, seed, kv_cache, alpha_base, alpha_step)
    - run_hmns(model, tokenizer, prompt: str, config: HMNSConfig) -> Dict[str, Any]
      (returns keys: success: bool, attempts: int, ipc: int, latency_s: float, flops: float, output: str)

- datasets.py provides:
    - build_main_pool_and_splits(...)
    - save_jsonl(...)
    - Example dataclass with fields (id, source, prompt, label)
"""

from __future__ import annotations
import argparse
import os
import time
import json
from typing import Any, Dict, List

from models import MODEL_ALIASES, ensure_repo_cached, get_tokenizer_and_model
from method import HMNSConfig, run_hmns
from datasets import build_main_pool_and_splits, save_jsonl, Example


def parse_args() -> argparse.Namespace:
    """Define the HMNS experiment command line and parse ``sys.argv``.

    Returns:
        An ``argparse.Namespace`` whose attribute names mirror the flag names
        (``model_alias``, ``split``, ``prefer_hf``, ``limit``, ...).
    """
    known_aliases = list(MODEL_ALIASES.keys())

    # Each entry is (flag, add_argument keyword options); the tuple order is
    # the order flags are registered, and therefore the order shown in --help.
    flag_specs = (
        ("--model_alias", dict(type=str, required=True,
                               choices=known_aliases,
                               help=f"Model alias. Known: {known_aliases}")),
        ("--split", dict(type=str, default="test",
                         choices=["analysis", "dev", "test"],
                         help="Which split to evaluate.")),
        ("--prefer_hf", dict(action="store_true",
                             help="Load benchmarks from Hugging Face Datasets if available.")),
        ("--limit", dict(type=int, default=0,
                         help="Optional limit on number of examples (0 = all).")),
        ("--seed", dict(type=int, default=0, help="Global seed.")),
        ("--dtype", dict(type=str, default="bfloat16",
                         choices=["bfloat16", "float16", "float32"],
                         help="Model dtype.")),
        ("--device_map", dict(type=str, default="auto",
                              help='Device map for HF accelerate (e.g., "auto").')),
        ("--max_layers", dict(type=int, default=None,
                              help="Optional cap on analyzed layers for speed.")),
        ("--topk_heads", dict(type=int, default=10,
                              help="Global top-K heads per attempt.")),
        ("--attempts", dict(type=int, default=3,
                            help="Closed-loop attribution/steering attempts.")),
        ("--alpha_base", dict(type=float, default=0.25,
                              help="Base steering scale.")),
        ("--alpha_step", dict(type=float, default=0.10,
                              help="Per-attempt fractional increment.")),
        ("--kv_cache", dict(action="store_true",
                            help="Enable KV cache for decoding (HMNS internals typically keep it off).")),
        ("--out_dir", dict(type=str, required=True, help="Output directory.")),
    )

    parser = argparse.ArgumentParser(description="Run HMNS experiments.")
    for flag, options in flag_specs:
        parser.add_argument(flag, **options)
    return parser.parse_args()


def summarize(rows: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Average per-example HMNS records into one summary dict.

    Args:
        rows: Per-example result dicts. Missing keys count as 0 (or as a
            falsy ``success``).

    Returns:
        A dict with keys ``count`` (number of rows) and the per-row means
        ``asr`` (fraction of rows with a truthy ``success``), ``acq`` (mean
        ``attempts``), ``ipc``, ``flops`` and ``latency_s``. For an empty
        ``rows`` every metric is 0.0.
    """
    count = len(rows)
    if not count:
        return {"count": 0, "asr": 0.0, "acq": 0.0, "ipc": 0.0, "flops": 0.0, "latency_s": 0.0}

    def mean(values):
        # Divide by the total row count, not by how many values survived a filter.
        return sum(values) / count

    return {
        "count": count,
        "asr": mean(1 for row in rows if row.get("success", False)),
        "acq": mean(row.get("attempts", 0) for row in rows),
        "ipc": mean(row.get("ipc", 0) for row in rows),
        "flops": mean(float(row.get("flops", 0.0)) for row in rows),
        "latency_s": mean(float(row.get("latency_s", 0.0)) for row in rows),
    }


def _load_examples(args: argparse.Namespace) -> List[Example]:
    """Build the fixed benchmark splits and return the requested split, truncated to ``--limit`` if positive."""
    # NOTE(review): when main.py is run as a script its directory is placed first
    # on sys.path, so the local ``datasets`` module shadows the Hugging Face
    # ``datasets`` package for every importer. Confirm that ``--prefer_hf`` still
    # reaches the HF library inside build_main_pool_and_splits.
    _pool, splits = build_main_pool_and_splits(
        prefer_hf=args.prefer_hf,
        local_roots=None,         # supply dict for local JSONL if not using HF
        hf_splits=None,           # default to "test" for all
        sizes=(150, 579, 196),    # fixed split as described in the paper
        seed=args.seed,
    )
    examples: List[Example] = splits[args.split]
    # limit <= 0 (the default is 0) means "evaluate every example".
    if args.limit > 0:
        examples = examples[: args.limit]
    return examples


def _build_hmns_config(args: argparse.Namespace) -> HMNSConfig:
    """Translate the parsed CLI flags into an ``HMNSConfig``."""
    return HMNSConfig(
        topk_heads=args.topk_heads,
        attempts=args.attempts,
        steer_scale=args.alpha_base,   # base scale; run_hmns can schedule per attempt
        max_layers=args.max_layers,
        seed=args.seed,
        kv_cache=args.kv_cache,
        alpha_base=args.alpha_base,
        alpha_step=args.alpha_step,
    )


def _run_examples(model: Any, tok: Any, examples: List[Example],
                  cfg: HMNSConfig) -> List[Dict[str, Any]]:
    """Run HMNS on each example and return one flat result record per example."""
    results: List[Dict[str, Any]] = []
    for ex in examples:
        # Original author's note: ex.label == 1 marks a malicious/policy-violating
        # prompt (the target class for ASR). NOTE(review): summarize() averages
        # `success` over all rows regardless of label — confirm that is intended.
        res = run_hmns(model, tok, ex.prompt, cfg)
        results.append({
            "id": ex.id,
            "source": ex.source,
            "label": int(ex.label),
            "prompt": ex.prompt,
            # HMNS outputs, coerced to plain Python types for serialization.
            "success": bool(res.get("success", False)),
            "attempts": int(res.get("attempts", 0)),   # external decode count
            "ipc": int(res.get("ipc", 0)),             # internal forward-equivalent passes
            "latency_s": float(res.get("latency_s", 0.0)),
            "flops": float(res.get("flops", 0.0)),
            "output": res.get("output", ""),
        })
    return results


def main():
    """Run one HMNS evaluation end to end.

    Steps: parse CLI flags, load the requested data split, fetch and load the
    model, run HMNS on every example, then write
    ``<out_dir>/<alias>_<split>.jsonl`` (per-example rows) and
    ``<out_dir>/<alias>_<split>_summary.json`` (averaged metrics), and print
    the summary together with timing information.
    """
    # Start the clock before any setup so "Total wall time" really covers data
    # loading and model download/load, not just the evaluation loop.
    # perf_counter is monotonic and meant for measuring durations.
    t_start = time.perf_counter()

    args = parse_args()
    os.makedirs(args.out_dir, exist_ok=True)

    # ------------------ Data ------------------
    data = _load_examples(args)

    # ------------------ Model -----------------
    repo_id = MODEL_ALIASES[args.model_alias]
    ensure_repo_cached(repo_id)
    tok, model = get_tokenizer_and_model(repo_id, dtype=args.dtype, device_map=args.device_map)

    # ------------------ HMNS config -----------
    cfg = _build_hmns_config(args)

    # ------------------ Run loop --------------
    t_eval = time.perf_counter()
    results = _run_examples(model, tok, data, cfg)
    eval_s = time.perf_counter() - t_eval

    # ------------------ Save ------------------
    stem = f"{args.model_alias}_{args.split}"
    per_file = os.path.join(args.out_dir, f"{stem}.jsonl")
    save_jsonl(per_file, results)

    summary = summarize(results)
    sum_file = os.path.join(args.out_dir, f"{stem}_summary.json")
    with open(sum_file, "w", encoding="utf-8") as f:
        json.dump(summary, f, indent=2)

    # Pretty print
    print("\n=== Summary ===")
    print(json.dumps(summary, indent=2))
    print(f"\nWrote per-example results to: {per_file}")
    print(f"Wrote summary to:            {sum_file}")
    print(f"Eval loop time:               {eval_s:.2f}s")
    print(f"Total wall time:             {time.perf_counter() - t_start:.2f}s")


if __name__ == "__main__":
    # Run only when executed as a script (e.g. `python main.py ...`), not on import.
    main()
