import json

from PIL import Image
from torch.utils.data import Dataset
import os
from torchvision.transforms import (CenterCrop, Compose, Normalize, Resize,
                                    ToTensor)

class sun397_test_b32(Dataset):
    """SUN397 test split whose class names use the "b32" naming style.

    Reads ``<dataset_path>/test.json``, whose top-level object maps each
    class key to a list of image paths relative to ``<dataset_path>/SUN397``.
    The integer label of an image is the position of its class key in the
    JSON object (key order as stored in the file).

    Args:
        dataset_path: Directory holding ``test.json`` and the ``SUN397`` folder.
        transform: Callable applied to each opened PIL image in
            ``__getitem__``; stored as ``self.preprocess``.

    Attributes:
        classnames: One readable class name per label (see ``read_the_json``).
        sample_dict: Image path (joined onto ``<dataset_path>/SUN397``) -> label.
        paths: The keys of ``sample_dict`` in insertion order; ``__getitem__``
            indexes into this list.
        preprocess: The ``transform`` passed to the constructor.
    """

    def __init__(self, dataset_path, transform):
        json_path = os.path.join(dataset_path, "test.json")
        self.classnames, tmp_image_paths = self.read_the_json(json_path)

        image_root = os.path.join(dataset_path, "SUN397")
        # If the same path is listed under several classes, the later class
        # overwrites the label (the path keeps its first insertion position).
        self.sample_dict = {
            os.path.join(image_root, rel_path): label
            for label, class_paths in enumerate(tmp_image_paths)
            for rel_path in class_paths
        }
        self.paths = list(self.sample_dict)
        self.preprocess = transform

    def read_the_json(self, jsonpath):
        """Parse the split file into class names and per-class image paths.

        A class name is derived from each key by replacing underscores with
        spaces and keeping only the part before the first ``/``
        (e.g. ``"apartment_building/outdoor"`` -> ``"apartment building"``).

        Args:
            jsonpath: Path to the split JSON file.

        Returns:
            ``(tags_list, file_paths_list)``: the class names and the
            corresponding JSON values, both in the JSON object's key order.
        """
        # JSON text is UTF-8 (RFC 8259); don't depend on the platform locale.
        with open(jsonpath, "r", encoding="utf-8") as file:
            data = json.load(file)
        tags_list = [key.replace("_", " ").split("/")[0] for key in data]
        file_paths_list = list(data.values())
        return tags_list, file_paths_list

    def _transform_test(self, n_px):
        """Build a test-time pipeline for ``n_px`` x ``n_px`` inputs.

        Bicubic resize, center crop, RGB conversion, tensor conversion, then
        per-channel normalization. The mean/std appear to be the OpenAI CLIP
        preprocessing constants.

        NOTE(review): this class never calls this method itself; presumably
        callers use it to build the ``transform`` argument — confirm.
        """
        return Compose([
            Resize(n_px, interpolation=Image.BICUBIC),
            CenterCrop(n_px),
            lambda image: image.convert("RGB"),
            ToTensor(),
            Normalize((0.48145466, 0.4578275, 0.40821073),
                      (0.26862954, 0.26130258, 0.27577711)),
        ])

    def __len__(self):
        """Return the number of distinct image paths."""
        return len(self.sample_dict)

    def __getitem__(self, index):
        """Return ``(self.preprocess(image), label)`` for sample ``index``."""
        path = self.paths[index]
        label = self.sample_dict[path]
        # The context manager releases the file handle deterministically, even
        # if the transform raises. load() decodes the pixels up front, so the
        # transform's result never depends on the file still being open.
        with Image.open(path) as img:
            img.load()
            image = self.preprocess(img)
        return image, label



class sun397_test_rn50(Dataset):
    """SUN397 test split whose class names use the "rn50" naming style.

    Reads ``<dataset_path>/test.json``, whose top-level object maps each
    class key to a list of image paths relative to ``<dataset_path>/SUN397``.
    The integer label of an image is the position of its class key in the
    JSON object (key order as stored in the file).

    Args:
        dataset_path: Directory holding ``test.json`` and the ``SUN397`` folder.
        transform: Callable applied to each opened PIL image in
            ``__getitem__``; stored as ``self.transform``.

    Attributes:
        classnames: One class name per label (see ``read_the_json``).
        sample_dict: Image path (joined onto ``<dataset_path>/SUN397``) -> label.
        paths: The keys of ``sample_dict`` in insertion order; ``__getitem__``
            indexes into this list.
        transform: The ``transform`` passed to the constructor.
    """

    def __init__(self, dataset_path, transform):
        json_path = os.path.join(dataset_path, "test.json")
        self.classnames, tmp_image_paths = self.read_the_json(json_path)

        image_root = os.path.join(dataset_path, "SUN397")
        # If the same path is listed under several classes, the later class
        # overwrites the label (the path keeps its first insertion position).
        self.sample_dict = {
            os.path.join(image_root, rel_path): label
            for label, class_paths in enumerate(tmp_image_paths)
            for rel_path in class_paths
        }
        self.paths = list(self.sample_dict)
        self.transform = transform

    def read_the_json(self, jsonpath):
        """Parse the split file into class names and per-class image paths.

        A class name is built from each key by reversing its ``/``-separated
        parts and joining them with spaces
        (e.g. ``"apartment_building/outdoor"`` -> ``"outdoor apartment_building"``).

        NOTE(review): unlike ``sun397_test_b32``, underscores are NOT replaced
        by spaces here. Left unchanged because it alters prompt text; confirm
        this is intended for the RN50 prompts.

        Args:
            jsonpath: Path to the split JSON file.

        Returns:
            ``(tags_list, file_paths_list)``: the class names and the
            corresponding JSON values, both in the JSON object's key order.
        """
        # JSON text is UTF-8 (RFC 8259); don't depend on the platform locale.
        with open(jsonpath, "r", encoding="utf-8") as file:
            data = json.load(file)
        tags_list = [" ".join(reversed(key.split("/"))) for key in data]
        file_paths_list = list(data.values())
        return tags_list, file_paths_list

    def __len__(self):
        """Return the number of distinct image paths."""
        return len(self.sample_dict)

    def __getitem__(self, index):
        """Return ``(self.transform(image), label)`` for sample ``index``."""
        path = self.paths[index]
        label = self.sample_dict[path]
        # The context manager releases the file handle deterministically, even
        # if the transform raises. load() decodes the pixels up front, so the
        # transform's result never depends on the file still being open.
        with Image.open(path) as img:
            img.load()
            image = self.transform(img)
        return image, label