import os
import json
import random

from torchvision import datasets, transforms
from torchvision.datasets.folder import ImageFolder, DatasetFolder, default_loader

from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.data.transforms import _pil_interp
from timm.data import create_transform

from typing import Any, Callable, cast, Dict, List, Optional, Tuple
# Lowercase file suffixes accepted by ImageNetDataset when no custom
# is_valid_file predicate is given.
IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')


def has_file_allowed_extension(filename: str, extensions: Tuple[str, ...]) -> bool:
    """Return True when *filename* ends with any suffix in *extensions*.

    Only *filename* is lower-cased before matching, so the entries of
    *extensions* should themselves be lowercase (as ``IMG_EXTENSIONS`` is).
    """
    lowered_name = filename.lower()
    return lowered_name.endswith(extensions)

def make_subsampled_dataset(
        directory, class_to_idx, extensions=None, is_valid_file=None, sampling_ratio=1., nb_classes=None):
    """Collect ``(path, class_index)`` samples, keeping a fraction of every class.

    Classes are visited in sorted name order. For each class directory, all
    valid files are gathered in sorted ``os.walk`` order (following symlinks),
    and only the first ``int(n_valid * sampling_ratio)`` of them are kept.

    Args:
        directory: root directory holding one sub-directory per class
            (``~`` is expanded).
        class_to_idx: mapping from class sub-directory name to integer label.
        extensions: tuple of allowed file suffixes. Exactly one of
            ``extensions`` / ``is_valid_file`` must be provided.
        is_valid_file: predicate on a full file path deciding whether to keep it.
        sampling_ratio: fraction of each class's valid files to keep. Values
            >= 1 keep everything; values <= 0 keep nothing.
        nb_classes: if not None, only the first ``nb_classes`` classes (in sorted
            name order) are scanned.

    Returns:
        List of ``(path, class_index)`` tuples. Classes whose directory does not
        exist are skipped silently.

    Raises:
        ValueError: if both or neither of ``extensions`` / ``is_valid_file`` are given.
    """
    directory = os.path.expanduser(directory)
    both_none = extensions is None and is_valid_file is None
    both_something = extensions is not None and is_valid_file is not None
    if both_none or both_something:
        raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")
    if extensions is not None:
        def is_valid_file(x: str) -> bool:
            return has_file_allowed_extension(x, cast(Tuple[str, ...], extensions))
    is_valid_file = cast(Callable[[str], bool], is_valid_file)

    instances = []
    for i, target_class in enumerate(sorted(class_to_idx.keys())):
        if nb_classes is not None and i >= nb_classes:
            break
        class_index = class_to_idx[target_class]
        target_dir = os.path.join(directory, target_class)
        if not os.path.isdir(target_dir):
            continue

        # Gather the valid files first so the quota is a fraction of the real
        # image count. Deriving it from os.listdir(target_dir) counted
        # sub-directories and non-image files too, which truncated classes with
        # nested folders even at sampling_ratio=1.
        valid_paths = []
        for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
            for fname in sorted(fnames):
                path = os.path.join(root, fname)
                if is_valid_file(path):
                    valid_paths.append(path)

        # Clamp at 0 so a negative ratio cannot turn the slice into "drop the last k".
        num_imgs = max(0, int(len(valid_paths) * sampling_ratio))
        instances.extend((path, class_index) for path in valid_paths[:num_imgs])
    return instances




class SubsampledDatasetFolder(DatasetFolder):
    """``DatasetFolder`` whose sample list keeps only a fraction of every class.

    Samples are discovered with ``make_subsampled_dataset`` using
    ``sampling_ratio`` and ``nb_classes``. ``classes`` and ``class_to_idx``
    still describe every class folder found under ``root``. ``samples`` and
    ``targets`` reflect the subsampled selection only.

    Raises:
        RuntimeError: if no sample is found.
    """

    def __init__(self, root, loader, extensions=None, transform=None, target_transform=None, is_valid_file=None, sampling_ratio=1., nb_classes=None):

        # Deliberately bypass DatasetFolder.__init__ and initialise the next
        # class in the MRO (torchvision's VisionDataset). Presumably this avoids
        # DatasetFolder's own full directory scan; sample discovery happens
        # below with subsampling instead.
        super(DatasetFolder, self).__init__(root, transform=transform,
                                            target_transform=target_transform)

        # Newer torchvision releases expose the public `find_classes` and no
        # longer provide the private `_find_classes`; fall back for old versions.
        find_classes = getattr(self, "find_classes", None) or self._find_classes
        classes, class_to_idx = find_classes(self.root)
        samples = make_subsampled_dataset(self.root, class_to_idx, extensions, is_valid_file, sampling_ratio=sampling_ratio, nb_classes=nb_classes)

        if len(samples) == 0:
            msg = "Found 0 files in subfolders of: {}\n".format(self.root)
            if extensions is not None:
                msg += "Supported extensions are: {}".format(",".join(extensions))
            raise RuntimeError(msg)

        self.loader = loader
        self.extensions = extensions

        self.classes = classes
        self.class_to_idx = class_to_idx
        self.samples = samples
        # Integer label of every kept sample, in the same order as `samples`.
        self.targets = [s[1] for s in samples]

    # __getitem__ and __len__ inherited from DatasetFolder


