import torch
import gzip
import numpy as np
import os
from PIL import Image
from torch.utils.data import Dataset
from tqdm import tqdm
import scipy.io
import pickle
import h5py


class ImageNetDatasetSmall(Dataset):
    """Two-class ImageNet subset read from ``<root_dir>/<split>/<class_idx>/``.

    Only the class folders named '235' and '696' are read; a sample's label
    is the position of its folder in that list (0 or 1).

    Args:
        root_dir: dataset root containing one sub-folder per split.
        split: name of the split sub-folder (e.g. 'train').
        transform: optional callable applied to the opened PIL image.
        debug: if True, keep file paths and open/transform lazily in
            ``__getitem__``; otherwise transform eagerly in ``__init__``.
        need_cls: accepted for interface compatibility; unused.
        data_size: max images per class; -1 means no limit.
    """

    def __init__(self, root_dir, split='train', transform=None, debug=False,
                 need_cls=False, data_size=-1):
        self.split = split
        self.root_dir = root_dir
        self.data_dir = os.path.join(root_dir, split)
        self.transform = transform
        self.debug = debug
        self.data_size = data_size

        self.images = []
        self.labels = []
        # Only two classes are used; folder names are the class indices.
        self.class_idxs = [str(class_id) for class_id in (235, 696)]

        print(f'Loading images {split} ... ')
        transform_now = not debug and self.transform
        for label, class_dir in enumerate(self.class_idxs):
            class_path = os.path.join(self.data_dir, class_dir)
            taken = 0
            for name in tqdm(os.listdir(class_path)):
                sample = os.path.join(class_path, name)
                if transform_now:
                    sample = self.transform(Image.open(sample))
                self.images.append(sample)
                self.labels.append(label)
                taken += 1
                if self.data_size != -1 and taken == self.data_size:
                    break
        print(f'Finished loading {len(self.labels)} {split} images ... ')

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(image, label)``; in debug mode the stored path is opened and transformed here."""
        sample, target = self.images[idx], self.labels[idx]
        if not (self.debug and self.transform):
            return sample, target
        return self.transform(Image.open(sample)), target
    

class MaskedImageNetDatasetSmall(Dataset):
    """Two-class ImageNet subset paired with pickled segmentation masks.

    Images are read from ``<root_dir>/<split>/<class_idx>/<file>`` and masks
    from ``<mask_dir>/<split>/<class_idx>/<file>.pkl``; ``mask_dir`` must
    mirror the layout of ``root_dir``. Each pickle holds a sequence of
    records whose ``'segmentation'`` entries are stacked into one tensor.

    Args:
        root_dir: image root containing one sub-folder per split.
        mask_dir: mask root with the same structure as ``root_dir``.
        split: name of the split sub-folder (e.g. 'train').
        transform: optional callable applied to the opened PIL image.
        mask_transform: optional callable applied to the stacked mask tensor.
        debug: if True, keep file paths and load lazily in ``__getitem__``.
        need_cls: accepted for interface compatibility; unused.
        data_size: max images per class; -1 (or any value <= 0) means no limit.
    """

    def __init__(self, root_dir, mask_dir, split='train', transform=None,
                 mask_transform=None, debug=False, need_cls=False, data_size=-1):
        self.split = split
        self.root_dir = root_dir
        self.mask_dir = mask_dir
        self.data_dir = os.path.join(root_dir, split)
        self.mask_dir_split = os.path.join(mask_dir, split)
        self.transform = transform
        self.mask_transform = mask_transform
        self.debug = debug
        self.data_size = data_size

        self.images = []
        self.masks = []
        self.labels = []
        self.class_idxs = [str(class_id) for class_id in (235, 696)]

        print(f'Loading images {split} ... ')
        transform_images_now = not debug and self.transform
        load_masks_now = not debug and self.mask_transform
        for label, class_dir in enumerate(self.class_idxs):
            image_dir = os.path.join(self.data_dir, class_dir)
            taken = 0
            for name in tqdm(os.listdir(image_dir)):
                sample = os.path.join(image_dir, name)
                if transform_images_now:
                    sample = self.transform(Image.open(sample))
                self.images.append(sample)

                mask_path = os.path.join(self.mask_dir_split, class_dir,
                                         name + '.pkl')
                if load_masks_now:
                    self.masks.append(self._load_mask(mask_path))
                else:
                    self.masks.append(mask_path)

                self.labels.append(label)

                taken += 1
                if taken == self.data_size:
                    break
        print(f'Finished loading {len(self.labels)} {split} images ... ')

    def _load_mask(self, mask_path):
        """Unpickle the mask records, stack their 'segmentation' arrays and apply ``mask_transform``."""
        # NOTE: pickle.load can execute arbitrary code; only use trusted mask files.
        with open(mask_path, 'rb') as handle:
            records = pickle.load(handle)
        stacked = torch.stack([torch.tensor(record['segmentation'])
                               for record in records])
        return self.mask_transform(stacked)  # can take top k

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(image, label, mask)``; in debug mode paths are resolved here."""
        sample = self.images[idx]
        target = self.labels[idx]
        mask = self.masks[idx]
        if self.debug:
            if self.transform:
                sample = self.transform(Image.open(sample))
            if self.mask_transform:
                mask = self._load_mask(mask)
        return sample, target, mask


