from torch.utils.data import Dataset
from tokenizer import NumberTokenizer
from numbers_class import NumberBasic
from task import Task

class NumberDataset(Dataset):
    """Map-style dataset that tokenizes number samples for a given task.

    Every raw sample is passed to ``task.preprocess_data``, whose result is
    unpacked as a pair ``(a, b)``. When ``b`` is ``None`` only ``a`` is handed
    to the tokenizer; otherwise both operands are.

    Args:
        data: Raw samples; each one is passed unchanged to
            ``task.preprocess_data``.
        tokenizer: Object whose ``encode_sample`` encodes the operand tuple.
        task: Object whose ``preprocess_data`` maps a raw sample to ``(a, b)``.
        training: Forwarded to ``encode_sample`` as ``contain_answer``.
        trunc: When not ``None``, only ``data[:trunc]`` is kept.
        return_numbers: When true, ``__getitem__`` returns a dict containing
            both the encoded sample and the operands instead of the encoded
            sample alone.
    """

    def __init__(
        self,
        data: list[list[NumberBasic]],
        tokenizer: NumberTokenizer,
        task: Task,
        training: bool = True,
        trunc: int | None = None,
        return_numbers: bool = False,
    ):
        super().__init__()
        self.tokenizer = tokenizer
        self.task = task
        # Slicing builds a new list, so the caller's list is never modified.
        self.data = data if trunc is None else data[:trunc]
        self.training = training
        self.return_numbers = return_numbers

    def __len__(self):
        """Return the number of samples kept after optional truncation."""
        return len(self.data)

    def __getitem__(self, idx):
        """Encode the sample at ``idx``.

        Returns:
            The value returned by ``tokenizer.encode_sample`` when
            ``return_numbers`` is false. Otherwise a dict with key
            ``"tokens"`` (that same value) and key ``"numbers"``: the single
            operand ``a`` when ``b`` is ``None``, else the tuple ``(a, b)``.
        """
        a, b = self.task.preprocess_data(self.data[idx])
        operands = (a,) if b is None else (a, b)
        # Encode exactly once; both return shapes share the same tokens.
        tokens = self.tokenizer.encode_sample(
            operands, contain_answer=self.training, return_tensor='pt'
        )
        if not self.return_numbers:
            return tokens
        return {
            "tokens": tokens,
            # A lone operand is returned bare rather than as a 1-tuple.
            "numbers": a if b is None else operands,
        }