import joblib
from copy import deepcopy
import numpy as np
from dataset_classes import *
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
from torch.utils.data import Dataset, DataLoader
import pickle as pkl
from train_lora_emotion import PostActivationLoRA, replace_with_lora_wrappers
from safetensors.torch import safe_open

# Hugging Face access token, read from the HF_KEY environment variable.
# A missing variable raises KeyError here, before any model is loaded.
# NOTE(review): `os` is not imported explicitly in this file; presumably it is
# re-exported by `from dataset_classes import *` — confirm.
hf_token = os.environ["HF_KEY"]
# huggingface_hub.login(hf_token)

def run_eval(mode_label):
    """Generate predictions for every dataset in ``datasets`` and summarise them.

    The function reads these module-level globals, which the driver loop sets
    before calling it: ``datasets``, ``model``, ``tokenizer``,
    ``emotion_list`` and ``target_emotion``.

    Args:
        mode_label: Free-form tag shown in the progress bar and in every
            printed / stored summary line.

    Returns:
        dict mapping dataset name to the newline-joined summary text
        (accuracy, and for a truthy ``target_emotion`` also the share of
        predictions equal to the target plus the full prediction
        distribution). Datasets that yield no examples are skipped.
    """
    results = {}
    for dataset_name, dataset in datasets.items():
        # batch_size = 128 if dataset_name != "french" else 300
        # Only a real (truthy) target is added to the dataset's label list.
        if target_emotion and target_emotion not in dataset.emotion_list:
            dataset.emotion_list.append(target_emotion)

        dataloader = DataLoader(
            dataset,
            batch_size=64,
            shuffle=True,
            collate_fn=synth_text_dataset.make_collate_fn(tokenizer.pad_token_id),
        )
        max_new_tokens = 128 if dataset_name == "hindi" else 60

        # Each batch is accumulated exactly once, inside the loop. Extending
        # again after the loop would count the last batch twice.
        preds, labels = [], []
        for batch in tqdm(dataloader, desc=f"{dataset_name} [{mode_label}]"):
            with torch.no_grad():
                generated_ids = model.generate(
                    inputs=batch["input_ids"].cuda(),
                    attention_mask=batch["attention_mask"].cuda(),
                    max_new_tokens=max_new_tokens,
                    do_sample=False,
                    temperature=0.01,
                    eos_token_id=tokenizer.eos_token_id,
                )
            decoded = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

            # Only the text after the last "Emotion:" marker is searched.
            preds.extend(
                extract_first_matching_emotion(sample.split("Emotion:")[-1], emotion_list)
                for sample in decoded
            )
            labels.extend(batch["labels"])

        if not preds:
            print(f"[{mode_label}] {dataset_name}: no examples, skipping.")
            continue

        if isinstance(labels[0], list):
            # Multi-label gold: a prediction is correct if it is any gold label.
            acc = np.sum([pred in gold for pred, gold in zip(preds, labels)]) / len(preds)
        else:
            acc = np.mean(np.array(preds) == np.array(labels))

        accuracy_line = f"[{mode_label}] {dataset_name} accuracy: {acc * 100:.2f}%"
        print(f"Target emotion: {target_emotion}")
        print(accuracy_line, end=("" if target_emotion else "\n"))
        output = [
            f"Target emotion: {target_emotion}",
            accuracy_line + ("" if target_emotion else "\n"),
        ]

        # save_predict = {"mode": mode_label, "dataset": dataset_name, "accuracy": acc}

        if target_emotion:
            target_emo_pred = np.mean([pred == target_emotion for pred in preds])
            print(f"; target_emotion prediction: {target_emo_pred * 100:.2f}%")
            output.append(f"; target_emotion prediction: {target_emo_pred * 100:.2f}%")

            # Unmatched generations (None) are reported under the name "None".
            y_pred = np.array(["None" if pred is None else pred for pred in preds])
            classes, counts = np.unique(y_pred, return_counts=True)
            counts = counts / np.sum(counts) * 100
            class_sort = np.argsort(counts)[::-1]
            dist_lines = [
                f"{emotion} ({count:.2f}%)"
                for emotion, count in zip(classes[class_sort], counts[class_sort])
            ]
            for line in dist_lines:
                print(line)
            output.extend(dist_lines)

            print(f"dataset: {dataset_name}")
            print(f"target: {target_emotion}, results: ", end="")
            for line in dist_lines:
                print(line, end=", ")
            summary = ", ".join(dist_lines)
            output.append(f"dataset: {dataset_name}")
            output.append(f"target: {target_emotion}, results: {summary}")
            print()

        # Stored for every evaluated dataset, with or without a target.
        results[dataset_name] = "\n".join(output)

    return results

