# https://huggingface.co/transformers/v3.2.0/custom_datasets.html
import os
from pathlib import Path
from sklearn.model_selection import train_test_split
from transformers import AutoTokenizer
import torch
from sklearn.model_selection import StratifiedShuffleSplit
import copy
import numpy as np
from datasets import load_dataset

# def read_imdb_split(split_dir):
#     split_dir = Path(split_dir)
#     texts = []
#     labels = []
#     for label_dir in ["pos", "neg"]:
#         for text_file in (split_dir/label_dir).iterdir():
#             texts.append(text_file.read_text())
#             labels.append(0 if label_dir == "neg" else 1)

#     return texts, labels

class EvalDataset(torch.utils.data.Dataset):
    """Map-style dataset pairing pre-tokenized encodings with labels.

    Every example is converted to a dict of tensors once, in ``__init__``, and
    stored in ``self.data`` as a NumPy object array so it can be re-indexed
    with index arrays (as ``subsample`` in this module does).
    """

    def __init__(self, encodings, labels, classification_only=True):
        """
        Args:
            encodings: mapping from feature name to a sequence indexed by
                example position (presumably tokenizer output such as
                ``get_encodings`` returns — confirm for other callers).
            labels: per-example labels; ``len(labels)`` fixes the dataset size.
            classification_only: if true, each feature tensor holds the
                example's value as-is; otherwise the value is stacked twice
                along a new leading dimension of size 2.
        """
        self.encodings = encodings
        self.labels = labels
        self.classification_only = classification_only
        self.data = np.array([self._encode_example(idx) for idx in range(len(self.labels))])

    def _encode_example(self, idx):
        """Return the feature tensors of example *idx* as a new dict."""
        if self.classification_only:
            return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        # NOTE(review): the same example is stacked twice (leading dim 2); the
        # consumer of this layout is outside this file — confirm its purpose.
        return {key: torch.tensor([val[idx], val[idx]]) for key, val in self.encodings.items()}

    def __getitem__(self, idx):
        """Return example *idx* as a fresh dict of tensors plus a ``labels`` tensor."""
        # Copy the cached dict so that reading an example never mutates self.data.
        item = dict(self.data[idx])
        item['labels'] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        """Number of examples, i.e. ``len(self.labels)``."""
        return len(self.labels)

def _num_labels(*label_lists):
    """Return one more than the largest label in any of *label_lists*.

    Empty sequences are skipped; if every sequence is empty the result is 0.
    """
    return max((max(labels) for labels in label_lists if len(labels) > 0), default=-1) + 1


def _load_encodings_cache(path):
    """Load the dict that ``get_encodings`` saved at *path* with ``torch.save``.

    The cached encodings may be tokenizer output objects (e.g. transformers'
    ``BatchEncoding``) rather than plain containers. ``torch.load`` refuses to
    unpickle those in ``weights_only=True`` mode, which is the default from
    torch 2.6, so full unpickling is requested explicitly wherever
    ``torch.load`` accepts the parameter. Full unpickling can execute
    arbitrary code: only use this on cache files written by this module.
    """
    import inspect

    if "weights_only" in inspect.signature(torch.load).parameters:
        return torch.load(path, weights_only=False)
    return torch.load(path)


def get_encodings(name):
    """Tokenize a Hugging Face dataset's train/test splits, caching the result.

    The result is cached at ``data/{name}/encodings.pt``. On later calls with
    the same *name*, the cache is returned without loading the dataset or the
    tokenizer.

    Args:
        name: dataset name passed to ``datasets.load_dataset``.
            ``"tweet_eval"`` is loaded with its ``"emoji"`` configuration.
            The dataset must have ``train`` and ``test`` splits with ``text``
            and ``label`` columns.

    Returns:
        ``(train_encodings, train_labels, test_encodings, test_labels,
        num_labels)``. Here ``num_labels`` is one more than the largest label
        in either split (0 if both are empty).
    """
    os.makedirs(f'data/{name}', exist_ok=True)
    save_path = f'data/{name}/encodings.pt'
    if os.path.exists(save_path):
        data = _load_encodings_cache(save_path)
        train_encodings, train_labels = data["train_encodings"], data["train_labels"]
        test_encodings, test_labels = data["test_encodings"], data["test_labels"]
    else:
        # e.g. imdb, ag_news; tweet_eval needs an explicit configuration name.
        if name == "tweet_eval":
            dataset = load_dataset(name, "emoji")
        else:
            dataset = load_dataset(name)

        train_texts, train_labels = dataset["train"]["text"], dataset["train"]["label"]
        test_texts, test_labels = dataset["test"]["text"], dataset["test"]["label"]

        tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased', use_fast=True)
        # Each split is tokenized separately, so each is padded to its own
        # longest sequence; truncation caps length at the model's maximum.
        train_encodings = tokenizer(train_texts, truncation=True, padding=True)
        test_encodings = tokenizer(test_texts, truncation=True, padding=True)

        torch.save({
            "train_encodings": train_encodings,
            "train_labels": train_labels,
            "test_encodings": test_encodings,
            "test_labels": test_labels,
        }, save_path)
    return train_encodings, train_labels, test_encodings, test_labels, \
        _num_labels(train_labels, test_labels)

def _take_subset(dataset, index):
    """Return a shallow copy of *dataset* keeping only the rows at *index*.

    Only ``data`` and ``labels`` are re-indexed (``labels`` becomes a NumPy
    array). Every other attribute (e.g. ``encodings``) is shared with
    *dataset*. This avoids deep-copying every example tensor just to discard
    most of them.
    """
    subset = copy.copy(dataset)
    subset.data = dataset.data[index]
    subset.labels = np.array(dataset.labels)[index]
    return subset


def subsample(dataset, n, get_reg=False):
    """Narrow *dataset* in place to a stratified random subset of ``n`` examples.

    *dataset* must expose ``len()``, an array-indexable ``data`` attribute and
    a ``labels`` sequence (as ``EvalDataset`` does). The split uses a fixed
    ``random_state`` and is therefore reproducible.

    Args:
        dataset: dataset to subsample; modified in place.
        n: number of examples to keep. ``n <= 0`` or ``n >= len(dataset)``
            means "keep everything".
        get_reg: if true, also return the examples that were left out.

    Returns:
        ``dataset`` (now holding the kept examples). When *get_reg* is true,
        returns ``(dataset, reg_dataset)``, where ``reg_dataset`` holds the
        left-out examples. ``reg_dataset`` is empty when nothing was left out.
    """
    if n <= 0 or n >= len(dataset):
        # Nothing is left out, but still honour the tuple contract for get_reg.
        if get_reg:
            return dataset, _take_subset(dataset, np.arange(0))
        return dataset
    sss = StratifiedShuffleSplit(n_splits=1, test_size=len(dataset) - n, train_size=n, random_state=0)
    train_index, test_index = next(iter(sss.split(np.zeros(len(dataset.labels)), dataset.labels)))
    # Build the left-out subset before *dataset* is narrowed in place.
    reg_dataset = _take_subset(dataset, test_index) if get_reg else None
    dataset.data = dataset.data[train_index]
    dataset.labels = np.array(dataset.labels)[train_index]
    if get_reg:
        return dataset, reg_dataset
    return dataset

if __name__ == "__main__":
    # This module only defines helpers; running it directly does nothing.
    ...