#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Run three experiments in one file:
  (A) main-only, (B) usage-only (raw/ortho), (C) main+usage (raw/ortho)
All user-editable params are grouped at the TOP and can also be overridden by CLI.
"""

import os, csv, argparse, time
import numpy as np
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# Disable autograd globally: this script only runs forward passes / generation.
torch.set_grad_enabled(False)

# =========================
# ===== USER CONFIG =======
# =========================
# Basics
MODEL_NAME        = "meta-llama/Meta-Llama-3-8B-Instruct"
LAYER_TO_STEER    = 14            # human-facing layer number (1-based)
SEED              = 42

# Generation settings
TEMPERATURE       = 0.0           # 0 => greedy
TOP_P             = 0.9
MAX_NEW_TOKENS    = 160
REPETITION_PEN    = 1.12
NO_REPEAT_NGRAM   = 3

# I/O
PROMPT_FILE       = "src/prompts.txt"         # one prompt per line
OUTPUT_DIR        = "outputs"
OUTPUT_BASENAME   = "steer_main_usage"    # produces {OUTPUT_BASENAME}_YYYYmmdd-HHMMSS.csv when ADD_TIMESTAMP is True
ADD_TIMESTAMP     = True

# Axis file paths
MAIN_AXIS_PATH    = "experiments/exp_main_axis/Meta-Llama-3-8B-Instruct_main/sentiment_axis_L14.npy"
AXIS_USAGE_PATHS  = {
    "genre"  : "experiments/exp_sub_axis/Meta-Llama-3-8B-Instruct_genre/sentiment_axis_L14.npy",
    "tone"   : "experiments/exp_sub_axis/Meta-Llama-3-8B-Instruct_tone/sentiment_axis_L14.npy",
    "context": "experiments/exp_sub_axis/Meta-Llama-3-8B-Instruct_contextual/sentiment_axis_L14.npy",
    "topic"  : "experiments/exp_sub_axis/Meta-Llama-3-8B-Instruct_topic/sentiment_axis_L14.npy",
}
# Usage axes to run, in this order; names missing from the loaded axes are skipped.
USAGES_TO_TRY     = ["genre", "tone", "context", "topic"]

# Experiment grids (steering strengths)
ALPHAS_MAIN_ONLY      = [-30, -15, 0, 15, 30]             # (A) main-only
BETAS_USAGE_ONLY      = [-30, -15, 0, 15, 30]             # (B) usage-only
ALPHAS_FOR_COMBO      = [10]                              # (C) main+usage: α
BETAS_FOR_COMBO       = [-40, -30, -15, 0, 15, 30, 40]    # (C) main+usage: β

# Miscellaneous
DTYPE                = torch.float16          # dtype used for the model and the axes
DEVICE_MAP           = "auto"                 # "auto" or an explicit device map
RUN_MAIN_ONLY        = True
RUN_USAGE_ONLY       = True
RUN_MAIN_PLUS_USAGE  = True
MODES_FOR_USAGE      = ["raw", "ortho"]       # set to ["ortho"] to run only the orthogonalized variant
# =========================
# ===== END CONFIG ========
# =========================


# ===== Utils =====
def set_seed(s: int):
    """Seed every random number generator this script relies on.

    Applies the same integer ``s`` to Python's ``random`` module, NumPy's
    global generator, and PyTorch's CPU and CUDA generators, in that order.
    """
    import random

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(s)

def _normalize(v: torch.Tensor, eps=1e-12):
    """Return ``v`` scaled to unit L2 norm, in ``v``'s own dtype.

    The norm and the division are carried out in at least float32. When done
    directly in half precision, the ``eps`` guard rounds to zero, so an
    all-zero input produced NaN (0/0) instead of zeros.

    Args:
        v: Tensor to normalise; the norm is taken over all of its elements.
        eps: Lower bound applied to the norm to avoid division by zero.

    Returns:
        A tensor with the same shape and dtype as ``v``.
    """
    # float16/bfloat16 -> float32; float32 and float64 are left as they are.
    work_dtype = torch.promote_types(v.dtype, torch.float32)
    wide = v.to(work_dtype)
    # Clamp rather than add: for any non-degenerate norm the result equals a
    # plain division, and a zero vector maps to zeros.
    return (wide / wide.norm().clamp_min(eps)).to(v.dtype)

def _proj(u: torch.Tensor, v_unit: torch.Tensor):
    """Return the component of ``u`` along ``v_unit``.

    The coefficient is the sum of the element-wise product of the two tensors.
    Nothing is normalised here, so the result is a true projection only when
    ``v_unit`` already has unit norm.
    """
    coeff = torch.mul(u, v_unit).sum()
    return torch.mul(coeff, v_unit)

def orthogonalize_to(vec: torch.Tensor, bases: list):
    """Remove from ``vec`` its components along ``bases`` and renormalise.

    Gram-Schmidt style: each basis direction is subtracted from the running
    residual in turn, then the residual is scaled to unit L2 norm.

    The arithmetic runs in at least float32 and the norm is clamped from
    below. In half precision an additive ``1e-12`` guard rounds to zero, so a
    ``vec`` lying entirely in the span of ``bases`` produced NaN; it now
    yields a zero vector.

    Args:
        vec: Vector to orthogonalise.
        bases: Iterable of tensors with the same shape as ``vec``. Each must
            already have unit norm, otherwise the subtraction is not a true
            projection.

    Returns:
        The normalised residual, with the dtype of ``vec``.
    """
    work_dtype = torch.promote_types(vec.dtype, torch.float32)
    residual = vec.to(work_dtype)
    for basis in bases:
        unit = basis.to(dtype=work_dtype)
        # Subtract the projection of the current residual onto this basis.
        residual = residual - (residual * unit).sum() * unit
    return (residual / residual.norm().clamp_min(1e-12)).to(vec.dtype)

def load_axis(path: str, device, dtype=DTYPE):
    """Load a steering axis from a ``.npy`` file and return it unit-normalised.

    The array is normalised in at least float32 and only then cast to
    ``dtype``. Casting to half precision first (as was done before) rounds the
    raw values before the norm is taken and can overflow for large-magnitude
    axes.

    Args:
        path: Path to a NumPy ``.npy`` file holding the axis.
        device: Torch device the tensor is created on.
        dtype: Dtype of the returned tensor.

    Returns:
        The normalised axis as a tensor of ``dtype`` on ``device``.
    """
    work_dtype = torch.promote_types(dtype, torch.float32)
    raw = torch.tensor(np.load(path), dtype=work_dtype, device=device)
    return _normalize(raw).to(dtype=dtype)

def cosine(a: torch.Tensor, b: torch.Tensor) -> float:
    """Return the cosine similarity of ``a`` and ``b`` as a Python float.

    Both inputs are converted to float32 and given a leading batch dimension
    of size one before the similarity is computed.
    """
    lhs = a.to(torch.float32)[None]
    rhs = b.to(torch.float32)[None]
    similarity = torch.nn.functional.cosine_similarity(lhs, rhs)
    return float(similarity.item())

def ensure_dir(p: str):
    """Create directory ``p`` (including parents) unless it is empty or exists.

    Nothing happens when ``p`` is falsy or when something already exists at
    that path.
    """
    if not p or os.path.exists(p):
        return
    os.makedirs(p, exist_ok=True)

def build_output_path(output_dir: str, base: str, add_ts: bool) -> str:
    """Return the CSV path inside ``output_dir``, creating the directory first.

    With ``add_ts`` the file is named ``{base}_YYYYmmdd-HHMMSS.csv`` using the
    current local time; otherwise it is ``{base}.csv``.
    """
    ensure_dir(output_dir)
    if add_ts:
        stamp = time.strftime("%Y%m%d-%H%M%S")
        filename = f"{base}_{stamp}.csv"
    else:
        filename = f"{base}.csv"
    return os.path.join(output_dir, filename)

def usage_axis(mode: str, usage_raw: torch.Tensor, u_main_vec: torch.Tensor | None):
    """返回 raw 或正交后的 usage 轴"""
    if mode == "raw" or (u_main_vec is None):
        return usage_raw
    u_main_f32  = _normalize(u_main_vec.to(dtype=torch.float32))
    u_usage_f32 = _normalize(usage_raw.to(dtype=torch.float32))
    u_usage_ortho = orthogonalize_to(u_usage_f32, [u_main_f32])
    return u_usage_ortho.to(dtype=(u_main_vec.dtype if u_main_vec is not None else DTYPE),
                            device=(u_main_vec.device if u_main_vec is not None else usage_raw.device))

# ===== Model / Hook =====
def load_model_and_tokenizer(model_name: str, dtype=DTYPE, device_map=DEVICE_MAP):
    """Load the tokenizer and causal LM for ``model_name``.

    The tokenizer is loaded first. If it has no pad token id, the EOS token id
    is used for padding. The model is loaded with the given ``dtype`` and
    ``device_map``.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    tok = AutoTokenizer.from_pretrained(model_name)
    if tok.pad_token_id is None:
        tok.pad_token_id = tok.eos_token_id

    lm = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=dtype,
        device_map=device_map,
    )
    return lm, tok

