# eval_prompt_steering.py
# Prompt-based steering baseline, aligned with AxBench J.1 template and your judge pipeline.
# - Generates a steering prompt from a concept using the Appendix J.1 template
# - Prepends it to the user instruction ("Question: ...")
# - Generates with a base HF model (e.g., Gemma-2-2b-it)
# - Scores with the judge (AsyncOpenAI or LocalHF), 0.0–2.0 floats + harmonic mean
# - Saves to: cache/results_prompt_eval_openai/<base_model>/prompt_baseline/<out_tag or concepts_id>/{eval.json,samples.parquet,prompt_cache.json}
# - Prints clear progress: SAE index/total (if provided) and concept k/N using tqdm.

from __future__ import annotations
import json, argparse, random, re, os, sys
from dataclasses import dataclass
from typing import Any, Dict, List, Tuple
from pathlib import Path

# Make ../ available (utils, judge, etc.): put the directory one level above
# this file on the import path. This must run before the project-local imports
# below (presumably `judge` lives there -- verify against the repo layout).
THIS_DIR = os.path.dirname(__file__)
SRC_DIR  = os.path.abspath(os.path.join(THIS_DIR, ".."))
# Prepend so SRC_DIR wins over same-named installed modules; the membership
# check avoids inserting a duplicate entry if it is already present.
if SRC_DIR not in sys.path:
    sys.path.insert(0, SRC_DIR)

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from tqdm import tqdm
import pandas as pd

# Reuse your judge + helpers
from judge import (
    AsyncOpenAIJudge, OpenAIJudgeConfig,
    LocalHFJudge, LocalJudgeConfig,
    harmonic_mean_0_2,
)
from eval_sae_steering import load_instructions_file, truncate_text, safe_float


# =========================
# Appendix J.1 steering-prompt template (verbatim)
# =========================
# Filled with str.format(concept=...) and sent as the *user* turn to the
# prompt-generation model (OpenAI chat or local HF). {concept} is the only
# placeholder; any literal braces added to this text would need to be doubled.
# NOTE(review): the first line ("LLM-based steering prompt generation") reads
# like the appendix caption rather than an instruction to the generator --
# confirm against the paper whether it belongs in the prompt.
STEERING_PROMPT_GEN_TEMPLATE = """LLM-based steering prompt generation

Generate a prompt to guide a language model in producing responses.

Objective: Direct the model to include content related to [{concept}] (the concept) in its responses. Ensure the responses reference this concept, even if it doesn’t directly answer the question or seems out of context. Optionally, provide in-context examples to reinforce this behaviour.

Return only the final prompt without any additional text.
"""

def compose_steered_input(steering_prompt: str, instruction: str) -> str:
    """Build the base-model input: steering prompt, blank line, ``Question: <instruction>``.

    Surrounding whitespace is removed from both parts before they are joined.
    """
    question_line = "Question: " + instruction.strip()
    return "\n\n".join((steering_prompt.strip(), question_line))


def load_concepts(path: str) -> List[Tuple[str, str]]:
    """Read a JSON concepts file and return ``[(concept_key, concept_text), ...]``.

    Accepted layouts:
      1) ``list[str]``       -- keys are synthesized as ``"c<index>"``, where the
         index is the element's position in the original list.
      2) ``dict[str, str]``  -- e.g. ``{"c0": "concept desc", ...}``; keys are
         passed through ``str()``.

    Entries whose value is not a string, or is blank after stripping, are
    skipped; kept values are stripped.

    Raises:
        ValueError: if the top-level JSON value is neither a list nor a dict,
            or if no usable concept remains.
    """
    with open(path, "r", encoding="utf-8") as fh:
        payload = json.load(fh)

    if isinstance(payload, list):
        candidates = ((f"c{pos}", value) for pos, value in enumerate(payload))
    elif isinstance(payload, dict):
        candidates = ((str(key), value) for key, value in payload.items())
    else:
        raise ValueError(f"Unsupported concepts file format: {type(payload)}")

    concepts = [
        (key, value.strip())
        for key, value in candidates
        if isinstance(value, str) and value.strip()
    ]
    if not concepts:
        raise ValueError(f"No valid concepts in {path}")
    return concepts


