import os
import sys
from typing import Dict, List, Tuple

from torch.utils.data import DataLoader, Dataset


class JsonDataset(Dataset):
    """Map-style dataset that renders each raw prompt through a chat template.

    Each stored prompt is wrapped as a single ``{"role": "user", ...}`` message
    and passed to ``tokenizer.apply_chat_template`` with ``tokenize=False`` and
    ``add_generation_prompt=True``. Items are whatever that call returns.
    For untokenized output this is presumably the formatted prompt string;
    that depends on the tokenizer, so confirm it against the caller.
    """

    def __init__(self, data, tokenizer) -> None:
        """
        Args:
            data (List): The prompts to serve. Each element becomes the
                ``content`` of one user message.
            tokenizer: Any object exposing ``apply_chat_template``. This is
                assumed to be a Hugging Face-style tokenizer; verify against
                callers.
        """
        self.tokenizer = tokenizer
        self.data = data

    def __len__(self):
        # One item per stored prompt.
        return len(self.data)

    def __getitem__(self, index):
        # Build a one-turn conversation, then let the tokenizer format it,
        # with the assistant generation prompt appended.
        messages = [{"role": "user", "content": self.data[index]}]
        return self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )



def load_data(
    promptList: List, tokenizer, batch_size: int
) -> Tuple[DataLoader, Dict[str, object]]:
    """Wrap prompts in a chat-templated dataset and batch them with a DataLoader.

    Args:
        promptList: Prompts to serve. The list is passed unchanged to
            ``JsonDataset``.
        tokenizer: Object exposing ``apply_chat_template``. It is forwarded
            to ``JsonDataset``.
        batch_size: Number of items per batch yielded by the DataLoader.

    Returns:
        A ``(dataloader, dataDict)`` tuple.

        - ``dataDict`` maps the string form of each prompt's position in
          ``promptList`` (``"0"``, ``"1"``, ...) to that prompt.
        - The DataLoader is created without ``shuffle``, so batches follow
          this same index order.
    """
    # enumerate yields unique positions, so keys can never collide and no
    # duplicate-index check is needed.
    dataDict = {str(idx): item for idx, item in enumerate(promptList)}

    processed_dataset = JsonDataset(data=promptList, tokenizer=tokenizer)
    dataloader = DataLoader(dataset=processed_dataset, batch_size=batch_size)

    return dataloader, dataDict
