#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Steer along usage axes (genre/tone/context/topic) and along the main sentiment axis
after removing usage components (main⊥usage), all at a single layer.

- Model: meta-llama/Meta-Llama-3-8B-Instruct (can be changed as needed)
- Layer: 1-based ("human") index 14, converted internally to the 0-based index 13
- Axes: load your trained .npy direction vectors
"""

import os
import random
import numpy as np
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

torch.set_grad_enabled(False)

# ================== Basic configuration ==================
model_name       = "meta-llama/Meta-Llama-3-8B-Instruct"
layer_to_steer   = 14            # Human index L14 (hook will subtract 1)
temperature      = 0.1
top_p            = 0.9
max_new_tokens   = 160
seed             = 42

# Alternative prompt, kept for quick switching:
# prompt = "The meeting lasted for two hours. Write about this situation in 60 words."
prompt = "A person waits in a hospital corridor for test results. Write ~60 words in third person, present tense."

# Paths for main axis & usage axes.
# They are derived from model_name and layer_to_steer so the axis files always
# match the model and the layer being steered; previously the layer number was
# hard-coded in every path and could silently drift from layer_to_steer.
_model_tag = model_name.split("/")[-1]                 # "Meta-Llama-3-8B-Instruct"
_axis_file = f"sentiment_axis_L{layer_to_steer}.npy"   # "sentiment_axis_L14.npy"

MAIN_AXIS_PATH = f"experiments/exp_main_axis/{_model_tag}_main/{_axis_file}"

# usage name -> experiment directory suffix (note: "context" lives under "..._contextual")
_USAGE_DIR_SUFFIX = {
    "genre":   "genre",
    "tone":    "tone",
    "context": "contextual",
    "topic":   "topic",
}
AXIS_USAGE_PATHS = {
    name: f"experiments/exp_sub_axis/{_model_tag}_{suffix}/{_axis_file}"
    for name, suffix in _USAGE_DIR_SUFFIX.items()
}
USAGES_TO_TRY   = ["genre", "tone", "context", "topic"]

# Steering strength
# Smaller grid, kept for quick runs:
# ALPHAS_FOR_MAIN = [-15, 15]      # main vs main⊥usage (β=0)
ALPHAS_FOR_MAIN = [-30, -15, 0, 15, 30]  # α·σ grid
BETA_LIST       = [-15, 15]      # usage-only (α=0)

# Number of samples generated per (axis, strength) setting
REPEATS_PER     = 1

# ================== Utility functions ==================
def set_seed(s: int):
    """Seed every random number generator used by this script with ``s``.

    Covers Python's ``random`` module, NumPy's global generator, the PyTorch
    default generator and the generators of all CUDA devices.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(s)

def _normalize(v: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """Return ``v`` divided by its norm.

    ``eps`` is added to the norm before dividing, so an all-zero input yields
    zeros instead of a division by zero.
    """
    length = v.norm()
    return v / (length + eps)

def _proj(u: torch.Tensor, v_unit: torch.Tensor) -> torch.Tensor:
    """Return the component of ``u`` along ``v_unit``: ``<u, v_unit> * v_unit``.

    ``v_unit`` is expected to be unit length; it is not re-normalized here.
    """
    coefficient = torch.sum(u * v_unit)
    return coefficient * v_unit

def orthogonalize_to(vec: torch.Tensor, bases: list) -> torch.Tensor:
    """Remove from ``vec`` its components along each vector in ``bases``, then normalize.

    The basis vectors are processed in order, each projection being taken from
    the running residual (Gram–Schmidt). ``bases`` is expected to hold vectors
    that are already unit length. With an empty ``bases`` the result is just
    the normalized ``vec``.
    """
    residual = vec
    for basis_vec in bases:
        along_basis = _proj(residual, basis_vec)
        # Out-of-place subtraction: the caller's tensor must not be modified.
        residual = residual - along_basis
    return _normalize(residual)

def load_axis(path: str, device, dtype=torch.float16) -> torch.Tensor:
    """Load a direction vector from a ``.npy`` file and return it unit-normalized.

    The vector is normalized in float32 and only then cast to ``dtype``.
    Casting to float16 first (as was done before) turns large components into
    ``inf`` and the result into NaN, and makes the 1e-12 guard against a zero
    norm underflow to 0.

    Args:
        path: Location of the ``.npy`` file.
        device: Device on which the returned tensor is placed.
        dtype: Dtype of the returned tensor (default ``torch.float16``).

    Returns:
        The normalized vector as a tensor of ``dtype`` on ``device``.
    """
    v = torch.tensor(np.load(path), dtype=torch.float32, device=device)
    # Same formula as _normalize, applied while still in float32.
    v = v / (v.norm() + 1e-12)
    return v.to(dtype)

def cosine(a: torch.Tensor, b: torch.Tensor) -> float:
    """Return the cosine similarity of two vectors as a Python float.

    Both inputs are upcast to float32 and given a leading batch dimension
    before the similarity is computed.
    """
    a_row = a.to(torch.float32).unsqueeze(0)
    b_row = b.to(torch.float32).unsqueeze(0)
    similarity = torch.nn.functional.cosine_similarity(a_row, b_row)
    return float(similarity.item())

# ================== Load model ==================
# Seed all RNGs first so that sampling in model.generate() starts from a fixed state.
set_seed(seed)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Fall back to the EOS token for padding when the tokenizer defines no pad token;
# pad_token_id is passed to model.generate() by the generation functions below.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id

# Load the weights in float16; device_map="auto" leaves device placement to the loader.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map="auto"
)
# Device on which the tokenized prompt and the axis tensors are placed.
device = model.device

