from tqdm import tqdm
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import os
import huggingface_hub
from torch.nn import functional as F
import pickle as pkl
import gc
import h5py
import numpy as np
from collections import defaultdict
import random
from datasets import load_dataset
import ast
from dataset_classes import *
import pickle as pkl

# Accumulator filled by the forward hooks built in `make_hook`: every hooked
# module name maps to a list with one record per processed sample.
hidden_state_data = defaultdict(list)  # {layer_name: [(label, avg_vec, index)]}

# The Hugging Face access token is read from the HF_KEY environment variable;
# a missing variable raises KeyError here, before anything is downloaded.
hf_token = os.environ["HF_KEY"]
huggingface_hub.login(hf_token)

# Exactly one `model_name` line should be active; the others are alternatives
# kept for quick switching.
# model_name = "meta-llama/Llama-3.1-8B"  # needs to be the same as the one in dataset_classes.py
model_name = "meta-llama/Llama-3.1-8B-Instruct"  # needs to be the same as the one in dataset_classes.py
# model_name = "mistralai/Ministral-8B-Instruct-2410"  # needs to be the same as the one in dataset_classes.py
# model_name = "Qwen/Qwen3-8B"  # needs to be the same as the one in dataset_classes.py
# model_name = "Qwen/Qwen3-8B-Base"  # needs to be the same as the one in dataset_classes.py
# model_name = "allenai/OLMo-2-1124-7B-Instruct"  # needs to be the same as the one in dataset_classes.py
# model_name = "allenai/OLMo-2-1124-7B"  # needs to be the same as the one in dataset_classes.py
# The tokenizer is configured for left padding. The model is loaded with
# `device_map="auto"` and no explicit dtype argument.
tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token, padding_side="left", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
# Reuse the EOS token as the padding token (presumably because the tokenizer
# ships without a dedicated pad token -- TODO confirm per model).
tokenizer.pad_token = tokenizer.eos_token

# Sub-module paths (relative to one decoder block) whose forward outputs are
# recorded: the attention projections plus the MLP projections.
target_modules = {
    "self_attn.q_proj",
    "self_attn.k_proj",
    "self_attn.v_proj",
    "self_attn.o_proj",
} | {
    "mlp.gate_proj",
    "mlp.up_proj",
    "mlp.down_proj",
}

# target_modules = {"mlp.down_proj"}
# layers = [10, 12, 14, 16, 18, 20]

# Accumulated records are flushed to disk every `save_interval` batches.
# `save_counter` is the number used in the next dump file's name and is
# incremented after each periodic flush.
save_interval = 25  # batches
save_counter = 0

# Side channel between the extraction loop and the forward hooks: the loop
# stores the current batch's labels and sample indices here before calling the
# model, and the hooks read them. Both entries start out as None.
current_labels_holder = dict.fromkeys(("labels", "indices"))

# Output directory for the dump files; created if it does not exist yet.
save_dir = "llama_instruct_hidden_state_dumps_mean"
# save_dir = "llama_hidden_state_dumps_mean"
os.makedirs(save_dir, exist_ok=True)


def make_hook(name):
    """Build a forward hook that records mean-pooled outputs of one module.

    The returned hook reads the current batch's labels and sample indices from
    the module-level ``current_labels_holder`` and appends one
    ``(label, vector, index)`` tuple per sample to
    ``hidden_state_data[name]``. If no labels have been published
    (``labels is None``) the hook does nothing.

    Args:
        name: Key under which this module's records are stored.

    Returns:
        A callable with the ``(module, inputs, output)`` forward-hook
        signature.
    """

    def hook(module, inputs, output):
        labels = current_labels_holder["labels"]
        indices = current_labels_holder["indices"]
        if labels is None:
            # No batch metadata available: nothing to attribute the output to.
            return

        with torch.no_grad():
            # Reduce over dim 1 directly in float32. For float32 outputs this
            # is identical to a plain mean; for fp16/bf16 outputs it avoids
            # low-precision accumulation and makes the result convertible to
            # NumPy (bfloat16 tensors cannot be converted).
            # NOTE(review): the mean runs over every position along dim 1; if
            # that is the sequence axis of a padded batch, padding positions
            # are averaged in as well -- confirm this is intended.
            avg = output.mean(dim=1, dtype=torch.float32).detach().cpu().numpy()

        # One record per sample; `astype` copies the row so each stored vector
        # is an independent float32 array.
        hidden_state_data[name].extend(
            (label, vec.astype("float32"), idx)
            for label, vec, idx in zip(labels, avg, indices)
        )

    return hook


