'''Framework default config'''

from framework.wresnet import resnet110_cifar, resnet32_cifar, resnet20_cifar, resnet44_cifar, resnet56_cifar
from framework.wresnet_no_bn import resnet110_no_bn_cifar, resnet32_no_bn_cifar, resnet20_no_bn_cifar, resnet44_no_bn_cifar, resnet56_no_bn_cifar
from framework.wvgg import vgg16_bn, vgg13_bn, vgg19_bn, vgg13, vgg16, vgg19
from torchvision.models.resnet import resnet50
from framework.mnist_net import MnistNet
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import MNIST, CIFAR10, SVHN, CIFAR100


from dataset.imagenet_on_philly import imagenet_on_philly
import torch


def get_config():
  """Return the framework's default configuration as a flat dict.

  Keys:
    root: dataset root directory placeholder.
    basic_net_<dataset>: default architecture name (a name accepted by
      get_arch) for each dataset.
    num_workers_<dataset>: default worker count for each dataset
      (presumably used as DataLoader ``num_workers`` — confirm at call site).
  """
  # Alternative root previously used: '/public/share/dataset/image/'
  data_root = '...'

  default_nets = {
      'basic_net_mnist': 'mnistnet',
      'basic_net_svhn': 'resnet32',
      'basic_net_cifar10': 'resnet32',
      'basic_net_cifar100': 'resnet110',
      'basic_net_imagenet': 'resnet50',
  }
  default_workers = {
      'num_workers_mnist': 1,
      'num_workers_svhn': 2,
      'num_workers_cifar10': 2,
      'num_workers_cifar100': 2,
      'num_workers_imagenet': 4,
  }

  # Same key order as before: root, then networks, then worker counts.
  merged = {'root': data_root}
  merged.update(default_nets)
  merged.update(default_workers)
  return merged


def get_arch(arch):
  """Instantiate a fresh model for the architecture name ``arch``.

  Names prefixed with ``<k>x`` (e.g. '10xresnet20') build the CIFAR ResNet
  with ``width=k``; ``_no_bn`` variants come from framework.wresnet_no_bn.
  'resnet50' is torchvision's ResNet-50 built with ``pretrained=False``.

  Raises:
    NotImplementedError: if ``arch`` is not one of the supported names
      (the message lists them).
  """
  # Zero-argument factories: only the requested model is ever constructed.
  builders = {
      'resnet20': lambda: resnet20_cifar(width=1),
      '10xresnet20': lambda: resnet20_cifar(width=10),
      'resnet32': lambda: resnet32_cifar(width=1),
      '2xresnet32': lambda: resnet32_cifar(width=2),
      '5xresnet32': lambda: resnet32_cifar(width=5),
      '10xresnet32': lambda: resnet32_cifar(width=10),
      'resnet44': lambda: resnet44_cifar(width=1),
      'resnet56': lambda: resnet56_cifar(width=1),
      'resnet110': lambda: resnet110_cifar(width=1),
      'resnet20_no_bn': lambda: resnet20_no_bn_cifar(width=1),
      'resnet32_no_bn': lambda: resnet32_no_bn_cifar(width=1),
      'resnet44_no_bn': lambda: resnet44_no_bn_cifar(width=1),
      'resnet56_no_bn': lambda: resnet56_no_bn_cifar(width=1),
      'resnet110_no_bn': lambda: resnet110_no_bn_cifar(width=1),
      'resnet50': lambda: resnet50(pretrained=False),
      'mnistnet': lambda: MnistNet(),
      'vgg13_bn': lambda: vgg13_bn(),
      'vgg16_bn': lambda: vgg16_bn(),
      'vgg19_bn': lambda: vgg19_bn(),
      'vgg13': lambda: vgg13(),
      'vgg16': lambda: vgg16(),
      'vgg19': lambda: vgg19(),
  }
  try:
    build = builders[arch]
  except KeyError:
    raise NotImplementedError(
        'Unsupported arch {!r}; choose one of: {}'.format(
            arch, ', '.join(sorted(builders)))) from None
  return build()