# ================== Load axes ==================
# Main sentiment axis. It stays None when the file is absent, in which case the
# main-axis experiments report an error instead of generating.
if os.path.exists(MAIN_AXIS_PATH):
    u_main = load_axis(MAIN_AXIS_PATH, device=device, dtype=torch.float16)
else:
    u_main = None
    # Warn here as well, consistent with the usage axes below, so a wrong path
    # is visible immediately rather than only through later "[ERROR]" outputs.
    print(f"[WARN] main axis missing: {MAIN_AXIS_PATH}")

# Usage axes keyed by usage name; axes whose file is absent are skipped with a warning.
usage_axes_raw = {}
for name, path in AXIS_USAGE_PATHS.items():
    if os.path.exists(path):
        usage_axes_raw[name] = load_axis(path, device=device, dtype=torch.float16)
    else:
        print(f"[WARN] usage axis missing: {name} -> {path}")

# ================== Construct usage subspace & main⊥usage ==================
# The axes are loaded as float16; do all projection arithmetic in float32.
usage_basis_f32 = [_normalize(v.to(torch.float32)) for v in usage_axes_raw.values()]
# Gram–Schmidt: orthonormalize each usage axis against the ones already accepted,
# so the list becomes an orthonormal basis of the usage subspace.
usage_basis_ortho_f32 = []
for b in usage_basis_f32:
    b_ortho = orthogonalize_to(b, usage_basis_ortho_f32)
    usage_basis_ortho_f32.append(b_ortho)

# Main axis with every usage component removed (unit length); None if it cannot be built.
u_main_perp = None
if u_main is not None and usage_basis_ortho_f32:
    u_main_unit_f32  = _normalize(u_main.to(torch.float32))
    u_main_perp_f32  = orthogonalize_to(u_main_unit_f32, usage_basis_ortho_f32)
    u_main_perp      = u_main_perp_f32.to(dtype=u_main.dtype, device=u_main.device)

    # u_main_perp_f32 is the *normalized* residual, so subtracting it from main
    # directly does not give the projection onto the usage subspace (that was the
    # previous bug: it reported 2 - 2|r| instead of 1 - |r|^2 and could exceed 1).
    # The residual's true length is <main, main⊥usage>; rescale before subtracting.
    residual_len = (u_main_unit_f32 * u_main_perp_f32).sum()
    proj = u_main_unit_f32 - residual_len * u_main_perp_f32
    # Fraction of main's squared norm explained by the usage subspace.
    r2   = float((proj.norm().item() ** 2) / (u_main_unit_f32.norm().item() ** 2 + 1e-12))
    r2   = min(max(r2, 0.0), 1.0)  # guard against rounding slightly outside [0, 1]
    print(f"[main⊥usage] cos(main, main⊥usage)={cosine(u_main, u_main_perp):+.4f}  R2_usage→main={r2:.2f}")
else:
    print("[main⊥usage] skipped: main axis missing or no usage basis.")

# ================== Construct orthogonalized usage axes (remove main axis component) ==================
# For every usage axis, build a variant with the main-axis component projected
# out, and report how similar the raw and the orthogonalized variants are to main.
usage_axes_ortho = {}
if u_main is None:
    # Without a main axis there is nothing to remove: reuse the raw axes as-is.
    usage_axes_ortho = usage_axes_raw
