import glob
import os
import os.path
import pathlib
from pathlib import Path
from shutil import move, rmtree
from typing import Any, Callable, Optional, Tuple

import numpy as np
import PIL
import torch
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import Subset
from torchvision import datasets
from torchvision.datasets import ImageFolder
from torchvision.datasets.utils import (check_integrity,
                                        download_and_extract_archive,
                                        download_url, verify_str_arg)


class ImageNetRandom(ImageFolder):
    """ImageNet-style ``ImageFolder`` restricted to a subset of classes.

    Loads ``<root>/train`` (``train=True``) or ``<root>/val`` (``train=False``)
    through :class:`torchvision.datasets.ImageFolder`. It then keeps only the
    images whose original class index appears in ``cls_lst`` and relabels them
    to their position in ``cls_lst``.

    Args:
        root: Dataset root directory expected to contain ``train/`` and
            ``val/`` sub-directories. A leading ``~`` is expanded.
        train: Load the ``train`` split if True, otherwise the ``val`` split.
        transform: Image transform; defaults to ``transforms.ToTensor()``.
        target_transform: Transform applied to the (remapped) label.
        download: Only suppresses the "dataset not found" check below; this
            class does not download anything.
        cls_lst: Original class indices (positions in the sorted class list
            built by ``ImageFolder``) to keep. A kept image's new label is the
            position of the first occurrence of its original index in
            ``cls_lst``. Defaults to the first 100 classes.

    Raises:
        RuntimeError: If ``root`` does not exist and ``download`` is False.
    """

    def __init__(self, root, train=True, transform=None, target_transform=None, download=False, cls_lst=None):
        self.root = os.path.expanduser(root)
        self.transform = transform
        self.target_transform = target_transform
        self.train = train

        # Use the expanded path so roots such as "~/data/imagenet" are found.
        self.fpath = self.root

        if not os.path.exists(self.fpath) and not download:
            raise RuntimeError(
                f'Dataset not found at {self.fpath}. '
                'You can use download=True to download it')

        split_dir = os.path.join(self.fpath, 'train' if self.train else 'val')
        super().__init__(
            split_dir,
            transform=transforms.ToTensor() if transform is None else transform,
            target_transform=target_transform,
        )

        keep = list(range(100)) if cls_lst is None else list(cls_lst)

        # Map original class index -> new label. setdefault keeps the FIRST
        # position of each index (same result as list.index). It also makes
        # the per-image lookup O(1) instead of scanning ``keep`` for every image.
        remap = {}
        for new_idx, old_idx in enumerate(keep):
            remap.setdefault(old_idx, new_idx)

        # Order class names by their NEW label so that classes[t] and
        # class_to_idx agree with the targets, even for an unsorted cls_lst.
        # New labels are unique per kept class, so the sort never compares
        # names.
        kept = sorted(
            (remap[old_idx], name)
            for old_idx, name in enumerate(self.classes)
            if old_idx in remap
        )
        self.classes = [name for _, name in kept]
        self.class_to_idx = {name: new_idx for new_idx, name in kept}

        samples = [
            (path, remap[label])
            for path, label in self.samples
            if label in remap
        ]
        # DatasetFolder.__getitem__ reads self.samples; self.imgs is only an
        # alias. Both must be rebound, or indexing would still return the
        # unfiltered images with their original labels.
        self.samples = samples
        self.imgs = samples
        self.targets = [label for _, label in samples]

    def __len__(self):
        """Return the number of images kept after class filtering."""
        return len(self.targets)