import csv
import torch
from torch.utils.data import Dataset, DataLoader, Sampler
import pandas as pd
import os
import huggingface_hub
from torch.nn import functional as F
import pickle as pkl
from collections import defaultdict
import random
from datasets import load_dataset
import ast
import numpy as np
import json
from copy import deepcopy

# Default on-disk locations of the emotion corpora consumed by the dataset
# classes in this module. All paths are relative, so they resolve against the
# process working directory.

# GoEmotions parquet shards: the "raw" train shard and the "simplified"
# validation / test shards (see go_emotions_dataset for how each is parsed).
go_emotions_path = "../../intralingua_emotion/datasets/human_emotion_datasets/go_emotions/raw/train-00000-of-00001.parquet"
go_emotions_path_val = "../../intralingua_emotion/datasets/human_emotion_datasets/go_emotions/simplified/validation-00000-of-00001.parquet"
go_emotions_path_test = "../../intralingua_emotion/datasets/human_emotion_datasets/go_emotions/simplified/test-00000-of-00001.parquet"
# Directory of per-emotion pickle files read by synth_text_dataset.
synth_dataset_path = "../../intralingua_emotion/datasets/emotions_dataset/singlesentences"
# CSV corpora; presumably one per dataset class of the matching name -- the
# classes take the path as an argument, so verify against the callers.
semeval_dataset_path = "interpretability/datasets/SemEval2007/affectivetext_annotated.csv"
maths_dataset_path = "interpretability/datasets/maths/sampled_qa_emotions.csv"
emoevents_en_dataset_path = "interpretability/datasets/EmoEvent/emotion_en.csv"
emoevents_es_dataset_path = "interpretability/datasets/EmoEvent/emotion_es.csv"
german_plays_dataset_path = "interpretability/datasets/german_plays/emotion_de.csv"
hindi_dataset_path = "interpretability/datasets/bhaav_dataset/hindi_dataset.csv"
italian_dataset_path = "interpretability/datasets/MultiEmotions-It/emotion_it.csv"


class text_dataset(Dataset):
    """Base dataset that turns (text, label) records into tokenized inputs.

    ``self.dataset`` must support ``dataset["text"][idx]``,
    ``dataset["label"][idx]`` and ``dataset["original_index"][idx]``.
    The constructor stores ``path`` there as-is; subclasses in this module
    replace it with the loaded data. ``__getitem__`` also reads
    ``self.emotion_list`` when the prompt template mentions ``emotion_list``;
    that attribute is set by subclasses, not here.
    """

    def __init__(self, path, tokenizer, prompt_template=""):
        """Store the data handle, the tokenizer and the optional chat prompt.

        Args:
            path: Value stored as ``self.dataset``.
            tokenizer: Tokenizer object; ``__getitem__`` uses its
                ``chat_template``, ``apply_chat_template``, ``tokenize`` and
                call interfaces.
            prompt_template: Falsy for "no template", otherwise a list of chat
                messages (dicts with a ``"content"`` entry) whose second
                message contains a ``{text}`` placeholder.
        """
        super().__init__()
        self.dataset = path
        self.tokenizer = tokenizer
        self.prompt_template = prompt_template

    def __len__(self):
        """Return the number of entries in the ``"text"`` column."""
        return len(self.dataset["text"])

    def __getitem__(self, idx):
        """Tokenize sample ``idx``.

        Returns:
            dict with ``input_ids`` and ``attention_mask`` (tensors produced by
            the tokenizer with ``return_tensors="pt"``), ``label``, ``index``
            (the sample's ``original_index``) and ``text_token_mask``, a 0/1
            tensor shaped like ``input_ids`` marking the sample-text tokens
            (all ones when no prompt template is used).
        """
        records = self.dataset
        text = records["text"][idx]
        label = records["label"][idx]
        index = records["original_index"][idx]

        has_chat_template = self.tokenizer.chat_template is not None
        template = self.prompt_template

        if not template:
            # No task prompt: feed the bare text, wrapped as a single user turn
            # when the tokenizer ships a chat template.
            if has_chat_template:
                text = self.tokenizer.apply_chat_template(
                    [{"role": "user", "content": text}],
                    tokenize=False,
                    add_generation_prompt=True,
                    enable_thinking=False,
                )
            encoded = self.tokenizer(text, return_tensors="pt")
            return {
                "input_ids": encoded["input_ids"],
                "attention_mask": encoded["attention_mask"],
                "label": label,
                "text_token_mask": torch.ones_like(encoded["input_ids"]),
                "index": index,
            }

        # Template path: the raw text is tokenized on its own only to learn how
        # many tokens it occupies.
        encoded_text = self.tokenizer(text, return_tensors="pt")

        messages = deepcopy(template)  # never mutate the shared template
        first_content = messages[0]["content"]
        if "emotion_list" in first_content:
            emotion_names = str(self.emotion_list).replace("'", "")
            messages[0]["content"] = first_content.format(emotion_list=emotion_names)

        # Tokenize the prompt while the text placeholder is still unfilled and
        # find the first "Ġ{" token; the text span starts right after it.
        # Presumably this is the placeholder's opening brace under a
        # byte-level BPE vocabulary -- TODO confirm for other tokenizers.
        unfilled = self.tokenizer.apply_chat_template(
            messages, add_generation_prompt=True, tokenize=False, enable_thinking=False
        )
        unfilled_tokens = np.array(self.tokenizer.tokenize(unfilled))
        span_start = np.flatnonzero(unfilled_tokens == "Ġ{")[0] + 1
        # The "- 1" presumably discounts a special token the tokenizer adds to
        # the stand-alone text -- TODO confirm.
        span_end = span_start + encoded_text["input_ids"].shape[1] - 1

        messages[1]["content"] = messages[1]["content"].format(text=text)
        rendered = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True, enable_thinking=False
        )
        encoded_full = self.tokenizer(rendered, return_tensors="pt")

        span_mask = torch.zeros_like(encoded_full["input_ids"])
        span_mask[0, span_start:span_end] = 1
        return {
            "input_ids": encoded_full["input_ids"],
            "attention_mask": encoded_full["attention_mask"],
            "label": label,
            "text_token_mask": span_mask,
            "index": index,
        }

    @staticmethod
    def make_collate_fn(pad_token_id=128009, max_seq_len=512):
        """Build a collate function that right-pads samples to a common width.

        Args:
            pad_token_id: Fill value used to pad ``input_ids``.
            max_seq_len: Samples whose ``input_ids`` are longer than this are
                dropped from the batch.

        Returns:
            A ``collate_fn(batch)`` callable. It returns ``None`` when every
            sample was dropped, otherwise a dict with ``input_ids``,
            ``attention_mask``, optionally ``text_token_mask``, ``labels`` and
            ``indices``.
        """
        def collate_fn(batch):
            kept = [item for item in batch if item["input_ids"].shape[1] <= max_seq_len]
            if len(kept) == 0:
                return None  # or raise an error if that's preferable
            width = max(item["input_ids"].shape[1] for item in kept)

            def stack_padded(key, fill):
                # Right-pad every (1, seq) tensor to `width`, then stack rows.
                rows = []
                for item in kept:
                    tensor = item[key]
                    rows.append(F.pad(tensor, pad=(0, width - tensor.shape[1]), mode="constant", value=fill))
                return torch.vstack(rows)

            collated = {
                "input_ids": stack_padded("input_ids", pad_token_id),
                "attention_mask": stack_padded("attention_mask", 0),
            }

            # Labels become a tensor only for tensor / Python-int labels;
            # anything else (e.g. strings) stays a plain list.
            raw_labels = [item["label"] for item in kept]
            first_label = kept[0]["label"]
            if isinstance(first_label, torch.Tensor):
                labels = torch.stack(raw_labels)
            elif isinstance(first_label, int):
                labels = torch.tensor(raw_labels)
            else:
                labels = raw_labels

            if "text_token_mask" in kept[0]:
                collated["text_token_mask"] = stack_padded("text_token_mask", 0)
            collated["labels"] = labels
            collated["indices"] = [item["index"] for item in kept]
            return collated
        return collate_fn



