import os
from PIL import Image
from torch.utils.data import Dataset


class Food101Dataset(Dataset):
    """Food-101 style image classification dataset.

    Expects the on-disk layout::

        root/
            images/<class_name>/<image_name>.jpg
            meta/<split>.txt

    Each non-blank line of ``meta/<split>.txt`` is an image path relative to
    ``images/`` without the ``.jpg`` extension; its first ``/``-separated
    component is the class name. Entries whose ``.jpg`` file does not exist
    are silently skipped.

    Classes are the sorted names of the subdirectories of ``images/``, and a
    class's integer label is its position in that sorted list.

    Args:
        root: Dataset root directory containing ``images/`` and ``meta/``.
        split: Name of the split file in ``meta/`` without the ``.txt``
            suffix (default ``'train'``).
        transform: Optional callable applied to each loaded RGB PIL image.
    """

    def __init__(self, root, split='train', transform=None):
        self.root = root
        self.split = split
        self.transform = transform

        self.image_dir = os.path.join(root, 'images')
        self.meta_dir = os.path.join(root, 'meta')

        # Only subdirectories are classes; stray files such as ``.DS_Store``
        # would otherwise become bogus classes and shift every label index.
        self.classes = sorted(
            name for name in os.listdir(self.image_dir)
            if os.path.isdir(os.path.join(self.image_dir, name))
        )
        self.class_to_idx = {cls_name: i for i, cls_name in enumerate(self.classes)}
        self.data, self.targets = self._load_metadata()

    def _load_metadata(self):
        """Read ``meta/<split>.txt`` and resolve each entry to an image file.

        Returns:
            A ``(data, targets)`` pair of parallel lists: image file paths
            (joined under ``self.image_dir``) and their integer class labels.

        Raises:
            OSError: If the split file cannot be opened.
            KeyError: If an existing image's first path component is not a
                class directory under ``images/``.
        """
        split_file = os.path.join(self.meta_dir, f'{self.split}.txt')
        data, targets = [], []

        with open(split_file, 'r', encoding='utf-8') as f:
            for line in f:
                path = line.strip()
                if not path:
                    # Ignore blank lines, e.g. a trailing newline at EOF.
                    continue
                cls_name = path.split('/')[0]
                img_path = os.path.join(self.image_dir, f"{path}.jpg")

                # Listed-but-missing images are dropped here rather than
                # failing later in __getitem__.
                if os.path.exists(img_path):
                    data.append(img_path)
                    targets.append(self.class_to_idx[cls_name])

        return data, targets

    def __len__(self):
        """Return the number of images found for this split."""
        return len(self.data)

    def __getitem__(self, index):
        """Load sample ``index`` and return ``(image, label)``.

        The image is converted to RGB; if ``transform`` was given, it is
        applied to that RGB image and its result is returned instead.
        """
        img_path = self.data[index]
        label = self.targets[index]
        # Image.open keeps the file handle open until the image is closed or
        # garbage-collected; the context manager releases it right away.
        # convert() returns a new, fully loaded image, so it stays valid
        # after the source image is closed.
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        if self.transform is not None:
            image = self.transform(image)

        return image, label
