"""Extract per-layer Q/K for a dataset and save to data/<model>/<dataset>/.

Example:
  python3 extract_attention.py \
    --model_key llama-3.1 \
    --dataset narrativeqa \
    --exp_name extract_attention \
    --max_samples 1

This will patch all attention layers to dump Q/K during the prefill pass of a
one-token generation call, which is run once per sample.
"""

from __future__ import annotations

import argparse
import json
import os
from typing import List

from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from src.utils.helpers import seed_everything
from src.utils.prompting import build_chat
from src.analysis.extractor import patch_model_for_extraction, set_extraction_dir_and_reset


def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the Q/K extraction script.

    Options:
        --model_key    (str, required) key used by ``main`` to look up the
                       model path in ``configs/model2path.json``.
        --dataset      (str, required) LongBench subset name; also the key
                       into ``configs/dataset2prompt.json``.
        --exp_name     (str, default "extract_attention") name of the output
                       subfolder for this run.
        --max_samples  (int, default 1) number of samples to process.
        --seed         (int, default 42) value passed to ``seed_everything``.

    Returns:
        The populated ``argparse.Namespace``.
    """
    cli = argparse.ArgumentParser(description="Extract per-layer Q/K from a model")

    # Mandatory string selectors.
    for flag in ("--model_key", "--dataset"):
        cli.add_argument(flag, type=str, required=True)

    # Optional settings, registered in the same order as they appear in --help.
    cli.add_argument("--exp_name", type=str, default="extract_attention")
    for flag, fallback in (("--max_samples", 1), ("--seed", 42)):
        cli.add_argument(flag, type=int, default=fallback)

    namespace = cli.parse_args()
    return namespace




def main() -> None:
    """Run Q/K extraction over the first ``--max_samples`` LongBench test samples.

    Steps:
      1. Seed RNGs and read ``configs/model2path.json`` and
         ``configs/dataset2prompt.json`` located next to this script.
      2. Load the tokenizer and model, then hand the model to
         ``patch_model_for_extraction`` together with the root output
         directory ``data/<model_key>/<dataset>/<exp_name>``.
      3. For each sample, point extraction at a ``sample_XXXX`` path under the
         root directory and run a one-token greedy ``generate`` call.

    Raises:
        KeyError: if ``--model_key`` or ``--dataset`` has no entry in the
            corresponding config file.
    """
    args = parse_args()
    seed_everything(args.seed)

    base_dir = os.path.dirname(os.path.abspath(__file__))
    cfg_dir = os.path.join(base_dir, "configs")
    # Context managers close the config files deterministically; the explicit
    # encoding keeps JSON decoding independent of the platform locale.
    with open(os.path.join(cfg_dir, "model2path.json"), "r", encoding="utf-8") as f:
        model2path = json.load(f)
    with open(os.path.join(cfg_dir, "dataset2prompt.json"), "r", encoding="utf-8") as f:
        dataset2prompt = json.load(f)

    model_name = args.model_key
    dataset_name = args.dataset
    # Resolve both config lookups before loading the model so that a mistyped
    # key fails immediately instead of after the model has been loaded.
    model_path = model2path[model_name]
    prompt_format = dataset2prompt[dataset_name]

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto", torch_dtype="auto")
    model.eval()

    root_save_dir = os.path.join(base_dir, "data", model_name, dataset_name, args.exp_name)
    os.makedirs(root_save_dir, exist_ok=True)
    patch_model_for_extraction(model, root_save_dir)

    data = load_dataset("THUDM/LongBench", dataset_name, split="test")

    # Run just a few samples to trigger prefill
    for idx, json_obj in enumerate(data):
        if idx >= args.max_samples:
            break
        # Per-sample output location; presumably created by the extractor
        # helper, since this function only creates root_save_dir itself.
        sample_dir = os.path.join(root_save_dir, f"sample_{idx:04d}")
        set_extraction_dir_and_reset(model, sample_dir)
        prompt = prompt_format.format(**json_obj)
        # NOTE(review): build_chat may already apply model-specific chat
        # formatting; its result is wrapped again by apply_chat_template
        # below -- confirm that this double wrapping is intended.
        prompt = build_chat(tokenizer, prompt, model_name)
        messages = [{"role": "user", "content": prompt}]
        # enable_thinking is forwarded to the chat template as an extra
        # keyword; presumably only some templates consume it -- TODO confirm.
        text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True, enable_thinking=False)
        inputs = tokenizer([text], return_tensors="pt").to(model.device)
        # Use a short generation to ensure prefill occurs; the generated
        # token itself is not needed, so the return value is discarded.
        model.generate(
            input_ids=inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_new_tokens=1,
            do_sample=False,
        )
        print(f"Processed sample {idx} -> {sample_dir}")

    print(f"Saved Q/K per layer under {root_save_dir}")


# Script entry point: run the extraction only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()