class go_emotions_dataset(text_dataset):
    """GoEmotions samples loaded from a parquet file.

    ``split == "train"`` expects one column per emotion (label = the column
    with the largest value) and keeps at most 250 rows per coarse
    ``remap_to_synth`` class. Any other split expects a ``labels`` column of
    integer-id sequences and keeps only rows with exactly one label.
    """

    def __init__(self, go_emotions_path, tokenizer, prompt_template="", split="train"):
        """Load and pre-process the parquet file at ``go_emotions_path``.

        Args:
            go_emotions_path: Parquet file to read.
            tokenizer: Tokenizer forwarded to ``text_dataset``.
            prompt_template: Optional chat prompt forwarded to ``text_dataset``.
            split: ``"train"`` selects the one-column-per-emotion layout;
                every other value selects the ``labels``-column layout.
        """
        super().__init__(go_emotions_path, tokenizer, prompt_template)
        # Integer label id -> emotion name, used for the non-train layout.
        self.go_emotions_mapping = {
            0: "admiration",
            1: "amusement",
            2: "anger",
            3: "annoyance",
            4: "approval",
            5: "caring",
            6: "confusion",
            7: "curiosity",
            8: "desire",
            9: "disappointment",
            10: "disapproval",
            11: "disgust",
            12: "embarrassment",
            13: "excitement",
            14: "fear",
            15: "gratitude",
            16: "grief",
            17: "joy",
            18: "love",
            19: "nervousness",
            20: "optimism",
            21: "pride",
            22: "realization",
            23: "relief",
            24: "remorse",
            25: "sadness",
            26: "surprise",
            27: "neutral",
        }
        frame = pd.read_parquet(go_emotions_path, engine='pyarrow')
        if split == "train":
            emotion_cols = frame.columns[9:]  # assumes emotions start at column index 9
            frame["label"] = frame[emotion_cols].idxmax(axis=1)
            frame["original_index"] = frame.index
            frame["temp_label"] = self.remap_to_synth(frame["label"])

            # Cap every coarse class at 250 rows. Iterating the groupby (rather
            # than groupby.apply) hands each group over with all its columns on
            # every pandas version; apply() stops passing the grouping column
            # in newer releases, which would make the drop below fail.
            sampled = [
                group.sample(n=min(len(group), 250), random_state=42)
                for _, group in frame.groupby("temp_label")
            ]
            frame = pd.concat(sampled).reset_index(drop=True) if sampled else frame.iloc[0:0].copy()
            self.dataset = frame.drop(columns=["temp_label"])

            print(f"Dataset size: {len(self)}")
        else:
            # Keep single-label rows only and translate the id to its name.
            frame = frame[frame["labels"].apply(lambda x: len(x) == 1)].copy()
            frame["labels"] = frame["labels"].apply(lambda x: self.go_emotions_mapping[x[0]])
            frame.rename(columns={"labels": "label"}, inplace=True)
            frame["temp_label"] = self.remap_to_synth(frame["label"])
            frame["original_index"] = frame.index
            self.dataset = frame.reset_index(drop=True)

        self.emotion_list = np.unique(self.dataset["label"]).tolist()

    @staticmethod
    def remap_to_synth(labels):
        """Map GoEmotions label names onto the coarse synthetic label set.

        Args:
            labels: Iterable of GoEmotions emotion names.

        Returns:
            ``np.ndarray`` of coarse labels, one per input; names without a
            mapping become ``"other"``.
        """
        mapping = {
            "happy": {"joy", "amusement"}, # originally name joy, excitement originally in here
            "excitement": {"excitement"},
            "anger": {"anger", "annoyance", "disapproval"},
            "sad": {"sadness", "disappointment", "grief", "remorse"},
            "fear": {"fear", "nervousness"},
            "disgust": {"disgust"},
            "surprise": {"surprise"},
            # "love": {"love", "caring", "desire", "gratitude"},
            "neutral": {"neutral"},
        }

        # Start with everything unmapped, then overwrite the known names.
        label_map = {label: "other" for label in np.unique(labels)}
        for new_label, old_set in mapping.items():
            for old_label in old_set:
                label_map[old_label] = new_label

        return np.array([label_map[label] for label in labels])