def extract_first_matching_emotion(output_text, emotion_labels):
    """Return the label that occurs earliest in ``output_text``.

    Matching is case-insensitive. If the text contains the
    ``<START>assistant`` marker (in any casing), only the part after its last
    occurrence is searched.

    Args:
        output_text: Generated text to scan.
        emotion_labels: Candidate labels. When several labels match at the
            same position, the one listed first wins.

    Returns:
        The matching label exactly as it appears in ``emotion_labels``, or
        ``None`` when no label occurs in the text.
    """
    # The text is lower-cased before splitting, so the marker must be
    # lower-case too or the split can never fire.
    haystack = output_text.lower().split("<start>assistant")[-1].strip()

    best_label, best_pos = None, -1
    for label in emotion_labels:
        # Lower-case the label as well so entries such as "I cannot" can match.
        pos = haystack.find(label.lower())
        if pos >= 0 and (best_pos < 0 or pos < best_pos):
            best_label, best_pos = label, pos
    return best_label


# Model family to evaluate; it selects the Hugging Face checkpoint below.
model_type = "olmo"

# Artifacts loaded from local files. weights_only=False lets torch.load
# unpickle arbitrary Python objects, so these files must be trusted.
steering_vectors = torch.load("steering_vectors.pt", weights_only=False)
# Alternative hidden-state files, currently disabled:
# manifolds_and_centroids = torch.load("hidden_state_emotional_data_no_lora.pt")
# manifolds_and_centroids = torch.load(f"hidden_state_synth_no_lora_{model_type}_with_exclusions.pt", weights_only=False)
manifolds_and_centroids = torch.load("hidden_state_synth_no_lora_olmo_base_with_exclusions.pt", weights_only=False)

# Checkpoint per model family; any other model_type falls back to Llama 3.1.
_MODEL_NAMES = {
    "qwen": "Qwen/Qwen3-8B",
    "olmo": "allenai/OLMo-2-1124-7B-Instruct",
    "gemma": "google/gemma-3-12b-it",
    "mistralai": "mistralai/Ministral-8B-Instruct-2410",
}
model_name = _MODEL_NAMES.get(model_type, "meta-llama/Meta-Llama-3.1-8B-Instruct")
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")

# Per-target-emotion results. They are persisted to `results_filename` so a
# later run can pick up the entries that already exist.
all_results = {}
results_filename = "olmo_base_in_instruct_latent_space_shift_tokence_40_epoch_1_classification_results.pkl"

if os.path.exists(results_filename):
    # NOTE: unpickling can execute arbitrary code; only load results files
    # that this script produced.
    with open(results_filename, "rb") as results_file:
        all_results = pkl.load(results_file)

# ---- Loop-invariant setup: built once and reused for every target emotion ----
tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token, padding_side="left", trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token

# Alternative prompt format, currently disabled:
# prompt_template = [
#     {"role": "user", "content": 'What emotion is expressed in the sentence: "{text}"?\n\nChoices: {emotion_list}\n\nAnswer:'},
#     {"role": "assistant", "content": f"{target_emotion}"}
# ]
prompt_template = [
    {"role": "system", "content": "Classify which emotion the text shows according to the list of the following emotions: {emotion_list}. "},
    {"role": "user", "content": "Text: {text}.\nEmotion: <START>"},
]

# Translate steering-vector keys into module paths, e.g. "up_proj_12" ->
# "model.layers.12.mlp.up_proj". The first two "_"-separated parts name the
# projection and the last part is the layer index. Projections that are not
# MLP projections are looked up under self_attn.
_MLP_PROJECTIONS = ("up_proj", "down_proj", "gate_proj")
target_modules = []
for key in steering_vectors.keys():
    key_parts = key.split("_")
    projection = "_".join(key_parts[0:2])
    submodule = "mlp" if projection in _MLP_PROJECTIONS else "self_attn"
    target_modules.append(f"model.layers.{key_parts[-1]}.{submodule}.{projection}")

