### python main.py --input orig_dataset/cs_domain/ARC-Challenge/arc_challenge_test.json --output results_output/test.json --model meta-llama/Llama-3.1-8B-Instruct --trigger reflection_trigger/output/test.pt --layer 16 --coeff 1.0,2.0
import os
import json
import argparse
from tqdm import tqdm
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM

from reflection_trigger.modeling_trigger import BertRegressor
from reflection_trigger.gen_training_data.utils import load_json, save_json

# Default to GPU index 1, but respect a value already exported by the user or a
# job scheduler instead of silently overwriting it. This only takes effect if it
# runs before CUDA is first initialized in this process.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "1")

def predict_steering_vector(question, bert_model, bert_tokenizer, device, eps=1e-12):
    """Predict a unit-norm steering vector for ``question``.

    The question is tokenized with ``bert_tokenizer`` (padding and truncation
    enabled), any ``token_type_ids`` entry is dropped, the remaining tensors are
    moved to ``device`` and passed to ``bert_model`` as keyword arguments with
    gradients disabled. The model output has its leading dimension squeezed and
    is divided by its norm.

    Args:
        question: Text passed to the tokenizer.
        bert_model: Callable accepting the tokenizer's tensors as keyword
            arguments and returning a tensor.
        bert_tokenizer: Callable returning a mapping of tensors.
        device: Device the tokenized inputs are moved to.
        eps: Lower bound applied to the norm before dividing, so that an
            all-zero prediction produces zeros rather than NaNs.

    Returns:
        The model output divided by ``max(norm, eps)``.
    """
    encoded = bert_tokenizer(question, return_tensors="pt", padding=True, truncation=True)
    # The regressor is called without segment ids; drop them if present.
    encoded.pop("token_type_ids", None)
    encoded = {name: tensor.to(device) for name, tensor in encoded.items()}

    with torch.no_grad():
        vec = bert_model(**encoded).squeeze(0)
        # A plain `vec / vec.norm()` turns a zero vector into NaNs, which would
        # then poison every hidden state the vector is added to.
        vec = vec / vec.norm().clamp(min=eps)
    return vec


def make_steering_hook(steering_vector, coefficient):
    """Build a forward hook that adds ``coefficient * steering_vector`` to a module's output.

    Args:
        steering_vector: Tensor added (after a leading ``unsqueeze(0)``, so it
            broadcasts over the leading dimensions) to the module output.
        coefficient: Scalar multiplier applied to the vector.

    Returns:
        A function with the ``(module, inputs, output)`` forward-hook signature.
        If ``output`` is a tuple, only its first element is modified and the
        remaining elements are passed through unchanged; otherwise the modified
        tensor itself is returned.
    """
    def steering_hook(module, inputs, output):
        hidden = output[0] if isinstance(output, tuple) else output
        # Match dtype as well as device: adding a float32 vector to half/bf16
        # activations would otherwise promote the layer output to float32.
        vec = steering_vector.to(device=hidden.device, dtype=hidden.dtype)
        steered = hidden + coefficient * vec.unsqueeze(0)
        if isinstance(output, tuple):
            return (steered,) + output[1:]
        return steered
    return steering_hook


