#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Main-axis steering (no usage). Generates text under multiple alpha strengths and
saves all generations to a TSV for Appendix.

Model: meta-llama/Meta-Llama-3-8B-Instruct
Layer: human-indexed L=14 (we hook layer index 13 internally)
Axis : sentiment main axis (unit-normalized) at that layer
"""

import os, csv, random, argparse
import numpy as np
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# --------------------- Utils ---------------------
def set_seed(seed: int = 42):
    """Seed Python's `random`, NumPy, and PyTorch (CPU and all CUDA devices) with `seed`."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    # Same order as a straight-line sequence of calls: stdlib, NumPy, torch CPU, torch CUDA.
    for apply_seed in seeders:
        apply_seed(seed)

def _normalize(v: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """Return `v` divided by its norm; `eps` is added to the norm before dividing."""
    denom = v.norm() + eps
    scaled = v / denom
    return scaled

def load_axis(path: str, device, dtype=torch.float16) -> torch.Tensor:
    """
    Load a steering axis from a .npy file and return it unit-normalized.

    Args:
        path: location of the .npy file holding the axis.
        device: torch device the returned tensor is placed on.
        dtype: dtype of the returned tensor (float16 by default).

    Returns:
        The axis divided by (its L2 norm + 1e-12), on `device` in `dtype`.
        An all-zero axis is returned as all zeros.

    Raises:
        FileNotFoundError: if `path` does not exist.
    """
    if not os.path.exists(path):
        raise FileNotFoundError(f"Main axis not found: {path}")
    # Normalize in float32 *before* casting to the (possibly half-precision)
    # target dtype. In float16 the 1e-12 epsilon rounds to zero (so a zero
    # axis would yield NaN) and large raw values overflow to inf.
    v = torch.tensor(np.load(path), dtype=torch.float32)
    if v.ndim > 1:  # sometimes saved as (1, H)
        v = v[0]
    # Same formula as _normalize (v / (||v|| + eps)), evaluated in float32.
    v = v / (v.norm() + 1e-12)
    return v.to(device=device, dtype=dtype)

def register_main_hook(model, layer_idx_human: int, alpha: float, u_main_vec: torch.Tensor):
    """
    Register a forward hook that steers the last-token hidden state at one layer:
       h' = h + alpha * RMS(h) * u_main

    Args:
        model: object exposing the decoder layers as `model.model.layers`.
        layer_idx_human: 1-based layer number (e.g., 14 -> physical 13).
        alpha: steering strength, in multiples of the last token's RMS.
        u_main_vec: steering direction of size H (the hidden width).

    Returns:
        The handle from `register_forward_hook`; call `.remove()` to undo.

    Raises:
        IndexError: if `layer_idx_human` is outside [1, number of layers].
    """
    layers = model.model.layers
    num_layers = len(layers)
    # Reject 0 / negatives explicitly: `layers[-1]` would otherwise silently
    # hook a layer counted from the end instead of failing.
    if not 1 <= layer_idx_human <= num_layers:
        raise IndexError(
            f"layer_idx_human must be in [1, {num_layers}], got {layer_idx_human}"
        )
    phys_idx = layer_idx_human - 1

    def hook_fn(module, inputs, outputs):
        # A layer may return a tuple (hidden_states, ...) or the bare
        # hidden-state tensor [B, T, H]; support both shapes of output.
        returns_tensor = isinstance(outputs, torch.Tensor)
        hidden = outputs if returns_tensor else outputs[0]
        h_last = hidden[:, -1, :]                                   # [B,H]
        # Do the arithmetic in float32: squaring half-precision activations
        # can overflow to inf, which would turn the delta into NaN.
        h32 = h_last.float()
        rms = h32.pow(2).mean(dim=-1, keepdim=True).sqrt()
        delta = alpha * rms * u_main_vec.to(h_last.device, dtype=torch.float32)
        # Clone so the tensor produced by the layer is not modified in place.
        hidden = hidden.clone()
        hidden[:, -1, :] = (h32 + delta).to(h_last.dtype)
        if returns_tensor:
            return hidden
        return (hidden,) + outputs[1:]

    return layers[phys_idx].register_forward_hook(hook_fn)

# --------------------- Main ---------------------
def main():
    """
    Command-line entry point: generate text under each steering strength and
    save every generation to a TSV.

    Steps: parse arguments, seed the RNGs, load tokenizer/model and the
    normalized steering axis, then for each alpha and each repeat register the
    steering hook, sample one continuation of the prompt, remove the hook, and
    write a row (prompt, layer, alpha_sigma, sample_id, text) to `--out_tsv`.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--model", default="meta-llama/Meta-Llama-3-8B-Instruct")
    ap.add_argument("--layer", type=int, default=14, help="human-indexed layer to steer")
    ap.add_argument("--axis_path", required=True,
                    help="path to sentiment main axis .npy for the chosen layer")
    ap.add_argument("--alphas", type=float, nargs="+", default=[0.0],
                    help="List of alpha values for steering (space separated, in units of sigma)")

    ap.add_argument("--prompt", default="A person waits in a hospital corridor for test results. Write ~60 words in third person, present tense.")
    ap.add_argument("--max_new_tokens", type=int, default=160)
    ap.add_argument("--temperature", type=float, default=0.1)
    ap.add_argument("--top_p", type=float, default=0.9)
    ap.add_argument("--repeats", type=int, default=1)
    ap.add_argument("--seed", type=int, default=42)
    ap.add_argument("--out_tsv", default="outputs/steer_appendix/steer_main_only.tsv")
    args = ap.parse_args()

    set_seed(args.seed)
    # dirname is "" when --out_tsv is a bare filename, and os.makedirs("")
    # raises; only create the directory when there is one to create.
    out_dir = os.path.dirname(args.out_tsv)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    print(f"[Config] model={args.model}  L={args.layer}  axis={args.axis_path}")
    print(f"[Config] alphas={args.alphas}  prompt='{args.prompt}'")

    # Load model & tokenizer
    tokenizer = AutoTokenizer.from_pretrained(args.model)
    if tokenizer.pad_token_id is None:
        # Fall back to EOS so generate() has a pad id to use.
        tokenizer.pad_token_id = tokenizer.eos_token_id
    model = AutoModelForCausalLM.from_pretrained(
        args.model, torch_dtype=torch.float16, device_map="auto"
    )
    device = model.device

    # Load and normalize main axis
    u_main = load_axis(args.axis_path, device=device, dtype=torch.float16)

    # Prepare alphas (interpreted as multiples of sigma=RMS)
    alpha_list = args.alphas

    # The prompt is identical for every run, so tokenize it once up front.
    inputs = tokenizer(args.prompt, return_tensors="pt").to(device)
    prompt_len = inputs["input_ids"].size(1)

    # Generate across the grid and save
    with open(args.out_tsv, "w", newline="", encoding="utf-8") as f:
        w = csv.writer(f, delimiter="\t")
        w.writerow(["prompt", "layer", "alpha_sigma", "sample_id", "text"])

        for a in alpha_list:
            print(f"\n=== alpha = {a}σ @ L{args.layer} ===")
            for r in range(args.repeats):
                # register hook
                hook = register_main_hook(model, args.layer, alpha=a, u_main_vec=u_main)
                try:
                    out = model.generate(
                        **inputs,
                        max_new_tokens=args.max_new_tokens,
                        temperature=args.temperature,
                        top_p=args.top_p,
                        do_sample=True,
                        repetition_penalty=1.12,
                        no_repeat_ngram_size=3,
                        pad_token_id=tokenizer.pad_token_id,
                    )
                finally:
                    # Always detach the hook, even if generation fails, so it
                    # cannot leak into a later run with a different alpha.
                    hook.remove()

                # Keep only the newly generated tokens (drop the echoed prompt).
                gen_tokens = out[0][prompt_len:]
                text = tokenizer.decode(gen_tokens, skip_special_tokens=True).strip()
                print(f"[sample {r+1}] {text}\n")
                w.writerow([args.prompt, args.layer, a, r, text])

    print(f"\n[OK] Saved generations to: {args.out_tsv}")

if __name__ == "__main__":
    # Turn autograd off globally before running: the script only generates text.
    torch.set_grad_enabled(False)
    main()