def get_dataset(dataset, root, transform_train, transform_test):
  """Build the train / train-eval / test splits for ``dataset``.

  Args:
    dataset: one of 'cifar10', 'cifar100', 'mnist', 'svhn', 'imagenet'.
    root: directory passed to the torchvision dataset constructors (with
      ``download=True``); not used for 'imagenet'.
    transform_train: transform for the training split.
    transform_test: transform for the train-eval copy and the test split.

  Returns:
    (trainset, trainset_test, testset, num_classes). ``trainset`` and
    ``trainset_test`` read the same training data but with
    ``transform_train`` and ``transform_test`` respectively.

  Raises:
    NotImplementedError: for an unsupported dataset name.
  """
  print(dataset)

  if dataset == 'imagenet':
    # ImageNet is read through imagenet_on_philly (True -> training data,
    # False -> test data) instead of torchvision; ``root`` is ignored here.
    trainset = imagenet_on_philly(True, transform_train)
    trainset_test = imagenet_on_philly(True, transform_test)
    testset = imagenet_on_philly(False, transform_test)
    return trainset, trainset_test, testset, 1000

  if dataset not in ('cifar10', 'cifar100', 'mnist', 'svhn'):
    raise NotImplementedError(
        'Unsupported dataset {!r}; expected one of cifar10, cifar100, '
        'mnist, svhn, imagenet'.format(dataset))

  dataset_cls, num_classes = {
      'cifar10': (CIFAR10, 10),
      'cifar100': (CIFAR100, 100),
      'mnist': (MNIST, 10),
      'svhn': (SVHN, 10),
  }[dataset]

  if dataset == 'svhn':
    # torchvision's SVHN selects its split with ``split=`` instead of ``train=``.
    train_split, test_split = {'split': 'train'}, {'split': 'test'}
  else:
    train_split, test_split = {'train': True}, {'train': False}

  trainset = dataset_cls(
      root=root, download=True, transform=transform_train, **train_split)
  trainset_test = dataset_cls(
      root=root, download=True, transform=transform_test, **train_split)
  testset = dataset_cls(
      root=root, download=True, transform=transform_test, **test_split)

  if dataset == 'cifar10':
    print('cifar 10 implemented')

  return trainset, trainset_test, testset, num_classes

# NOTE: no transforms.Normalize() step is applied for any dataset; each pipeline ends with ToTensor().
def get_transform(dataset):
  """Return the default ``(train_transform, test_transform)`` pair for ``dataset``.

  Every pipeline ends with ``transforms.ToTensor()``. No
  ``transforms.Normalize`` step is applied for any dataset. The CIFAR
  statistics previously used (now disabled) were
  mean (0.4914, 0.4822, 0.4465) and std (0.2023, 0.1994, 0.2010).

  Raises:
    NotImplementedError: for an unsupported dataset name.
  """
  if dataset in ('cifar10', 'cifar100', 'svhn'):
    # The 32x32 datasets share pad-4 random crop + horizontal flip for training.
    # NOTE(review): horizontal flips on SVHN digits are unusual; kept to
    # preserve existing behavior.
    train_steps = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
    ]
    test_steps = []
  elif dataset == 'mnist':
    train_steps = []
    test_steps = []
  elif dataset == 'imagenet':
    train_steps = [
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
    ]
    test_steps = [
        transforms.Resize(256),
        transforms.CenterCrop(224),
    ]
  else:
    raise NotImplementedError(
        'Unsupported dataset {!r}; expected one of cifar10, cifar100, '
        'mnist, svhn, imagenet'.format(dataset))

  default_transform_train = transforms.Compose(
      train_steps + [transforms.ToTensor()])
  default_transform_test = transforms.Compose(
      test_steps + [transforms.ToTensor()])
  return default_transform_train, default_transform_test


def get_pin_memory(dataset):
  """Return the pin-memory flag for ``dataset``: True only for 'imagenet'.

  Presumably passed to DataLoader(pin_memory=...) by the caller — confirm.
  """
  if dataset == 'imagenet':
    return True
  return False


class TupleDataset(torch.utils.data.dataset.Dataset):
  """Map-style dataset over a pair ``(inputs, targets)`` of indexable sequences.

  Item ``i`` is ``(inputs[i], targets[i])``. When ``transform`` is given it
  is applied to ``inputs[i]`` only; targets are returned unchanged.
  """

  def __init__(self, data, transform=None):
    # data: (inputs, targets); assumes both have the same length (not checked).
    self.data = data
    self.transform = transform

  def __getitem__(self, index):
    inputs, targets = self.data[0], self.data[1]
    sample = inputs[index]
    if self.transform is not None:
      sample = self.transform(sample)
    return sample, targets[index]

  def __len__(self):
    # len() equals shape[0] for tensors/ndarrays and also supports plain lists.
    return len(self.data[0])
