from torch.utils.data import Dataset
import pandas as pd


class PromptAnswerDataset(Dataset):
    """Map-style dataset of ``(prompt, answer)`` pairs loaded from a CSV file.

    The CSV must contain ``prompt`` and ``answer`` columns; each row becomes
    one sample.
    """

    # Columns every sample is built from; checked once at load time.
    _REQUIRED_COLUMNS = ('prompt', 'answer')

    def __init__(self, config):
        """Load the CSV referenced by ``config['data_directory']``.

        Args:
            config: Mapping with a ``'data_directory'`` entry. Despite the
                name, the value is passed straight to ``pd.read_csv``, so it
                must be a CSV file path (or file-like buffer), not a directory.

        Raises:
            KeyError: if ``config`` has no ``'data_directory'`` entry, or the
                loaded CSV lacks a ``prompt`` or ``answer`` column.
        """
        self.data_directory = config['data_directory']
        self.dataset = pd.read_csv(self.data_directory)
        # Fail fast with a clear message instead of on the first __getitem__.
        missing = [
            column
            for column in self._REQUIRED_COLUMNS
            if column not in self.dataset.columns
        ]
        if missing:
            raise KeyError(
                f"CSV {self.data_directory!r} is missing required "
                f"column(s): {missing}"
            )

    def __len__(self):
        """Return the number of rows (samples) in the loaded CSV."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return the ``(prompt, answer)`` pair at row position ``idx``.

        Lookup is positional (``iloc``), not label-based. As a result,
        negative indices count from the end, the result does not depend on
        the DataFrame's index labels, and an out-of-range index raises
        ``IndexError``. That exception is what Python's sequence-iteration
        fallback (``for x in dataset``) needs in order to stop.
        """
        prompt = self.dataset['prompt'].iloc[idx]
        answer = self.dataset['answer'].iloc[idx]
        return prompt, answer