class synth_text_dataset(text_dataset):
    """Synthetic emotion passages read from per-emotion pickle files.

    ``self.dataset`` is a dict of parallel lists with keys ``"text"``,
    ``"label"`` and ``"original_index"``.
    """

    def __init__(self, synth_dataset_path, tokenizer, exclusion=("sarcastic", "condescension"), N=5000, prompt_template="", cache_usage="cached", synonyms=True):
        """Load the dataset from the cache and/or build it from the pickles.

        Args:
            synth_dataset_path: Directory of pickle files. The emotion name is
                taken from the text after the last ``"_"`` in each file name,
                minus its final four characters.
            tokenizer: Tokenizer forwarded to ``text_dataset``.
            exclusion: Emotion names whose files are skipped.
            N: Maximum number of passages kept per emotion.
            prompt_template: Optional chat prompt forwarded to ``text_dataset``.
            cache_usage: ``"cached"`` loads the cached selection;
                ``"compliment"`` loads the cache and then builds a dataset from
                the passages that are NOT in it; ``"new"`` builds a fresh
                selection; ``"new_save"`` builds one and overwrites the cache.
                Any other value leaves ``self.dataset`` as the path.
            synonyms: When true, ``self.emotion_synonyms`` is populated and
                used by ``sample_emotion`` / ``__getitem__``.
        """
        super().__init__(synth_dataset_path, tokenizer, prompt_template)
        cache_path = "hidden_state_dumps_all_tokens/cached_synth_dataset_order.pkl"

        # Handle "cached" and preload for "compliment"
        if cache_usage in {"cached", "compliment"}:
            # NOTE(security): unpickling runs arbitrary code; only point this
            # at cache / dataset files you trust.
            with open(cache_path, "rb") as f:
                self.dataset, self.emotion_list = pkl.load(f)

        # If compliment, build from non-cached items
        build_from_complement = cache_usage == "compliment"
        build_new = cache_usage in {"new", "new_save"}

        if build_from_complement or build_new:
            cached_set = set(zip(self.dataset["original_index"], self.dataset["label"])) if build_from_complement else set()

            # Load one pickle per emotion, closing every file handle promptly.
            all_emotion_passages = {}
            for fname in os.listdir(synth_dataset_path):
                emotion = fname.split("_")[-1][:-4]
                if emotion in exclusion:
                    continue
                with open(os.path.join(synth_dataset_path, fname), "rb") as f:
                    all_emotion_passages[emotion] = pkl.load(f)

            grouped = defaultdict(list)
            for label, items in all_emotion_passages.items():
                entries = [
                    (ctx["text"], label, ctx["id"])
                    for item in items for ctx in item["ctxs"]
                    if not build_from_complement or (ctx["id"], label) not in cached_set
                ]
                # Random subset of at most N passages (uses the global `random` state).
                random.shuffle(entries)
                grouped[label].extend(entries[:N])

            sampled = [x for entries in grouped.values() for x in entries]

            self.dataset = {
                "text": [x[0] for x in sampled],
                "label": [x[1] for x in sampled],
                "original_index": [x[2] for x in sampled]
            }
            self.emotion_list = np.unique(self.dataset["label"]).tolist()

            if cache_usage == "new_save":
                with open(cache_path, "wb") as f:
                    pkl.dump([self.dataset, self.emotion_list], f)
        self.synonyms = synonyms
        if synonyms:
            # Canonical emotion -> candidate surface forms with sampling weights.
            self.emotion_synonyms = {
                "sad": {
                    "words": ["sad", "gloomy", "down", "blue", "depressed", "melancholy", "miserable", "heartbroken", "unhappy", "sorrowful", "sadness"],
                    "weights": [0.3, 0.09, 0.08, 0.08, 0.07, 0.06, 0.06, 0.05, 0.07, 0.06, 0.08]
                },
                "happy": {
                    "words": ["happy", "joyful", "cheerful", "content", "pleased", "delighted", "glad", "elated", "upbeat", "thrilled", "ecstatic", "overjoyed"],
                    "weights": [0.3, 0.09, 0.08, 0.08, 0.08, 0.07, 0.08, 0.06, 0.06, 0.05, 0.05, 0.05]
                },
                "neutral": {
                    "words": ["neutral", "calm", "indifferent", "unemotional", "blank", "detached", "even", "objective", "stoic"],
                    "weights": [0.3, 0.13, 0.11, 0.11, 0.08, 0.08, 0.06, 0.06, 0.07]
                },
                "fear": {
                    "words": ["fear", "afraid", "scared", "terrified", "frightened", "nervous", "anxious", "panicked", "alarmed", "petrified", "dread", "wary"],
                    "weights": [0.3, 0.1, 0.09, 0.08, 0.08, 0.07, 0.06, 0.05, 0.05, 0.04, 0.04, 0.04]
                },
                "envy": {
                    "words": ["envy", "jealous", "resentful", "covetous", "begrudge", "envious", "desire", "spiteful", "bitter", "yearning"],
                    "weights": [0.3, 0.14, 0.12, 0.1, 0.08, 0.07, 0.06, 0.05, 0.04, 0.04]
                },
                "anger": {
                    "words": ["anger", "mad", "furious", "irate", "annoyed", "outraged", "enraged", "frustrated", "resentful", "indignant", "rage", "fuming", "agitated"],
                    "weights": [0.3, 0.1, 0.1, 0.08, 0.08, 0.07, 0.06, 0.06, 0.05, 0.04, 0.03, 0.02, 0.01]
                },
                "surprise": {
                    "words": ["surprise", "astonished", "amazed", "startled", "shocked", "stunned", "speechless", "aghast", "jarred", "bewildered"],
                    "weights": [0.3, 0.13, 0.12, 0.1, 0.09, 0.08, 0.06, 0.05, 0.04, 0.03]
                },
                "disgust": {
                    "words": ["disgust", "revolted", "repulsed", "sickened", "revolting", "nauseated", "distaste", "revulsion", "loathing", "abhorrence"],
                    "weights": [0.3, 0.13, 0.12, 0.1, 0.09, 0.08, 0.06, 0.05, 0.04, 0.03]
                },
                "excitement": {
                    "words": ["excitement", "thrilled", "ecstatic", "elated", "euphoric", "pumped", "delighted", "energized", "exhilarated", "overjoyed"],
                    "weights": [0.3, 0.13, 0.12, 0.1, 0.09, 0.08, 0.06, 0.05, 0.04, 0.03]
                }
            }

    def sample_emotion(self, emotion, skew=0.3):
        """Return ``emotion`` or a weighted-random synonym of it.

        The canonical name is returned whenever ``random.random() > skew``,
        and always when the dataset was built with ``synonyms=False`` (no
        synonym table exists in that case).

        Raises:
            KeyError: if a synonym is requested for an emotion that has no
                entry in ``self.emotion_synonyms``.
        """
        if random.random() > skew:
            return emotion  # canonical
        if not self.synonyms:
            return emotion  # synonym table was never built
        pool = self.emotion_synonyms[emotion]
        return random.choices(pool["words"], weights=pool["weights"], k=1)[0]

    @staticmethod
    def remap_to_synth(labels):
        """Identity mapping: these labels already are the synthetic label set."""
        return labels

    def __getitem__(self, idx):
        """Tokenize sample ``idx``.

        Returns:
            dict with ``input_ids``, ``attention_mask``, ``label`` and
            ``index``. When the prompt template's second message has role
            ``"assistant"``, the dict additionally carries
            ``text_wo_answers_id`` / ``text_wo_answers_attention_mask``: the
            tokenization of the first message alone.
        """
        text, label, index = self.dataset["text"][idx], self.dataset["label"][idx], self.dataset["original_index"][idx]
        if self.tokenizer.chat_template is not None and not self.prompt_template:
            text = self.tokenizer.apply_chat_template([{"role": "user", "content": text}], tokenize=False, add_generation_prompt=True, enable_thinking=False)
        tokenized_text = self.tokenizer(text, return_tensors="pt")
        if self.prompt_template:
            prompt = deepcopy(self.prompt_template)  # never mutate the shared template
            prompt[0]["content"] = prompt[0]["content"].format(text=text, emotion_list=str(self.emotion_list).replace("'", ""))
            full_text = self.tokenizer.apply_chat_template(prompt, tokenize=False, add_generation_prompt=True, enable_thinking=False)
            full_tokenized_text = self.tokenizer(full_text, return_tensors="pt")
            if len(prompt) > 1 and prompt[1]["role"] == "assistant":
                emotion = prompt[1]["content"]
                if self.synonyms:
                    # NOTE(review): this synonym is drawn after `full_text` was
                    # rendered and `prompt[1]` is not read again below, so it
                    # never reaches the returned tensors. Kept as-is because
                    # the intended behavior is unclear -- confirm and either
                    # move it above the rendering or drop it.
                    prompt[1]["content"] = random.choices(self.emotion_synonyms[emotion]["words"], weights=self.emotion_synonyms[emotion]["weights"], k=1)[0]
                text_wo_answer = self.tokenizer.apply_chat_template([prompt[0]], tokenize=False, add_generation_prompt=True, enable_thinking=False)
                tokenized_text_wo_answer = self.tokenizer(text_wo_answer, return_tensors="pt")
                return {
                    "input_ids": full_tokenized_text["input_ids"],
                    "attention_mask": full_tokenized_text["attention_mask"],
                    "text_wo_answers_id": tokenized_text_wo_answer["input_ids"],
                    "text_wo_answers_attention_mask": tokenized_text_wo_answer["attention_mask"],
                    "label": label,
                    "index": index
                }
        else:
            full_tokenized_text = tokenized_text
        return {
            "input_ids": full_tokenized_text["input_ids"],
            "attention_mask": full_tokenized_text["attention_mask"],
            "label": label,
            "index": index
        }

    @staticmethod
    def make_collate_fn(pad_token_id=128009, max_seq_len=512):
        """Build a collate function that right-pads samples to a common width.

        Args:
            pad_token_id: Fill value used to pad token-id tensors.
            max_seq_len: Accepted for signature parity with
                ``text_dataset.make_collate_fn``; not used here, so no sample
                is dropped for being too long.

        Returns:
            A ``collate_fn(batch)`` callable returning a dict with
            ``input_ids``, ``attention_mask``, ``labels`` (list), ``indices``
            (list) and, when present in the samples, the padded
            ``text_wo_answers_id`` / ``text_wo_answers_attention_mask``.
        """
        def collate_fn(batch):
            max_length = max([i["input_ids"].shape[1] for i in batch])
            input_ids = torch.vstack([F.pad(i["input_ids"], pad=(0, max_length - i["input_ids"].shape[1]), mode="constant", value=pad_token_id) for i in batch])
            attention_masks = torch.vstack([F.pad(i["attention_mask"], pad=(0, max_length - i["attention_mask"].shape[1]), mode="constant", value=0) for i in batch])
            output_dict = {
                "input_ids": input_ids,
                "attention_mask": attention_masks,
                "labels": [i["label"] for i in batch],
                "indices": [i["index"] for i in batch]
            }
            if "text_wo_answers_id" in batch[0]:
                # Padded to the width of the full inputs so both tensors align.
                text_wo_answers_id = torch.vstack([F.pad(i["text_wo_answers_id"], pad=(0, max_length - i["text_wo_answers_id"].shape[1]), mode="constant", value=pad_token_id) for i in batch])
                text_wo_answers_attention_mask = torch.vstack([F.pad(i["text_wo_answers_attention_mask"], pad=(0, max_length - i["text_wo_answers_attention_mask"].shape[1]), mode="constant", value=0) for i in batch])
                output_dict["text_wo_answers_id"] = text_wo_answers_id
                output_dict["text_wo_answers_attention_mask"] = text_wo_answers_attention_mask
            return output_dict
        return collate_fn


