import os
import json
import scipy.io
import numpy as np
from PIL import Image
from torch.utils.data import Dataset


def read_json(fpath):
    """Load and return the JSON document stored at ``fpath``.

    The file is decoded explicitly as UTF-8 (the encoding JSON text is
    specified to use) instead of the platform's locale-dependent default,
    so non-ASCII content loads identically on every machine.

    Args:
        fpath: Path to the JSON file.

    Returns:
        The deserialized Python object (dict, list, str, ...).
    """
    with open(fpath, 'r', encoding='utf-8') as f:
        return json.load(f)


class Flowers102Dataset(Dataset):
    """Oxford Flowers-102 image classification dataset.

    The official split file defines three subsets: setid['tstid'] (6149
    images), setid['valid'] (1020 images) and setid['trnid'] (1020 images).
    For 16-shot learning and full-data settings, this class re-splits them:
    setid['tstid'] + setid['valid'] are used as train samples and
    setid['trnid'] as test samples.

    Files read from ``root``:
        imagelabels.mat   -- ``labels`` array of 1-based class ids, one per image
        setid.mat         -- ``trnid`` / ``valid`` / ``tstid`` arrays of 1-based image ids
        cat_to_name.json  -- mapping from 1-based class id (string key) to class name
        jpg/image_XXXXX.jpg

    Args:
        root: Dataset root directory.
        split: ``'train'`` or ``'test'``.
        transform: Optional callable applied to each loaded RGB PIL image.

    Raises:
        ValueError: If ``split`` is not ``'train'`` or ``'test'``.
    """

    def __init__(self, root, split='train', transform=None):
        # Reject a bad split before any file I/O so the caller gets the
        # intended ValueError rather than, e.g., a FileNotFoundError.
        if split not in ('train', 'test'):
            raise ValueError("Split must be 'train' or 'test'")

        self.split = split
        self.transform = transform

        labels_path = os.path.join(root, 'imagelabels.mat')
        # Indexing with [0] takes the first row of the loaded array
        # (presumably a 1xN row vector as stored by MATLAB).
        self.labels = scipy.io.loadmat(labels_path)['labels'][0]
        self.lab2cname_file = os.path.join(root, 'cat_to_name.json')
        self.classes = self.load_class_names(read_json(self.lab2cname_file))

        setid = scipy.io.loadmat(os.path.join(root, 'setid.mat'))
        if split == 'train':
            ids = np.concatenate((setid['tstid'][0], setid['valid'][0]))
        else:
            ids = setid['trnid'][0]
        # Image ids in setid.mat are 1-based; keep 0-based indices so they
        # can index self.labels directly.
        self.image_ids = ids - 1

        # File names use the 1-based id, zero-padded to five digits.
        self.data = [os.path.join(root, 'jpg', f'image_{i + 1:05d}.jpg') for i in self.image_ids]
        # Convert 1-based class ids to 0-based targets.
        self.targets = [int(self.labels[i]) - 1 for i in self.image_ids]

    def __len__(self):
        """Return the number of samples in this split."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        ``image`` is the RGB image, passed through ``transform`` when one is
        set; ``label`` is the 0-based class index.
        """
        # The context manager releases the file handle deterministically.
        # Pillow's convert() returns a new image object, so it remains
        # usable after the source file is closed.
        with Image.open(self.data[idx]) as img:
            image = img.convert('RGB')

        if self.transform is not None:
            image = self.transform(image)

        return image, self.targets[idx]

    def load_class_names(self, class_dict):
        """Turn a ``{'1': name, '2': name, ...}`` mapping into a 0-indexed list.

        Args:
            class_dict: Mapping whose keys are 1-based class ids (ints or
                numeric strings) and whose values are class names.

        Returns:
            list: ``result[k]`` is the name of class id ``k + 1``.

        Raises:
            ValueError: If an id lies outside ``1..len(class_dict)``. Without
                this check, id ``0`` would silently overwrite the last entry
                through negative indexing, and larger ids would fail with an
                opaque IndexError.
        """
        num_classes = len(class_dict)
        classname = [None] * num_classes
        for class_number, class_name in class_dict.items():
            idx = int(class_number) - 1
            if not 0 <= idx < num_classes:
                raise ValueError(
                    f'class id {class_number!r} is outside the expected range 1..{num_classes}'
                )
            classname[idx] = class_name

        return classname