def register_hook_main_usage(model, layer_idx_human, alpha_main, u_main_vec, beta_usage, u_usage_vec):
    """Register a forward hook that steers the last position of one layer.

    On every forward pass of the chosen layer the hidden state at the final
    sequence position is replaced by

        h_last' = h_last + alpha * RMS(h_last) * u_main + beta * RMS(h_last) * u_usage

    where RMS is taken over the hidden dimension. A term is skipped when its
    vector is ``None`` or its strength is zero.

    Args:
        model: Object exposing the layers as ``model.model.layers``.
        layer_idx_human: 1-based index of the layer to hook.
        alpha_main: Strength applied to ``u_main_vec``.
        u_main_vec: Main steering direction, or ``None``.
        beta_usage: Strength applied to ``u_usage_vec``.
        u_usage_vec: Usage steering direction, or ``None``.

    Returns:
        The handle returned by ``register_forward_hook``; call ``.remove()``
        on it to detach the hook.

    Raises:
        IndexError: If ``layer_idx_human`` is outside ``1..len(layers)``.
            Without this check ``0`` would wrap around to the last layer.
    """
    layers = model.model.layers
    if not 1 <= layer_idx_human <= len(layers):
        raise IndexError(
            f"layer_idx_human={layer_idx_human} is out of range; expected 1..{len(layers)}"
        )
    target_layer = layers[layer_idx_human - 1]  # convert to 0-based

    use_main = u_main_vec is not None and alpha_main != 0.0
    use_usage = u_usage_vec is not None and beta_usage != 0.0

    def hook_fn(module, inputs, outputs):
        if not (use_main or use_usage):
            # Nothing to add: returning None leaves the layer output untouched.
            return None

        # The layer may hand back a tuple whose first item is the hidden state
        # or the hidden-state tensor itself (this differs between transformers
        # releases); support both.
        is_tuple = isinstance(outputs, tuple)
        hidden = outputs[0] if is_tuple else outputs   # [B,T,H]
        h_last = hidden[:, -1, :]                      # [B,H]

        # RMS in float32: squaring float16 activations overflows to inf once a
        # value exceeds ~255, which would turn the whole delta into inf/NaN.
        rms = h_last.to(torch.float32).pow(2).mean(dim=-1, keepdim=True).sqrt()  # [B,1]

        delta = torch.zeros_like(h_last, dtype=torch.float32)
        if use_main:
            delta = delta + alpha_main * rms * u_main_vec.to(h_last.device, dtype=torch.float32)
        if use_usage:
            delta = delta + beta_usage * rms * u_usage_vec.to(h_last.device, dtype=torch.float32)

        # Clone so the layer's own output tensor is not modified in place.
        steered = hidden.clone()
        steered[:, -1, :] = h_last + delta.to(h_last.dtype)
        return (steered,) + outputs[1:] if is_tuple else steered

    return target_layer.register_forward_hook(hook_fn)