else:
    u_main_f32 = _normalize(u_main.to(torch.float32))
    for k, v in usage_axes_raw.items():
        # Project in float32, then return to the axis' own dtype and device.
        v_f32 = _normalize(v.to(torch.float32))
        v_ortho_f32 = orthogonalize_to(v_f32, [u_main_f32])
        usage_axes_ortho[k] = v_ortho_f32.to(dtype=v.dtype, device=v.device)

    # Print cosine similarities with main axis
    for k in USAGES_TO_TRY:
        if k not in usage_axes_raw:
            continue
        c_raw   = cosine(usage_axes_raw[k], u_main)
        c_ortho = cosine(usage_axes_ortho[k], u_main)
        print(f"[AXIS={k:7s}] cos(raw, main)={c_raw:+.4f}  cos(ortho, main)={c_ortho:+.4f}")

# ================== Register generic hook ==================
def register_hook(model, layer_idx_human, alpha_main, u_main_vec, beta_usage, u_usage_vec):
    """
    Register a forward hook that injects into the last token hidden state:
      h' = h + alpha_main * RMS(h) * u_main + beta_usage * RMS(h) * u_usage
    Either term is ignored if its coefficient = 0 or its vector is None.

    Only the last position of each forward call is modified; the layer's
    original output tensor is left untouched (a modified clone is returned).

    Args:
        model: Object exposing the decoder layers as ``model.model.layers``.
        layer_idx_human: 1-based index of the layer to hook.
        alpha_main: Strength for ``u_main_vec`` (in units of RMS(h)).
        u_main_vec: Direction for the main term, or None.
        beta_usage: Strength for ``u_usage_vec`` (in units of RMS(h)).
        u_usage_vec: Direction for the usage term, or None.

    Returns:
        The hook handle; call ``.remove()`` on it to stop steering.

    Raises:
        IndexError: if ``layer_idx_human`` is outside ``1..len(layers)``.
            Without this check, 0 would silently wrap around to the last layer.
    """
    layers = model.model.layers
    if not 1 <= layer_idx_human <= len(layers):
        raise IndexError(
            f"layer_idx_human must be in 1..{len(layers)}, got {layer_idx_human}"
        )
    phys_idx = layer_idx_human - 1  # convert to 0-based index

    def hook_fn(module, inputs, outputs):
        # Depending on the library version a decoder layer returns either a tuple
        # whose first item is the hidden states, or the hidden-state tensor itself.
        is_tuple = isinstance(outputs, tuple)
        hidden = outputs[0] if is_tuple else outputs   # [B,T,H]
        h_last = hidden[:, -1, :]                      # [B,H]
        # Square in float32: in float16 any |h| above ~255 overflows to inf.
        rms = h_last.float().pow(2).mean(dim=-1, keepdim=True).sqrt().to(h_last.dtype)

        delta = torch.zeros_like(h_last)
        if u_main_vec is not None and alpha_main != 0.0:
            delta = delta + alpha_main * rms * u_main_vec.to(h_last.device, dtype=h_last.dtype)
        if u_usage_vec is not None and beta_usage != 0.0:
            delta = delta + beta_usage * rms * u_usage_vec.to(h_last.device, dtype=h_last.dtype)

        # Write into a clone so the tensor produced by the layer is not mutated.
        hidden = hidden.clone()
        hidden[:, -1, :] = h_last + delta
        if is_tuple:
            return (hidden,) + outputs[1:]
        return hidden

    return layers[phys_idx].register_forward_hook(hook_fn)

# ================== Generation functions ==================
@torch.inference_mode()
def generate_with_main(alpha_main: float, use_perp: bool = False) -> str:
    """Generate a continuation of ``prompt`` while steering along the main axis.

    Args:
        alpha_main: Steering strength passed to the hook.
        use_perp: If True, steer along main⊥usage (``u_main_perp``). When that
            vector is unavailable the raw main axis is used instead.

    Returns:
        The decoded newly generated tokens (prompt excluded), or an
        ``"[ERROR] ..."`` string when no main axis was loaded.
    """
    if u_main is None:
        return "[ERROR] main axis not available."
    u_m = u_main_perp if (use_perp and u_main_perp is not None) else u_main

    # Tokenize before registering the hook so a tokenizer failure cannot leave a hook behind.
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    h = register_hook(model, layer_to_steer, alpha_main=alpha_main, u_main_vec=u_m,
                      beta_usage=0.0, u_usage_vec=None)
    try:
        out = model.generate(
            **inputs, max_new_tokens=max_new_tokens, temperature=temperature, top_p=top_p,
            do_sample=True, repetition_penalty=1.12, no_repeat_ngram_size=3,
            pad_token_id=tokenizer.pad_token_id,
        )
    finally:
        # Always detach the hook: callers catch exceptions and keep going, so a
        # leaked hook would keep steering (and stack up) in later generations.
        h.remove()
    # Drop the prompt tokens; keep only what was generated.
    gen = out[0][inputs["input_ids"].size(1):]
    return tokenizer.decode(gen, skip_special_tokens=True)