class ImageNetDatasetMedium(Dataset):
    """ImageNet-2012 restricted to the WNIDs listed in data/imagenet_10_classes/wnids.txt.

    Expected layout under ``root_dir``:
      * ``ILSVRC2012_img_train/<wnid>/<file>`` for the 'train' split,
      * a flat ``ILSVRC2012_img_val/`` folder for any other split,
      * ``ILSVRC2012_devkit_t12/data/meta.mat`` and
        ``ILSVRC2012_devkit_t12/data/ILSVRC2012_validation_ground_truth.txt``.
    The wnids.txt path is resolved relative to the current working directory.

    Args:
        root_dir: ImageNet root directory.
        split: 'train'; any other value is treated as 'val'.
        transform: optional callable applied to the opened PIL image.
        debug: if True, store image paths and open/transform them lazily in
            ``__getitem__``; otherwise transform eagerly in ``__init__``.
        need_cls: accepted for interface compatibility; unused.
        data_size: 'train': max images per class; 'val': max images in total;
            -1 means no limit.
        label2id: optional mapping from class name (meta.mat 'words') to an
            id; when None the class name itself is used as the label.
    """

    def __init__(self, root_dir, split='train', transform=None, debug=False,
                 need_cls=False, data_size=-1, label2id=None):
        self.split = split
        self.root_dir = root_dir
        if split == 'train':
            self.data_dir = os.path.join(root_dir, 'ILSVRC2012_img_train')
        else:  # val
            self.data_dir = os.path.join(root_dir, 'ILSVRC2012_img_val')
        self.transform = transform
        self.debug = debug
        self.data_size = data_size
        self.label2id = label2id

        self.images = []
        self.labels = []
        self.wnid2name = {}
        self.name2wnid = {}
        meta_mat_path = os.path.join(root_dir, "ILSVRC2012_devkit_t12/data/meta.mat")
        meta = scipy.io.loadmat(meta_mat_path, squeeze_me=True)["synsets"]
        wnids = meta["WNID"].tolist()
        class_names = meta["words"]
        for wnid, name in zip(wnids, class_names):
            self.wnid2name[wnid] = name
            self.name2wnid[name] = wnid

        class_wnid_filename = 'data/imagenet_10_classes/wnids.txt'
        with open(class_wnid_filename, 'rt') as input_file:
            # Skip blank lines so a stray empty line does not become a '' WNID.
            wnids_10 = [line.strip() for line in input_file if line.strip()]

        print(f'Loading images {split} ... ')
        if split == 'train':
            for dirname in tqdm(wnids_10):
                count = 0
                class_dir = os.path.join(self.data_dir, dirname)
                for filename in tqdm(os.listdir(class_dir)):
                    image, label = self.get_images_labels(
                        os.path.join(class_dir, filename), debug, dirname)
                    self.images.append(image)
                    self.labels.append(label)
                    count += 1
                    if self.data_size != -1 and count == self.data_size:
                        break
        else:  # 'val'
            val_ground_truth_path = os.path.join(
                root_dir,
                "ILSVRC2012_devkit_t12/data/ILSVRC2012_validation_ground_truth.txt")
            with open(val_ground_truth_path) as input_file:
                # Each line is a 1-based index into the meta.mat synsets.
                val_ground_truth = [wnids[int(idx.strip()) - 1]
                                    for idx in input_file]
            wanted = set(wnids_10)
            idxs = [i for i, wnid in enumerate(val_ground_truth) if wnid in wanted]
            all_val_filenames = self._list_val_files(self.data_dir)
            count = 0
            for pos, i in enumerate(tqdm(idxs)):
                filename = all_val_filenames[i]
                if pos == 0:
                    # Print the first selected file once as a sanity check; kept
                    # separate from ``count`` so it cannot skew the data_size limit.
                    print(filename)
                image, label = self.get_images_labels(
                    os.path.join(self.data_dir, filename), debug,
                    val_ground_truth[i])
                self.images.append(image)
                self.labels.append(label)
                count += 1
                if self.data_size != -1 and count == self.data_size:
                    break

        print(f'Finished loading {len(self.labels)} {split} images ... ')

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(image, label)``; in debug mode the stored path is opened and transformed here."""
        image = self.images[idx]
        label = self.labels[idx]
        if self.debug and self.transform:
            image_obj = Image.open(image)
            image = self.transform(image_obj)
        return image, label

    def get_images_labels(self, image_path, debug, label_name):
        """Build one ``(image, label)`` sample.

        Args:
            image_path: path to the image file.
            debug: if True (or no transform is set) the path itself is returned
                as the image; otherwise the opened image is transformed.
            label_name: WNID of the sample's class.

        Returns:
            ``(image, label)`` where label is the class name, or
            ``label2id[class name]`` when ``label2id`` was given.
        """
        if not debug and self.transform:
            image = self.transform(Image.open(image_path))
        else:
            image = image_path
        name = self.wnid2name[label_name]
        label = name if self.label2id is None else self.label2id[name]
        return image, label

    @staticmethod
    def _list_val_files(data_dir):
        """Return the validation file names in ascending order.

        The devkit ground-truth file lists one label per image in ascending
        file-name order, while ``os.listdir`` returns names in arbitrary
        order; sorting makes position ``i`` refer to the same image in both.
        """
        return sorted(os.listdir(data_dir))
    

class MaskedImageNetDatasetMedium(Dataset):
    """ImageNet-2012 subset (WNIDs from data/imagenet_10_classes/wnids.txt) with per-image masks.

    Images come from ``<root_dir>/ILSVRC2012_img_train`` ('train') or
    ``<root_dir>/ILSVRC2012_img_val`` (any other split). Masks come from
    ``mask_dir_split`` = ``<root_dir>/<image folder name><mask_dir>``
    (``mask_dir`` is a suffix appended to the image folder name), which is
    either:

    * a directory mirroring the image layout with one ``<file>.h5`` (dataset
      ``'masks'``) or ``<file>.pkl`` per image — '.h5' when ``mask_dir`` ends
      with 'h5', '.pkl' otherwise; or
    * a single HDF5 file keyed by ``<wnid>/<file>`` (train) or ``<file>`` (val).

    When ``debug`` is False and ``mask_transform`` is set, masks are loaded
    and transformed in ``__init__``; otherwise a reference (a file path, or an
    HDF5 key tuple) is stored and resolved in ``__getitem__``.

    Args:
        root_dir: ImageNet root directory (also holds ILSVRC2012_devkit_t12).
        mask_dir: suffix naming the mask directory / HDF5 file.
        split: 'train'; any other value is treated as 'val'.
        transform: optional callable applied to the opened PIL image.
        mask_transform: optional callable applied to the mask tensor.
        debug: if True, store references and load lazily in ``__getitem__``.
        need_cls: accepted for interface compatibility; unused.
        data_size: 'train': max images per class; 'val': max images in total;
            -1 means no limit.
        label2id: optional mapping from class name (meta.mat 'words') to an id.
    """

    def __init__(self, root_dir, mask_dir, split='train', transform=None,
                 mask_transform=None, debug=False, need_cls=False, data_size=-1,
                 label2id=None):
        self.split = split
        self.root_dir = root_dir
        self.mask_dir = mask_dir

        image_folder = 'ILSVRC2012_img_train' if split == 'train' else 'ILSVRC2012_img_val'
        self.data_dir = os.path.join(root_dir, image_folder)
        self.mask_dir_split = os.path.join(root_dir, image_folder + mask_dir)

        self.transform = transform
        self.mask_transform = mask_transform
        self.debug = debug
        self.data_size = data_size
        self.label2id = label2id

        self.images = []
        self.masks = []
        self.labels = []
        self.wnid2name = {}
        self.name2wnid = {}
        meta_mat_path = os.path.join(root_dir, "ILSVRC2012_devkit_t12/data/meta.mat")
        meta = scipy.io.loadmat(meta_mat_path, squeeze_me=True)["synsets"]
        wnids = meta["WNID"].tolist()
        class_names = meta["words"]
        for wnid, name in zip(wnids, class_names):
            self.wnid2name[wnid] = name
            self.name2wnid[name] = wnid

        class_wnid_filename = 'data/imagenet_10_classes/wnids.txt'
        with open(class_wnid_filename, 'rt') as input_file:
            # Skip blank lines so a stray empty line does not become a '' WNID.
            wnids_10 = [line.strip() for line in input_file if line.strip()]

        # A directory means one mask file per image; otherwise one HDF5 file.
        per_file_masks = os.path.isdir(self.mask_dir_split)

        if split == 'train':
            print(f'Loading images {split} ... ')
            for dirname in wnids_10:
                count = 0
                if per_file_masks:
                    class_dir = os.path.join(self.data_dir, dirname)
                    for filename in tqdm(os.listdir(class_dir)):
                        self._add_sample(os.path.join(class_dir, filename),
                                         debug, dirname)
                        self.masks.append(self._mask_from_file(
                            self._mask_file_path(dirname, filename), debug))
                        count += 1
                        if self.data_size != -1 and count == self.data_size:
                            break
                else:
                    with h5py.File(self.mask_dir_split, 'r') as input_file:
                        group = input_file[dirname]
                        for filename in group.keys():
                            self._add_sample(
                                os.path.join(self.data_dir, dirname, filename),
                                debug, dirname)
                            self.masks.append(self._mask_from_h5(
                                group, filename, (dirname, filename), debug))
                            count += 1
                            if self.data_size != -1 and count == self.data_size:
                                break
        else:  # 'val'
            val_ground_truth_path = os.path.join(
                root_dir,
                "ILSVRC2012_devkit_t12/data/ILSVRC2012_validation_ground_truth.txt")
            with open(val_ground_truth_path) as input_file:
                # Each line is a 1-based index into the meta.mat synsets.
                val_ground_truth = [wnids[int(idx.strip()) - 1]
                                    for idx in input_file]
            wanted = set(wnids_10)
            idxs = [i for i, wnid in enumerate(val_ground_truth) if wnid in wanted]
            all_val_filenames = self._list_val_files(self.data_dir)
            count = 0
            if per_file_masks:
                for i in idxs:
                    filename = all_val_filenames[i]
                    self._add_sample(os.path.join(self.data_dir, filename),
                                     debug, val_ground_truth[i])
                    self.masks.append(self._mask_from_file(
                        self._mask_file_path(filename), debug))
                    count += 1
                    if self.data_size != -1 and count == self.data_size:
                        break
            else:
                with h5py.File(self.mask_dir_split, 'r') as input_file:
                    for i in idxs:
                        filename = all_val_filenames[i]
                        self._add_sample(os.path.join(self.data_dir, filename),
                                         debug, val_ground_truth[i])
                        self.masks.append(self._mask_from_h5(
                            input_file, filename, (filename,), debug))
                        count += 1
                        if self.data_size != -1 and count == self.data_size:
                            break

        print(f'Finished loading {len(self.labels)} {split} images ... ')

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(image, label, mask)``; in debug mode stored references are resolved here."""
        image = self.images[idx]
        label = self.labels[idx]
        mask = self.masks[idx]
        if self.debug and self.transform:
            image_obj = Image.open(image)
            image = self.transform(image_obj)
        if self.debug and self.mask_transform:
            masks_i = self._read_lazy_mask(mask, image)
            try:
                mask_tensor = torch.tensor(masks_i)
            except Exception:
                # Report which reference produced an unconvertible mask, then re-raise.
                print('mask_file_path failed:', mask)
                raise
            mask = self.mask_transform(mask_tensor)  # can take top k
        return image, label, mask

    def get_images_labels(self, image_path, debug, label_name):
        """Build one ``(image, label)`` sample.

        Args:
            image_path: path to the image file.
            debug: if True (or no transform is set) the path itself is returned
                as the image; otherwise the opened image is transformed.
            label_name: WNID of the sample's class.

        Returns:
            ``(image, label)`` where label is the class name, or
            ``label2id[class name]`` when ``label2id`` was given.
        """
        if not debug and self.transform:
            image_obj = Image.open(image_path)
            image = self.transform(image_obj)
        else:
            image = image_path

        if self.label2id is None:
            label = self.wnid2name[label_name]
        else:
            label = self.label2id[self.wnid2name[label_name]]

        return image, label

    def _add_sample(self, image_path, debug, label_name):
        """Append the image and label for one file to ``self.images`` / ``self.labels``."""
        image, label = self.get_images_labels(image_path, debug, label_name)
        self.images.append(image)
        self.labels.append(label)

    def _mask_file_path(self, *parts):
        """Path of a per-image mask file: ``mask_dir_split/[dirs/]<file>`` plus '.h5' or '.pkl'."""
        suffix = '.h5' if self.mask_dir.endswith('h5') else '.pkl'
        *dirs, filename = parts
        return os.path.join(self.mask_dir_split, *dirs, filename + suffix)

    def _read_mask_file(self, mask_file_path):
        """Read a per-image mask file: the 'masks' dataset of an .h5 file, or an unpickled object."""
        if self.mask_dir.endswith('h5'):
            with h5py.File(mask_file_path, 'r') as input_file:
                return input_file['masks'][:]
        # NOTE: pickle.load can execute arbitrary code; only use trusted mask files.
        with open(mask_file_path, 'rb') as input_file:
            return pickle.load(input_file)

    def _mask_from_file(self, mask_file_path, debug):
        """Return the transformed mask (eager mode) or the file path to load later."""
        if not debug and self.mask_transform:
            mask = torch.tensor(self._read_mask_file(mask_file_path))
            return self.mask_transform(mask)  # can take top k
        return mask_file_path

    def _mask_from_h5(self, container, filename, ref, debug):
        """Return the transformed mask read from ``container[filename]`` (eager mode) or ``ref``."""
        if not debug and self.mask_transform:
            mask = torch.tensor(container[filename][:])
            return self.mask_transform(mask)  # can take top k
        return ref

    def _read_lazy_mask(self, mask_ref, image):
        """Load the raw mask array for a reference stored by ``__init__``.

        ``mask_ref`` is either a ``(wnid, filename)`` / ``(filename,)`` key into
        the single HDF5 mask file, or the path of a per-image .h5/.pkl file.
        Missing or empty HDF5 entries fall back to an all-ones mask sized from
        ``image.shape[1]`` x ``image.shape[2]`` (assumes ``image`` is a C x H x W
        tensor after ``transform`` — TODO confirm).
        """
        if isinstance(mask_ref, tuple):
            with h5py.File(self.mask_dir_split, 'r') as input_file:
                if len(mask_ref) == 2:
                    dirname, filename = mask_ref
                    container = input_file[dirname]
                else:  # (filename,)
                    filename = mask_ref[0]
                    container = input_file
                if filename in container:
                    masks_i = container[filename][:]
                    if len(masks_i) > 0:
                        return masks_i
                # NOTE(review): this fallback is 2-D float, unlike the 3-D bool
                # fallback for per-file .h5 masks below; confirm what
                # mask_transform expects.
                return np.ones((image.shape[1], image.shape[2]))
        if self.mask_dir.endswith('h5'):
            with h5py.File(mask_ref, 'r') as input_file:
                masks_i = input_file['masks'][:]
            if len(masks_i) == 0:
                # numpy arrays have no .bool(); build the boolean array directly.
                masks_i = np.ones((1, image.shape[1], image.shape[2]), dtype=bool)
            return masks_i
        # NOTE: pickle.load can execute arbitrary code; only use trusted mask files.
        with open(mask_ref, 'rb') as input_file:
            return pickle.load(input_file)

    @staticmethod
    def _list_val_files(data_dir):
        """Return the validation file names in ascending order.

        The devkit ground-truth file lists one label per image in ascending
        file-name order, while ``os.listdir`` returns names in arbitrary
        order; sorting makes position ``i`` refer to the same image in both.
        """
        return sorted(os.listdir(data_dir))
    

