""" Dataset Factory

Hacked together by / Copyright 2021, Ross Wightman
"""
import os

import torchvision

from compensate.loadCUB200.dataset import CUB


from torchvision import transforms
import torch

from torchvision import datasets
from torch.utils.data import DataLoader
from compensate.loadStanfordDogs.load import load_standford_dog
from compensate.loadMIT64.mit67 import Dataloder as load_min67
from compensate.loadCUB200.cub200 import Dataloder as load_cub200
from compensate.loadCaltech256.dataset import Caltech256 as load_full_Caltech256

from torchvision.datasets import CIFAR100, CIFAR10, MNIST, QMNIST, KMNIST, FashionMNIST, ImageNet, ImageFolder
# Optional torchvision datasets. Each import is guarded so the module still loads
# when the installed torchvision does not provide the class; the matching has_*
# flag records availability and is checked before the class is used.
try:
    # was `torchvision.datadsets` (typo), which made this import always fail
    from torchvision.datasets import Places365
    has_places365 = True
except ImportError:
    has_places365 = False
try:
    from torchvision.datasets import INaturalist
    has_inaturalist = True
except ImportError:
    has_inaturalist = False
try:
    from torchvision.datasets import Caltech256
    has_Caltech256 = True
except ImportError:
    has_Caltech256 = False
from .dataset import IterableImageDataset, ImageDataset

# Maps the dataset name that follows the 'torch/' prefix to a torchvision dataset
# class that is constructed with a boolean `train=` argument.
_TORCH_BASIC_DS = dict(
    cifar10=CIFAR10,
    cifar100=CIFAR100,
    mnist=MNIST,
    qmnist=QMNIST,
    qmist=QMNIST,  # misspelled legacy key, kept so existing configs keep working
    kmnist=KMNIST,
    fashion_mnist=FashionMNIST,
)
# Split names that are treated as the training split.
_TRAIN_SYNONYM = {'train', 'training'}
# Split names that are treated as the evaluation / validation split.
_EVAL_SYNONYM = {'val', 'valid', 'validation', 'eval', 'evaluation'}

# Output size handed to the RandomResizedCrop / Resize transforms in this file.
IMAGE_SIZE = [224, 224]
# Three-element mean / std sequences passed to transforms.Normalize for the
# training and evaluation pipelines respectively.
# NOTE(review): presumably measured on the CUB-200 train / test splits, yet they
# are also reused for the Caltech-256 subsets — confirm the source dataset.
TRAIN_MEAN = [0.48560741861744905, 0.49941626449353244, 0.43237713785804116]
TRAIN_STD = [0.2321024260764962, 0.22770540015765814, 0.2665100547329813]
TEST_MEAN = [0.4862169586881995, 0.4998156522834164, 0.4311430419332438]
TEST_STD = [0.23264268069040475, 0.22781080253662814, 0.26667253517177186]

# Machine-specific data directory (the commented line is an alternative location).
# NOTE(review): no function visible in this file reads this module-level `path`
# (load_CUB200 assigns its own local `path`) — confirm usage before removing.
# path = '/home/qxpineapple/pytorch-image-models-master/data/CUB_200_2011'
path = 'C:/Users/13126/Desktop/pytorch-image-models-master/data/'



def load_CUB200(is_train, root=None):
    """Build the CUB-200-2011 dataset for one split.

    Args:
        is_train: truthy selects the training split with random-resized-crop and
            horizontal-flip augmentation; falsy selects the test split with a
            plain resize.
        root: dataset directory handed to ``CUB``. When omitted, the original
            hard-coded location is used so existing callers behave as before.

    Returns:
        A ``CUB`` dataset instance configured for the selected split.
    """
    if root is None:
        # Legacy machine-specific default, kept only for backward compatibility.
        root = 'C:\\Users\\13126\\Desktop\\pytorch-image-models-master\\data\\CUB_200_2011'

    # Only the pipeline that is actually needed gets constructed.
    # NOTE(review): the leading ToPILImage implies CUB yields arrays/tensors
    # rather than PIL images — confirm against compensate.loadCUB200.dataset.
    if is_train:
        train_transforms = transforms.Compose([
            transforms.ToPILImage(),
            transforms.RandomResizedCrop(IMAGE_SIZE),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(TRAIN_MEAN, TRAIN_STD),
        ])
        return CUB(
            root,
            train=True,
            transform=train_transforms,
            target_transform=None,
        )

    test_transforms = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize(IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(TEST_MEAN, TEST_STD),
    ])
    return CUB(
        root,
        train=False,
        transform=test_transforms,
        target_transform=None,
    )


