import os
import torch
import h5py
from tqdm import tqdm
from collections import defaultdict
import numpy as np

def load_hidden_states_h5(path):
    """Read every top-level group of an HDF5 hidden-state dump.

    Each group is expected to hold three parallel datasets: ``labels``
    (stored as bytes, decoded here), ``vecs`` and ``indices``. They are
    zipped together row by row.

    Args:
        path: Path to the ``.h5`` file to read.

    Returns:
        dict mapping group name -> list of ``(label, vec, index)`` tuples,
        where ``label`` is a decoded ``str``.
    """
    loaded = {}
    with h5py.File(path, "r") as h5:
        for group_name, group in tqdm(h5.items(), desc="Groups"):
            decoded_labels = [raw.decode() for raw in group["labels"][:]]
            loaded[group_name] = list(
                zip(decoded_labels, group["vecs"][:], group["indices"][:])
            )
    return loaded

def merge_hidden_states(h1, h2):
    """Concatenate two hidden-state mappings key by key.

    For every key present in either input, the result holds ``h1``'s rows
    followed by ``h2``'s rows. A missing side contributes nothing. Every
    value in the result is a new list, so neither input is mutated or
    aliased.

    Keys are ordered deterministically: first ``h1``'s keys in their order,
    then keys that only appear in ``h2``. The previous version iterated a
    ``set`` of keys, whose order changes between runs with string-hash
    randomisation.

    Args:
        h1: Mapping of key -> list of rows.
        h2: Mapping of key -> list of rows.

    Returns:
        A plain ``dict`` of key -> concatenated list of rows.
    """
    # dict.fromkeys keeps first-seen order and drops duplicates.
    ordered_keys = dict.fromkeys([*h1, *h2])
    return {key: h1.get(key, []) + h2.get(key, []) for key in ordered_keys}

def restructure_hidden_states(hidden_states):
    """Regroup a flat key -> rows mapping into layer -> block -> proj -> rows.

    Each key is split on ``"."``:
      * The first segment's text after its first ``"_"`` is parsed as the
        integer layer index.
      * The second and third segments become the block and projection names.
      * Any further segments are ignored.

    For example, a key shaped like ``"layer_3.mlp.up_proj"`` is assumed;
    confirm against the dump producer.

    Args:
        hidden_states: Mapping of dotted key -> value (stored unchanged).

    Returns:
        A nested ``defaultdict``: ``int`` layer -> ``defaultdict`` of block
        name -> ``dict`` of projection name -> value.
    """
    tree = defaultdict(lambda: defaultdict(dict))
    for key, rows in hidden_states.items():
        segments = key.split(".")
        layer_idx = int(segments[0].split("_")[1])
        block_name, proj_name = segments[1], segments[2]
        tree[layer_idx][block_name][proj_name] = rows
    return tree

def _encode_labels(labels):
    """Return ``labels`` as a fixed-length ``S`` array of UTF-8 bytes.

    ``np.array(..., dtype="S")`` encodes ``str`` items with ASCII and raises
    ``UnicodeEncodeError`` on any non-ASCII character. Labels are therefore
    encoded explicitly. The readers in this file decode with
    ``bytes.decode()`` (UTF-8 by default), so the text round-trips.
    Items that are already ``bytes`` pass through unchanged.
    """
    return np.array(
        [lbl if isinstance(lbl, bytes) else str(lbl).encode("utf-8") for lbl in labels],
        dtype="S",
    )


def flatten_and_save_by_proj(nested_hidden_states, save_dir, prefix):
    """Write one gzip-compressed HDF5 file per (layer, block, projection).

    Each file is named ``{prefix}_layer_{layer}_{block}_{proj}.h5`` inside
    ``save_dir``. It holds three root datasets:
      * ``labels``: UTF-8 bytes.
      * ``vecs``: the row vectors stacked with ``np.stack``.
      * ``indices``: cast to ``int32``.

    Args:
        nested_hidden_states: layer -> block -> proj -> list of
            ``(label, vec, index)`` rows.
        save_dir: Output directory, created if missing.
        prefix: Dataset prefix prepended to every file name.

    Entries with no rows are skipped with a message, because nothing can be
    stacked for them.
    """
    os.makedirs(save_dir, exist_ok=True)
    for layer, blocks in tqdm(nested_hidden_states.items(), desc="Layers"):
        for block, projs in blocks.items():
            for proj, data in projs.items():
                if not data:
                    # zip(*[]) cannot be unpacked into three names; skip instead of crashing.
                    tqdm.write(f"Skipping layer {layer} {block}.{proj}: no rows to save.")
                    continue
                labels, vecs, indices = zip(*data)
                fpath = os.path.join(save_dir, f"{prefix}_layer_{layer}_{block}_{proj}.h5")
                with h5py.File(fpath, "w") as f:
                    f.create_dataset("labels", data=_encode_labels(labels), compression="gzip")
                    f.create_dataset("vecs", data=np.stack(vecs), compression="gzip")
                    f.create_dataset("indices", data=np.array(indices, dtype="int32"), compression="gzip")


