"""
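Compute per-layer activation difference vectors (mean "real" activations minus
mean "contrastive" activations) for a trait and save them under ./activations.

Example usage: python compute_vectors.py \
    --chat_history_file chat_histories/impatience.json \
    --trait impatience \
    --model_name meta-llama/Llama-3.1-8B-Instruct

Optional: --vector_type {prompt,response} (default: response) selects the
token span over which the activations are averaged.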
"""

import json
from transformers import AutoModelForCausalLM, AutoTokenizer
import os
import torch
import torch.nn.functional as F
import argparse
from tqdm import tqdm

def load_model(model_name):
    """Load a causal language model and its tokenizer from the same checkpoint.

    Args:
        model_name: Model identifier or path handed to ``from_pretrained``.

    Returns:
        Tuple ``(model, tokenizer)``; the model is already switched to eval mode.
    """
    # `dtype` is the newer keyword for selecting the weight precision.
    load_kwargs = {"device_map": "auto", "dtype": torch.bfloat16}
    model = AutoModelForCausalLM.from_pretrained(model_name, **load_kwargs)
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
    return model, tokenizer

# Helpers for tokenizing conversations and extracting per-layer activations
def tokenize_conv(messages, tokenizer, device):
    """Tokenize a full conversation with the tokenizer's chat template.

    Args:
        messages: Chat messages passed straight to ``apply_chat_template``.
        tokenizer: Object exposing ``apply_chat_template``.
        device: Device the resulting token-id tensor is moved to.

    Returns:
        The tensor of token ids returned by the chat template, on ``device``.
    """
    token_ids = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        # No generation prompt: the conversation already ends with the
        # assistant turn, matching how the prompt length is measured.
        add_generation_prompt=False,
        return_tensors="pt",
        # Ask for the bare id tensor explicitly. Recent transformers releases
        # return a BatchEncoding dict by default, which the caller would then
        # pass to the model as `input_ids` (verify against installed version).
        return_dict=False,
    )
    return token_ids.to(device)

def compute_layer_activations(model, conv, vector_type, prompt_len):
    """Run one forward pass and average each layer's activations over a token span.

    Args:
        model: Causal LM called as ``model(input_ids=..., output_hidden_states=True)``;
            its output must expose ``logits`` and ``hidden_states``, and
            ``model.config.num_hidden_layers`` must be set.
        conv: Token ids of the full conversation, shape (1, seq_len). Only the
            first batch element is used for the activations.
        vector_type: ``"prompt"`` averages over tokens ``[:prompt_len]``;
            ``"response"`` averages over tokens ``[prompt_len:]``.
        prompt_len: Number of leading tokens that belong to the prompt.

    Returns:
        Tuple ``(layer_activations, loss)``: a CPU tensor of shape
        (num_layers, hidden_size) holding the per-layer mean activation, and
        the mean next-token cross entropy over the non-prompt tokens (float).

    Raises:
        ValueError: If ``vector_type`` is not recognised, or if the selected
            token span is empty (its mean would be NaN).
    """
    if vector_type == "prompt":
        token_span = slice(None, prompt_len)
    elif vector_type == "response":
        token_span = slice(prompt_len, None)
    else:
        raise ValueError(
            f"Unknown vector_type {vector_type!r}; expected 'prompt' or 'response'"
        )

    # Fail before the forward pass: an empty span would yield a NaN mean that
    # silently corrupts every vector averaged together with it later.
    seq_len = conv.shape[-1]
    if len(range(seq_len)[token_span]) == 0:
        raise ValueError(
            f"No {vector_type} tokens to average: prompt_len={prompt_len}, seq_len={seq_len}"
        )

    outputs = model(
        input_ids=conv,
        output_hidden_states=True,
    )

    logits = outputs.logits  # (batch_size (1), seq_len, vocab_size)
    hidden = outputs.hidden_states[1:]  # drop the embedding output, keep the layers
    del outputs

    assert model.config.num_hidden_layers == len(hidden), f"Number of hidden layers {len(hidden)} should match the config {model.config.num_hidden_layers}"

    # Next-token loss restricted to the response: prompt positions are masked
    # with the ignore index, then logits/labels are shifted by one position.
    labels = conv.clone()
    labels[..., :prompt_len] = -100
    # Upcast so the loss is not computed in reduced precision.
    shift_logits = logits[..., :-1, :].float()
    # With a sharded model the logits may live on a different device than `conv`.
    shift_labels = labels[..., 1:].to(shift_logits.device)
    loss = F.cross_entropy(
        shift_logits.reshape(-1, shift_logits.size(-1)),
        shift_labels.reshape(-1),
        ignore_index=-100,
    ).item()

    # Reduce layer by layer (rather than stacking first) so that hidden states
    # located on different devices are each averaged where they are.
    layer_means = [
        layer_hidden[0, token_span, :].mean(dim=0).detach().cpu()  # (hidden_size,)
        for layer_hidden in hidden
    ]
    layer_activations_tensor = torch.stack(layer_means, dim=0)  # (num_layers, hidden_size)

    return layer_activations_tensor, loss

def _prompt_token_count(tokenizer, messages):
    """Return the token count of `messages` without its final turn."""
    # NOTE(review): with add_generation_prompt=False, any assistant-turn header
    # the template emits is not part of this prefix, so it would be counted in
    # the response span — confirm this is intended. This also assumes the
    # prompt tokens are a prefix of the full-conversation tokens.
    prompt_ids = tokenizer.apply_chat_template(
        messages[:-1],  # messages are two turns (user, then assistant); drop the reply
        tokenize=True,
        add_generation_prompt=False,
        return_tensors="pt",
        # Ask for the bare id tensor explicitly; recent transformers releases
        # return a BatchEncoding dict by default (verify against installed version).
        return_dict=False,
    )
    return prompt_ids.shape[1]


def main():
    """Compute and save the mean real-minus-contrastive activation vectors.

    Reads chat histories from ``--chat_history_file`` (a JSON list whose items
    have ``"real"`` and ``"contrastive"`` message lists), runs both
    conversations of every item through the model, averages the per-layer
    activations across items, and saves the difference of the two means to
    ``./activations/diff_vectors_<vector_type>_<trait>_<model>.pt``.

    Raises:
        ValueError: If the chat history file contains no entries.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--trait", type=str, required=True)
    parser.add_argument("--model_name", type=str, default="meta-llama/Llama-3.1-8B-Instruct")
    parser.add_argument("--vector_type", type=str, default="response", choices=["prompt", "response"])
    parser.add_argument("--chat_history_file", type=str, default="chat_histories/skeptic_chat_histories.json")

    args = parser.parse_args()

    vector_type = args.vector_type
    trait = args.trait

    with open(args.chat_history_file, "r", encoding="utf-8") as f:
        chat_histories = json.load(f)

    # Fail early with a clear message instead of an opaque torch.stack error
    # after the model has already been loaded.
    if not chat_histories:
        raise ValueError(f"No chat histories found in {args.chat_history_file}")

    real_vectors = []
    contrastive_vectors = []
    real_losses = []
    contrastive_losses = []

    model, tokenizer = load_model(args.model_name)

    activation_dir = "./activations"
    os.makedirs(activation_dir, exist_ok=True)

    for history in tqdm(chat_histories, desc="Computing vectors"):
        real_messages, contrastive_messages = history["real"], history["contrastive"]

        real_prompt_len = _prompt_token_count(tokenizer, real_messages)
        contrastive_prompt_len = _prompt_token_count(tokenizer, contrastive_messages)

        real_conv = tokenize_conv(real_messages, tokenizer, model.device)
        contrastive_conv = tokenize_conv(contrastive_messages, tokenizer, model.device)

        with torch.no_grad():
            real_layer_activations, real_loss = compute_layer_activations(
                model, real_conv, vector_type, real_prompt_len
            )
            real_vectors.append(real_layer_activations)
            real_losses.append(round(real_loss, 2))

            contrastive_layer_activations, contrastive_loss = compute_layer_activations(
                model, contrastive_conv, vector_type, contrastive_prompt_len
            )
            contrastive_vectors.append(contrastive_layer_activations)
            contrastive_losses.append(round(contrastive_loss, 2))

        # Release per-item tensors before the next forward pass.
        del real_conv, contrastive_conv
        torch.cuda.empty_cache()

    real_vectors_mean = torch.stack(real_vectors, dim=0).mean(dim=0)  # (num_layers, hidden_size)
    contrastive_vectors_mean = torch.stack(contrastive_vectors, dim=0).mean(dim=0)
    diff_vectors = real_vectors_mean - contrastive_vectors_mean

    print(f"Diff vector shape (num_layers, hidden_dim): {diff_vectors.shape}")
    print(f"Cross entropy on target response: {real_losses}")
    print(f"Cross entropy on contrast responses: {contrastive_losses}")

    # Keep only the last path component of the model name; strip a trailing
    # slash first so a local directory path does not yield an empty name.
    model_name_safe = args.model_name.rstrip("/").split("/")[-1]
    # Build the path once so the saved file and the log message cannot diverge.
    output_path = f"{activation_dir}/diff_vectors_{vector_type}_{trait}_{model_name_safe}.pt"
    with open(output_path, "wb") as f:
        torch.save(diff_vectors, f)
    # Report success only after the file has been closed.
    print(f"Saved diff vectors to {output_path}")

# Run the command-line entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()