from PIL import Image
from torch.utils.data import Dataset
import os
from torchvision.transforms import (CenterCrop, Compose, Normalize, Resize,
                                    ToTensor)

class Standfordcars_test_rn50(Dataset):
    """Stanford Cars test split read from ``<dataset_path>/test/<class>/<image>``.

    Every sub-directory of ``test`` is one class. Its integer label is the
    directory's position in sorted order, and ``self.classnames[label]`` is the
    directory name with its last space-separated word (presumably the model
    year) moved to the front, e.g. ``"BMW M3 2010"`` -> ``"2010 BMW M3"``.

    Args:
        dataset_path: Root folder that contains the ``test`` directory.
        transform: Callable applied to each opened PIL image in ``__getitem__``.
    """

    def __init__(self, dataset_path, transform):
        path = os.path.join(dataset_path, "test")
        # Scan once so class names and labels come from the same listing.
        class_dirs, self.sample_dict = self._index_split(path)
        self.classnames = [self._year_first(name) for name in class_dirs]
        self.transform = transform
        self.paths = list(self.sample_dict.keys())

    @staticmethod
    def _index_split(path):
        """Map every image file under *path*'s class sub-directories to its label.

        Returns ``(class_dirs, sample_dict)``: the sorted class directory names
        and a ``{image_path: label}`` dict. ``os.listdir`` order is arbitrary,
        so both levels are sorted to keep labels reproducible across machines.
        Non-directory entries at the class level (e.g. ``.DS_Store``) and
        non-file entries at the image level are skipped instead of crashing.
        """
        class_dirs = sorted(
            name for name in os.listdir(path)
            if os.path.isdir(os.path.join(path, name))
        )
        sample_dict = {}
        for label, class_dir in enumerate(class_dirs):
            class_path = os.path.join(path, class_dir)
            for img_name in sorted(os.listdir(class_path)):
                img_path = os.path.join(class_path, img_name)
                if os.path.isfile(img_path):
                    sample_dict[img_path] = label
        return class_dirs, sample_dict

    @staticmethod
    def _year_first(folder_name):
        """Move the last space-separated word of *folder_name* to the front."""
        words = folder_name.split(" ")
        return " ".join([words[-1]] + words[:-1])

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.sample_dict)

    def __getitem__(self, index):
        """Return ``(transform(image), label)`` for the sample at *index*."""
        img_path = self.paths[index]
        image = Image.open(img_path)
        label = self.sample_dict[img_path]
        return self.transform(image), label


class Standfordcars_test_b32(Dataset):
    """Stanford Cars test split with built-in 224px preprocessing.

    Reads ``<dataset_path>/test/<class>/<image>``. Every sub-directory of
    ``test`` is one class; its integer label is the directory's position in
    sorted order and ``self.classnames[label]`` is the raw directory name.
    The normalization constants in ``_transform_test`` match OpenAI CLIP's
    image preprocessing (presumably targeting a ViT-B/32 model -- confirm).

    Args:
        dataset_path: Root folder that contains the ``test`` directory.
    """

    def __init__(self, dataset_path):
        path = os.path.join(dataset_path, "test")
        # Scan once so class names and labels come from the same listing.
        class_dirs, self.sample_dict = self._index_split(path)
        self.classnames = class_dirs
        self.paths = list(self.sample_dict.keys())
        self.preprocess = self._transform_test(224)

    @staticmethod
    def _index_split(path):
        """Map every image file under *path*'s class sub-directories to its label.

        Returns ``(class_dirs, sample_dict)``: the sorted class directory names
        and a ``{image_path: label}`` dict. ``os.listdir`` order is arbitrary,
        so both levels are sorted to keep labels reproducible across machines.
        Non-directory entries at the class level (e.g. ``.DS_Store``) and
        non-file entries at the image level are skipped instead of crashing.
        """
        class_dirs = sorted(
            name for name in os.listdir(path)
            if os.path.isdir(os.path.join(path, name))
        )
        sample_dict = {}
        for label, class_dir in enumerate(class_dirs):
            class_path = os.path.join(path, class_dir)
            for img_name in sorted(os.listdir(class_path)):
                img_path = os.path.join(class_path, img_name)
                if os.path.isfile(img_path):
                    sample_dict[img_path] = label
        return class_dirs, sample_dict

    @staticmethod
    def _to_rgb(image):
        """Convert a PIL image to RGB.

        A named function instead of a lambda so that ``self.preprocess`` (and
        hence the dataset) can be pickled by multi-worker DataLoaders.
        """
        return image.convert("RGB")

    def _transform_test(self, n_px):
        """Build the eval pipeline: bicubic resize, center crop, RGB, tensor, normalize."""
        try:
            from torchvision.transforms import InterpolationMode
            bicubic = InterpolationMode.BICUBIC
        except ImportError:
            # Older torchvision without InterpolationMode takes PIL constants.
            bicubic = Image.BICUBIC
        return Compose([
            Resize(n_px, interpolation=bicubic),
            CenterCrop(n_px),
            self._to_rgb,
            ToTensor(),
            Normalize((0.48145466, 0.4578275, 0.40821073),
                      (0.26862954, 0.26130258, 0.27577711)),
        ])

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.sample_dict)

    def __getitem__(self, index):
        """Return ``(preprocess(image), label)`` for the sample at *index*."""
        img_path = self.paths[index]
        image = Image.open(img_path)
        label = self.sample_dict[img_path]
        return self.preprocess(image), label
