from PIL import Image
from torch.utils.data import Dataset
import os
import json
from scipy.io import loadmat
from torchvision.io import read_image
from torchvision.transforms import (CenterCrop, Compose, Normalize, Resize,
                                    ToTensor)

class flowers_test_dataset(Dataset):
    """Test split of the Oxford Flowers-102 dataset.

    ``dataset_path`` is expected to contain:

    * ``labels_file`` -- a MATLAB file whose ``"labels"`` array holds the
      1-based class id of every image (image ``i`` -> ``labels[i - 1]``);
    * ``split_file`` -- a MATLAB file whose ``"tstid"`` array lists the
      1-based ids of the test images;
    * ``names_file`` -- a JSON object mapping 1-based class ids (as strings)
      to human-readable class names;
    * ``jpg/image_XXXXX.jpg`` -- the images, ids zero-padded to 5 digits.

    Args:
        dataset_path: Root directory of the dataset.
        transform: Callable applied to each RGB PIL image. ``None`` returns
            the PIL image unchanged.
        labels_file: Name of the label ``.mat`` file inside ``dataset_path``
            (default: the official ``imagelabels.mat``).
        split_file: Name of the split ``.mat`` file inside ``dataset_path``
            (default: the official ``setid.mat``).
        names_file: Name of the class-name JSON file inside ``dataset_path``.
            NOTE(review): ``cat_to_name.json`` is the commonly used name but is
            not part of the official distribution -- override if yours differs.

    Attributes:
        root: The dataset root directory.
        classnames: Class names ordered by class id (index == 0-based label).
        sample_dict: Maps each test image path to its 0-based class index.
        paths: The keys of ``sample_dict`` in insertion order; indexed by
            ``__getitem__``.
        transform: The transform passed in.
    """

    def __init__(self, dataset_path, transform=None, *,
                 labels_file="imagelabels.mat",
                 split_file="setid.mat",
                 names_file="cat_to_name.json"):
        labelno_mat_path = os.path.join(dataset_path, labels_file)
        split_mat_path = os.path.join(dataset_path, split_file)
        label_name_json = os.path.join(dataset_path, names_file)

        # read_the_mat resolves image paths relative to self.root, so it must
        # be set before that call.
        self.root = dataset_path
        self.classnames, self.sample_dict = self.read_the_mat(
            labelno_mat_path, split_mat_path, label_name_json)
        self.paths = list(self.sample_dict.keys())
        self.transform = transform

    def read_the_mat(self, label_mat_path, split_mat_path, label_name_json):
        """Parse the dataset metadata files.

        Args:
            label_mat_path: Path to the ``.mat`` file holding a ``"labels"``
                array of 1-based class ids, one per image.
            split_mat_path: Path to the ``.mat`` file holding a ``"tstid"``
                array of 1-based test image ids.
            label_name_json: Path to the JSON file mapping class ids to names.

        Returns:
            A ``(classnames, path_label_dict)`` tuple. ``classnames`` lists the
            class names ordered by numeric class id. ``path_label_dict`` maps
            each test image path (under ``<root>/jpg``) to its 0-based class
            index.
        """
        # ravel() accepts both row (1, N) and column (N, 1) MATLAB vectors.
        labels = loadmat(label_mat_path)["labels"].ravel().tolist()
        test_ids = loadmat(split_mat_path)["tstid"].ravel().tolist()

        with open(label_name_json, "r", encoding="utf-8") as file:
            label_data = json.load(file)
        # JSON keys are strings; order by their numeric value so that e.g.
        # "10" comes after "2".
        id_name_pairs = sorted(
            ((int(key), name) for key, name in label_data.items()),
            key=lambda pair: pair[0],
        )
        classnames = [name for _, name in id_name_pairs]

        images_dir = os.path.join(self.root, "jpg")
        path_label_dict = {}
        for raw_id in test_ids:
            # .mat numeric arrays may be stored as doubles; cast before
            # formatting so 12.0 becomes "image_00012.jpg", not "image_012.0.jpg".
            image_id = int(raw_id)
            img_path = os.path.join(images_dir, f"image_{image_id:05d}.jpg")
            # Image ids and class ids are both 1-based in the .mat files.
            path_label_dict[img_path] = int(labels[image_id - 1]) - 1
        return classnames, path_label_dict

    def __len__(self):
        """Return the number of test images."""
        return len(self.sample_dict)

    def __getitem__(self, index):
        """Return ``(image, label)`` for the ``index``-th test image.

        The image is converted to RGB so every sample has 3 channels. It is
        then passed through ``self.transform`` when one is set.
        """
        path = self.paths[index]
        # The context manager closes the file handle. convert() returns a
        # fully loaded copy, so the image stays usable after the file closes.
        with Image.open(path) as img:
            image = img.convert("RGB")
        label = self.sample_dict[path]
        if self.transform is not None:
            image = self.transform(image)
        return image, label
