# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
import os
import os.path
import sys
import json
import scipy
import numpy as np

import pathlib
from typing import Any, Callable, Optional, Tuple

from PIL import Image

from torch.utils.data import Dataset
from torchvision.datasets import VisionDataset
from torchvision import datasets, transforms
from torchvision.datasets.folder import ImageFolder, default_loader
from torchvision.datasets.utils import verify_str_arg

from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.data import create_transform


def build_dataset(is_train, args):
    """Build the dataset selected by ``args.data_set`` and its class count.

    Args:
        is_train: If truthy, load the training split; otherwise the evaluation
            split ('val' folder for IMNET, split='test' for food101 and
            stanfordcar, train=False for the remaining datasets).
        args: Namespace-like config. Reads ``data_set`` and ``data_path``,
            ``inat_category`` for INAT19, and whatever ``build_transform``
            reads.

    Returns:
        A ``(dataset, nb_classes)`` tuple.

    Raises:
        ValueError: If ``args.data_set`` is not a supported dataset name.
    """
    transform = build_transform(is_train, args)

    if args.data_set == 'CIFAR10':
        dataset = datasets.CIFAR10(args.data_path, train=is_train, transform=transform, download=True)
        nb_classes = 10

    elif args.data_set == 'MiniIMNET':
        dataset = miniImagenet(root=args.data_path, train=is_train, transform=transform)
        nb_classes = 100

    elif args.data_set == 'TinyIMNET':
        dataset = TinyImageNet(root=args.data_path, train=is_train, transform=transform)
        nb_classes = 200

    elif args.data_set == 'CIFAR100':
        dataset = datasets.CIFAR100(args.data_path, train=is_train, transform=transform, download=True)
        nb_classes = 100

    elif args.data_set == 'IMNET':
        # ImageNet is expected as an ImageFolder tree with train/ and val/ subdirectories.
        root = os.path.join(args.data_path, 'train' if is_train else 'val')
        dataset = datasets.ImageFolder(root, transform=transform)
        nb_classes = 1000

    elif args.data_set == 'INAT19':
        dataset = INatDataset(args.data_path, train=is_train, year=2019,
                              category=args.inat_category, transform=transform)
        # The class count depends on the chosen category, so ask the dataset.
        nb_classes = dataset.nb_classes

    elif args.data_set == 'food101':
        dataset = datasets.Food101(root=args.data_path, split='train' if is_train else 'test', transform=transform)
        nb_classes = 101

    elif args.data_set == 'flowers102':
        dataset = Flowers(args.data_path, train=bool(is_train), transform=transform)
        nb_classes = 102

    elif args.data_set == 'stanfordcar':
        dataset = StanfordCars(args.data_path, split='train' if is_train else 'test', transform=transform)
        nb_classes = 196

    elif args.data_set == 'cub200':
        dataset = CUB200(args.data_path, train=bool(is_train), transform=transform)
        nb_classes = 200

    # TODO: DTD is not wired up yet; torchvision's DTD takes split=..., not train=...
    else:
        # Without this branch an unknown name would surface as an opaque
        # UnboundLocalError on the return statement below.
        raise ValueError(
            f"Unknown dataset '{args.data_set}'. Supported datasets: "
            "CIFAR10, CIFAR100, MiniIMNET, TinyIMNET, IMNET, INAT19, "
            "food101, flowers102, stanfordcar, cub200."
        )

    return dataset, nb_classes


def build_transform(is_train, args):
    """Build the image transform pipeline for training or evaluation.

    Args:
        is_train: If truthy, return timm's training transform with the
            augmentation settings taken from ``args``; otherwise return a
            deterministic evaluation pipeline.
        args: Namespace-like config. Reads ``input_size``, ``eval_crop_ratio``,
            ``color_jitter``, ``aa``, ``train_interpolation``, ``reprob``,
            ``remode`` and ``recount``.

    Returns:
        For training, the object returned by timm's ``create_transform``
        (indexed via ``.transforms``). For evaluation, a ``transforms.Compose``
        of an optional resize + center crop, ``ToTensor`` and ImageNet
        mean/std normalization.
    """
    # Inputs of 32px or smaller (e.g. CIFAR) are kept at native resolution.
    resize_im = args.input_size > 32
    if is_train:
        # this should always dispatch to transforms_imagenet_train
        transform = create_transform(
            input_size=args.input_size,
            is_training=True,
            color_jitter=args.color_jitter,
            auto_augment=args.aa,
            interpolation=args.train_interpolation,
            re_prob=args.reprob,
            re_mode=args.remode,
            re_count=args.recount,
        )
        if not resize_im:
            # For small images, replace the first (random-resized-crop) step
            # with a padded RandomCrop instead of rescaling.
            transform.transforms[0] = transforms.RandomCrop(
                args.input_size, padding=4)
        return transform

    t = []
    if resize_im:
        # Resize the shorter side so that the center crop of input_size
        # covers eval_crop_ratio of it (e.g. 224 / 0.875 -> resize to 256).
        size = int(args.input_size / args.eval_crop_ratio)
        t.append(
            # BICUBIC is the named equivalent of the legacy PIL integer code 3.
            transforms.Resize(size, interpolation=transforms.InterpolationMode.BICUBIC),
        )
        t.append(transforms.CenterCrop(args.input_size))

    t.append(transforms.ToTensor())
    t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD))
    return transforms.Compose(t)
