import json

from PIL import Image
from torch.utils.data import Dataset
import os
from torchvision.transforms import (CenterCrop, Compose, Normalize, Resize,
                                    ToTensor)

class FGVCAircraft_test(Dataset):
    """Image-folder dataset read from ``<dataset_path>/test/<class_name>/<file>``.

    Every sub-directory of ``test`` is treated as one class. Class indices are
    assigned by enumerating the *sorted* class directory names, so the
    ``classnames`` list and the integer labels are reproducible across
    machines and runs.

    Attributes set by the constructor:
        classnames: sorted list of class directory names; the position of a
            name in this list is its integer label.
        sample_dict: mapping of file path -> integer label.
        paths: list of file paths, in the order used by ``__getitem__``.
        preprocess: the ``transform`` callable passed to the constructor.
    """

    def __init__(self, dataset_path, transform):
        """Index every file below ``<dataset_path>/test``.

        Args:
            dataset_path: directory that contains the ``test`` folder.
            transform: callable applied to each opened image in
                ``__getitem__``.

        Raises:
            FileNotFoundError: if ``<dataset_path>/test`` does not exist.
        """
        test_root = os.path.join(dataset_path, "test")

        # os.listdir() yields entries in arbitrary, filesystem-dependent
        # order, so sort to make the class-name -> index mapping stable.
        # Non-directory entries (e.g. ".DS_Store") are skipped: they are not
        # classes, and listing them below would raise NotADirectoryError.
        self.classnames = sorted(
            name
            for name in os.listdir(test_root)
            if os.path.isdir(os.path.join(test_root, name))
        )

        self.sample_dict = {}
        for label, classname in enumerate(self.classnames):
            class_dir = os.path.join(test_root, classname)
            # Sorted as well, so sample order (and thus index -> sample) is
            # deterministic too.
            for filename in sorted(os.listdir(class_dir)):
                self.sample_dict[os.path.join(class_dir, filename)] = label

        self.paths = list(self.sample_dict)
        self.preprocess = transform

    def __len__(self):
        """Return the number of indexed files."""
        return len(self.paths)

    def __getitem__(self, index):
        """Return ``(transformed_image, label)`` for the sample at ``index``.

        Raises:
            IndexError: if ``index`` is outside the range of ``paths``.
        """
        path = self.paths[index]
        # NOTE(review): the image is handed to the transform in whatever mode
        # the file is stored in; confirm the transform copes with non-RGB
        # inputs if such files can occur.
        image = self.preprocess(Image.open(path))
        return image, self.sample_dict[path]