class twitter_dataset(text_dataset):
    """Samples from the ``dair-ai/emotion`` dataset, at most 167 per label."""

    def __init__(self, emotions_path, tokenizer, prompt_template="", split="train"):
        """Download/load the dataset and subsample it per label.

        Args:
            emotions_path: Forwarded to ``text_dataset``; the data itself is
                fetched with ``load_dataset("dair-ai/emotion")``.
            tokenizer: Tokenizer forwarded to ``text_dataset``.
            prompt_template: Optional chat prompt forwarded to ``text_dataset``.
            split: Name of the split to take from the loaded dataset.
        """
        super().__init__(emotions_path, tokenizer, prompt_template)
        dataset = load_dataset("dair-ai/emotion")
        label_map = {0: "sadness", 1: "joy", 2: "love", 3: "anger", 4: "fear", 5: "surprise"}
        frame = pd.DataFrame(dataset[split])
        frame["label"] = frame["label"].map(label_map)
        frame["original_index"] = frame.index
        # Iterating the groupby (instead of groupby.apply) keeps the "label"
        # column in every group on all pandas versions; apply() stops passing
        # the grouping column in newer releases, which would lose the labels.
        sampled = [
            group.sample(n=min(len(group), 167), random_state=42)
            for _, group in frame.groupby("label")
        ]
        self.dataset = pd.concat(sampled).reset_index(drop=True) if sampled else frame.iloc[0:0].copy()
        self.emotion_list = np.unique(self.dataset["label"]).tolist()

    @staticmethod
    def remap_to_synth(labels):
        """Map dataset label names onto the synthetic label set.

        Returns:
            ``np.ndarray`` with one coarse label per input; names without a
            mapping (e.g. ``"love"``) become ``"other"``.
        """
        mapping = {
            "anger": {"anger"},
            "fear": {"fear"},
            "happy": {"joy"},
            "sad": {"sadness"},
            "surprise": {"surprise"},
        }

        # Start with everything unmapped, then overwrite the known names.
        label_map = {label: "other" for label in np.unique(labels)}
        for new_label, old_set in mapping.items():
            for old_label in old_set:
                label_map[old_label] = new_label

        return np.array([label_map[label] for label in labels])