def generate_output(prompt, model, tokenizer, layer_module, device,
                    steering=False, steering_vector=None, coefficient=5.0):
    """Greedy-decode a completion for ``prompt``, optionally steering one layer.

    Args:
        prompt: Text passed to ``tokenizer``.
        model: Object exposing ``generate(**inputs, ...)``.
        tokenizer: Callable tokenizer that also exposes ``decode``.
        layer_module: Module on which the steering forward hook is registered.
        device: Device the tokenized inputs are moved to.
        steering: Whether to install the steering hook for this call.
        steering_vector: Vector handed to ``make_steering_hook``; when ``None``
            no hook is installed even if ``steering`` is true.
        coefficient: Multiplier handed to ``make_steering_hook``.

    Returns:
        The decoded, whitespace-stripped text of the newly generated tokens.
    """
    inputs = tokenizer(prompt, return_tensors="pt", padding=True).to(device)
    prompt_len = inputs["input_ids"].shape[1]

    handle = None
    if steering and steering_vector is not None:
        handle = layer_module.register_forward_hook(
            make_steering_hook(steering_vector, coefficient)
        )

    try:
        with torch.no_grad():
            output_ids = model.generate(**inputs, max_new_tokens=1024, do_sample=False)
    finally:
        # Always detach the hook, even if generation fails; a leaked hook would
        # keep steering (and stack with) every later call on this layer.
        if handle is not None:
            handle.remove()

    # Decode only the tokens that follow the prompt. Splitting the full decoded
    # text on the word "assistant" would cut off any answer that itself
    # contains that word.
    new_tokens = output_ids[0][prompt_len:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()


# ====== Main Pipeline ======
def main(args):
    """Run steered generation for every question in the input file and save the results.

    For each item in the dataset loaded from ``args.input``, a steering vector
    is predicted from ``item["question"]`` and one completion is generated per
    coefficient in ``args.coeff``. Each completion is stored on the item under
    ``with_steering_output_layer{layer}_{coeff}``, and the whole dataset is
    written to ``args.output`` at the end.

    Args:
        args: Parsed CLI namespace with ``input``, ``output``, ``model``,
            ``trigger``, ``layer`` and ``coeff`` attributes. ``args.output`` is
            filled in with a generated path when it is ``None``.

    Raises:
        ValueError: If ``args.coeff`` contains no coefficient or a
            non-numeric entry.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Parse the cheap inputs first so a typo fails before the models are loaded.
    # Empty entries (e.g. a trailing comma) are ignored.
    coeff_list = [float(c) for c in args.coeff.split(",") if c.strip()]
    if not coeff_list:
        raise ValueError("--coeff must contain at least one coefficient")

    # === Auto-generate output path if not given ===
    if args.output is None:
        base = os.path.basename(args.input)
        name, _ = os.path.splitext(base)
        args.output = f"results_output/{name}_infer_layer{args.layer}.json"
    # Make sure the destination directory exists up front, so a long inference
    # run is not lost to a missing directory at save time.
    os.makedirs(os.path.dirname(args.output) or ".", exist_ok=True)

    # === Load dataset ===
    data = load_json(args.input)

    # === Load LLaMA ===
    llama_tokenizer = AutoTokenizer.from_pretrained(args.model)
    llama_tokenizer.pad_token = llama_tokenizer.eos_token
    llama_model = AutoModelForCausalLM.from_pretrained(
        args.model, output_hidden_states=True, device_map=None
    ).to(device).eval()
    layer_module = llama_model.model.layers[args.layer]

    # === Load BERT Regressor ===
    # The steering vector is added to the LLM's hidden states, so size the
    # regressor from the model config instead of hard-coding 4096 (kept as the
    # fallback when the config does not expose `hidden_size`).
    hidden_size = getattr(llama_model.config, "hidden_size", 4096)
    bert_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    bert_model = BertRegressor(output_dim=hidden_size)
    # NOTE(security): torch.load unpickles the checkpoint; only pass trusted
    # files via --trigger.
    bert_model.load_state_dict(torch.load(args.trigger, map_location=device))
    bert_model.to(device).eval()

    for item in tqdm(data, desc="Running inference"):
        question = item["question"]
        steering_vector = predict_steering_vector(question, bert_model, bert_tokenizer, device)

        # === Fixed prompt ===
        # The leading spaces on each line below are part of the prompt string.
        # NOTE(review): the template already starts with <|begin_of_text|> and
        # generate_output tokenizes with default special-token handling, so a
        # second BOS token may be prepended — confirm this is intended.
        prompt = f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
            You are an expert at answering questions.
            <|eot_id|>

            <|start_header_id|>user<|end_header_id|>
            Please answer the following question. Provide concise reasoning if needed.

            At the end, format your final answer using:
            [ANSWER] (choice letter) [/ANSWER]

            Question:
            {question}
            <|eot_id|>
            <|start_header_id|>assistant<|end_header_id|>
            """

        for coeff in coeff_list:
            out = generate_output(
                prompt,
                llama_model, llama_tokenizer, layer_module,
                device,
                steering=True, steering_vector=steering_vector, coefficient=coeff
            )
            item[f"with_steering_output_layer{args.layer}_{coeff}"] = out

    save_json(data, args.output)

if __name__ == "__main__":
    # Command-line entry point: declare the options in one table, register
    # them on the parser, then hand the parsed namespace to main().
    cli = argparse.ArgumentParser(
        description="Use BERT steering vector to control LLaMA generation."
    )
    option_table = (
        ("--input", {"required": True, "help": "Input JSON file"}),
        ("--output", {"required": False, "help": "Output JSON file (optional)"}),
        ("--model", {"default": "meta-llama/Llama-3.1-8B-Instruct"}),
        ("--trigger", {"required": True, "help": "Path to trained BERT regressor checkpoint"}),
        ("--layer", {"type": int, "default": 16, "help": "LLaMA layer index"}),
        ("--coeff", {"type": str, "default": "1.0,2.0,3.0,5.0", "help": "Comma-separated coefficients"}),
    )
    for flag, options in option_table:
        cli.add_argument(flag, **options)

    main(cli.parse_args())