class ImageNetDataset(SubsampledDatasetFolder):
    """Image-folder dataset with optional per-class subsampling.

    Files are filtered by ``IMG_EXTENSIONS`` unless a custom ``is_valid_file``
    predicate is supplied. Any other keyword arguments (``transform``,
    ``sampling_ratio``, ``nb_classes``, ...) are forwarded unchanged to
    ``SubsampledDatasetFolder``. ``imgs`` is an alias of ``samples``.
    """

    def __init__(self, root, loader=default_loader, is_valid_file=None,  **kwargs):
        # Exactly one filter may be active: the caller's predicate or the default suffix list.
        extensions = None if is_valid_file is not None else IMG_EXTENSIONS
        super().__init__(root, loader, extensions, is_valid_file=is_valid_file, **kwargs)
        self.imgs = self.samples


def build_dataset(is_train, args):
    """Build the dataset selected by ``args.data_set`` together with its class count.

    Args:
        is_train: build the training split (and training transform) when True,
            the evaluation one otherwise.
        args: namespace read for ``data_set``, ``data_path`` and, for the
            ImageNet family, ``sampling_ratio`` (training only) and
            ``nb_classes``. ``build_transform`` reads further fields.

    Returns:
        ``(dataset, nb_classes)``. CIFAR10 reports 10 classes and CIFAR100
        reports 100. The ImageNet family reports ``args.nb_classes``, or 1000
        when that is None.

    Raises:
        NotImplementedError: for an unrecognised ``args.data_set``.
    """
    transform = build_transform(is_train, args)
    data_set = args.data_set

    if data_set == 'CIFAR10':
        dataset = datasets.CIFAR10(args.data_path, train=is_train, transform=transform, download=True)
        return dataset, 10
    if data_set == 'CIFAR100':
        dataset = datasets.CIFAR100(args.data_path, train=is_train, transform=transform, download=True)
        return dataset, 100

    if data_set == 'IMNET':
        root = os.path.join(args.data_path, 'train' if is_train else 'val')
    elif data_set.startswith('IMNET-C') or data_set in ('IMNET-A', 'IMNET-R', 'IMNET-SK', 'IMNET-V2'):
        # Evaluation variants (presumably ImageNet-C/-A/-R/-Sketch/-V2) are read
        # directly from data_path with no train/val sub-folder.
        # NOTE(review): the corruption type/severity encoded in an 'IMNET-C-...'
        # name is not used to pick a sub-folder; data_path must already point at
        # the desired corruption/severity directory.
        root = args.data_path
    else:
        raise NotImplementedError

    dataset = ImageNetDataset(root, transform=transform,
                              sampling_ratio=(args.sampling_ratio if is_train else 1.),
                              nb_classes=args.nb_classes)
    nb_classes = args.nb_classes if args.nb_classes is not None else 1000
    return dataset, nb_classes


def build_transform(is_train, args):
    """Build the image transform pipeline for training or evaluation.

    Training uses timm ``create_transform`` with the augmentation settings
    found on ``args``. For inputs of 32px or less, the first transform is
    replaced by ``RandomCrop(input_size, padding=4)``.

    Evaluation works as follows when ``input_size > 32``:
    - resize to ``int(input_size * 256 / 224)`` with bicubic interpolation;
    - center-crop to ``input_size``;
    - in all cases, ``ToTensor`` and ImageNet mean/std normalisation follow.
    """
    resize_im = args.input_size > 32
    if is_train:
        # this should always dispatch to transforms_imagenet_train
        transform = create_transform(
            input_size=args.input_size,
            is_training=True,
            no_aug=args.no_aug,
            color_jitter=args.color_jitter,
            auto_augment=args.aa,
            interpolation=_pil_interp(args.train_interpolation),
            re_prob=args.reprob,
            re_mode=args.remode,
            re_count=args.recount,
        )
        if not resize_im:
            # replace RandomResizedCropAndInterpolation with
            # RandomCrop
            transform.transforms[0] = transforms.RandomCrop(
                args.input_size, padding=4)
        return transform

    t = []
    if resize_im:
        size = int((256 / 224) * args.input_size)  # to maintain same ratio w.r.t. 224 images
        # `_pil_interp` maps a method *name* to a PIL constant. Passing the int 3
        # (PIL's BICUBIC value) does not match any name, and in the timm versions
        # that export `_pil_interp` it falls through to the bilinear default.
        # Name the method explicitly.
        t.append(
            transforms.Resize(size, interpolation=_pil_interp('bicubic')),
        )
        t.append(transforms.CenterCrop(args.input_size))

    t.append(transforms.ToTensor())
    t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD))
    return transforms.Compose(t)
