# For licensing see accompanying LICENSE file.
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
from pathlib import Path
import transformers
from .collators import DictCollatorWithPadding
import torch
import pandas as pd
from collections import OrderedDict
import json
import pickle as pkl


class MultiEmotionDataset(torch.utils.data.Dataset):
    """
    Implements a loader for a multi-class emotion classification dataset.

    Records are read from the pickle file
    ``<parent of path>/emotions_dataset/gpl_retrieval_results_nq_train_<split>.pkl``,
    where ``split`` must be one of :attr:`LABEL_NAMES` (the split selects an
    emotion label, not a train/test partition). Each record is a dict that is
    expected to carry a ``"ctxs"`` list whose first entry has a ``"text"``
    field; the loader assigns every record an ``"id"`` of the form
    ``"<label>_<idx>"``. Items are returned with an emotion-classification
    prompt in ``"input_text"`` and the tokenizer outputs merged in.
    """

    # Emotion name -> integer class id.
    LABEL_MAP = OrderedDict([
        ("neutral", 0),
        ("anger", 1),
        ("condescension", 2),
        ("disgust", 3),
        ("envy", 4),
        ("excitement", 5),
        ("fear", 6),
        ("happy", 7),
        ("humor", 8),
        ("sad", 9),
        ("sarcastic", 10),
        ("surprise", 11)
    ])

    # Emotion names in class-id order; derived from LABEL_MAP so the two
    # cannot drift apart.
    LABEL_NAMES = list(LABEL_MAP)

    def __init__(
        self, path: Path, split: str, tokenizer: "transformers.PreTrainedTokenizer"
    ) -> None:
        """Load the records for one emotion label.

        Args:
            path: A path whose *parent* directory contains ``emotions_dataset/``.
            split: Emotion label (one of ``LABEL_NAMES``) whose file is loaded.
            tokenizer: Callable tokenizer applied in ``__getitem__``.

        Raises:
            ValueError: If ``split`` is not a known label or the file is empty.
        """
        self.split = split
        # Data files live in the parent directory of `path`. Path.parent is
        # portable and yields "." (not "") when `path` has no directory part,
        # which previously produced the absolute path "/emotions_dataset/...".
        self.path = str(Path(path).parent)
        self.tokenizer = tokenizer

        if split not in self.LABEL_NAMES:
            raise ValueError(
                f"Unknown split {split!r}; expected one of {self.LABEL_NAMES}"
            )

        data_file = f"{self.path}/emotions_dataset/gpl_retrieval_results_nq_train_{split}.pkl"
        # NOTE(security): pickle.load can execute arbitrary code; only load
        # files from a trusted source.
        with open(data_file, "rb") as f:
            records = pkl.load(f)
        for idx, record in enumerate(records):
            record["id"] = f"{split}_{idx}"
        self.data = list(records)

        if not self.data:
            raise ValueError(f"No records found in {data_file}")
        # Positions into self.data that are currently exposed (see set_label).
        self.index = torch.arange(len(self.data))

    def set_label(self, label: str) -> None:
        """Restrict the dataset to records whose ``"id"`` contains ``label``.

        Matching is a substring test on the id (ids look like
        ``"<label>_<idx>"``). Each call replaces any previous restriction.
        """
        keep = [i for i, record in enumerate(self.data) if label in record["id"]]
        self.index = torch.tensor(keep, dtype=torch.long)

    def __getitem__(self, item):
        """Return a prompt-augmented, tokenized copy of the ``item``-th exposed record."""
        # Shallow-copy so the cached record is not mutated (and does not keep
        # growing with prompt/token fields) on every access.
        datum = dict(self.data[int(self.index[item])])
        datum["input_text"] = (f"Classify the following passage as one of the following emotions: [{', '.join(self.LABEL_NAMES)}] The answer should only contain the emotion.\n" +
                               f"Passage: {datum['ctxs'][0]['text']}" + f"\nEmotion: ")
        tokens = self.tokenizer(datum["input_text"], truncation=True, padding=False)
        datum.update(tokens)
        return datum

    def __len__(self):
        """Number of records currently exposed (after any ``set_label``)."""
        return len(self.index)


def get_multiemotion_dataset(
    path: Path, split: str, tokenizer: transformers.PreTrainedTokenizer
) -> "tuple[MultiEmotionDataset, DictCollatorWithPadding]":
    """Build the multi-emotion dataset together with its padding collator.

    Args:
        path: A path whose parent directory contains ``emotions_dataset/``.
        split: Emotion label whose records should be loaded.
        tokenizer: Tokenizer shared by the dataset and the collator.

    Returns:
        A ``(dataset, collator)`` tuple (the previous ``Dataset`` return
        annotation did not match what is actually returned).
    """
    dataset = MultiEmotionDataset(path, split, tokenizer)
    collator = DictCollatorWithPadding(tokenizer)
    return dataset, collator