class ImageNetDataset(Dataset):
    """Full ImageNet-2012 classification dataset.

    'train' reads every class folder in ``<root_dir>/ILSVRC2012_img_train``;
    any other split reads every file in the flat
    ``<root_dir>/ILSVRC2012_img_val`` folder, labelled from
    ``ILSVRC2012_devkit_t12/data/ILSVRC2012_validation_ground_truth.txt``.

    NOTE: the ``debug`` argument is overridden to True, so images are always
    stored as paths and opened/transformed lazily in ``__getitem__``.

    Args:
        root_dir: ImageNet root directory.
        split: 'train'; any other value is treated as 'val'.
        transform: optional callable applied to the opened PIL image.
        debug: ignored (always treated as True).
        need_cls: accepted for interface compatibility; unused.
        data_size: 'train': max images per class; 'val': max images in total;
            -1 means no limit.
        label2id: optional mapping from class name (meta.mat 'words') to an
            id; when None the class name itself is used as the label.
    """

    def __init__(self, root_dir, split='train', transform=None, debug=False,
                 need_cls=False, data_size=-1, label2id=None):
        # Always load lazily: the caller's ``debug`` value is deliberately ignored.
        debug = True
        self.split = split
        self.root_dir = root_dir
        if split == 'train':
            self.data_dir = os.path.join(root_dir, 'ILSVRC2012_img_train')
        else:  # val
            self.data_dir = os.path.join(root_dir, 'ILSVRC2012_img_val')
        self.transform = transform
        self.debug = debug
        self.data_size = data_size
        self.label2id = label2id

        self.images = []
        self.labels = []
        self.wnid2name = {}
        self.name2wnid = {}
        meta_mat_path = os.path.join(root_dir, "ILSVRC2012_devkit_t12/data/meta.mat")
        meta = scipy.io.loadmat(meta_mat_path, squeeze_me=True)["synsets"]
        wnids = meta["WNID"].tolist()
        class_names = meta["words"]
        for wnid, name in zip(wnids, class_names):
            self.wnid2name[wnid] = name
            self.name2wnid[name] = wnid

        if split == 'train':
            print(f'Loading images {split} ... ')
            for dirname in tqdm(os.listdir(self.data_dir)):
                count = 0
                class_dir = os.path.join(self.data_dir, dirname)
                for filename in tqdm(os.listdir(class_dir)):
                    image, label = self.get_images_labels(
                        os.path.join(class_dir, filename), debug, dirname)
                    self.images.append(image)
                    self.labels.append(label)
                    count += 1
                    if self.data_size != -1 and count == self.data_size:
                        break
        else:  # 'val'
            val_ground_truth_path = os.path.join(
                root_dir,
                "ILSVRC2012_devkit_t12/data/ILSVRC2012_validation_ground_truth.txt")
            with open(val_ground_truth_path) as input_file:
                # Each line is a 1-based index into the meta.mat synsets.
                val_ground_truth = [wnids[int(idx.strip()) - 1]
                                    for idx in input_file]
            count = 0
            for i, filename in enumerate(tqdm(self._list_val_files(self.data_dir))):
                image, label = self.get_images_labels(
                    os.path.join(self.data_dir, filename), debug,
                    val_ground_truth[i])
                self.images.append(image)
                self.labels.append(label)
                count += 1
                if self.data_size != -1 and count == self.data_size:
                    break

        print(f'Finished loading {len(self.labels)} {split} images ... ')

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(image, label)``; in debug mode the stored path is opened and transformed here."""
        image = self.images[idx]
        label = self.labels[idx]
        if self.debug and self.transform:
            image_obj = Image.open(image)
            image = self.transform(image_obj)
        return image, label

    def get_images_labels(self, image_path, debug, label_name):
        """Build one ``(image, label)`` sample.

        Args:
            image_path: path to the image file.
            debug: if True (or no transform is set) the path itself is returned
                as the image; otherwise the opened image is transformed.
            label_name: WNID of the sample's class.

        Returns:
            ``(image, label)`` where label is the class name, or
            ``label2id[class name]`` when ``label2id`` was given.
        """
        if not debug and self.transform:
            image = self.transform(Image.open(image_path))
        else:
            image = image_path
        name = self.wnid2name[label_name]
        label = name if self.label2id is None else self.label2id[name]
        return image, label

    @staticmethod
    def _list_val_files(data_dir):
        """Return the validation file names in ascending order.

        The devkit ground-truth file lists one label per image in ascending
        file-name order, while ``os.listdir`` returns names in arbitrary
        order; sorting makes position ``i`` refer to the same image in both.
        """
        return sorted(os.listdir(data_dir))
    

class MaskedImageNetDataset(Dataset):
    def __init__(self, root_dir, mask_dir, split='train', transform=None, 
                 mask_transform=None, debug=False, need_cls=False, data_size=-1,
                 label2id=None):
        need_cls = True
        debug = True
        self.split = split
        self.root_dir = root_dir # '/nlp/data/vision_datasets/ImageNet_old'
        self.mask_dir = mask_dir # /nlp/data/vision_datasets/ImageNet_old/ILSVRC2012_img_train_sam_vit_h
        
        if split == 'train':
            self.data_dir = os.path.join(root_dir, 'ILSVRC2012_img_train')
            self.mask_dir_split = os.path.join(root_dir, 
                                                   'ILSVRC2012_img_train' + mask_dir)
        else: # val
            self.data_dir = os.path.join(root_dir, 'ILSVRC2012_img_val')
            self.mask_dir_split = os.path.join(root_dir, 
                                                   'ILSVRC2012_img_val'  + mask_dir)
        
        self.transform = transform
        self.mask_transform = mask_transform
        self.debug = debug
        self.data_size = data_size
        self.label2id = label2id

        self.images = []
        self.masks = []
        self.labels = []
        self.wnid2name = {}
        self.name2wnid = {}
        meta_mat_path = os.path.join(root_dir, "ILSVRC2012_devkit_t12/data/meta.mat")       
        meta = scipy.io.loadmat(meta_mat_path, squeeze_me=True)["synsets"]
        wnids = meta["WNID"].tolist()
        class_names = meta["words"]
        for wnid, name in zip(wnids, class_names):
            self.wnid2name[wnid] = name
            self.name2wnid[name] = wnid

        class_wnid_filename = 'data/imagenet_10_classes/wnids.txt'
        with open(class_wnid_filename, 'rt') as input_file:
            wnids_10 = [wnid.strip() for wnid in input_file.readlines()]

        class_idxs = [wnids.index(wnid) for wnid in wnids_10]

        if split == 'train':
            print(f'Loading images {split} ... ')
            # for dirname in tqdm(os.listdir(self.data_dir)):
            # for dirname in wnids_10:
            # count_db = 0
            for dirname in tqdm(os.listdir(self.data_dir)):
                count = 0
                if os.path.isdir(self.mask_dir_split):
                    for filename in tqdm(os.listdir(os.path.join(self.data_dir, dirname))):
                        image_path = os.path.join(self.data_dir, dirname, filename)
                        # print('image', image)
                        image, label = self.get_images_labels(image_path, 
                                                              debug, 
                                                              dirname)
                        self.images.append(image)
                        self.labels.append(label)
                        
                        if mask_dir.endswith('h5'):
                            mask_file_path = os.path.join(self.mask_dir_split, 
                                                        dirname, 
                                                        filename + '.h5')
                        else:
                            mask_file_path = os.path.join(self.mask_dir_split, 
                                                        dirname, 
                                                        filename + '.pkl')
                        if not debug and self.mask_transform:
                            if mask_dir.endswith('h5'):
                                with h5py.File(mask_file_path, 'r') as input_file:
                                    masks_i = input_file['masks'][:]
                            else:
                                with open(mask_file_path, 'rb') as input_file:
                                    masks_i = pickle.load(input_file)
                            mask = torch.tensor(masks_i)
                            mask = self.mask_transform(mask)  # can take top k
                            self.masks.append(mask)
                        else:
                            self.masks.append(mask_file_path)

                        if self.data_size != -1:
                            count += 1
                            if count == self.data_size:
                                break
                else:
                    with h5py.File(self.mask_dir_split, 'r') as input_file:
                        group = input_file[dirname]
                        for filename in group.keys():
                            image_path = os.path.join(self.data_dir, dirname, filename)
                            # print('image', image)
                            image, label = self.get_images_labels(image_path, 
                                                                  debug, 
                                                                  dirname)
                            self.images.append(image)
                            self.labels.append(label)
                            
                            if not debug and self.mask_transform:
                                masks_i = group[filename][:]
                                mask = torch.tensor(masks_i)
                                mask = self.mask_transform(mask)  # can take top k
                                self.masks.append(mask)
                            else:
                                self.masks.append((dirname, filename))
                            if self.data_size != -1:
                                count += 1
                                if count == self.data_size:
                                    break
                        # count_db += 1
                        # if count_db == 100:
                        #     break
                    
        else: # 'val'
            val_ground_truth_path = os.path.join(root_dir, 
                                     "ILSVRC2012_devkit_t12/data/ILSVRC2012_validation_ground_truth.txt")
            with open(val_ground_truth_path) as input_file:
                val_ground_truth = [wnids[int(idx.strip()) - 1]  \
                                    for idx in input_file.readlines()]
            
            # for i, filename in tqdm(enumerate(os.listdir(self.data_dir))):
            idxs = [i for i, label in enumerate(val_ground_truth) \
                    if label in wnids_10]
            count = 0
            all_val_filenames = list(os.listdir(self.data_dir))
            # for i, filename in tqdm(enumerate(os.listdir(self.data_dir))):
            
            if os.path.isdir(self.mask_dir_split):
                # for i in idxs:
                for i, filename in tqdm(enumerate(os.listdir(self.data_dir))):
                    # filename = all_val_filenames[i]

                    # Get image
                    image_path = os.path.join(self.data_dir, filename)

                    image, label = self.get_images_labels(image_path, 
                                                          debug, 
                                                          val_ground_truth[i])
                    self.images.append(image)
                    self.labels.append(label)
                
                    if mask_dir.endswith('h5'):
                        mask_file_path = os.path.join(self.mask_dir_split, filename + '.h5')
                    else:
                        mask_file_path = os.path.join(self.mask_dir_split, filename + '.pkl')
                    if not debug and self.mask_transform:
                        if mask_dir.endswith('h5'):
                            with h5py.File(mask_file_path, 'r') as input_file:
                                masks_i = input_file['masks'][:]
                        else:
                            with open(mask_file_path, 'rb') as input_file:
                                masks_i = pickle.load(input_file)
                        if len(masks_i.shape) == 0:
                            masks_i = np.ones((1,
                                               image.shape[1], 
                                               image.shape[2])).bool()
                        mask = torch.tensor(masks_i)
                        mask = self.mask_transform(mask)  # can take top k
                        self.masks.append(mask)
                    else:
                        self.masks.append(mask_file_path)
                        
                    if self.data_size != -1:
                        count += 1
                        if count == self.data_size:
                            break
            else:
                with h5py.File(self.mask_dir_split, 'r') as input_file:
                    for i, filename in tqdm(enumerate(os.listdir(self.data_dir))):
                        # filename = all_val_filenames[i]

                        # Get image
                        image_path = os.path.join(self.data_dir, filename)

                        image, label = self.get_images_labels(image_path, 
                                                              debug, 
                                                              val_ground_truth[i])
                        self.images.append(image)
                        self.labels.append(label)
                        
                        if not debug and self.mask_transform:
                            masks_i = input_file[filename][:]
                            if len(masks_i.shape) == 0:
                                masks_i = np.ones((image.shape[1], 
                                                   image.shape[2]))
                            mask = torch.tensor(masks_i)
                            mask = self.mask_transform(mask)  # can take top k
                            self.masks.append(mask)
                        else:
                            self.masks.append(tuple([filename]))
                        if self.data_size != -1:
                            count += 1
                            if count == self.data_size:
                                break

        print(f'Finished loading {len(self.labels)} {split} images ... ')

    def __len__(self):
        """Return the dataset size, i.e. the number of collected labels."""
        num_samples = len(self.labels)
        return num_samples

    def __getitem__(self, idx):
        """Return ``(image, label, mask)`` for sample ``idx``.

        When ``self.debug`` is set, the stored image is a file path and the
        stored mask is a reference; both are loaded and transformed here
        (each only if the matching transform is configured). A mask reference is:

        * a tuple ``(dirname, filename)`` or ``(filename,)`` naming a dataset
          inside the single HDF5 file ``self.mask_dir_split``;
        * a per-image ``.h5`` path (when ``self.mask_dir`` ends with ``'h5'``)
          holding a ``'masks'`` dataset;
        * otherwise a per-image pickle path.

        For the two HDF5 forms, empty masks fall back to an all-ones mask sized
        from ``image.shape[1]`` and ``image.shape[2]``. The tuple form also does
        this for a missing entry. This requires ``image`` to already be a
        transformed tensor (assumed channels-first, (C, H, W) - TODO confirm).
        """
        image = self.images[idx]
        label = self.labels[idx]
        mask = self.masks[idx]
        if self.debug and self.transform:
            image_obj = Image.open(image)
            image = self.transform(image_obj)
        if self.debug and self.mask_transform:
            mask_file_path = mask

            def is_empty(masks):
                # 0-d arrays (the "no masks" form __init__ checks for via
                # len(shape) == 0) have no len(), so test ndim first.
                return np.ndim(masks) == 0 or len(masks) == 0

            if isinstance(mask_file_path, tuple):
                # Reference into the shared HDF5 file:
                # (dirname, filename) -> file[dirname][filename];
                # (filename,)         -> file[filename].
                with h5py.File(self.mask_dir_split, 'r') as input_file:
                    if len(mask_file_path) == 2:
                        dirname, filename = mask_file_path
                        group = input_file[dirname]
                    else:  # == 1
                        filename = mask_file_path[0]
                        group = input_file
                    masks_i = group[filename][:] if filename in group else None
                if masks_i is None or is_empty(masks_i):
                    masks_i = np.ones((image.shape[1], image.shape[2]))
            elif self.mask_dir.endswith('h5'):
                with h5py.File(mask_file_path, 'r') as input_file:
                    masks_i = input_file['masks'][:]
                if is_empty(masks_i):
                    # dtype=bool: ndarray has no .bool() method (that is torch API).
                    masks_i = np.ones((1, image.shape[1], image.shape[2]),
                                      dtype=bool)
            else:
                # NOTE: unpickling can execute arbitrary code; only load trusted
                # mask files.
                with open(mask_file_path, 'rb') as input_file:
                    masks_i = pickle.load(input_file)
            # Log which file was bad, then re-raise the original error
            # (the previous code retried the same call inside a bare except).
            try:
                mask = torch.tensor(masks_i)
            except Exception:
                print('mask_file_path failed:', mask_file_path)
                raise
            try:
                mask = self.mask_transform(mask)  # can take top k
            except Exception:
                print('mask', mask.shape, mask_file_path)
                raise
        return image, label, mask
    
    def get_images_labels(self, image_path, debug, label_name):
        """Return ``(image, label)`` for a single image file.

        The image is opened and passed through ``self.transform`` right away
        unless ``debug`` is true or no transform is set. In that case the path
        itself is returned so the caller can load it lazily later.

        ``label_name`` is a WNID key into ``self.wnid2name``. The resulting class
        name is the label, or is mapped through ``self.label2id`` when that
        mapping is not ``None``.
        """
        load_now = bool(not debug and self.transform)
        image = self.transform(Image.open(image_path)) if load_now else image_path

        class_name = self.wnid2name[label_name]
        label = class_name if self.label2id is None else self.label2id[class_name]
        return image, label
    

class MaskedImageNetDatasetOld(Dataset):
    """ILSVRC2012 images paired with per-image pickled segmentation masks.

    Images are read from ``<root_dir>/ILSVRC2012_img_train/<wnid>/`` (train)
    or ``<root_dir>/ILSVRC2012_img_val/`` (any other ``split``). Class metadata
    and validation labels come from ``<root_dir>/ILSVRC2012_devkit_t12/data``.
    Masks are read from ``<mask folder>/[<wnid>/]<image filename>.pkl``.

    Args:
        root_dir: ImageNet root directory.
        mask_dir: string appended to the image folder *name* (plain string
            concatenation, not a path join) to form the mask folder.
            NOTE(review): presumably a suffix such as ``'_sam_vit_h'``; an
            absolute path here would yield a nonsensical location.
        split: ``'train'``; any other value loads the validation set.
        transform: callable applied to each opened ``PIL.Image``.
        mask_transform: callable applied to each mask tensor.
        debug: if true, ``__init__`` stores only file paths and all loading and
            transforming happens in ``__getitem__``. The same applies when the
            matching transform is not set: the path is stored and returned as-is.
        need_cls: unused; accepted for signature compatibility.
        data_size: ``-1`` loads everything; otherwise the maximum number of
            images per class directory (train) or in total (val).
        label2id: optional mapping from class name to label. When ``None``,
            labels are the ``words`` strings from ``meta.mat``.
    """

    def __init__(self, root_dir, mask_dir, split='train', transform=None,
                 mask_transform=None, debug=False, need_cls=False, data_size=-1,
                 label2id=None):
        self.split = split
        self.root_dir = root_dir
        self.mask_dir = mask_dir

        image_folder = ('ILSVRC2012_img_train' if split == 'train'
                        else 'ILSVRC2012_img_val')
        self.data_dir = os.path.join(root_dir, image_folder)
        # mask_dir is concatenated onto the folder name (see class docstring).
        self.mask_dir_split = os.path.join(root_dir, image_folder + mask_dir)
        self.transform = transform
        self.mask_transform = mask_transform
        self.debug = debug
        self.data_size = data_size
        self.label2id = label2id

        self.images = []
        self.masks = []
        self.labels = []
        wnids = self._load_class_metadata()

        if split == 'train':
            self._load_train_split()
        else:  # 'val'
            self._load_val_split(wnids)

        print(f'Finished loading {len(self.labels)} {split} images ... ')

    def _load_class_metadata(self):
        """Fill ``wnid2name``/``name2wnid`` from the devkit meta.mat; return the WNIDs."""
        self.wnid2name = {}
        self.name2wnid = {}
        meta_mat_path = os.path.join(self.root_dir,
                                     "ILSVRC2012_devkit_t12/data/meta.mat")
        meta = scipy.io.loadmat(meta_mat_path, squeeze_me=True)["synsets"]
        wnids = meta["WNID"]
        for wnid, name in zip(wnids, meta["words"]):
            self.wnid2name[wnid] = name
            self.name2wnid[name] = wnid
        return wnids

    def _load_image(self, path):
        """Return the transformed image now, or ``path`` when loading is deferred."""
        if not self.debug and self.transform:
            return self.transform(Image.open(path))
        return path

    def _label_for(self, wnid):
        """Map a WNID to its class name, then through ``label2id`` if given."""
        name = self.wnid2name[wnid]
        return name if self.label2id is None else self.label2id[name]

    def _load_train_split(self):
        """Collect every ``<wnid>/<file>`` of the train folder (cap per class: data_size)."""
        print(f'Loading images {self.split} ... ')
        # Sorted so sample order and the data_size subset are reproducible.
        for dirname in tqdm(sorted(os.listdir(self.data_dir))):
            count = 0
            class_dir = os.path.join(self.data_dir, dirname)
            for filename in tqdm(sorted(os.listdir(class_dir))):
                self.images.append(self._load_image(os.path.join(class_dir, filename)))
                self.labels.append(self._label_for(dirname))

                mask_file_path = os.path.join(self.mask_dir_split,
                                              dirname,
                                              filename + '.pkl')
                if not self.debug and self.mask_transform:
                    # NOTE: pickle can execute code; only load trusted mask files.
                    with open(mask_file_path, 'rb') as input_file:
                        masks_i = pickle.load(input_file)
                    mask = torch.tensor(masks_i)
                    self.masks.append(self.mask_transform(mask))  # can take top k
                else:
                    self.masks.append(mask_file_path)

                if self.data_size != -1:
                    count += 1
                    if count == self.data_size:
                        break

    def _load_val_split(self, wnids):
        """Collect validation images, labelled from the devkit ground-truth file."""
        val_ground_truth_path = os.path.join(
            self.root_dir,
            "ILSVRC2012_devkit_t12/data/ILSVRC2012_validation_ground_truth.txt")
        # Each line is a 1-based index into the meta.mat WNID list.
        with open(val_ground_truth_path) as input_file:
            val_ground_truth = [wnids[int(line.strip()) - 1] for line in input_file]

        count = 0
        # Line i of the ground truth labels the i-th image in filename order.
        # os.listdir order is arbitrary, so sort to keep images and labels aligned.
        for i, filename in enumerate(tqdm(sorted(os.listdir(self.data_dir)))):
            self.images.append(self._load_image(os.path.join(self.data_dir, filename)))
            self.labels.append(self._label_for(val_ground_truth[i]))

            mask_file_path = os.path.join(self.mask_dir_split, filename + '.pkl')
            if not self.debug and self.mask_transform:
                with open(mask_file_path, 'rb') as input_file:
                    masks_i = pickle.load(input_file)
                # NOTE(review): here the pickle is read as a list of dicts with a
                # 'segmentation' entry, while the train split and __getitem__ pass
                # the unpickled object straight to torch.tensor. Confirm which
                # format the val mask files use. An empty list makes torch.stack raise.
                mask = torch.stack([torch.tensor(m['segmentation']) for m in masks_i])
                self.masks.append(self.mask_transform(mask))  # can take top k
            else:
                self.masks.append(mask_file_path)

            if self.data_size != -1:
                count += 1
                if count == self.data_size:
                    break

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(image, label, mask)``, loading lazily when ``debug`` is set."""
        image = self.images[idx]
        label = self.labels[idx]
        mask = self.masks[idx]
        if self.debug and self.transform:
            image = self.transform(Image.open(image))
        if self.debug and self.mask_transform:
            with open(mask, 'rb') as input_file:
                masks_i = pickle.load(input_file)
            mask = torch.tensor(masks_i)
            # No segments (empty, or 0-d which has no len()): fall back to a
            # single all-True mask over the image. This assumes image is a
            # transformed tensor whose last two dims are (H, W). TODO confirm.
            if mask.dim() == 0 or len(mask) == 0:
                mask = torch.ones(image.shape[-2], image.shape[-1]).bool().unsqueeze(0)
            mask = self.mask_transform(mask)  # can take top k
        return image, label, mask