# ---------- Prompt generation backends ----------
@dataclass
class PromptGenConfig:
    """Settings for PromptEngineer's steering-prompt generation backend."""
    backend: str = "openai_async"          # openai_async | hf_local
    model: str = "gpt-4o-mini"             # OpenAI model name (or HF model id if hf_local)
    # OpenAI-compatible endpoint used by openai_async. PromptEngineer sends the
    # OPENAI_API_KEY environment variable to this host.
    # NOTE(review): the default is not api.openai.com -- confirm it is intended.
    base_url: str = "https://api.shubiaobiao.cn/v1"
    timeout: float = 60.0                  # client timeout passed to AsyncOpenAI (openai_async only)
    max_new_tokens: int = 256              # hf_local only
    device: str = "cuda:0"                 # hf_local only
    cache_path: str | None = None          # JSON file caching steering prompts by concept key

class PromptEngineer:
    """Generate, post-process and cache one steering prompt per concept.

    The J.1 template (``STEERING_PROMPT_GEN_TEMPLATE``) is filled with the
    concept text and sent to the backend selected by ``cfg.backend``:

    * ``"openai_async"`` -- chat completion through ``AsyncOpenAI`` at
      ``cfg.base_url`` (key read from ``$OPENAI_API_KEY``), temperature 0.
    * ``"hf_local"`` -- greedy decoding with a local Hugging Face causal LM,
      using a hand-written Gemma-style turn format.

    The reply is stripped and truncated to its first paragraph. When
    ``cfg.cache_path`` is set, results are persisted in a JSON object mapping
    concept key -> prompt and reused by later calls and later runs.

    Raises:
        ValueError: from ``__init__`` if ``cfg.backend`` is not recognised.
    """

    def __init__(self, cfg: PromptGenConfig):
        self.cfg = cfg
        self._client = None
        self._tok = None
        self._model = None
        self._loop = None  # event loop reused by _run_async (openai_async only)

        # Previously generated prompts, keyed by concept key.
        self.cache: Dict[str, str] = self._load_cache(cfg.cache_path)

        if cfg.backend == "openai_async":
            from openai import AsyncOpenAI
            import httpx
            self._client = AsyncOpenAI(
                api_key=os.environ.get("OPENAI_API_KEY"),
                base_url=cfg.base_url,
                timeout=cfg.timeout,
                http_client=httpx.AsyncClient(
                    limits=httpx.Limits(max_keepalive_connections=100, max_connections=1000),
                    headers={"Connection": "close"},
                ),
                max_retries=3,
            )
        elif cfg.backend == "hf_local":
            self._tok = AutoTokenizer.from_pretrained(cfg.model)
            self._model = AutoModelForCausalLM.from_pretrained(cfg.model).to(cfg.device)
            if self._tok.pad_token_id is None:
                self._tok.pad_token = self._tok.eos_token
        else:
            raise ValueError(f"Unknown backend: {cfg.backend}")

    @staticmethod
    def _load_cache(path: str | None) -> Dict[str, str]:
        """Return the JSON prompt cache stored at *path*, or ``{}``.

        A missing path, unreadable or invalid JSON, or a top-level value that
        is not a JSON object all yield an empty cache. Problems are reported
        on stderr so they are not silently ignored.
        """
        if not path or not os.path.exists(path):
            return {}
        try:
            with open(path, "r", encoding="utf-8") as fh:
                data = json.load(fh)
        except (OSError, ValueError) as err:  # JSONDecodeError/UnicodeDecodeError are ValueErrors
            print(f"[WARN] Ignoring unreadable prompt cache {path}: {err}", file=sys.stderr)
            return {}
        if not isinstance(data, dict):
            print(f"[WARN] Ignoring prompt cache {path}: expected a JSON object", file=sys.stderr)
            return {}
        return data

    def _save_cache(self) -> None:
        """Write ``self.cache`` to ``cfg.cache_path`` atomically (best effort).

        The JSON goes to a sibling ``.tmp`` file that is then moved over the
        target with ``os.replace``, so an interrupted run cannot leave a
        truncated cache behind. Failures are reported on stderr but do not
        abort the evaluation.
        """
        path = self.cfg.cache_path
        tmp_path = f"{path}.tmp"
        try:
            with open(tmp_path, "w", encoding="utf-8") as fh:
                json.dump(self.cache, fh, ensure_ascii=False, indent=2)
            os.replace(tmp_path, path)
        except (OSError, TypeError, ValueError) as err:
            print(f"[WARN] Could not write prompt cache {path}: {err}", file=sys.stderr)

    def _run_async(self, coro):
        """Run *coro* to completion on this instance's persistent event loop.

        ``asyncio.run`` would create and close a new loop on every call, while
        the single ``httpx.AsyncClient`` held by ``self._client`` keeps pooled
        connections bound to the loop they were opened on. Reusing one loop
        for the lifetime of the instance keeps client and loop consistent.
        """
        import asyncio

        loop = getattr(self, "_loop", None)
        if loop is None or loop.is_closed():
            loop = asyncio.new_event_loop()
            self._loop = loop
        return loop.run_until_complete(coro)

    async def _gen_openai(self, concept: str) -> str:
        """Ask the OpenAI-compatible chat endpoint for a steering prompt (temperature 0)."""
        tmpl = STEERING_PROMPT_GEN_TEMPLATE.format(concept=concept)
        resp = await self._client.chat.completions.create(
            model=self.cfg.model,
            temperature=0.0,
            messages=[
                {"role": "system", "content": "You are a helpful prompt engineer."},
                {"role": "user", "content": tmpl},
            ],
        )
        return (resp.choices[0].message.content or "").strip()

    @torch.no_grad()
    def _gen_hf(self, concept: str) -> str:
        """Greedily generate a steering prompt with the local HF model.

        Only the newly generated tokens are decoded. If the full sequence were
        decoded with ``skip_special_tokens=True``, tokenizers that register
        the turn markers as special tokens (as Gemma's does) would drop
        ``<start_of_turn>``. There would then be nothing to split on, and the
        whole generation prompt would leak into the result.
        """
        prompt = (
            f"<start_of_turn>system\nYou are a helpful prompt engineer.<end_of_turn>\n"
            f"<start_of_turn>user\n{STEERING_PROMPT_GEN_TEMPLATE.format(concept=concept)}<end_of_turn>\n"
            f"<start_of_turn>model\n"
        )
        inputs = self._tok(prompt, return_tensors="pt").to(self.cfg.device)
        prompt_len = inputs["input_ids"].shape[1]
        out = self._model.generate(
            **inputs,
            do_sample=False,
            max_new_tokens=self.cfg.max_new_tokens,
            pad_token_id=self._tok.pad_token_id,
            eos_token_id=self._tok.eos_token_id,
        )
        return self._tok.decode(out[0, prompt_len:], skip_special_tokens=True).strip()

    def gen_prompt(self, concept_key: str, concept_text: str) -> str:
        """Return the steering prompt for *concept_text*, cached under *concept_key*.

        A non-empty cached entry is returned as-is. Otherwise a prompt is
        generated, truncated to its first paragraph and, if ``cfg.cache_path``
        is set, stored and flushed to disk.
        """
        cached = str(self.cache.get(concept_key, "")).strip()
        if cached:
            return cached
        if self.cfg.backend == "openai_async":
            raw = self._run_async(self._gen_openai(concept_text))
        else:
            raw = self._gen_hf(concept_text)
        # Keep only the first paragraph. Note that this also drops any optional
        # in-context examples the generator appends after a blank line.
        text = raw.strip().split("\n\n")[0].strip()
        if self.cfg.cache_path:
            self.cache[concept_key] = text
            self._save_cache()
        return text


