import json

from PIL import Image
from torch.utils.data import Dataset
import os
from torchvision.transforms import (CenterCrop, Compose, Normalize, Resize,
                                    ToTensor)

class eurosat_test(Dataset):
    """Test split of EuroSAT, described by an ``EuroSAT.json`` split file.

    The split file is read from ``<dataset_path>/EuroSAT.json`` and every
    image path listed in it is resolved relative to ``<dataset_path>/2750``.

    Attributes:
        classnames: Class names ordered by ascending integer label.
        sample_dict: Maps each full image path to its integer label.
        paths: Full image paths; defines the index order of the dataset.
        transform: Callable applied to each opened PIL image.
    """

    def __init__(self, dataset_path, transform):
        """Load the test split found under *dataset_path*.

        Args:
            dataset_path: Directory holding ``EuroSAT.json`` and the ``2750``
                image folder.
            transform: Callable invoked on every opened image in
                ``__getitem__``.
        """
        json_path = os.path.join(dataset_path, "EuroSAT.json")
        self.classnames, relative_labels = self.read_the_json(json_path)

        image_root = os.path.join(dataset_path, "2750")
        self.sample_dict = {
            os.path.join(image_root, rel_path): label
            for rel_path, label in relative_labels.items()
        }
        self.paths = list(self.sample_dict)
        self.transform = transform

    def read_the_json(self, jsonpath):
        """Parse the ``"test"`` entries of the split file at *jsonpath*.

        Each entry is indexed as ``entry[0]`` (relative image path),
        ``entry[1]`` (label, coerced with ``int``) and ``entry[2]`` (class
        name).

        Returns:
            A ``(classnames, samples)`` tuple: ``classnames`` is the list of
            class names sorted by their integer label, ``samples`` maps each
            relative image path to its integer label.
        """
        with open(jsonpath, "r", encoding="utf-8") as file:
            data = json.load(file)

        samples: dict = {}
        name_by_label: dict = {}
        for entry in data["test"]:
            label = int(entry[1])
            samples[entry[0]] = label
            # Keep the first name seen for a label.
            name_by_label.setdefault(label, entry[2])

        # Order names by label rather than by first appearance in the file,
        # so that classnames[i] is the name of label i even when the split
        # file does not list its classes in label order.
        classnames = [name_by_label[label] for label in sorted(name_by_label)]
        return classnames, samples

    def __len__(self):
        """Return the number of distinct test images."""
        return len(self.paths)

    def __getitem__(self, index):
        """Return ``(transform(image), label)`` for the sample at *index*."""
        path = self.paths[index]
        image = Image.open(path)
        label = self.sample_dict[path]
        return self.transform(image), label

