import torch
from torch.utils.data import Dataset

class TextDataset(Dataset):
    """Pre-tokenized ``prompt + answer`` examples for language-model training.

    Each entry of ``training_data`` is read via ``item["text"]`` (the prompt)
    and ``item["object"]`` (the answer).

    Each example is built as follows:
    * Prompt and answer are joined with exactly one space and terminated
      with ``"."``.
    * The string is tokenized, padded/truncated to ``max_length`` and stored.
    * ``labels`` are a copy of ``input_ids``, except that padded positions
      (``attention_mask == 0``) are replaced with ``ignore_index``, so the
      loss is not computed on padding.
    """

    def __init__(self, training_data, tokenizer, max_length=32, ignore_index=-100):
        """Tokenize every example in ``training_data`` up front.

        Args:
            training_data: Iterable of mappings with ``"text"`` and
                ``"object"`` keys.
            tokenizer: Callable invoked as
                ``tokenizer(text, padding=..., truncation=..., max_length=...,
                return_tensors="pt")``. It must return a mapping with
                ``"input_ids"`` and ``"attention_mask"`` tensors of shape
                ``(1, max_length)``. Presumably a Hugging Face tokenizer —
                such tokenizers need a pad token set for
                ``padding="max_length"``; confirm with the caller.
            max_length: Fixed sequence length after padding/truncation.
            ignore_index: Label value written at padded positions. The default
                ``-100`` matches ``torch.nn.CrossEntropyLoss``'s default
                ``ignore_index``.
        """
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.ignore_index = ignore_index
        self.examples = []

        for item in training_data:
            full_text = self._join_prompt_answer(item["text"], item["object"])

            encoded = tokenizer(
                full_text,
                padding="max_length",
                truncation=True,
                max_length=max_length,
                return_tensors="pt",
            )

            # Drop the batch dimension of 1 added by return_tensors="pt".
            input_ids = encoded["input_ids"].squeeze(0)
            attention_mask = encoded["attention_mask"].squeeze(0)

            # Use the attention mask rather than comparing against the pad id:
            # when pad_token == eos_token, genuine EOS tokens must stay in the
            # labels. masked_fill returns a new tensor, so input_ids is untouched.
            labels = input_ids.masked_fill(attention_mask == 0, ignore_index)

            self.examples.append({
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "labels": labels,
            })

    @staticmethod
    def _join_prompt_answer(prompt, answer):
        """Join prompt and answer with exactly one space and end with '.'."""
        separator = "" if prompt.endswith(" ") else " "
        return f"{prompt}{separator}{answer}."

    def __len__(self):
        """Return the number of tokenized examples."""
        return len(self.examples)

    def __getitem__(self, idx):
        """Return a fresh dict with ``input_ids``, ``attention_mask`` and ``labels``.

        The dict is new on each call, but the tensors are the stored ones.
        """
        example = self.examples[idx]
        return {
            "input_ids": example["input_ids"],
            "attention_mask": example["attention_mask"],
            "labels": example["labels"],
        }