# ---------- save-path helpers ----------
def _san(name: str) -> str:
    return re.sub(r"[^a-zA-Z0-9_.-]+", "_", name).strip("_")

def _build_save_paths(base_model: str, concepts_id: str, out_tag: str | None = None) -> Tuple[Path, Path, Path, Path]:
    """Create (if needed) this run's output directory and return its paths.

    The directory is
    ``cache/results_prompt_eval_openai/<base_model>/prompt_baseline/<tag>``,
    relative to the working directory. Here ``tag`` is *out_tag* when truthy,
    otherwise *concepts_id*; both variable components are passed through
    ``_san``.

    Returns:
        ``(root_dir, eval_json, samples_parquet, prompt_cache_json)``.
    """
    model_dir = _san(base_model)
    tag_dir = _san(out_tag if out_tag else concepts_id)
    root = Path("cache", "results_prompt_eval_openai", model_dir, "prompt_baseline", tag_dir)
    root.mkdir(parents=True, exist_ok=True)
    return (
        root,
        root / "eval.json",
        root / "samples.parquet",
        root / "prompt_cache.json",
    )


# ---------- main ----------
def _parse_args() -> argparse.Namespace:
    """Define and parse this script's command-line options."""
    ap = argparse.ArgumentParser()
    # Base HF model used for response generation
    ap.add_argument("--base_model", type=str, default="google/gemma-2-2b-it")
    ap.add_argument(
        "--no_chat_template", action="store_true",
        help="Feed the raw steered text to the base model instead of wrapping it in the "
             "tokenizer's chat template (the template is used by default when available).",
    )
    # Prompt-generation backend/model
    ap.add_argument("--prompt_gen_backend", type=str, default="openai_async", choices=["openai_async", "hf_local"])
    ap.add_argument("--prompt_gen_model", type=str, default="gpt-4o-mini")
    # Judge
    ap.add_argument("--judge_backend", type=str, default="openai_async", choices=["openai_async", "hf_local"])
    ap.add_argument("--judge_model", type=str, default="gpt-4o-mini")
    # Data
    ap.add_argument("--instructions_file", type=str, required=True)
    ap.add_argument("--concepts_file", type=str, required=True)
    # Eval settings
    ap.add_argument("--dev_k", type=int, default=5, help="Per concept: dev K items + holdout K items.")
    ap.add_argument("--max_new_tokens", type=int, default=128)
    ap.add_argument("--temperature", type=float, default=0.7)
    ap.add_argument("--top_p", type=float, default=0.95)
    ap.add_argument("--seed", type=int, default=123)
    # Progress / output tag
    ap.add_argument("--out_tag", type=str, default=None, help="Optional subfolder name; default=basename(concepts_file)")
    ap.add_argument("--sae_index", type=int, default=0, help="1-based index of this SAE among all SAEs (for logging).")
    ap.add_argument("--total_saes", type=int, default=0, help="Total number of SAEs (for logging).")
    # Debug printing
    ap.add_argument("--debug", action="store_true")
    ap.add_argument("--sample_print_k", type=int, default=1)
    ap.add_argument("--print_chars", type=int, default=300)
    return ap.parse_args()


