from PIL import Image
from torch.utils.data import Dataset
import os

class DTD_test_dataset(Dataset):
    """Map-style dataset over one split of the Describable Textures Dataset.

    The split file ``<dataset_path>/labels/<label_file>`` lists one image per
    line as ``<class_name>/<file_name>``. Images are read from
    ``<dataset_path>/images/<class_name>/<file_name>``.

    Attributes:
        root: Dataset root directory passed to the constructor.
        classnames: Class names in order of first appearance in the split
            file; a sample's label is its class's index in this list.
        sample_dict: Maps each image path to its integer label.
        paths: Keys of ``sample_dict`` in file order; indexed by
            ``__getitem__``.
        transform: Callable applied to each loaded PIL image, or ``None``.
    """

    def __init__(self, dataset_path, transform=None, label_file="test1.txt"):
        """Build the sample index from a DTD split file.

        Args:
            dataset_path: Root directory containing ``labels/`` and ``images/``.
            transform: Callable applied to each PIL image in ``__getitem__``;
                ``None`` returns the image unchanged.
            label_file: Split file name inside ``labels/``. Defaults to
                ``"test1.txt"``, the split this class always used previously.
        """
        # root must be set before parsing: read_the_label_txt reads self.root.
        self.root = dataset_path
        label_path = os.path.join(dataset_path, "labels", label_file)
        self.classnames, self.sample_dict = self.read_the_label_txt(label_path)
        self.paths = list(self.sample_dict.keys())
        self.transform = transform

    def read_the_label_txt(self, txt_path):
        """Parse a DTD split file into class names and a path -> label map.

        Blank lines are ignored. Class indices are assigned in order of first
        appearance of each class name in the file.

        Args:
            txt_path: Path to the split file.

        Returns:
            ``(labels, path_label_dict)``: ``labels`` is the list of class
            names; ``path_label_dict`` maps each image path (joined onto
            ``<root>/images``) to the index of its class in ``labels``.
        """
        # dict preserves insertion order, so its keys double as the ordered
        # class-name list while giving O(1) name -> index lookups.
        label_to_index = {}
        path_label_dict = {}
        images_path = os.path.join(self.root, "images")
        with open(txt_path, "r", encoding="utf-8") as file:
            for line in file:
                entry = line.strip()
                if not entry:
                    # A blank (e.g. trailing) line would otherwise register a
                    # bogus "" class and a path to the images dir itself.
                    continue
                label_name = entry.split("/")[0]
                index = label_to_index.setdefault(label_name, len(label_to_index))
                path_label_dict[os.path.join(images_path, entry)] = index
        return list(label_to_index), path_label_dict

    def __len__(self):
        """Return the number of distinct image paths in the split."""
        return len(self.sample_dict)

    def __getitem__(self, index):
        """Load sample ``index`` and return ``(image, label)``.

        The image is the PIL image after ``self.transform`` (if any);
        ``label`` is the integer class index.
        """
        path = self.paths[index]
        label = self.sample_dict[path]
        # Open the file ourselves and force decoding inside the ``with`` so the
        # handle is closed deterministically, even if decoding or the
        # transform raises.
        with open(path, "rb") as fp:
            image = Image.open(fp)
            image.load()
        if self.transform is not None:
            image = self.transform(image)
        return image, label