import os
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from .memfolder import ImageMemFolder
from . import download_blob


def get_dataset(dataset_name, data_dir, split, transform=None, imsize=None, bucket='bucket2', **kwargs):
  """Build the dataset called ``dataset_name`` and tag it with its sample geometry.

  Names in the built-in list are routed to the module-level ``get_<name>``
  function; every other name is routed to ``get_imageFolderDataset``. Both
  routes are called with exactly the same arguments.

  Args:
    dataset_name: Dataset identifier; also selects the builder.
    data_dir: Directory handed to the builder.
    split: Split name handed to the builder.
    transform: Optional transform handed to the builder.
    imsize: Optional image size handed to the builder.
    bucket: Bucket name handed to the builder.
    **kwargs: Extra keyword arguments handed to the builder.

  Returns:
    The dataset object, with ``nchannels`` and ``imsize`` attributes set from
    its first sample.
  """
  builtin_names = ['imagenet', 'cifar10', 'cifar100', 'svhn', 'mnist']
  if dataset_name in builtin_names:
    builder = globals()[f'get_{dataset_name}']
  else:
    # NOTE(review): this route forwards imsize although get_imageFolderDataset
    # declares no such parameter, so it travels on inside **kwargs - confirm
    # the dataset class accepts it.
    builder = get_imageFolderDataset

  dataset = builder(dataset_name, data_dir, split, transform=transform, imsize=imsize, bucket=bucket, **kwargs)

  # Probe the first sample; presumably a (channels, height, width) tensor - TODO confirm.
  first_image = dataset[0][0]
  dataset.nchannels = first_image.size(0)
  dataset.imsize = first_image.size(1)
  return dataset


def get_aug(split, imsize=None, aug='large'):
  """Return the list of spatial transforms for a split and augmentation mode.

  Args:
    split: ``'train'`` selects the random crop; any other value selects the
      resize + center-crop pair.
    imsize: Output crop size. Defaults to 224 when ``aug == 'large'`` and to
      32 otherwise.
    aug: ``'large'`` for the random-resized-crop recipe; any other value
      selects the padded random-crop recipe.

  Returns:
    A list of ``torchvision.transforms`` objects (not yet composed).
  """
  training = split == 'train'

  if aug != 'large':
    side = 32 if imsize is None else imsize
    if training:
      # Padding is one eighth of the crop size (4 for the default 32).
      return [transforms.RandomCrop(side, padding=round(side / 8))]
    return [transforms.Resize(side), transforms.CenterCrop(side)]

  side = 224 if imsize is None else imsize
  if training:
    return [transforms.RandomResizedCrop(side, scale=(0.2, 1.0))]
  # Resize to about 1.143x the crop size first (256 for the default 224).
  return [transforms.Resize(round(side * 1.143)), transforms.CenterCrop(side)]


def get_transform(split, normalize=None, transform=None, imsize=None, aug='large'):
  """Return ``transform`` if supplied, otherwise compose the default pipeline.

  The default pipeline is the ``get_aug`` steps followed by ``ToTensor`` and
  a normalisation step. When ``transform`` is supplied, every other argument
  is ignored.

  Args:
    split: Split name forwarded to ``get_aug``.
    normalize: Optional normalisation transform; a three-channel default is
      used when omitted.
    transform: Optional ready-made transform that short-circuits the build.
    imsize: Optional crop size forwarded to ``get_aug``.
    aug: Augmentation mode forwarded to ``get_aug``.

  Returns:
    The supplied transform, or a ``transforms.Compose`` of the default steps.
  """
  if transform is not None:
    return transform

  if normalize is None:
    # Three-channel default statistics (presumably the ImageNet mean/std).
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

  spatial_steps = get_aug(split, imsize=imsize, aug=aug)
  pipeline = spatial_steps + [transforms.ToTensor(), normalize]
  return transforms.Compose(pipeline)


def get_imageFolderDataset(dataset_name, data_dir, split, transform=None, bucket='bucket2', **kwargs):
  """Fetch ``<dataset_name>/<split>.pkl`` and wrap it in an ``ImageMemFolder``.

  The local directory ``<data_dir>/<dataset_name>`` is created if needed,
  then ``download_blob`` is called with the bucket, the source path
  ``<dataset_name>/<split>.pkl`` and the local path
  ``<data_dir>/<dataset_name>/<split>.pkl``.

  Args:
    dataset_name: Dataset identifier; used as the sub-directory name and as
      the prefix of the source path.
    data_dir: Local root directory for downloaded data.
    split: Split name; used as the file stem on both sides.
    transform: Optional transform forwarded to ``ImageMemFolder``.
    bucket: Bucket name forwarded to ``download_blob``.
    **kwargs: Extra keyword arguments forwarded to ``ImageMemFolder``.

  Returns:
    The ``ImageMemFolder`` built from the local path (passed without the
    ``.pkl`` suffix).
  """
  local_dir = os.path.join(data_dir, dataset_name)
  # makedirs creates missing parents (a nested data_dir or dataset_name) and,
  # with exist_ok=True, tolerates another process creating the directory
  # between a check and the mkdir call.
  os.makedirs(local_dir, exist_ok=True)

  source_path = os.path.join(dataset_name, split)
  destination_path = os.path.join(local_dir, split)
  download_blob(bucket, source_path + '.pkl', destination_path + '.pkl')

  return ImageMemFolder(destination_path, transform=transform, **kwargs)