@torch.inference_mode()
def generate_one(model, tokenizer, prompt: str,
                 layer_to_steer: int,
                 alpha_main: float, u_main_vec,
                 beta_usage: float, u_usage_vec,
                 max_new_tokens: int, temperature: float, top_p: float,
                 repetition_pen: float, no_repeat_ngram: int):
    """Generate a continuation of ``prompt`` with the steering hook attached.

    The hook is registered for the duration of this call and always removed
    afterwards, even if tokenisation or generation raises.

    Args:
        model: Causal LM to generate with.
        tokenizer: Tokenizer matching ``model``.
        prompt: Text to continue; it is tokenised as-is.
        layer_to_steer: 1-based index of the layer to hook.
        alpha_main: Strength of the main-axis term.
        u_main_vec: Main steering direction, or ``None``.
        beta_usage: Strength of the usage-axis term.
        u_usage_vec: Usage steering direction, or ``None``.
        max_new_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature; ``<= 0`` selects greedy decoding.
        top_p: Nucleus-sampling threshold, used only when sampling.
        repetition_pen: Repetition penalty passed to ``generate``.
        no_repeat_ngram: N-gram size that may not repeat.

    Returns:
        The decoded newly generated tokens (prompt excluded), with special
        tokens stripped.
    """
    sampling = temperature > 0.0
    gen_kwargs = dict(
        max_new_tokens=max_new_tokens,
        do_sample=sampling,
        # temperature/top_p only apply to sampling. For greedy decoding they
        # are explicitly unset instead of forwarding temperature=0.0, which
        # transformers reports as an inconsistent generation config.
        temperature=temperature if sampling else None,
        top_p=top_p if sampling else None,
        repetition_penalty=repetition_pen,
        no_repeat_ngram_size=no_repeat_ngram,
        pad_token_id=tokenizer.pad_token_id,
    )

    handle = register_hook_main_usage(model, layer_to_steer, alpha_main, u_main_vec, beta_usage, u_usage_vec)
    try:
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        out = model.generate(**inputs, **gen_kwargs)
        # generate() returns prompt + continuation; keep the continuation only.
        prompt_len = inputs["input_ids"].size(1)
        return tokenizer.decode(out[0][prompt_len:], skip_special_tokens=True)
    finally:
        handle.remove()

