from PIL import Image
from torch.utils.data import Dataset
import os
import json
from torchvision.io import read_image
from torchvision.transforms import (CenterCrop, Compose, Normalize, Resize,
                                    ToTensor)

class food_test_dataset(Dataset):
    """Image-classification dataset built from a split file of ``<class>/<id>`` lines.

    Each non-empty line of the split file names one image as
    ``<class_name>/<image_id>`` (no extension). The image is loaded from
    ``<dataset_path>/images/<class_name>/<image_id>.jpg``. Class indices are
    assigned in order of first appearance in the split file.

    Attributes:
        root: The ``dataset_path`` passed to the constructor.
        classnames: Class names; ``classnames[i]`` is the name of label ``i``.
        sample_dict: Maps each image path to its integer label.
        paths: Image paths in split-file order (the keys of ``sample_dict``).
        transform: Callable applied to each loaded image, or ``None``.
    """

    def __init__(self, dataset_path, transform=None, split_file=".txt"):
        """Index the split file found under ``dataset_path``.

        Args:
            dataset_path: Dataset root directory. It must contain the split
                file and an ``images`` subdirectory.
            transform: Callable applied to each RGB PIL image in
                ``__getitem__``. If ``None``, the PIL image is returned as is.
            split_file: Split-file name, relative to ``dataset_path``.

        Raises:
            OSError: If the split file cannot be opened.
        """
        # NOTE(review): the default ".txt" resolves to "<dataset_path>/.txt",
        # which looks like a truncated filename (perhaps "test.txt"). It is
        # kept as the default for backward compatibility; confirm the intended
        # name and pass split_file explicitly.
        split_txt_path = os.path.join(dataset_path, split_file)
        # root must be set before read_the_label_txt, which reads self.root.
        self.root = dataset_path
        self.classnames, self.sample_dict = self.read_the_label_txt(split_txt_path)
        self.paths = list(self.sample_dict.keys())
        self.transform = transform

    def read_the_label_txt(self, txt_path):
        """Parse a split file into class names and a path -> label mapping.

        Args:
            txt_path: Text file with one ``<class_name>/<image_id>`` entry per
                line. Blank lines are ignored.

        Returns:
            A ``(labels, path_label_dict)`` tuple. ``labels`` lists class names
            in order of first appearance. ``path_label_dict`` maps
            ``<root>/images/<entry>.jpg`` to the index of its class in
            ``labels``. A repeated entry maps to a single key.
        """
        images_path = os.path.join(self.root, "images")
        # Insertion-ordered dict: class name -> index. It gives O(1) lookups
        # instead of a list.index() scan per line.
        label_to_index = {}
        path_label_dict = {}
        with open(txt_path, "r", encoding="utf-8") as file:
            for line in file:
                entry = line.strip()
                if not entry:
                    # A blank line would otherwise create a bogus "" class
                    # pointing at "<root>/images/.jpg".
                    continue
                label_name = entry.split("/")[0]
                index = label_to_index.setdefault(label_name, len(label_to_index))
                img_path = os.path.join(images_path, entry + ".jpg")
                path_label_dict[img_path] = index
        return list(label_to_index), path_label_dict

    def __len__(self):
        """Return the number of distinct image paths in the split."""
        return len(self.sample_dict)

    def __getitem__(self, index):
        """Load one sample.

        Args:
            index: Position in ``self.paths``.

        Returns:
            ``(image, label)``. ``image`` is the RGB PIL image passed through
            ``self.transform`` when one is set, and ``label`` is its class index.
        """
        path = self.paths[index]
        label = self.sample_dict[path]
        # Image.open is lazy and keeps the file open. Converting inside the
        # context loads the pixels and releases the handle right away. Forcing
        # RGB makes grayscale/RGBA/CMYK JPEGs come out as 3-channel images.
        with Image.open(path) as img:
            image = img.convert("RGB")
        if self.transform is not None:
            image = self.transform(image)
        return image, label