# Attach a recording hook to every targeted sub-module of every decoder block.
# Records are keyed as "layer_<block index>.<sub-module path>".
# for layer_idx, decoder_block in enumerate(model.language_model.layers):
for layer_idx, decoder_block in enumerate(model.model.layers):
    # if layer_idx not in layers:
    #     continue
    for rel_name, module in decoder_block.named_modules():
        if rel_name not in target_modules:
            continue
        module.register_forward_hook(make_hook(f"layer_{layer_idx}.{rel_name}"))


def save_hidden_states_h5(save_path, hidden_state_data):
    """Write recorded hidden states to an HDF5 file, one group per module.

    Each non-empty entry of ``hidden_state_data`` becomes a group holding three
    gzip-compressed datasets of equal length:

    * ``labels``  -- byte strings; a list label is stored as its items joined
      with commas, any other label as ``str(label)``.
    * ``vecs``    -- the per-sample vectors stacked along a new first axis.
    * ``indices`` -- the sample indices as int32.

    Args:
        save_path: Destination file. It is opened in ``"w"`` mode, so an
            existing file is overwritten.
        hidden_state_data: Mapping from module name to a list of
            ``(label, vector, index)`` records. Entries with no records are
            skipped.
    """
    with h5py.File(save_path, "w") as f:
        for name, records in hidden_state_data.items():
            if not records:
                continue
            labels, vecs, indices = zip(*records)
            # Encode explicitly as UTF-8: NumPy's str -> "S" conversion is
            # ASCII-only and raises UnicodeEncodeError on non-ASCII labels.
            # For pure-ASCII labels the stored bytes are unchanged.
            encoded_labels = [
                (",".join(map(str, label)) if isinstance(label, list) else str(label)).encode("utf-8")
                for label in labels
            ]
            grp = f.create_group(name)
            grp.create_dataset("labels", data=np.array(encoded_labels, dtype="S"), compression="gzip")
            grp.create_dataset("vecs", data=np.stack(vecs), compression="gzip")
            grp.create_dataset("indices", data=np.array(indices, dtype="int32"), compression="gzip")


# Build one dataset object per corpus. Every constructor receives a path (or
# an empty string) and the shared tokenizer. The dataset classes and the
# `*_path` names are not defined in this file; presumably they are provided by
# `from dataset_classes import *` -- verify there.
# NOTE(review): all datasets are constructed here, even those that are not
# selected in `datasets` below.
go_emotions_ds = go_emotions_dataset(go_emotions_path, tokenizer)
synth_emotions_ds = synth_text_dataset(synth_dataset_path, tokenizer, cache_usage="new", N=2000, exclusion=(), synonyms=False)
hindi_ds = hindi_dataset(hindi_dataset_path, tokenizer)
semeval_emotions_ds = semeval_dataset(semeval_dataset_path, tokenizer)
twitter_emotions_ds = twitter_dataset("", tokenizer)
french_emotions_ds = french_dataset("", tokenizer)
emoevents_en_ds = emoevents_dataset(emoevents_en_dataset_path, tokenizer)
emoevents_es_ds = emoevents_dataset(emoevents_es_dataset_path, tokenizer)
german_plays_ds = german_plays_dataset(german_plays_dataset_path, tokenizer)
it_emo_ds = italian_dataset(italian_dataset_path, tokenizer)
math_ds = maths_dataset("", tokenizer)
# Selection of datasets to process: the key becomes the `ds_name` prefix of
# the dump files written by the extraction loop. Uncomment entries to include
# more datasets in a run.
datasets = {
    # "synth": synth_emotions_ds,
    # "semeval": semeval_emotions_ds, "german": german_plays_ds,
    #  "italian": it_emo_ds, "emoevent_es": emoevents_es_ds, "emoevent_en": emoevents_en_ds,
    #  "french": french_emotions_ds, "twitter": twitter_emotions_ds, "goemotions": go_emotions_ds,
    #  "hindi": hindi_ds
    "math": math_ds
}