class emoevents_dataset(text_dataset):
    """EmoEvent CSV samples, at most 1000 per label, without ``"others"``."""

    def __init__(self, emotions_path, tokenizer, prompt_template="", split="train"):
        """Read the CSV for ``split`` and subsample it per label.

        Args:
            emotions_path: CSV path of the train split. For any other split,
                ``_<split>`` is inserted before the file extension.
            tokenizer: Tokenizer forwarded to ``text_dataset``.
            prompt_template: Optional chat prompt forwarded to ``text_dataset``.
            split: ``"train"`` or the suffix of another split file.
        """
        super().__init__(emotions_path, tokenizer, prompt_template)
        if split != "train":
            # splitext copes with extensions of any length, unlike a fixed
            # four-character slice.
            root, ext = os.path.splitext(emotions_path)
            emotions_path = f"{root}_{split}{ext}"
        frame = pd.read_csv(emotions_path)
        frame["original_index"] = frame.index
        # Iterating the groupby (instead of groupby.apply) keeps the "label"
        # column in every group on all pandas versions; apply() stops passing
        # the grouping column in newer releases, which would lose the labels.
        sampled = [
            group.sample(n=min(len(group), 1000), random_state=42)
            for _, group in frame.groupby("label")
        ]
        frame = pd.concat(sampled).reset_index(drop=True) if sampled else frame.iloc[0:0].copy()
        self.dataset = frame[frame["label"] != "others"].reset_index(drop=True)
        self.emotion_list = np.unique(self.dataset["label"]).tolist()

    @staticmethod
    def remap_to_synth(labels):
        """Return a RANDOM synthetic label for every input label.

        NOTE(review): the deterministic mapping is commented out below and
        labels are drawn uniformly at random with ``np.random.choice``
        instead. This looks like a deliberate random baseline -- confirm
        before relying on the output.

        Args:
            labels: Array-like exposing ``.shape`` (e.g. ``np.ndarray`` or
                ``pd.Series``).

        Returns:
            ``np.ndarray`` of shape ``labels.shape``.
        """
        random_labels = ["neutral", "happy", "sad", "fear", "disgust", "envy", "anger", "surprise", "excitement"]
        new_label_map = np.random.choice(random_labels, labels.shape, replace=True)

        # mapping = {
        #     "anger": {"anger"},
        #     "disgust": {"disgust"},
        #     "fear": {"fear"},
        #     "happy": {"joy"},
        #     "sad": {"sadness"},
        #     "surprise": {"surprise"},
        # }
        #
        # label_map = {label: "other" for label in np.unique(labels)}
        # for new_label, old_set in mapping.items():
        #     for old_label in old_set:
        #         label_map[old_label] = new_label
        #
        # new_label_map = np.array([label_map[label] for label in labels])
        return new_label_map