def _split_instructions(pool: List[str], seed: int, concept_key: str, k: int) -> Tuple[List[str], List[str]]:
    """Pick disjoint dev and holdout instruction lists (``k`` each) for one concept.

    The shuffle RNG is seeded with the string ``f"{seed}:{concept_key}"``.
    ``random.Random`` converts str seeds deterministically, so the split is
    reproducible across processes. The builtin ``hash()`` would not be,
    because string hashing is salted per process (PYTHONHASHSEED).
    """
    rnd = random.Random(f"{seed}:{concept_key}")
    shuffled = list(pool)
    rnd.shuffle(shuffled)
    return shuffled[:k], shuffled[k: 2 * k]


def _generate_response(tok, mdl, steered_input: str, args: argparse.Namespace, device: str,
                       use_chat_template: bool) -> str:
    """Generate the base model's reply to *steered_input* and return only the new text.

    When *use_chat_template* is set, the input is wrapped as a single user
    turn with the tokenizer's chat template (ending in the generation prompt).
    Only tokens after the prompt are decoded. Decoding ``out[0]`` whole would
    hand the judge the steering prompt and question, which already name the
    concept.
    """
    if use_chat_template:
        prompt = tok.apply_chat_template(
            [{"role": "user", "content": steered_input}],
            tokenize=False,
            add_generation_prompt=True,
        )
        # The rendered template already contains any BOS token it needs.
        inputs = tok(prompt, return_tensors="pt", add_special_tokens=False).to(device)
    else:
        inputs = tok(steered_input, return_tensors="pt").to(device)

    sample = args.temperature > 0
    out = mdl.generate(
        **inputs,
        do_sample=sample,
        temperature=args.temperature if sample else None,
        top_p=args.top_p if sample else None,
        max_new_tokens=args.max_new_tokens,
        pad_token_id=tok.pad_token_id,
        eos_token_id=tok.eos_token_id,
    )
    prompt_len = inputs["input_ids"].shape[1]
    return tok.decode(out[0, prompt_len:], skip_special_tokens=True)


