"""
Example usage: python steer.py \
    --prompt_type router \
    --trait confusion \
    --layer 12 \
    --strengths 0 3.1 3.4
"""

import json
from transformers import AutoModelForCausalLM, AutoTokenizer
import os
import torch
import argparse
from tqdm import tqdm

# Load
def load_model(model_id="meta-llama/Llama-3.1-8B-Instruct"):
    """Load a causal LM in eval mode together with its tokenizer.

    The model is loaded in bfloat16 with ``device_map="auto"``. Flash
    Attention 2 is requested when CUDA is available; otherwise no attention
    implementation is specified and the library default is used.

    Args:
        model_id: Hugging Face model identifier or local path.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    # Only request the flash-attention kernels when a GPU is present.
    attention_backend = "flash_attention_2" if torch.cuda.is_available() else None
    language_model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        attn_implementation=attention_backend,
        dtype=torch.bfloat16,
    ).eval()
    return language_model, AutoTokenizer.from_pretrained(model_id)

def load_diff_vector(trait, model_name, layer=-1):
    """Load the stored difference vectors for ``trait``.

    Files are looked up in the ``activations`` directory next to this file.
    The model-specific file
    ``diff_vectors_response_{trait}_{model}.pt`` is tried first, where
    ``model`` is the part of ``model_name`` after the last ``/``. If that
    file does not exist, the generic ``diff_vectors_response_{trait}.pt`` is
    loaded instead.

    Args:
        trait: Trait name embedded in the file name.
        model_name: Model identifier; only its last ``/``-separated component
            is used.
        layer: Index into the loaded object. The default ``-1`` is a
            sentinel meaning "return everything that was loaded"; it does
            NOT select the last entry.

    Returns:
        The whole loaded object when ``layer == -1``, otherwise
        ``diff_vectors[layer]``.

    Raises:
        FileNotFoundError: If neither the model-specific nor the generic
            file exists.
    """
    vector_dir = os.path.join(os.path.dirname(__file__), "activations")
    model_name_safe = model_name.split("/")[-1]
    diff_vector_path = f"{vector_dir}/diff_vectors_response_{trait}_{model_name_safe}.pt"
    print(f"Loading from {diff_vector_path}")

    try:
        diff_vectors = torch.load(diff_vector_path, weights_only=True)
    except FileNotFoundError:
        fallback_path = f"{vector_dir}/diff_vectors_response_{trait}.pt"
        # Report the file that is actually used so the log is not misleading.
        print(f"Not found; falling back to {fallback_path}")
        diff_vectors = torch.load(fallback_path, weights_only=True)

    if layer == -1:
        return diff_vectors
    return diff_vectors[layer]

def mix_vector(diff_vector_matrix, trait_dict, traits):
    """Return a strength-weighted sum of the rows of ``diff_vector_matrix``.

    Args:
        diff_vector_matrix: Tensor of shape (num_traits, hidden_size); row
            ``i`` is the vector for ``traits[i]``.
        trait_dict: Dict mapping trait name to its mixing strength. Traits
            missing from the dict get strength 0.0.
        traits: List of trait names, in the same order as the matrix rows.

    Returns:
        Tensor of shape (hidden_size,) with the same dtype and device as
        ``diff_vector_matrix``.
    """
    # Build the weights in the matrix's own dtype: matmul requires matching
    # dtypes, so a default float32 (or int64) weight vector would fail
    # against bf16/fp16/fp64 matrices.
    strengths = torch.tensor(
        [trait_dict.get(trait, 0.0) for trait in traits],
        device=diff_vector_matrix.device,
        dtype=diff_vector_matrix.dtype,
    )  # (num_traits,)
    return diff_vector_matrix.T @ strengths  # (hidden_size,)

def append_system_prompt(messages, system_prompt_base, system_prompt):
    """Return a copy of ``messages`` with ``system_prompt`` added.

    If the first message has role ``"system"``, ``system_prompt`` is appended
    to its content (``system_prompt_base`` is not used in that case).
    Otherwise a new system message with content
    ``system_prompt_base + system_prompt`` is placed in front.

    The input list and its message dicts are never modified, so calling this
    repeatedly on the same conversation does not accumulate suffixes.

    Args:
        messages: List of chat message dicts with "role" and "content" keys.
            May be empty.
        system_prompt_base: Base text used only when no system message exists.
        system_prompt: Text to append to the system message.

    Returns:
        A new list of messages.
    """
    if messages and messages[0]["role"] == "system":
        # Copy the first message so the caller's dict stays untouched.
        updated_first = dict(messages[0])
        updated_first["content"] = updated_first["content"] + system_prompt
        return [updated_first] + list(messages[1:])

    new_system = {"role": "system", "content": system_prompt_base + system_prompt}
    return [new_system] + list(messages)

def generate_steered_response(model, tokenizer, messages, diff_vector, strength, layer, temperature=0.7, max_tokens=256):
    """Sample a response while adding a steering vector at one decoder layer.

    A forward hook is registered on ``model.model.layers[layer]``. On every
    forward call of that layer, ``diff_vector * strength`` is added in place
    to the last sequence position of the layer's hidden-state output. The
    hook is always removed afterwards, even if generation fails.

    Args:
        model: Causal LM exposing ``device``, ``generate`` and
            ``model.layers``.
        tokenizer: Tokenizer providing ``apply_chat_template`` and ``decode``.
        messages: Chat messages passed to the chat template.
        diff_vector: Steering direction, shape (hidden_size,).
        strength: Scalar multiplier applied to ``diff_vector``.
        layer: Index into ``model.model.layers`` of the layer to hook.
        temperature: Sampling temperature.
        max_tokens: Maximum number of new tokens to generate.

    Returns:
        The decoded text of ``outputs[0]`` with special tokens kept; callers
        extract the assistant turn from it.

    Raises:
        TypeError: If the hooked layer's output is neither a tensor nor a
            tuple whose first element is a tensor.
    """
    prompt = tokenizer.apply_chat_template(
        messages, return_tensors="pt", add_generation_prompt=True
    ).to(model.device)

    # Scale once up front instead of on every forward pass.
    steering = diff_vector.to(model.device) * strength

    def hook(module, input, output):
        # Depending on the library version a decoder layer returns either the
        # hidden states directly or a tuple whose first item is the hidden
        # states; both are (batch_size, seq_len, hidden_size).
        hidden = output[0] if isinstance(output, tuple) else output
        if not isinstance(hidden, torch.Tensor):
            raise TypeError("Output is not a tensor from the layer hook.")

        # With device_map="auto" the hooked layer may sit on a different
        # device than model.device, so follow the hidden state's device.
        hidden[:, -1, :] += steering.to(hidden.device)
        return output

    h = model.model.layers[layer].register_forward_hook(hook)  # layers do not store embedding layer

    try:
        with torch.no_grad():
            outputs = model.generate(
                prompt,
                max_new_tokens=max_tokens,
                do_sample=True,
                temperature=temperature,
                top_p=0.95,
            )
    finally:
        h.remove()

    return tokenizer.decode(outputs[0], skip_special_tokens=False)


def load_prompt(prompt_type):
    """Return the seed conversation for the given scenario.

    Each scenario is a two-message list: a system message describing the
    persona and a first user message. A fresh list is built on every call.

    Args:
        prompt_type: One of "refund", "router", "job" or "transformer".

    Returns:
        A list of two chat message dicts with "role" and "content" keys.

    Raises:
        KeyError: If ``prompt_type`` is not a known scenario.
    """
    # Persona instructions shared word-for-word by the role-play scenarios.
    persona_rules = (
        "who is represented by the 'user' tag. "
        "Your conversations are very concise, natural, and human, and should use only one or two sentences each turn. "
        "Reply directly to the agent and don't think about what you should say. "
        "You should feel free to tell the agent about your emotions and concerns. "
        "NEVER speak more than two sentences."
    )
    telecom_intro = "You are a customer of a major telecom service asking for "
    app_questions = "You have specific questions to ask the human agent representative through the app, "

    def conversation(system_text, opening_line):
        return [
            {"role": "system", "content": system_text},
            {"role": "user", "content": opening_line},
        ]

    scenarios = {
        "refund": conversation(
            telecom_intro + "a refund. " + app_questions + persona_rules,
            "Hello, you are not eligible for a refund sir.",
        ),
        "router": conversation(
            telecom_intro + "help with setting up a router. " + app_questions + persona_rules,
            "Sir, we can only help you on the phone and cannot send anyone over.",
        ),
        "job": conversation(
            "You are interviewing for a job and there is no response from the company. "
            "You are emailing the HR to ask for an update, " + persona_rules,
            "Hey, we want to take a few more days to make our decision.",
        ),
        "transformer": conversation(
            "You are trying to explain the transformer architecture to a beginner, who just don't get it.",
            "Hey, can you explain it again?",
        ),
    }
    return scenarios[prompt_type]
    
def extract_assistant_response(text, assistant_id="assistant", end_header_id="<|end_header_id|>"):
    """Pull the last assistant turn out of a decoded chat transcript.

    The turn starts right after the final occurrence of
    ``assistant_id + end_header_id`` and runs up to the next ``<|eot_id|>``
    (or to the end of ``text`` when there is none). Surrounding whitespace is
    stripped.

    Args:
        text: Decoded transcript that still contains special tokens.
        assistant_id: Role name that precedes the header terminator.
        end_header_id: Token that closes a role header.

    Returns:
        The stripped assistant turn, or ``text`` unchanged when no assistant
        header is found.
    """
    header = f"{assistant_id}{end_header_id}"
    header_pos = text.rfind(header)
    if header_pos < 0:
        # No assistant header at all: hand the transcript back untouched.
        return text

    after_header = text[header_pos + len(header):]
    # Keep everything before the first end-of-turn token, if there is one.
    return after_header.split("<|eot_id|>", 1)[0].strip()

def vprint(strength, response):
    """Print one steering result: the strength, the assistant turn, a rule.

    Args:
        strength: Steering strength used for this generation.
        response: Raw decoded model output; only the extracted assistant
            turn is printed.
    """
    header_line = f"Strength: {strength}"
    print(header_line)
    assistant_text = extract_assistant_response(response)
    print(assistant_text)
    separator = "-" * 100
    print(separator)

def main():
    """Run a steering sweep from the command line and save it as JSON.

    Flags:
        --strengths: One or more steering strengths (default: 3).
        --layer: Decoder layer index to steer (default: 10).
        --prompt_type: Scenario name passed to ``load_prompt``
            (default: "refund").
        --trait: Trait whose difference vector is loaded (required).
        --model_name: Model identifier
            (default: "meta-llama/Llama-3.1-8B-Instruct").

    For each strength a response is generated, printed, and collected. The
    results are written to
    ``steering_results/{trait}_{layer}_{prompt_type}.json`` relative to the
    current working directory.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--strengths", type=float, nargs="+", default=[3])
    parser.add_argument("--layer", type=int, default=10)
    parser.add_argument("--prompt_type", type=str, default="refund")
    parser.add_argument("--trait", type=str, required=True)
    parser.add_argument("--model_name", type=str, default="meta-llama/Llama-3.1-8B-Instruct")

    args = parser.parse_args()

    strengths = args.strengths
    layer = args.layer
    prompt_type = args.prompt_type
    trait = args.trait
    model_name = args.model_name

    # load_diff_vector treats layer == -1 as "return every layer's vector",
    # which cannot be added to a single hidden state; reject it up front
    # instead of failing inside the hook after the model has been loaded.
    if layer == -1:
        parser.error("--layer -1 is not supported; pass an explicit layer index")

    # Do the cheap lookups first so a bad prompt type or a missing vector
    # file fails before the (slow) model load.
    simulated_input = load_prompt(prompt_type)
    diff_vector = load_diff_vector(trait, model_name, layer)

    model, tokenizer = load_model(model_name)

    os.makedirs("steering_results", exist_ok=True)

    results = []
    for strength in strengths:
        response = generate_steered_response(model, tokenizer, simulated_input, diff_vector, strength, layer)
        vprint(strength, response)
        results.append({
            "strength": strength,
            "response": extract_assistant_response(response)
        })

    # Log results to JSON file
    output_data = {
        "trait": trait,
        "layer": layer,
        "prompt_type": prompt_type,
        "results": results
    }

    with open(f"steering_results/{trait}_{layer}_{prompt_type}.json", "w") as f:
        json.dump(output_data, f, indent=2)

# Run the steering sweep only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()