import torch
import os
import PIL


class MYEXTDATASET(torch.utils.data.Dataset):
    """Image dataset indexed from a nested directory tree.

    The expected layout is::

        <data_root>/<iteration>/<class name>/<prompt dir>/00000<i>.jpg

    where ``<iteration>`` runs over ``0 .. max_iter - 1`` and ``<i>`` is a
    sample index starting at 1.  Sample index 1 is collected from every
    prompt directory before index 2 is considered, and so on, until
    ``sampling_num * max_iter`` images have been gathered or no further
    images can be found.

    Args:
        class_to_idx: Mapping from class directory name to integer label.
            A class directory missing from the mapping raises ``KeyError``.
        data_root: Root directory containing one sub-directory per iteration.
        max_iter: Number of iteration directories (named ``0``, ``1``, ...)
            to scan.
        sampling_num: Per-iteration image budget; the dataset holds at most
            ``sampling_num * max_iter`` images.
        transforms: Callable applied to each opened PIL image in
            ``__getitem__``.
    """

    def __init__(self, class_to_idx, data_root, max_iter, sampling_num, transforms):
        super().__init__()
        self.data_root = data_root
        self.transforms = transforms
        self.snum = sampling_num
        self.max_iter = max_iter
        self.class_to_idx = class_to_idx
        self.imgs = []
        self.labels = []
        self._collect_samples()

    def _collect_samples(self):
        """Fill ``self.imgs`` / ``self.labels`` by scanning ``self.data_root``."""
        limit = self.snum * self.max_iter
        i = 1
        while True:
            count_before = len(self.imgs)
            # NOTE(review): this name is six digits only while i < 10
            # ("0000010.jpg" for i == 10); confirm against the generator's
            # file naming before switching to zero-padding.
            filename = f"00000{i}.jpg"
            for it in range(self.max_iter):
                iter_root = os.path.join(self.data_root, f"{it}")
                # os.listdir order is arbitrary; sort so that the images kept
                # under the cap are reproducible across runs and machines.
                for iclass in sorted(os.listdir(iter_root)):
                    class_root = os.path.join(iter_root, iclass)
                    if not os.path.isdir(class_root):
                        # Stray files next to the class directories are not classes.
                        continue
                    for im in sorted(os.listdir(class_root)):
                        path = os.path.join(class_root, im, filename)
                        if not os.path.exists(path):
                            print(f"{path} does not exist.")
                            continue
                        self.imgs.append(path)
                        self.labels.append(self.class_to_idx[iclass])
                        if len(self.imgs) >= limit:
                            return
            if len(self.imgs) == count_before:
                # No directory holds sample index i, so higher indices are
                # assumed absent as well.  Without this exit the scan would
                # loop forever whenever fewer than `limit` images exist.
                print(
                    f"Only {len(self.imgs)} of {limit} requested images were "
                    f"found under {self.data_root}."
                )
                return
            i += 1

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.imgs)

    def __getitem__(self, index):
        """Return ``(transforms(image), label)`` for the sample at ``index``."""
        # A bare ``import PIL`` does not load the ``PIL.Image`` submodule, so
        # import it explicitly instead of relying on another library having
        # done so as a side effect.
        from PIL import Image

        item = Image.open(self.imgs[index])
        img = self.transforms(item)
        target = self.labels[index]
        return img, target
    