def main():
    """Run the prompt-steering baseline for every concept in ``--concepts_file``.

    For each concept:

    1. Build (or load from the cache) a J.1 steering prompt.
    2. Draw ``dev_k`` dev and ``dev_k`` holdout instructions deterministically.
    3. Generate a base-model response to ``steering prompt + Question: ...``
       for each instruction.
    4. Score each response with the judge (concept / instruct / fluency, plus
       their harmonic mean).

    Holdout means are checkpointed to ``eval.json`` after every concept. All
    per-sample rows are written to ``samples.parquet`` once at the end.
    """
    args = _parse_args()

    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    random.seed(args.seed)
    # Sampled decoding (temperature > 0) draws from torch's RNG, so seed it too.
    torch.manual_seed(args.seed)

    concepts_id = os.path.splitext(os.path.basename(args.concepts_file))[0]
    sae_label = concepts_id
    sae_prefix = ""
    if args.sae_index > 0 and args.total_saes > 0:
        sae_prefix = f"[SAE {args.sae_index}/{args.total_saes}] "
    print(f"{sae_prefix}Starting PromptSteering evaluation for SAE: {sae_label}")

    root_dir, eval_json_path, sample_parquet_path, prompt_cache_path = _build_save_paths(
        args.base_model, concepts_id if args.out_tag is None else args.out_tag
    )
    print(f"[INFO] Output dir: {root_dir}")

    # Prompt engineer and judge
    pe = PromptEngineer(PromptGenConfig(
        backend=args.prompt_gen_backend,
        model=args.prompt_gen_model,
        cache_path=str(prompt_cache_path),
        device=device
    ))
    judge = AsyncOpenAIJudge(OpenAIJudgeConfig(model=args.judge_model)) \
        if args.judge_backend == "openai_async" \
        else LocalHFJudge(LocalJudgeConfig(model_name=args.judge_model, device=device))

    print(f"[INFO] PromptGen backend/model : {args.prompt_gen_backend} / {args.prompt_gen_model}")
    print(f"[INFO] Judge backend/model     : {args.judge_backend} / {args.judge_model}")

    # Base model for response generation
    print(f"[INFO] Loading base model on {device}: {args.base_model}")
    tok = AutoTokenizer.from_pretrained(args.base_model)
    mdl = AutoModelForCausalLM.from_pretrained(args.base_model).to(device)
    if tok.pad_token_id is None:
        tok.pad_token = tok.eos_token
    use_chat_template = (not args.no_chat_template) and bool(getattr(tok, "chat_template", None))
    print(f"[INFO] Base-model chat template : {'on' if use_chat_template else 'off'}")

    # Data
    instructions_pool = load_instructions_file(args.instructions_file, min_needed=2 * args.dev_k)
    concept_items = load_concepts(args.concepts_file)
    total_concepts = len(concept_items)
    print(f"[INFO] Instructions: {len(instructions_pool)} | Concepts in this SAE: {total_concepts}")

    all_results: Dict[str, Any] = {}
    rows_for_parquet: List[Dict[str, Any]] = []

    # Concept-level progress bar
    pbar_desc = f"{sae_prefix}Concepts [{sae_label}]"
    for idx, (ckey, ctext) in enumerate(tqdm(concept_items, desc=pbar_desc, unit="concept"), start=1):
        # 1) Steering prompt (J.1 template)
        steering_prompt = pe.gen_prompt(ckey, ctext)

        # 2) Split dev/holdout (reproducible for a given --seed)
        dev_insts, hold_insts = _split_instructions(instructions_pool, args.seed, ckey, args.dev_k)

        # 3) Generate + judge; dev first, then holdout
        hold_rows: List[Dict[str, Any]] = []
        for split, insts in (("dev", dev_insts), ("hold", hold_insts)):
            for n_seen, inst in enumerate(insts, 1):
                steered_input = compose_steered_input(steering_prompt, inst)
                resp = _generate_response(tok, mdl, steered_input, args, device, use_chat_template)
                c_raw, i_raw, f_raw = judge.score(ctext, inst, resp)
                c, i_s, f = safe_float(c_raw), safe_float(i_raw), safe_float(f_raw)
                overall = harmonic_mean_0_2(c, i_s, f)
                row = {
                    "concept_key": ckey, "concept_text": ctext, "split": split,
                    "instruction": inst, "steering_prompt": steering_prompt,
                    "steered_input": steered_input,
                    "response": resp, "concept_score": c, "instruct_score": i_s,
                    "fluency_score": f, "overall": overall,
                }
                rows_for_parquet.append(row)
                if split == "hold":
                    hold_rows.append(row)
                if args.debug and n_seen <= args.sample_print_k:
                    print(f"[{split.upper()} sample] {ckey} ({idx}/{total_concepts})")
                    print("  Instruction:", truncate_text(inst, args.print_chars))
                    print("  Response   :", truncate_text(resp, args.print_chars))
                    print(f"  Scores(c,i,f)->overall: ({c:.3f},{i_s:.3f},{f:.3f})->{overall:.3f}")

        # 4) Holdout means (0.0 when there are no holdout items)
        n_hold = max(1, len(hold_rows))
        hold_mean = {
            "concept": sum(r["concept_score"] for r in hold_rows) / n_hold,
            "instruct": sum(r["instruct_score"] for r in hold_rows) / n_hold,
            "fluency": sum(r["fluency_score"] for r in hold_rows) / n_hold,
            "overall": sum(r["overall"] for r in hold_rows) / n_hold,
        }
        all_results[ckey] = {
            "concept": ctext,
            "steering_prompt": steering_prompt,
            "holdout": {"mean": {k: float(v) for k, v in hold_mean.items()}}
        }

        # Incremental checkpoint after each concept
        with open(eval_json_path, "w", encoding="utf-8") as fh:
            json.dump(all_results, fh, ensure_ascii=False, indent=2)

    # Write per-sample parquet
    if rows_for_parquet:
        df = pd.DataFrame(rows_for_parquet)
        df.to_parquet(sample_parquet_path, index=False)
        print(f"[DONE] Wrote samples -> {sample_parquet_path}")

    print(f"[DONE] Wrote per-concept results -> {eval_json_path}")


# Script entry point: run the evaluation only when executed directly, not on import.
if __name__ == "__main__":
    main()