def _search_split(root, split):
    """Resolve the split-specific sub-folder of ``root``.

    Args:
        root: dataset root directory.
        split: split name; anything from the first ``[`` onwards is ignored
            (e.g. ``'train[:10]'`` is looked up as ``'train'``).

    Returns:
        ``root/<split>`` if that path exists; otherwise the first existing
        ``root/<synonym>`` for the train / eval synonym group the split belongs
        to; otherwise ``root`` unchanged.
    """
    # look for sub-folder with name of split in root and use that if it exists
    split_name = split.split('[')[0]
    try_root = os.path.join(root, split_name)
    if os.path.exists(try_root):
        return try_root

    if split_name in _TRAIN_SYNONYM:
        synonyms = _TRAIN_SYNONYM
    elif split_name in _EVAL_SYNONYM:
        synonyms = _EVAL_SYNONYM
    else:
        # not a recognised split name, nothing more to try
        return root

    # The synonym groups are sets, whose iteration order for strings depends on
    # hash randomisation. Sorting makes the chosen folder reproducible when
    # several synonym folders exist side by side.
    for synonym in sorted(synonyms):
        candidate = os.path.join(root, synonym)
        if os.path.exists(candidate):
            return candidate
    return root


def create_dataset(
        name,
        root,
        split='validation',
        search_split=True,
        class_map=None,
        load_bytes=False,
        is_training=False,
        download=False,
        batch_size=None,
        repeats=0,
        **kwargs
):
    """ Dataset factory method

    In parenthesis after each arg are the type of dataset supported for each arg, one of:
      * folder - default, timm folder (or tar) based ImageDataset
      * torch - torchvision based datasets
      * TFDS - Tensorflow-datasets wrapper in IterableDataset interface via IterableImageDataset
      * all - any of the above

    Besides the torchvision classes, the ``torch/`` prefix also selects several
    project-specific datasets (``caltech256-30``, ``caltech256-60``,
    ``caltech256-full``, ``caltech256``, ``cub-200``, ``stanford-dogs``, ``mit67``).
    Those branches read from fixed relative locations and choose between their
    train and test data with ``is_training``; they do not use ``root``, ``split``,
    ``download`` or ``**kwargs``.

    Args:
        name: dataset name, empty is okay for folder based datasets
        root: root folder of dataset (all)
        split: dataset split (all)
        search_split: search for split specific child fold from root so one can specify
            `imagenet/` instead of `/imagenet/val`, etc on cmd line / config. (folder, torch/folder)
        class_map: specify class -> index mapping via text file or dict (folder)
        load_bytes: load data, return images as undecoded bytes (folder)
        download: download dataset if not present and supported (TFDS, torch)
        is_training: create dataset in train mode, this is different from the split.
            For Iterable / TFDS it enables shuffle; for the project-specific torch/
            datasets it selects the train vs. test data; ignored for other datasets.
        batch_size: batch size hint for (TFDS)
        repeats: dataset repeats per iteration i.e. epoch (TFDS)
        **kwargs: other args to pass to dataset

    Returns:
        Dataset object

    Raises:
        AssertionError: if a ``torch/`` dataset name is unknown, or the installed
            torchvision lacks the requested INaturalist / Places365 class.
    """
    name = name.lower()
    if name.startswith('torch/'):
        name = name.split('/', 2)[-1]
        torch_kwargs = dict(root=root, download=download, **kwargs)
        if name in _TORCH_BASIC_DS:
            ds_class = _TORCH_BASIC_DS[name]
            use_train = split in _TRAIN_SYNONYM
            ds = ds_class(train=use_train, **torch_kwargs)
        elif name in ('inaturalist', 'inat'):
            # Raised explicitly (not via `assert`) so the check survives `python -O`.
            if not has_inaturalist:
                raise AssertionError('Please update to PyTorch 1.10, torchvision 0.11+ for Inaturalist')
            # split may be given as '<target_type>/<version>', where target_type
            # can itself be several '_'-joined values
            target_type = 'full'
            split_split = split.split('/')
            if len(split_split) > 1:
                target_type = split_split[0].split('_')
                if len(target_type) == 1:
                    target_type = target_type[0]
                split = split_split[-1]
            if split in _TRAIN_SYNONYM:
                split = '2021_train'
            elif split in _EVAL_SYNONYM:
                split = '2021_valid'
            ds = INaturalist(version=split, target_type=target_type, **torch_kwargs)
        elif name == 'places365':
            if not has_places365:
                raise AssertionError('Please update to a newer PyTorch and torchvision for Places365 dataset.')
            if split in _TRAIN_SYNONYM:
                split = 'train-standard'
            elif split in _EVAL_SYNONYM:
                split = 'val'
            ds = Places365(split=split, **torch_kwargs)
        elif name == 'imagenet':
            if split in _EVAL_SYNONYM:
                split = 'val'
            ds = ImageNet(split=split, **torch_kwargs)
        elif name in ('image_folder', 'folder'):
            # in case torchvision ImageFolder is preferred over timm ImageDataset for some reason
            if search_split and os.path.isdir(root):
                # look for split specific sub-folder in root
                root = _search_split(root, split)
            ds = ImageFolder(root, **kwargs)
        elif name in ('caltech256-30', 'caltech256-60', 'caltech256-full'):
            if is_training:
                # 'caltech256-30' -> 'train_30', 'caltech256-full' -> 'train_full'
                subset = 'train_' + name.rsplit('-', 1)[-1]
                subset_transform = transforms.Compose([
                    transforms.RandomResizedCrop(IMAGE_SIZE),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                    transforms.Normalize(TRAIN_MEAN, TRAIN_STD)
                ])
            else:
                subset = 'val'
                subset_transform = transforms.Compose([
                    transforms.Resize(IMAGE_SIZE),
                    transforms.ToTensor(),
                    transforms.Normalize(TEST_MEAN, TEST_STD)
                ])
            # NOTE(review): expanduser("") returns "" so this path is relative to
            # the working directory; "~" may have been intended — confirm.
            load_path = os.path.join(os.path.expanduser(""), 'data/Caltech-256', subset)
            ds = datasets.ImageFolder(load_path, transform=subset_transform)
        elif name == 'caltech256':
            train_dir = './data/archive/caltech256/train'
            val_dir = './data/archive/caltech256/val'
            # class names always come from the training directory listing
            classes = sorted(os.listdir(train_dir))

            # NOTE(review): this randomly augmenting pipeline is applied to the
            # validation data as well — confirm that this is intended.
            transform = transforms.Compose([
                transforms.Resize(size=300),
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomVerticalFlip(p=0.3),
                transforms.ToTensor(),
                transforms.RandomApply(torch.nn.ModuleList([transforms.RandomCrop(size=227)]), p=0.5),
                transforms.Resize(size=(224, 224)),
                transforms.RandomErasing(p=0.4)
            ])
            data_dir = train_dir if is_training else val_dir
            ds = load_full_Caltech256(data_dir, classes, transform=transform)
        elif name == 'cub-200':
            dataset_path = os.path.join(os.path.expanduser(""), 'data/CUB_200_2011')
            split_info1 = os.path.join(dataset_path, "images.txt")
            split_info2 = os.path.join(dataset_path, "train_test_split.txt")
            # NOTE(review): getloader() may hand back DataLoader objects rather than
            # datasets — confirm against compensate.loadCUB200.cub200.
            train_dataset, test_dataset = load_cub200(dataset_path, split_info1, split_info2).getloader()
            ds = train_dataset if is_training else test_dataset
        elif name == 'stanford-dogs':
            train_dataset, test_dataset, classes = load_standford_dog("stanford_dogs")
            ds = train_dataset if is_training else test_dataset
        elif name == 'mit67':
            dataset_path = os.path.join(os.path.expanduser(""), 'data/MIT67')
            train_source = os.path.join(dataset_path, "TrainImages.txt")
            test_source = os.path.join(dataset_path, "TestImages.txt")
            train_dataset, test_dataset = load_min67(dataset_path, train_source, test_source).getloader()
            ds = train_dataset if is_training else test_dataset
        else:
            # Explicit raise: with `assert False` under `python -O` execution would
            # fall through to `return ds` and fail with UnboundLocalError instead.
            raise AssertionError(f"Unknown torchvision dataset {name}")
    elif name.startswith('tfds/'):
        ds = IterableImageDataset(
            root, parser=name, split=split, is_training=is_training,
            download=download, batch_size=batch_size, repeats=repeats, **kwargs)
    else:
        # FIXME support more advance split cfg for ImageFolder/Tar datasets in the future
        if search_split and os.path.isdir(root):
            # look for split specific sub-folder in root
            root = _search_split(root, split)
        ds = ImageDataset(root, parser=name, class_map=class_map, load_bytes=load_bytes, **kwargs)
    return ds