def get_imagenet(dataset_name, data_dir, split, transform=None, imsize=None, bucket='bucket2', **kwargs):
  """Build the ImageNet dataset through ``get_imageFolderDataset``.

  Uses the ``'large'`` augmentation recipe with the default normalisation
  unless ``transform`` is supplied. ``imsize`` only influences the transform;
  it is not forwarded to ``get_imageFolderDataset``.

  Returns:
    Whatever ``get_imageFolderDataset`` returns for these arguments.
  """
  pipeline = get_transform(
      split,
      normalize=None,
      transform=transform,
      imsize=imsize,
      aug='large',
  )
  return get_imageFolderDataset(
      dataset_name,
      data_dir,
      split,
      transform=pipeline,
      bucket=bucket,
      **kwargs,
  )


def get_cifar10(dataset_name, data_dir, split, transform=None, imsize=None, bucket='pytorch-data', **kwargs):
  """Build torchvision's CIFAR10 dataset with the ``'small'`` augmentation recipe.

  ``dataset_name`` and ``bucket`` are accepted for signature parity with the
  other builders and are not used here.

  Args:
    data_dir: Root directory passed to ``datasets.CIFAR10``.
    split: ``'train'`` selects the training set; anything else selects the
      non-training set (``train=False``).
    transform: Optional ready-made transform; built via ``get_transform``
      when omitted.
    imsize: Optional crop size for the default transform.
    **kwargs: Extra keyword arguments for ``datasets.CIFAR10``.

  Returns:
    The ``datasets.CIFAR10`` instance (constructed with ``download=True``).
  """
  use_train = (split == 'train')
  pipeline = get_transform(split, transform=transform, imsize=imsize, aug='small')
  return datasets.CIFAR10(
      data_dir,
      train=use_train,
      transform=pipeline,
      download=True,
      **kwargs,
  )


def get_cifar100(dataset_name, data_dir, split, transform=None, imsize=None, bucket='pytorch-data', **kwargs):
  """Build torchvision's CIFAR100 dataset with the ``'small'`` augmentation recipe.

  ``dataset_name`` and ``bucket`` are accepted for signature parity with the
  other builders and are not used here.

  Args:
    data_dir: Root directory passed to ``datasets.CIFAR100``.
    split: ``'train'`` selects the training set; anything else selects the
      non-training set (``train=False``).
    transform: Optional ready-made transform; built via ``get_transform``
      when omitted.
    imsize: Optional crop size for the default transform.
    **kwargs: Extra keyword arguments for ``datasets.CIFAR100``.

  Returns:
    The ``datasets.CIFAR100`` instance (constructed with ``download=True``).
  """
  use_train = (split == 'train')
  pipeline = get_transform(split, transform=transform, imsize=imsize, aug='small')
  return datasets.CIFAR100(
      data_dir,
      train=use_train,
      transform=pipeline,
      download=True,
      **kwargs,
  )


def get_svhn(dataset_name, data_dir, split, transform=None, imsize=None, bucket='pytorch-data', **kwargs):
  """Build torchvision's SVHN dataset with the ``'small'`` augmentation recipe.

  ``dataset_name`` and ``bucket`` are accepted for signature parity with the
  other builders and are not used here.

  Args:
    data_dir: Root directory passed to ``datasets.SVHN``.
    split: ``'train'`` is passed through; every other value is mapped to
      ``'test'``.
    transform: Optional ready-made transform; built via ``get_transform``
      when omitted.
    imsize: Optional crop size for the default transform.
    **kwargs: Extra keyword arguments for ``datasets.SVHN``.

  Returns:
    The ``datasets.SVHN`` instance, with a ``classes`` attribute set to the
    strings ``'0'`` through ``'9'``.
  """
  # The transform is chosen from the caller's split, before it is remapped.
  pipeline = get_transform(split, transform=transform, imsize=imsize, aug='small')

  svhn_split = 'test'
  if split == 'train':
    svhn_split = 'train'

  svhn = datasets.SVHN(data_dir, split=svhn_split, transform=pipeline, download=True, **kwargs)
  # This function sets the class names itself, as digit strings.
  svhn.classes = [str(digit) for digit in range(10)]
  return svhn

def get_mnist(dataset_name, data_dir, split, transform=None, imsize=None, bucket='pytorch-data', **kwargs):
  """Build torchvision's MNIST dataset with the ``'small'`` augmentation recipe.

  Unlike the other builders, this one supplies its own single-channel
  normalisation (mean 0.131, std 0.289) to ``get_transform``.
  ``dataset_name`` and ``bucket`` are accepted for signature parity with the
  other builders and are not used here.

  Args:
    data_dir: Root directory passed to ``datasets.MNIST``.
    split: ``'train'`` selects the training set; anything else selects the
      non-training set (``train=False``).
    transform: Optional ready-made transform; built via ``get_transform``
      when omitted.
    imsize: Optional crop size for the default transform.
    **kwargs: Extra keyword arguments for ``datasets.MNIST``.

  Returns:
    The ``datasets.MNIST`` instance (constructed with ``download=True``).
  """
  # NOTE(review): with imsize left at None the 'small' recipe targets 32; the
  # train path pads and crops while the eval path resizes, so source images
  # smaller than 32 would be scaled differently per split - confirm intended.
  single_channel_norm = transforms.Normalize(mean=[0.131], std=[0.289])
  pipeline = get_transform(
      split,
      normalize=single_channel_norm,
      transform=transform,
      imsize=imsize,
      aug='small',
  )
  use_train = (split == 'train')
  return datasets.MNIST(data_dir, train=use_train, transform=pipeline, download=True, **kwargs)