@torch.inference_mode()
def generate_with_usage(usage_name: str, beta: float, use_ortho: bool = True) -> str:
    """Generate a continuation of ``prompt`` while steering along one usage axis.

    Args:
        usage_name: Key of the usage axis (must be present in ``usage_axes_raw``).
        beta: Steering strength passed to the hook.
        use_ortho: If True, use the axis from ``usage_axes_ortho`` (main-axis
            component removed); otherwise use the raw axis.

    Returns:
        The decoded newly generated tokens (prompt excluded), or an
        ``"[ERROR] ..."`` string when the requested axis was not loaded.
    """
    if usage_name not in usage_axes_raw:
        return f"[ERROR] usage axis '{usage_name}' not available."
    u_use = usage_axes_ortho[usage_name] if use_ortho else usage_axes_raw[usage_name]

    # Tokenize before registering the hook so a tokenizer failure cannot leave a hook behind.
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    h = register_hook(model, layer_to_steer, alpha_main=0.0, u_main_vec=None,
                      beta_usage=beta, u_usage_vec=u_use)
    try:
        out = model.generate(
            **inputs, max_new_tokens=max_new_tokens, temperature=temperature, top_p=top_p,
            do_sample=True, repetition_penalty=1.12, no_repeat_ngram_size=3,
            pad_token_id=tokenizer.pad_token_id,
        )
    finally:
        # Always detach the hook: callers catch exceptions and keep going, so a
        # leaked hook would keep steering (and stack up) in later generations.
        h.remove()
    # Drop the prompt tokens; keep only what was generated.
    gen = out[0][inputs["input_ids"].size(1):]
    return tokenizer.decode(gen, skip_special_tokens=True)

# ================== Main program ==================
if __name__ == "__main__":
    # Part A: steer along the main axis, comparing the raw axis with main⊥usage
    # at every strength in ALPHAS_FOR_MAIN (no usage steering).
    print(f"\n=== MAIN vs MAIN⊥USAGE @ L{layer_to_steer} (β=0) ===")
    main_variants = (("main_raw", False), ("main_perp", True))
    for a in ALPHAS_FOR_MAIN:
        for mode, use_perp in main_variants:
            banner = f"[{mode}] alpha={a}σ"
            print("\n" + banner)
            for sample_no in range(1, REPEATS_PER + 1):
                # Best effort: a failed generation is reported inline and the sweep goes on.
                try:
                    txt = generate_with_main(alpha_main=a, use_perp=use_perp)
                except Exception as e:
                    txt = f"[ERROR] {type(e).__name__}: {e}"
                print(f"[sample {sample_no}] {txt.strip()}\n")

    # Part B: steer along each usage axis alone (no main-axis steering),
    # comparing the raw axis with its main-orthogonalized variant.
    print(f"\n=== USAGE-only steering @ L{layer_to_steer} (alpha_main=0) ===")
    usage_variants = (("raw", False), ("ortho", True))
    for usage_name in USAGES_TO_TRY:
        if usage_name not in usage_axes_raw:
            print(f"[WARN] skip '{usage_name}': axis not found.")
            continue
        if u_main is not None:
            raw_cos = cosine(usage_axes_raw[usage_name], u_main)
            ortho_cos = cosine(usage_axes_ortho[usage_name], u_main)
            print(f"[AXIS={usage_name}] cos(raw, main)={raw_cos:+.4f}  "
                  f"cos(ortho, main)={ortho_cos:+.4f}")
        for beta in BETA_LIST:
            for mode, use_ortho in usage_variants:
                banner = f"[{usage_name}:{mode}] beta={beta}σ"
                print("\n" + banner)
                for sample_no in range(1, REPEATS_PER + 1):
                    # Best effort, as in part A.
                    try:
                        txt = generate_with_usage(usage_name=usage_name, beta=beta, use_ortho=use_ortho)
                    except Exception as e:
                        txt = f"[ERROR] {type(e).__name__}: {e}"
                    print(f"[sample {sample_no}] {txt.strip()}\n")
