# coding=utf-8
import os
import random
import torch
import numpy as np
import scipy.io as io
from torchvision import datasets, transforms
from .image_augment import CIFAR10Policy, Cutout


def random_spilt(root_dir, train_radio):
    """Randomly split the ``png``/``mat`` files under ``root_dir`` into train/test items.

    The split is made independently inside every directory visited by
    ``os.walk``: ``int(n * train_radio)`` of that directory's ``n`` matching
    files go to the training set and the rest go to the test set. Randomness
    comes from ``np.random``, so seeding numpy makes the split reproducible.

    Args:
        root_dir: Directory to walk recursively.
        train_radio: Fraction of each directory's matching files placed in the
            training set (the count is truncated with ``int()``).

    Returns:
        ``(train_items, test_items)``: two lists of ``(file_name, label, dir_path)``
        tuples, where ``label`` is ``"<parent>/<dir>"`` built from the last two
        components of ``dir_path``.
    """
    train_items = []
    test_items = []
    for root, dirs, files in os.walk(root_dir):
        # Sorting in place makes os.walk descend in a deterministic order.
        dirs.sort()

        # Filter *before* splitting so train_radio is a fraction of the usable
        # files, not of everything in the directory. Sorting makes the result
        # independent of filesystem listing order (reproducible under a seed).
        data_files = sorted(f for f in files if f.endswith(("png", "mat")))
        if not data_files:
            continue

        # The label depends only on the directory, so compute it once.
        # Normalize separators so the label is built the same way on Windows.
        parts = root.replace(os.sep, '/').split('/')
        parent = parts[-2] if len(parts) > 1 else parts[-1]
        label = parent + "/" + parts[-1]

        shuffled_indices = np.random.permutation(len(data_files))
        train_size = int(len(data_files) * train_radio)
        # A set gives O(1) membership tests; `in` on an ndarray is a linear scan.
        train_indices = set(shuffled_indices[:train_size].tolist())

        for i, f in enumerate(data_files):
            item = (f, label, root)
            if i in train_indices:
                train_items.append(item)
            else:
                test_items.append(item)
    print("== TrainSet %d items, TestSet %d items" % (len(train_items), len(test_items)))
    return train_items, test_items


def index_classes(items):
    """Assign a contiguous integer id to every distinct label in ``items``.

    Args:
        items: Iterable of tuples whose element at index 1 is the label, such
            as the ``(file_name, label, dir_path)`` tuples built by
            ``random_spilt``.

    Returns:
        dict mapping each distinct label to an id in ``0..k-1``; ids follow the
        order in which labels first appear.
    """
    # dict.fromkeys de-duplicates while keeping first-appearance order.
    ordered_labels = dict.fromkeys(entry[1] for entry in items)
    idx = {label: position for position, label in enumerate(ordered_labels)}
    print("== Found %d classes" % len(idx))
    return idx

















def _load_imagenet(root):
    """Build ImageNet ``ImageFolder`` datasets from ``root/train`` and ``root/val``.

    Returns ``(train_data, val_data, 1000)``.
    """
    traindir = os.path.join(root, 'train')
    valdir = os.path.join(root, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    transform_train = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    transform_test = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])

    train_data = datasets.ImageFolder(root=traindir, transform=transform_train)
    val_data = datasets.ImageFolder(root=valdir, transform=transform_test)
    return train_data, val_data, 1000


def load_dataset(name, root, cutout=False, auto_aug=False):
    """Build the train/validation datasets for a supported benchmark.

    Args:
        name: One of ``'CIFAR10'``, ``'CIFAR100'``, ``'MNIST'`` or ``'imagenet'``.
        root: For the torchvision datasets, the directory they are downloaded
            to (``download=True``). For ``'imagenet'``, a directory containing
            ``train/`` and ``val/`` sub-folders read with ``ImageFolder``.
        cutout: Append ``Cutout(n_holes=1, length=16)`` to the training
            pipeline after ``ToTensor``. Ignored for ``'imagenet'``.
        auto_aug: Append ``CIFAR10Policy()`` to the training pipeline before
            ``ToTensor``. Ignored for ``'imagenet'``.

    Returns:
        ``(train_data, val_data, num_class)``.

    Raises:
        NotImplementedError: If ``name`` is not a supported dataset.
    """
    # Fail fast, before any transforms are built, on unsupported names.
    if name not in ('CIFAR10', 'CIFAR100', 'MNIST', 'imagenet'):
        raise NotImplementedError("Unsupported dataset: %r" % (name,))
    if name == 'imagenet':
        # ImageNet has its own fixed pipelines; cutout/auto_aug do not apply.
        return _load_imagenet(root)

    if name == 'CIFAR10':
        dataset_cls = datasets.CIFAR10
        num_class = 10
        # NOTE(review): these std values differ from the commonly quoted
        # CIFAR-10 per-channel std (~0.247, 0.243, 0.261). They are kept
        # unchanged so results stay comparable. Confirm this is intended.
        normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    elif name == 'CIFAR100':
        dataset_cls = datasets.CIFAR100
        num_class = 100
        normalize = transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))
    else:  # 'MNIST'
        dataset_cls = datasets.MNIST
        num_class = 10
        normalize = transforms.Normalize((0.1307,), (0.3081,))

    train_transform = []
    if name in ('CIFAR10', 'CIFAR100'):
        train_transform += [transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip()]
    if auto_aug:
        # PIL-level augmentation, so it must run before ToTensor.
        train_transform.append(CIFAR10Policy())
    train_transform.append(transforms.ToTensor())
    if cutout:
        # Applied after ToTensor (operates on the tensor image).
        train_transform.append(Cutout(n_holes=1, length=16))
    train_transform.append(normalize)

    val_transform = transforms.Compose([transforms.ToTensor(), normalize])

    train_data = dataset_cls(root=root, train=True, download=True,
                             transform=transforms.Compose(train_transform))
    val_data = dataset_cls(root=root, train=False, download=True,
                           transform=val_transform)
    return train_data, val_data, num_class