class german_plays_dataset(text_dataset):
    """German-plays CSV samples, at most 1000 per label."""

    def __init__(self, emotions_path, tokenizer, prompt_template=""):
        """Read the CSV at ``emotions_path`` and subsample it per label.

        Args:
            emotions_path: CSV file with (at least) a ``label`` column; the
                base class additionally reads a ``text`` column.
            tokenizer: Tokenizer forwarded to ``text_dataset``.
            prompt_template: Optional chat prompt forwarded to ``text_dataset``.
        """
        super().__init__(emotions_path, tokenizer, prompt_template)
        frame = pd.read_csv(emotions_path)
        frame["original_index"] = frame.index

        # Iterating the groupby (instead of groupby.apply) keeps the "label"
        # column in every group on all pandas versions; apply() stops passing
        # the grouping column in newer releases, which would lose the labels.
        sampled = [
            group.sample(n=min(len(group), 1000), random_state=42)
            for _, group in frame.groupby("label")
        ]
        self.dataset = pd.concat(sampled).reset_index(drop=True) if sampled else frame.iloc[0:0].copy()

        print(f"Dataset size: {len(self)}")
        self.emotion_list = np.unique(self.dataset["label"]).tolist()

    @staticmethod
    def remap_to_synth(labels):
        """Map dataset label names onto the synthetic label set.

        Returns:
            ``np.ndarray`` with one label per input; names without a mapping
            become ``"other"``.
        """
        mapping = {
            "anger": {"anger"},
            "disgust": {"disgust"},
            "envy": {"envy"},
            "excitement": {"excitement"},
            "fear": {"fear"},
            "happy": {"happy"},
            "sad": {"sad"},
        }

        # Start with everything unmapped, then overwrite the known names.
        label_map = {label: "other" for label in np.unique(labels)}
        for new_label, old_set in mapping.items():
            for old_label in old_set:
                label_map[old_label] = new_label

        return np.array([label_map[label] for label in labels])


class hindi_dataset(text_dataset):
    """Hindi CSV samples (read as ``utf-8-sig``), at most 100 per label."""

    def __init__(self, emotions_path, tokenizer, prompt_template=""):
        """Read the CSV at ``emotions_path`` and subsample it per label.

        Args:
            emotions_path: CSV file with (at least) a ``label`` column; the
                base class additionally reads a ``text`` column.
            tokenizer: Tokenizer forwarded to ``text_dataset``.
            prompt_template: Optional chat prompt forwarded to ``text_dataset``.
        """
        super().__init__(emotions_path, tokenizer, prompt_template)
        frame = pd.read_csv(emotions_path, encoding="utf-8-sig")
        frame["original_index"] = frame.index

        # Iterating the groupby (instead of groupby.apply) keeps the "label"
        # column in every group on all pandas versions; apply() stops passing
        # the grouping column in newer releases, which would lose the labels.
        sampled = [
            group.sample(n=min(len(group), 100), random_state=42)
            for _, group in frame.groupby("label")
        ]
        self.dataset = pd.concat(sampled).reset_index(drop=True) if sampled else frame.iloc[0:0].copy()

        print(f"Dataset size: {len(self)}")
        self.emotion_list = np.unique(self.dataset["label"]).tolist()

    @staticmethod
    def remap_to_synth(labels):
        """Map the capitalised dataset labels onto the synthetic label set.

        Returns:
            ``np.ndarray`` with one label per input; names without a mapping
            become ``"other"``.
        """
        mapping = {
            "anger": {"Anger"},
            "happy": {"Joy"},
            "neutral": {"Neutral"},
            "sad": {"Sad"},
        }

        # Start with everything unmapped, then overwrite the known names.
        label_map = {label: "other" for label in np.unique(labels)}
        for new_label, old_set in mapping.items():
            for old_label in old_set:
                label_map[old_label] = new_label

        return np.array([label_map[label] for label in labels])


