import multiprocessing
import cv2
import numpy as np

from tensorpack.dataflow.dataset import Cifar10, Cifar100, ILSVRC12, \
      ILSVRC12Files, SVHNDigit, Caltech101Silhouettes
from tensorpack import BatchData, QueueInput, StagingInput, RandomMixData
from tensorpack.dataflow import imgaug, AugmentImageComponent, \
      MultiThreadMapData, MultiProcessRunnerZMQ


def get_cifar10(train_or_test, dir, shuffle=True):
  """Return the tensorpack CIFAR-10 DataFlow for the requested split.

  ``train_or_test`` is forwarded unchanged to ``Cifar10``, together with the
  ``shuffle`` flag and the dataset directory ``dir``.
  """
  split = train_or_test
  return Cifar10(split, dir=dir, shuffle=shuffle)


def get_cifar100(train_or_test, dir, shuffle=True):
  """Return the tensorpack CIFAR-100 DataFlow for the requested split.

  ``train_or_test`` is forwarded unchanged to ``Cifar100``, together with the
  ``shuffle`` flag and the dataset directory ``dir``.
  """
  split = train_or_test
  return Cifar100(split, dir=dir, shuffle=shuffle)


def get_imagenet(train_or_test, dir, shuffle=True):
  """Return an ImageNet (ILSVRC12) DataFlow for the requested split.

  The ``'train'`` split uses ``ILSVRC12`` and honours ``shuffle``. Any other
  split name is passed to ``ILSVRC12Files`` with shuffling always disabled.
  Both use the ``'original'`` directory structure.
  """
  if train_or_test != 'train':
    # Non-training splits are read in a fixed order regardless of `shuffle`.
    return ILSVRC12Files(dir=dir, name=train_or_test,
                         shuffle=False, dir_structure='original')
  return ILSVRC12(dir=dir, name=train_or_test,
                  shuffle=shuffle, dir_structure='original')


def get_svhn(train_or_test, dir, shuffle=True):
  """Return an SVHN digit DataFlow.

  ``'train'`` and ``'extra'`` are loaded as named and honour ``shuffle``.
  Every other value falls back to the ``'test'`` split with shuffling
  disabled.
  """
  if train_or_test not in ('train', 'extra'):
    # Unknown or evaluation split names map to the test set, unshuffled.
    return SVHNDigit('test', data_dir=dir, shuffle=False)
  return SVHNDigit(train_or_test, data_dir=dir, shuffle=shuffle)


def get_caltech(train_or_test, dir, shuffle=True):
  """Return a Caltech-101 Silhouettes DataFlow.

  ``'train'`` selects the training split; any other value selects
  ``'test'``. ``shuffle`` is forwarded for both splits.
  """
  split = 'train' if train_or_test == 'train' else 'test'
  return Caltech101Silhouettes(split, dir=dir, shuffle=shuffle)


def get_dataset(config, batch_size=16,
                is_training=True, shuffle=True):
  """Build a batched dataset described by a config mapping.

  Parameters
  ----------
  config : mapping
    Must contain ``'name'`` (dataset name) and ``'path'`` (dataset
    directory). A missing key raises ``KeyError``, checked in that order.
  batch_size : int
    Batch size.
  is_training : bool
    Whether to build the training or the test pipeline.
  shuffle : bool
    Whether to shuffle.

  Returns
  -------
  The DataFlow produced by ``_get_ds``.
  """
  dataset_name = config['name']
  data_dir = config['path']
  return _get_ds(dataset_name, is_training, data_dir,
                 batch_size=batch_size, shuffle=shuffle)


def _get_ds(name, is_training, dir,
            batch_size=1, shuffle=True):
  """

  Parameters
  ----------
  name : str
    Dataset name.
  is_training : flag
    Whether get train dataset or test.
  dir : str
    Directory of dataset.
  batch_size : int
    Batch size.
  shuffle : bool
    Whether shuffle.
  """
  remainder = not is_training
  if name == 'cifar10':
    train_or_test = 'train' if is_training else 'test'
    ds = get_cifar10(train_or_test, dir, shuffle=shuffle)

    if is_training:
      augmentors = [
        imgaug.CenterPaste((40, 40)),
        imgaug.RandomCrop((32, 32)),
        imgaug.Flip(horiz=True),
        imgaug.Brightness(63),
        imgaug.Contrast((0.2, 1.8)),
      ]
      ds = AugmentImageComponent(ds, augmentors)
    ds = BatchData(ds, batch_size, remainder=remainder)

  elif name == 'svhn':
    train_or_test = 'train' if is_training else 'test'
    if is_training:
      d1 = get_svhn('train', dir=dir, shuffle=shuffle)
      d2 = get_svhn('extra', dir=dir, shuffle=shuffle)
      ds = RandomMixData([d1, d2])
      augmentors = [
        imgaug.CenterPaste((38, 38)),
        imgaug.RandomCrop((32, 32)),
        imgaug.Brightness(30),
        imgaug.Contrast((0.5, 1.5)),
      ]
    else:
      augmentors = []
      ds = get_svhn('test', dir=dir, shuffle=False)
    ds = AugmentImageComponent(ds, augmentors)
    ds = BatchData(ds, batch_size, remainder=remainder)

  else:
    raise ValueError('Dataset %s not supported' % name)
  return ds
