import os
import torch
import argparse
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
from utils import load_json

# Restrict this process to physical GPU 1. The variable has to be in place
# before CUDA is initialised in this process (the is_available() query below
# may trigger that). Note this overwrites any value inherited from the shell.
os.environ.update({"CUDA_VISIBLE_DEVICES": "1"})

# Fall back to the CPU when no (visible) CUDA device is present.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")


def register_layer_hook(model, layer_idx, activation_cache, key="residual"):
    """
    Attach a forward hook to ``model.model.layers[layer_idx]`` that records its output.

    Every forward pass through that layer stores the layer's output, detached
    from the autograd graph, in ``activation_cache[key]`` (overwriting any
    previous value). When the layer returns a tuple, only its first element is
    stored. The hook handle is returned; call ``.remove()`` on it to detach
    the hook again.
    """
    target_layer = model.model.layers[layer_idx]

    def _capture(module, inputs, outputs):
        # Some layer implementations return a tuple whose first element is the
        # hidden state; others return the tensor directly.
        hidden = outputs[0] if isinstance(outputs, tuple) else outputs
        activation_cache[key] = hidden.detach()

    return target_layer.register_forward_hook(_capture)


def compute_vector(text, model, tokenizer, activation_cache, key="residual", device="cpu"):
    """
    Compute the attention-mask-weighted mean of a hooked activation for ``text``.

    Runs one forward pass of ``model`` on the tokenized ``text`` and reads the
    activation that a forward hook (e.g. one installed by
    ``register_layer_hook``) stored in ``activation_cache[key]`` during that
    pass. Positions whose attention mask is 0 are excluded from the mean.

    Args:
        text: Input accepted by ``tokenizer`` (a string, or presumably a list
            of strings for a batch).
        model: Model whose forward pass populates ``activation_cache[key]``.
        tokenizer: Tokenizer called with ``padding=True``, so it is expected
            to have a pad token configured (``main`` sets one).
        activation_cache: Dict shared with the hook; cleared before the pass.
        key: Cache key the hook writes to.
        device: Device the tokenized inputs are moved to.

    Returns:
        The mean activation over unmasked tokens, with the leading batch
        dimension squeezed away when it has size 1. The dtype matches the
        captured activation. An input with no unmasked tokens yields zeros.

    Raises:
        KeyError: If the forward pass did not populate ``activation_cache[key]``.
    """
    activation_cache.clear()
    inputs = tokenizer(text, return_tensors="pt", padding=True).to(device)
    with torch.no_grad():
        model(**inputs)

    try:
        residual = activation_cache[key]
    except KeyError:
        raise KeyError(
            f"No activation captured under key {key!r}; "
            "is a forward hook registered for this key?"
        ) from None

    # Accumulate in float32: summing many half-precision activations loses
    # precision. The result is cast back to the activation dtype at the end.
    # The activation is assumed to be (batch, seq, hidden), matching the
    # (batch, seq) attention mask broadcast over the last axis.
    hidden = residual.float()
    mask = inputs["attention_mask"].unsqueeze(-1).to(device=hidden.device, dtype=hidden.dtype)
    # Clamp so a sequence with no unmasked tokens gives 0 instead of 0/0 = NaN.
    token_counts = mask.sum(1).clamp(min=1.0)
    avg = (hidden * mask).sum(1) / token_counts
    return avg.squeeze(0).to(residual.dtype)


def main():
    """
    Build a dataset of unit-norm "reflection direction" vectors and save it.

    For each item of the input JSON, the mean output of
    ``model.model.layers[--layer]`` is computed both for the item's
    ``"reflection"`` text and for its ``"raw_answer_response"`` text. The
    difference (reflection minus raw answer) is scaled to unit L2 norm and
    stored with the item's ``"question"``. The resulting list of
    ``{"question": ..., "vector": ...}`` dicts is written with ``torch.save``.
    A difference with zero (or non-finite) norm is stored unnormalized and
    reported, instead of becoming a NaN vector.
    """
    parser = argparse.ArgumentParser(description="Build reflection vector dataset from raw/reflection answers.")
    parser.add_argument("--input", type=str, required=True, help="Path to input JSON file.")
    parser.add_argument("--output", type=str, required=True, help="Path to save output .pt file.")
    parser.add_argument("--model", type=str, default="meta-llama/Llama-3.1-8B-Instruct", help="Model name or path.")
    parser.add_argument("--layer", type=int, default=16, help="Layer index to extract activations from.")
    args = parser.parse_args()

    # === Load model and tokenizer ===
    tokenizer = AutoTokenizer.from_pretrained(args.model)
    if tokenizer.pad_token is None:
        # compute_vector tokenizes with padding=True, which needs a pad token;
        # keep the tokenizer's own pad token when it already defines one.
        tokenizer.pad_token = tokenizer.eos_token
    # Activations are captured by a forward hook, so there is no need for the
    # model to also return every layer's hidden states on each forward pass
    # (output_hidden_states=True would only cost memory here).
    model = AutoModelForCausalLM.from_pretrained(args.model).to(device).eval()

    activation_cache = {}
    handle = register_layer_hook(model, args.layer, activation_cache)

    # === Load dataset ===
    raw_data = load_json(args.input)
    dataset = []

    # === Build reflection vectors ===
    try:
        for item in tqdm(raw_data, desc="Building vectors"):
            q = item["question"]
            vec_ref = compute_vector(item["reflection"], model, tokenizer, activation_cache, device=device)
            vec_dir = compute_vector(item["raw_answer_response"], model, tokenizer, activation_cache, device=device)
            # compute_vector runs under no_grad and the hook detaches, so no
            # extra .detach() is needed here.
            diff = vec_ref - vec_dir
            norm = diff.norm()
            if norm > 0:
                vec = diff / norm
            else:
                # Dividing by a zero norm would give 0/0 = NaN; keep the raw
                # difference instead and tell the user.
                tqdm.write(
                    f"Warning: zero or non-finite difference vector for question {q!r}; "
                    "storing it unnormalized."
                )
                vec = diff
            dataset.append({"question": q, "vector": vec.cpu()})
    finally:
        # Remove the hook even if building the dataset fails part-way.
        handle.remove()

    # === Save dataset ===
    torch.save(dataset, args.output)
    print(f"Saved {len(dataset)} samples to {args.output}")


if __name__ == "__main__":
    main()