class italian_dataset(text_dataset):
    """Italian CSV samples, at most 1000 per (raw) label value."""

    def __init__(self, emotions_path, tokenizer, prompt_template=""):
        """Read the CSV at ``emotions_path``, subsample and normalise labels.

        String labels are parsed with ``ast.literal_eval`` and their items
        joined with ``","``; non-string labels are converted with ``str``.
        Sampling groups on the raw label value, i.e. before this conversion.

        Args:
            emotions_path: CSV file with (at least) a ``label`` column; the
                base class additionally reads a ``text`` column.
            tokenizer: Tokenizer forwarded to ``text_dataset``.
            prompt_template: Optional chat prompt forwarded to ``text_dataset``.
        """
        super().__init__(emotions_path, tokenizer, prompt_template)
        frame = pd.read_csv(emotions_path)
        frame["original_index"] = frame.index

        # Iterating the groupby (instead of groupby.apply) keeps the "label"
        # column in every group on all pandas versions; apply() stops passing
        # the grouping column in newer releases, which would lose the labels.
        sampled = [
            group.sample(n=min(len(group), 1000), random_state=42)
            for _, group in frame.groupby("label")
        ]
        frame = pd.concat(sampled).reset_index(drop=True) if sampled else frame.iloc[0:0].copy()
        frame["label"] = frame["label"].apply(lambda x: ",".join(ast.literal_eval(x)) if isinstance(x, str) else str(x))
        self.dataset = frame
        print(f"Dataset size: {len(self)}")
        self.emotion_list = np.unique(self.dataset["label"]).tolist()

    @staticmethod
    def remap_to_synth(labels):
        """Map dataset label names onto the synthetic label set.

        Returns:
            ``np.ndarray`` with one label per input; names without a mapping
            (including comma-joined multi-labels) become ``"other"``.
        """
        mapping = {
            "anger": {"anger"},
            "disgust": {"disgust"},
            "excitement": {"excitement"},
            "fear": {"fear"},
            "happy": {"happy"},
            "sad": {"sad"},
            "surprise": {"surprise"},
        }

        # Start with everything unmapped, then overwrite the known names.
        label_map = {label: "other" for label in np.unique(labels)}
        for new_label, old_set in mapping.items():
            for old_label in old_set:
                label_map[old_label] = new_label

        return np.array([label_map[label] for label in labels])


class french_dataset(text_dataset):
    """Single-label emotional sentences from ``TextToKids/EmoTextToKids-sentences``."""

    def __init__(self, emotions_path, tokenizer, prompt_template=""):
        """Load, filter and subsample the dataset (at most 185 rows per label).

        Args:
            emotions_path: Forwarded to ``text_dataset``; the data itself is
                fetched with ``load_dataset``.
            tokenizer: Tokenizer forwarded to ``text_dataset``.
            prompt_template: Optional chat prompt forwarded to ``text_dataset``.
        """
        super().__init__(emotions_path, tokenizer, prompt_template)
        dataset = load_dataset("TextToKids/EmoTextToKids-sentences")
        df = pd.DataFrame(dataset["train"])
        df = df[df["is_emotional"] == True].reset_index(drop=True)
        df = df[["target_sentence", "categories"]].copy()
        # Keep sentences annotated with exactly one category.
        df = df[df["categories"].map(len) == 1].copy()
        df["label"] = df["categories"].str[0]

        # Remove rare emotions
        label_counts = df["label"].value_counts()
        valid = label_counts[label_counts >= 50].index
        df = df[df["label"].isin(valid)].reset_index(drop=True)

        # Remap
        df["label"] = df["label"].replace({"sadness": "sad", "joy": "happy"})
        df = df[["target_sentence", "label"]].copy()
        # NOTE: original_index is the row position AFTER the filters above,
        # not the position in the raw dataset (an earlier assignment that was
        # discarded by the column selection has been removed).
        df["original_index"] = df.index
        df = df[df["label"] != "other"].reset_index(drop=True)
        df = df.rename(columns={"target_sentence": "text"})
        self.dataset = df
        self.emotion_list = np.unique(self.dataset["label"]).tolist()
        # Iterating the groupby (instead of groupby.apply) keeps the "label"
        # column in every group on all pandas versions; apply() stops passing
        # the grouping column in newer releases, which would lose the labels.
        sampled = [
            group.sample(n=min(len(group), 185), random_state=42)
            for _, group in df.groupby("label")
        ]
        self.dataset = pd.concat(sampled).reset_index(drop=True) if sampled else df.iloc[0:0].copy()

    @staticmethod
    def remap_to_synth(labels):
        """Map dataset label names onto the synthetic label set.

        Returns:
            ``np.ndarray`` with one label per input; names without a mapping
            become ``"other"``.
        """
        mapping = {
            "anger": {"anger"},
            "fear": {"fear"},
            "happy": {"happy"},
            "sad": {"sad"},
            "surprise": {"surprise"},
        }

        # Start with everything unmapped, then overwrite the known names.
        label_map = {label: "other" for label in np.unique(labels)}
        for new_label, old_set in mapping.items():
            for old_label in old_set:
                label_map[old_label] = new_label

        return np.array([label_map[label] for label in labels])