# [block, projection] pairs, grouped by block. The MLP projections are
# listed first, then the self-attention projections.
valid_blocks = [
    [block, proj]
    for block, projs in (
        ("mlp", ("up_proj", "down_proj", "gate_proj")),
        ("self_attn", ("q_proj", "k_proj", "v_proj", "o_proj")),
    )
    for proj in projs
]

def load_layer_data(layer_idx, block, proj, act_dir, mani_dir):
    """Load one projection's activations together with its manifold file.

    Args:
        layer_idx: Layer number used in the file names.
        block: Block name (e.g. ``"mlp"``).
        proj: Projection name (e.g. ``"up_proj"``).
        act_dir: Directory holding ``layer_{i}_{block}_{proj}.h5``.
        mani_dir: Directory holding ``layer_{i}_{block}_{proj}_manifold.h5``.

    Returns:
        ``None`` if either file is missing. Otherwise a tuple
        ``(labels, vecs, indices, manifold)``:
          * ``labels`` is a list of decoded ``str``.
          * ``vecs`` and ``indices`` are the full datasets.
          * ``manifold`` maps ``"U"``, ``"S"``, ``"Vh"`` and ``"mean"`` to
            their datasets.
    """
    act_path = os.path.join(act_dir, f"layer_{layer_idx}_{block}_{proj}.h5")
    mani_path = os.path.join(mani_dir, f"layer_{layer_idx}_{block}_{proj}_manifold.h5")

    # Both files are required; bail out early if either one is absent.
    if not os.path.exists(act_path) or not os.path.exists(mani_path):
        return None

    with h5py.File(act_path, "r") as act_file:
        labels = [raw.decode() for raw in act_file["labels"][:]]
        vecs = act_file["vecs"][:]
        indices = act_file["indices"][:]

    with h5py.File(mani_path, "r") as mani_file:
        manifold = {name: mani_file[name][:] for name in ("U", "S", "Vh", "mean")}

    return labels, vecs, indices, manifold


# Other dataset prefixes this script has been run with:
# prefixes = ["synth", "hindi", "semeval", "german", "italian", "emoevent_es", "emoevent_en", "french", "twitter", "goemotions"]
prefixes = ["math"]
target_path = "llama_instruct_hidden_state_dumps_mean"
# Sorted so rows are merged in a reproducible order (os.listdir order is arbitrary).
rel_files = sorted(name for name in os.listdir(target_path) if name.endswith(".h5"))

for prefix in tqdm(prefixes, desc="Dataset Progress"):
    # flatten_and_save_by_proj writes "{prefix}_layer_<n>_<block>_<proj>.h5" into
    # target_path itself. Those files hold root datasets rather than groups, so
    # load_hidden_states_h5 cannot read them. Exclude them so a re-run does not
    # crash on its own previous outputs.
    output_stem = f"{prefix}_layer_"
    cur_rel_files = [
        name for name in rel_files
        if prefix in name and not name.startswith(output_stem)
    ]
    if not cur_rel_files:
        tqdm.write(f"No hidden-state dumps found for prefix {prefix!r} in {target_path}; skipping.")
        continue

    # Merge files one at a time; rows for each key keep file order.
    all_hidden_states = {}
    for cur_rel_file in cur_rel_files:
        all_hidden_states = merge_hidden_states(
            all_hidden_states,
            load_hidden_states_h5(os.path.join(target_path, cur_rel_file)),
        )
    nested_hidden_states = restructure_hidden_states(all_hidden_states)
    flatten_and_save_by_proj(nested_hidden_states, target_path, prefix)