# dataloader = DataLoader(go_emotions_ds, batch_size=256, shuffle=False, collate_fn=go_emotions_dataset.collate_fn, drop_last=False)
# dataloader = DataLoader(synth_emotions_ds, batch_size=112, shuffle=False, collate_fn=go_emotions_dataset.collate_fn, drop_last=False)
# dataloader = DataLoader(semeval_emotions_ds, batch_size=112, shuffle=False, collate_fn=go_emotions_dataset.collate_fn, drop_last=False)
# dataloader = DataLoader(hindi_ds, batch_size=112, shuffle=False, collate_fn=go_emotions_dataset.collate_fn, drop_last=False)
# dataloader = DataLoader(french_emotions_ds, batch_size=112, shuffle=False, collate_fn=go_emotions_dataset.collate_fn, drop_last=False)
# dataloader = DataLoader(emoevents_es_ds, batch_size=112, shuffle=False, collate_fn=go_emotions_dataset.collate_fn, drop_last=False)
# dataloader = DataLoader(german_plays_ds, batch_size=90, shuffle=False, collate_fn=go_emotions_dataset.collate_fn, drop_last=False)
# dataloader = DataLoader(it_emo_ds, batch_size=112, shuffle=False, collate_fn=go_emotions_dataset.collate_fn, drop_last=False)
# Run every selected dataset through the model. The forward hooks fill
# `hidden_state_data`; it is flushed to an HDF5 file every `save_interval`
# batches and once more at the end of each dataset.
for ds_name, dataset in datasets.items():
    # if os.path.exists(f"{save_dir}/{ds_name}_dump_done.pkl"):
    #     continue
    dataloader = DataLoader(
        dataset,
        batch_size=60,
        shuffle=False,
        collate_fn=go_emotions_dataset.make_collate_fn(tokenizer.pad_token_id),
        drop_last=False,
    )
    with tqdm(total=len(dataloader)) as pbar:
        for batch_idx, batch in enumerate(dataloader):
            # Publish the batch metadata for the hooks. Indices are converted
            # to plain ints when they are not already ints.
            current_labels_holder["labels"] = batch["labels"]
            if not isinstance(batch["indices"][0], int):
                batch["indices"] = [idx.item() for idx in batch["indices"]]
            current_labels_holder["indices"] = batch["indices"]

            # Remove the bookkeeping entries so that only the remaining keys
            # are passed to the model as keyword arguments.
            del batch["labels"]
            del batch["indices"]
            if "text_token_mask" in batch:
                del batch["text_token_mask"]

            # Only the hook side effects matter; the model output is dropped.
            with torch.no_grad():
                model(**batch)

            if (batch_idx + 1) % save_interval == 0:
                save_path = f"{save_dir}/{ds_name}_dump_{save_counter}.h5"
                save_hidden_states_h5(save_path, hidden_state_data)
                for records in hidden_state_data.values():
                    records.clear()
                gc.collect()
                torch.cuda.empty_cache()

                save_counter += 1
            pbar.update(1)

    # Flush the tail of the dataset. The dict keeps its keys after `clear()`,
    # so test for actual records rather than dict truthiness; otherwise an
    # empty file is written whenever the batch count is a multiple of
    # `save_interval`.
    if any(hidden_state_data.values()):
        save_path = f"{save_dir}/{ds_name}_dump_{save_counter}.h5"
        save_hidden_states_h5(save_path, hidden_state_data)
        # Drop the flushed records so they cannot leak into the first dump of
        # the next dataset.
        for records in hidden_state_data.values():
            records.clear()

    # Marker file signalling that this dataset was processed completely.
    with open(f"{save_dir}/{ds_name}_dump_done.pkl", "wb") as done_file:
        pkl.dump([""], done_file)