class semeval_dataset(text_dataset):
    """SemEval affective-text CSV with multi-label emotion annotations.

    Each sample's ``label`` is a list of emotion names.
    """

    def __init__(self, emotions_path, tokenizer, prompt_template=""):
        """Read the CSV and derive a list of emotions for every row.

        An emotion is assigned when its score column is ``>= 50``; rows where
        no score reaches the threshold fall back to the single highest-scoring
        emotion.

        Args:
            emotions_path: CSV file with the six emotion score columns.
            tokenizer: Tokenizer forwarded to ``text_dataset``.
            prompt_template: Optional chat prompt forwarded to
                ``text_dataset``; when set, ``emotion_list`` is flattened to
                the unique individual emotion names.
        """
        super().__init__(emotions_path, tokenizer, prompt_template)
        self.dataset = pd.read_csv(emotions_path)
        emotion_cols = ["anger", "disgust", "fear", "joy", "sadness", "surprise"]
        scores = self.dataset[emotion_cols]

        def emotions_over_threshold(row):
            # Multi-label binarization
            return [name for name in emotion_cols if row[name] >= 50]

        thresholded = scores.apply(emotions_over_threshold, axis=1)
        self.dataset["label"] = thresholded

        # Fallback to single argmax if empty
        best_positions = scores.values.argmax(axis=1)
        resolved = []
        for chosen, best in zip(self.dataset["label"], best_positions):
            if chosen:
                resolved.append(chosen)
            else:
                resolved.append([emotion_cols[best]])
        self.dataset["label"] = resolved

        self.dataset["original_index"] = self.dataset.index
        self.emotion_list = np.unique(self.dataset["label"]).tolist()

        if prompt_template:
            flattened = [name for combo in self.emotion_list for name in combo]
            self.emotion_list = np.unique(flattened).tolist()

    @staticmethod
    def remap_to_synth(labels):
        """Map label names onto the synthetic label set.

        When the first entry of ``labels`` is a list, every entry is reduced
        to its first element before mapping.

        Returns:
            ``np.ndarray`` with one label per input; names without a mapping
            become ``"other"``.
        """
        mapping = {
            "happy": {"joy"},
            "anger": {"anger"},
            "sad": {"sadness"},
            "fear": {"fear"},
            "disgust": {"disgust"},
            "surprise": {"surprise"},
        }

        if isinstance(labels[0], list):
            labels = [entry[0] for entry in labels]

        # Everything starts as "other"; known names are overwritten below.
        translation = dict.fromkeys(np.unique(labels), "other")
        for target, sources in mapping.items():
            for source in sources:
                translation[source] = target

        remapped = [translation[name] for name in labels]
        return np.array(remapped)


class DynamicBatchSampler(Sampler):
    """Batch sampler that packs samples under a padded-token budget.

    Samples are ordered from longest to shortest and greedily grouped so that
    ``batch_size * longest_length_in_batch`` never exceeds ``max_tokens``
    (a single sample longer than the budget still forms its own batch).
    Each yielded batch is a list of plain ``int`` dataset indices.
    """

    def __init__(self, lengths, max_tokens, shuffle=False):
        """
        lengths: list[int], token length of each sample in the dataset
        max_tokens: int, maximum number of tokens per batch (batch_size * max_seq_len)
        shuffle: bool, whether to shuffle samples each epoch. Equal-length
            samples are ordered randomly and the batches are yielded in random
            order; the global ``np.random`` state is used.
        """
        self.lengths = np.array(lengths)
        self.max_tokens = max_tokens
        self.shuffle = shuffle

    def _ordered_indices(self, shuffle):
        """Return all indices sorted by descending length (ties optionally shuffled)."""
        idxs = np.arange(len(self.lengths))
        if shuffle:
            np.random.shuffle(idxs)
        # Sort by length to minimize pad waste. The stable sort keeps the
        # (possibly shuffled) relative order of equal-length samples, so the
        # result is deterministic when shuffle is off.
        return idxs[np.argsort(-self.lengths[idxs], kind="stable")]

    def _pack(self, order):
        """Greedily split ``order`` into batches that respect ``max_tokens``."""
        batches = []
        current = []
        widest = 0
        for idx in order:
            length = self.lengths[idx]
            if not current:
                # A batch always takes its first sample, even if over budget.
                current, widest = [int(idx)], length
                continue
            candidate = max(widest, length)
            if (len(current) + 1) * candidate > self.max_tokens:
                batches.append(current)
                current, widest = [int(idx)], length
            else:
                current.append(int(idx))
                widest = candidate
        if current:
            batches.append(current)
        return batches

    def __iter__(self):
        batches = self._pack(self._ordered_indices(self.shuffle))
        if self.shuffle:
            # Without this, batches would always arrive longest-first and
            # shuffling would only affect ties.
            batches = [batches[i] for i in np.random.permutation(len(batches))]
        return iter(batches)

    def __len__(self):
        # Exact number of batches: the packing depends only on the sorted
        # lengths, so it is the same with or without shuffling.
        return len(self._pack(self._ordered_indices(False)))

