import torch
from torch.utils.data import DataLoader, Dataset
from datasets import load_dataset


class CustomDataset(Dataset):
    """Map-style dataset that pairs each sentence with its label by position.

    ``sentences`` and ``labels`` are stored exactly as passed in (no copy, no
    validation). The dataset size is taken from ``labels``.
    """

    def __init__(self, sentences, labels):
        self.sentences, self.labels = sentences, labels

    def __len__(self):
        """Return the number of labels held by the dataset."""
        size = len(self.labels)
        return size

    def __getitem__(self, idx):
        """Return the ``(sentence, label)`` tuple found at position ``idx``."""
        return self.sentences[idx], self.labels[idx]


def collator_fn(batch):
    """Collate a batch of indexable samples into inputs and a label tensor.

    Args:
        batch: Sequence of samples where element ``0`` is the model input
            (left untouched) and element ``1`` is the label.

    Returns:
        A tuple ``(inputs, labels)``: ``inputs`` is a plain list of every
        sample's first element, ``labels`` is a ``torch.Tensor`` built from
        every sample's second element.
    """
    inputs = []
    for sample in batch:
        inputs.append(sample[0])

    targets = []
    for sample in batch:
        targets.append(sample[1])

    return inputs, torch.tensor(targets)


def _mnli_pair_dataset(split):
    """Wrap one MNLI split as a CustomDataset of ``[premise, hypothesis]`` pairs."""
    pairs = [[premise, hypothesis] for premise, hypothesis in zip(split["premise"], split["hypothesis"])]
    # Sentences and labels are read from the same split so they stay aligned.
    return CustomDataset(sentences=pairs, labels=split["label"])


def load_data_hf(config):
    """Build training and evaluation DataLoaders for a GLUE task.

    Args:
        config: Object exposing ``dataset_name`` (one of ``"sst2"``,
            ``"mnli_matched"``, ``"mnli_mismatched"``), ``train_batch_size``,
            ``eval_batch_size`` and ``torch_seed``.

    Returns:
        ``(train_loader, test_loader)``. The test loader is built from the
        GLUE validation split that corresponds to ``dataset_name``.

    Raises:
        ValueError: If ``config.dataset_name`` is not a supported dataset.
    """
    dataset_name = config.dataset_name
    train_batch_size = config.train_batch_size
    eval_batch_size = config.eval_batch_size
    # Seed torch's global RNG before the shuffling loaders are created.
    torch.manual_seed(config.torch_seed)

    # MNLI variants share the training split and differ only in which
    # validation split is used for evaluation.
    mnli_eval_splits = {
        "mnli_matched": "validation_matched",
        "mnli_mismatched": "validation_mismatched",
    }

    if dataset_name == "sst2":
        dataset = load_dataset("glue", "sst2")
        train_set = CustomDataset(sentences=dataset["train"]["sentence"], labels=dataset["train"]["label"])
        test_set = CustomDataset(sentences=dataset["validation"]["sentence"], labels=dataset["validation"]["label"])
    elif dataset_name in mnli_eval_splits:
        dataset = load_dataset("glue", "mnli")
        train_set = _mnli_pair_dataset(dataset["train"])
        # Previously the mismatched sentences were paired with labels from
        # "validation_matched"; both now come from the same split.
        test_set = _mnli_pair_dataset(dataset[mnli_eval_splits[dataset_name]])
    else:
        raise ValueError("The dataset is not available. Check data/torch_data.py.")

    train_loader = DataLoader(train_set, batch_size=train_batch_size, shuffle=True, collate_fn=collator_fn)
    # NOTE(review): the evaluation loader is shuffled as well; kept unchanged
    # to preserve existing behaviour — confirm whether this is intended.
    test_loader = DataLoader(test_set, batch_size=eval_batch_size, shuffle=True, collate_fn=collator_fn)
    return train_loader, test_loader