for target_emotion in ["fear", "surprise", "sad", "anger", "excitement", "envy", "neutral", "happy", "disgust"]:
    if target_emotion in all_results:
        # To force a re-run, remove this emotion's entry from the results file.
        print(f"Skipping {target_emotion}, already evaluated shift performance.")
        continue
    weight_path = f"emotion_loras/olmo_base_space_in_instruct_{target_emotion}_latent_space_shift_tokence_40_epoch_1"  # needs to be the same as the one in dataset_classes.py
    print(weight_path)

    # NOTE(review): the wrappers are installed on the same `model` object in
    # every iteration; confirm replace_with_lora_wrappers handles modules that
    # were already wrapped for a previous emotion.
    replace_with_lora_wrappers(model, target_emotion, target_modules, steering_vectors=steering_vectors, Vh=manifolds_and_centroids["manifolds"], r=40, alpha=16)

    # Gather every tensor from all safetensors shards of this emotion's checkpoint.
    shard_files = sorted(name for name in os.listdir(weight_path) if name.endswith(".safetensors"))
    state_dict = {}
    for shard in shard_files:
        with safe_open(os.path.join(weight_path, shard), framework="pt", device="cpu") as shard_reader:
            state_dict.update({k: shard_reader.get_tensor(k) for k in shard_reader.keys()})

    # strict=False ignores key mismatches, so report checkpoint tensors that
    # matched no model parameter rather than silently dropping them.
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    if unexpected:
        print(f"Warning: {len(unexpected)} checkpoint tensors matched no model parameter, e.g. {list(unexpected)[:3]}")

    # Datasets are rebuilt for every emotion because run_eval appends the
    # target emotion to each dataset's emotion_list.
    datasets = {
        "semeval": semeval_dataset(semeval_dataset_path, tokenizer, prompt_template),
        "twitter": twitter_dataset("", tokenizer, prompt_template),
        "french": french_dataset("", tokenizer, prompt_template),
        "emoevents_en": emoevents_dataset(emoevents_en_dataset_path, tokenizer, prompt_template),
        "emoevents_es": emoevents_dataset(emoevents_es_dataset_path, tokenizer, prompt_template),
        "german_plays": german_plays_dataset(german_plays_dataset_path, tokenizer, prompt_template),
        "italian": italian_dataset(italian_dataset_path, tokenizer, prompt_template),
        "go_emotions": go_emotions_dataset(go_emotions_path, tokenizer, prompt_template),
        "hindi": hindi_dataset(hindi_dataset_path, tokenizer, prompt_template),
        "synth": synth_text_dataset(synth_dataset_path, tokenizer, prompt_template=prompt_template, N=30, cache_usage="new"),
    }

    # The synth dataset only supplies the base label vocabulary; it is removed
    # from `datasets` below and therefore not evaluated.
    dataset_name = "synth"
    emotion_list = deepcopy(datasets[dataset_name].emotion_list)
    emotion_list.extend(["neutral", "regret", "disappointment", "anxiety", "concern", "relief", "sarcasm", "frustration", "sadness",
                         "no emotional content", "stress", "anticipation", "excitement", "curiosity", "interest", "can't determine", "pain",
                         "hope", "optimism", "couldn't determine", "couldn't find", "none of the above", "no emotion expressed", "je ne vois pas d'émotion", "I cannot", "I can't"])
    del datasets["synth"]
    # Also accept each label wrapped in angle brackets, e.g. "<fear>".
    emotion_list.extend([f"<{i}>" for i in emotion_list])

    # Not referenced anywhere else in this file.
    all_pred_save_dir = f"semeval_goto_{target_emotion}_emotions_latent.csv" if target_emotion else "semeval_remove_emotions_latent.csv"
    results = run_eval("multi_layer_model_shift_hooks")
    all_results[target_emotion] = results
    # Checkpoint after every emotion so an interrupted sweep keeps its progress.
    with open(results_filename, "wb") as results_file:
        pkl.dump(all_results, results_file)

with open(results_filename, "wb") as results_file:
    pkl.dump(all_results, results_file)
