from torchvision.datasets import ImageFolder
from src.eval.eval_tasks.utils.classification_utils import IMAGENET_1K_CLASS_ID_TO_LABEL
from torch.utils.data import Dataset
import json
import os
from PIL import Image

class ClassificationDataset(Dataset):
    """Binary image classification dataset backed by a JSON Lines file.

    Each non-blank line of the annotations file is parsed as one JSON
    object. ``__getitem__`` reads the keys ``"id"``, ``"img"``, ``"text"``
    and ``"label"`` from that object.
    """

    def __init__(self, image_dir_path, annotations_path):
        """Read all annotations into memory.

        Args:
            image_dir_path: Directory that contains the image files.
            annotations_path: Path to the JSON Lines annotations file.

        Raises:
            OSError: If the annotations file cannot be opened.
            json.JSONDecodeError: If a non-blank line is not valid JSON.
        """
        self.image_dir_path = image_dir_path
        # Decode explicitly as UTF-8 so non-ASCII annotation text does not
        # depend on the platform's default encoding.
        with open(annotations_path, "r", encoding="utf-8") as f:
            # Skip whitespace-only lines (e.g. a trailing newline at the end
            # of the file), which json.loads would otherwise reject.
            self.annotations = [json.loads(line) for line in f if line.strip()]

    def __len__(self) -> int:
        """Return the number of loaded annotations."""
        return len(self.annotations)

    def __getitem__(self, idx):
        """Return the sample at position ``idx`` as a dict.

        The returned dict has the keys ``"id"``, ``"image"`` (a PIL image),
        ``"ocr"`` (the annotation's ``"text"`` value), ``"class_name"``
        (``"yes"`` when the label equals 1, otherwise ``"no"``) and
        ``"class_id"`` (the raw label value).
        """
        annotation = self.annotations[idx]
        # Only the final path component of "img" is used; any directory
        # prefix stored in the annotation is replaced by image_dir_path.
        file_name = annotation["img"].split("/")[-1]
        img_path = os.path.join(self.image_dir_path, file_name)
        image = Image.open(img_path)
        # Image.open is lazy; load() forces the pixel data to be read now.
        image.load()

        label = annotation["label"]
        return {
            "id": annotation["id"],
            "image": image,
            "ocr": annotation["text"],
            "class_name": "yes" if label == 1 else "no",
            "class_id": label,
        }