import json
import os
import matplotlib.pyplot as plt
import numpy as np
import torch
import umap
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM

# Restrict this process to physical GPU 1. This must run before the first CUDA
# query below so the mask is in place when CUDA is initialised.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

# With a single GPU left visible, it is addressed as "cuda:0"; fall back to
# the CPU when no CUDA device is available.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

def get_llama_hidden_state(text, tokenizer, model, layer_index=16):
    """Return the mean-pooled hidden state of ``text`` at one transformer layer.

    Args:
        text: Input string (or list of strings) passed to ``tokenizer``.
        tokenizer: Callable tokenizer returning a mapping of tensors that
            supports ``.to(device)``.
        model: Model that returns ``hidden_states`` when called with
            ``output_hidden_states=True``; must expose a ``device`` attribute.
        layer_index: Zero-based transformer layer to read. Non-negative values
            select the output of that layer; negative values count from the
            end, so ``-1`` is the last layer.

    Returns:
        A tensor of shape ``[hidden_dim]`` for a single input (the leading
        batch axis is squeezed away only when it has size 1).

    Raises:
        IndexError: If ``layer_index`` does not refer to an existing layer.
    """
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True).to(model.device)

    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True, return_dict=True)

    all_states = outputs.hidden_states
    # Entry 0 is the embedding output, so there is one fewer layer than entries.
    num_layers = len(all_states) - 1
    if not -num_layers <= layer_index < num_layers:
        raise IndexError(
            f"layer_index {layer_index} is out of range for a model with {num_layers} layers"
        )

    # Non-negative indices are shifted past the embedding entry. Negative
    # indices already count from the end of the tuple and must NOT be shifted,
    # otherwise -1 would resolve to entry 0 (the embeddings).
    state_index = layer_index + 1 if layer_index >= 0 else layer_index
    hidden = all_states[state_index]

    # Average pooling over the sequence axis, ignoring padded positions when
    # the tokenizer supplied an attention mask.
    mask = inputs["attention_mask"] if "attention_mask" in inputs else None
    if mask is None:
        pooled = hidden.mean(dim=1)
    else:
        weights = mask.unsqueeze(-1).to(hidden.dtype)
        # clamp guards against division by zero for an all-padding row.
        pooled = (hidden * weights).sum(dim=1) / weights.sum(dim=1).clamp(min=1)

    return pooled.squeeze(0)  # shape: [hidden_dim] for a single input

def vis_umap(
    json_path,
    llama_model_name,
    output_img="llama_umap.png",
    layer_index=-1,
    title="UMAP Embedding"
):
    """Plot a 2-D UMAP projection of hidden states for paired model outputs.

    Each record in the JSON file contributes two texts, read from the keys
    ``"no_steering_output"`` and ``"with_steering_output_1.0"`` (a missing key
    is treated as an empty string). Both texts are embedded with
    ``get_llama_hidden_state``, all vectors are projected jointly with UMAP,
    and each pair is drawn as a blue point and a red point joined by a gray
    line. The figure is written to ``output_img`` and then shown.

    Args:
        json_path: Path to a UTF-8 JSON file holding a list of record objects.
        llama_model_name: Model identifier or path given to ``from_pretrained``
            for both the tokenizer and the causal LM.
        output_img: Destination image path; missing parent directories are
            created.
        layer_index: Layer selector forwarded to ``get_llama_hidden_state`` and
            echoed in the figure title.
        title: Title prefix; ``" (Layer <layer_index>)"`` is appended.

    Raises:
        ValueError: If the JSON file contains no records.
    """
    # Create the output directory up front so a bad path fails before the
    # expensive model pass rather than at savefig time.
    out_dir = os.path.dirname(output_img)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    # Load inference results
    with open(json_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    if not data:
        raise ValueError(f"No records found in {json_path!r}; nothing to embed.")

    # Load tokenizer and model
    tokenizer = AutoTokenizer.from_pretrained(llama_model_name)
    # Tokenization is requested with padding=True, which needs a pad token;
    # reuse EOS for that purpose.
    tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(llama_model_name, output_hidden_states=True).to(device).eval()

    vecs_0 = []
    vecs_1 = []

    for item in tqdm(data, desc="Extracting LLaMA hidden vectors"):
        text_0 = item.get("no_steering_output", "")
        text_1 = item.get("with_steering_output_1.0", "")

        vec0 = get_llama_hidden_state(text_0, tokenizer, model, layer_index)
        vec1 = get_llama_hidden_state(text_1, tokenizer, model, layer_index)

        # Upcast to float32 first: NumPy cannot represent bfloat16 tensors.
        vecs_0.append(vec0.float().cpu().numpy())
        vecs_1.append(vec1.float().cpu().numpy())

    # Fit one projection on both groups so they share a coordinate system.
    all_vecs = np.stack(vecs_0 + vecs_1, axis=0)
    reducer = umap.UMAP(n_components=2, random_state=42)
    embeddings = reducer.fit_transform(all_vecs)

    n_pairs = len(vecs_0)
    emb_0 = embeddings[:n_pairs]
    emb_1 = embeddings[n_pairs:]

    # Plotting. The font overrides are scoped to this figure (including the
    # save/show calls, since tick labels are styled lazily at draw time) so
    # the global matplotlib configuration is left untouched.
    font_overrides = {
        'font.size': 15,
        'axes.titlesize': 18,
        'legend.fontsize': 15,
        'xtick.labelsize': 13,
        'ytick.labelsize': 13,
    }
    with plt.rc_context(font_overrides):
        fig = plt.figure(figsize=(12, 8))

        # One scatter call per group gives exactly one legend entry each.
        plt.scatter(emb_0[:, 0], emb_0[:, 1], c='blue', label='non-reflective(coeff=0.0)', alpha=0.6)
        plt.scatter(emb_1[:, 0], emb_1[:, 1], c='red', label='reflective(coeff=1.0)', alpha=0.6)

        # With 2-D inputs, plot draws one line per column: each column holds
        # the two endpoints of a (non-reflective, reflective) pair.
        plt.plot(
            np.stack([emb_0[:, 0], emb_1[:, 0]]),
            np.stack([emb_0[:, 1], emb_1[:, 1]]),
            c='gray',
            alpha=0.3,
        )

        plt.title(f"{title} (Layer {layer_index})")
        plt.xlabel("UMAP Dimension 1", fontsize=14)
        plt.ylabel("UMAP Dimension 2", fontsize=14)
        plt.legend()
        plt.tight_layout()
        plt.savefig(output_img, dpi=300)
        plt.show()
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)

# Script entry point: visualise the MMLU inference results at layer 16.
# Guarded so that importing this module does not trigger the model load and
# plotting run.
if __name__ == "__main__":
    vis_umap(
        json_path="results_output/mmlu_test_infer_v0_layer16.json",
        llama_model_name="meta-llama/Llama-3.1-8B-Instruct",
        output_img="analysis/output/mmlu_umap_layer16.png",
        layer_index=16,
        # vis_umap appends " (Layer <layer_index>)" itself, so the layer is
        # not repeated in the title text.
        title="MMLU-Med: UMAP Embedding of Reflective Responses"
    )

# Alternative run on the CSQA validation set (disabled; uncomment to use):
# vis_umap(
#     json_path="results_output/csqa_val_infer_v0_layer16.json",
#     llama_model_name="meta-llama/Llama-3.1-8B-Instruct",
#     output_img="analysis/output/csqa_umap_layer16.png",
#     layer_index=16,
#     title="CSQA-Val: UMAP Embedding (Layer 16) of Reflective Responses"
# )