# ===== Main =====
def main():
    """Run the configured steering experiments and write all outputs to one CSV.

    Steps:
      1. Parse CLI arguments; every one defaults to the matching constant at
         the top of the file.
      2. Load prompts (one per non-empty line), the model, and the axis files.
      3. Run (A) main-only, (B) usage-only and (C) main+usage, each if enabled.

    Each CSV row holds: prompt_id, prompt, exp_mode, usage, ortho, alpha_main,
    beta_usage, output. A generation error is recorded in the ``output``
    column as ``[ERROR] ...`` and the run continues. Rows are flushed as they
    are written, so an interrupted run keeps everything generated so far.

    Raises:
        FileNotFoundError: If the prompt file does not exist.
    """
    parser = argparse.ArgumentParser()
    # Overrides for the constants at the top of the file.
    parser.add_argument("--model_name", default=MODEL_NAME)
    parser.add_argument("--layer", type=int, default=LAYER_TO_STEER)
    parser.add_argument("--seed", type=int, default=SEED)
    parser.add_argument("--temperature", type=float, default=TEMPERATURE)
    parser.add_argument("--top_p", type=float, default=TOP_P)
    parser.add_argument("--max_new_tokens", type=int, default=MAX_NEW_TOKENS)
    parser.add_argument("--rep_pen", type=float, default=REPETITION_PEN)
    parser.add_argument("--no_repeat", type=int, default=NO_REPEAT_NGRAM)
    parser.add_argument("--prompt_file", default=PROMPT_FILE)
    parser.add_argument("--output_dir", default=OUTPUT_DIR)
    parser.add_argument("--output_base", default=OUTPUT_BASENAME)
    parser.add_argument("--add_ts", action=argparse.BooleanOptionalAction, default=ADD_TIMESTAMP)

    # Which experiments to run.
    parser.add_argument("--run_main_only", action=argparse.BooleanOptionalAction, default=RUN_MAIN_ONLY)
    parser.add_argument("--run_usage_only", action=argparse.BooleanOptionalAction, default=RUN_USAGE_ONLY)
    parser.add_argument("--run_main_plus_usage", action=argparse.BooleanOptionalAction, default=RUN_MAIN_PLUS_USAGE)

    # Strength grids, overridable as comma-separated lists.
    parser.add_argument("--alphas_main_only", default=",".join(map(str, ALPHAS_MAIN_ONLY)))
    parser.add_argument("--betas_usage_only", default=",".join(map(str, BETAS_USAGE_ONLY)))
    parser.add_argument("--alphas_for_combo", default=",".join(map(str, ALPHAS_FOR_COMBO)))
    parser.add_argument("--betas_for_combo", default=",".join(map(str, BETAS_FOR_COMBO)))

    # Usage-axis modes: raw / ortho.
    parser.add_argument("--modes_for_usage", default=",".join(MODES_FOR_USAGE))
    args = parser.parse_args()

    # Parse grids and modes.
    def parse_list_floats(s):
        return [float(x) for x in s.split(",") if x.strip() != ""]

    alphas_main_only = parse_list_floats(args.alphas_main_only)
    betas_usage_only = parse_list_floats(args.betas_usage_only)
    alphas_for_combo = parse_list_floats(args.alphas_for_combo)
    betas_for_combo = parse_list_floats(args.betas_for_combo)
    modes_for_usage = [m.strip() for m in args.modes_for_usage.split(",") if m.strip() != ""]

    # Reject typos before the model is loaded: an unknown mode would otherwise
    # be run as "ortho" while being logged with ortho=False.
    bad_modes = [m for m in modes_for_usage if m not in ("raw", "ortho")]
    if bad_modes:
        parser.error(f"--modes_for_usage: unknown mode(s) {bad_modes}; expected 'raw' and/or 'ortho'")

    model_name = args.model_name
    layer_to_steer = args.layer
    seed = args.seed
    temperature = args.temperature
    top_p = args.top_p
    max_new_tokens = args.max_new_tokens
    rep_pen = args.rep_pen
    no_repeat = args.no_repeat
    prompt_file = args.prompt_file
    output_dir = args.output_dir
    output_base = args.output_base
    add_ts = args.add_ts
    run_main_only = args.run_main_only
    run_usage_only = args.run_usage_only
    run_main_plus_usage = args.run_main_plus_usage

    # The axis paths are fixed in the config and are not re-resolved per layer.
    if layer_to_steer != LAYER_TO_STEER:
        print(f"[WARN] --layer={layer_to_steer} differs from LAYER_TO_STEER={LAYER_TO_STEER}; "
              "the axis files from MAIN_AXIS_PATH / AXIS_USAGE_PATHS are used unchanged.")

    set_seed(seed)

    # Prompts: one per non-empty line.
    if not os.path.exists(prompt_file):
        raise FileNotFoundError(f"PROMPT_FILE not found: {prompt_file}")
    with open(prompt_file, "r", encoding="utf-8") as f:
        prompts = [ln.strip() for ln in f if ln.strip()]
    print(f"[INFO] Loaded {len(prompts)} prompts from {prompt_file}")
    if not prompts:
        print(f"[WARN] no prompts found in {prompt_file}; nothing will be generated.")

    # Model.
    model, tokenizer = load_model_and_tokenizer(model_name, dtype=DTYPE, device_map=DEVICE_MAP)

    # Axes.
    u_main = load_axis(MAIN_AXIS_PATH, device=model.device, dtype=DTYPE) if os.path.exists(MAIN_AXIS_PATH) else None
    if u_main is None:
        print(f"[WARN] main axis not found: {MAIN_AXIS_PATH}")

    usage_axes = {}
    for name, p in AXIS_USAGE_PATHS.items():
        if os.path.exists(p):
            usage_axes[name] = load_axis(p, device=model.device, dtype=DTYPE)
            if u_main is not None:
                print(f"[AXIS={name}] Cos(raw, main)={cosine(usage_axes[name], u_main):+.4f}")
        else:
            print(f"[WARN] skip '{name}': not found -> {p}")

    # Output file.
    out_csv = build_output_path(output_dir, output_base, add_ts)
    with open(out_csv, "w", newline="", encoding="utf-8") as fout:
        writer = csv.writer(fout)
        writer.writerow([
            "prompt_id", "prompt",
            "exp_mode",          # main_only / usage_only / main_plus_usage
            "usage",             # usage name, or '—' for main-only
            "ortho",             # True/False
            "alpha_main",        # main-axis strength
            "beta_usage",        # usage-axis strength
            "output"
        ])

        def run_prompts(exp_mode, usage_name, is_ortho, alpha, beta, usage_vec):
            """Generate for every prompt at one setting and log one row each."""
            for i, prompt in enumerate(prompts, 1):
                try:
                    txt = generate_one(model, tokenizer, prompt,
                                       layer_to_steer,
                                       alpha_main=alpha, u_main_vec=u_main,
                                       beta_usage=beta, u_usage_vec=usage_vec,
                                       max_new_tokens=max_new_tokens,
                                       temperature=temperature, top_p=top_p,
                                       repetition_pen=rep_pen,
                                       no_repeat_ngram=no_repeat)
                except Exception as e:
                    # Best effort: record the failure and keep going.
                    txt = f"[ERROR] {type(e).__name__}: {e}"
                writer.writerow([i, prompt, exp_mode, usage_name, is_ortho, alpha, beta, txt.strip()])
                # Each row costs a full generation; do not lose it on a crash.
                fout.flush()

        def iter_usage_axes():
            """Yield (usage_name, mode, axis) for each available usage and mode."""
            for usage_name in USAGES_TO_TRY:
                if usage_name not in usage_axes:
                    continue
                for mode in modes_for_usage:
                    if mode == "ortho" and u_main is None:
                        # Without the main axis the "ortho" axis would just be
                        # the raw one, yet the rows would claim ortho=True.
                        print(f"[WARN] skip {usage_name}:ortho: main axis not found, cannot orthogonalize.")
                        continue
                    yield usage_name, mode, usage_axis(mode, usage_axes[usage_name], u_main)

        # ===== (A) main-only =====
        if run_main_only:
            if u_main is None:
                print("[INFO] main-only skipped: main axis not found.")
            else:
                print("\n===== (A) MAIN-ONLY =====")
                for alpha in alphas_main_only:
                    print(f"[main-only] alpha={alpha}σ")
                    run_prompts("main_only", "—", False, alpha, 0.0, None)

        # ===== (B) usage-only (raw / ortho) =====
        if run_usage_only:
            print("\n===== (B) USAGE-ONLY =====")
            for usage_name, mode, u_use in iter_usage_axes():
                print(f"[usage-only:{usage_name}:{mode}] betas={betas_usage_only}")
                for beta in betas_usage_only:
                    run_prompts("usage_only", usage_name, (mode == "ortho"), 0.0, beta, u_use)

        # ===== (C) main + usage (raw / ortho) =====
        if run_main_plus_usage:
            if u_main is None:
                print("[INFO] main+usage skipped: main axis not found.")
            else:
                print("\n===== (C) MAIN + USAGE =====")
                for usage_name, mode, u_use in iter_usage_axes():
                    print(f"[main+usage:{usage_name}:{mode}] alphas={alphas_for_combo}, betas={betas_for_combo}")
                    for alpha in alphas_for_combo:
                        for beta in betas_for_combo:
                            run_prompts("main_plus_usage", usage_name, (mode == "ortho"), alpha, beta, u_use)

    print(f"\n[OK] All done. Results saved to: {out_csv}")

# Run the experiments only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
