from .transformations import cifar_transform_tr, cifar_transform_ts
from torchvision.datasets import CIFAR100, CIFAR10
from .vgg import *
from .resnet import *
from .densenet import *

def get_all_transformations():
    """Return the lookup table of data transformations.

    Keys have the form ``'<dataset>_<split>'``. CIFAR-10 and CIFAR-100
    entries point at the same two objects: ``cifar_transform_tr`` for the
    train split and ``cifar_transform_ts`` for the test split.

    Returns:
        dict: a new dictionary on every call, mapping key -> transformation.
    """
    train_tf = cifar_transform_tr
    test_tf = cifar_transform_ts

    registry = {}
    registry['cifar10_train'] = train_tf
    registry['cifar10_test'] = test_tf
    registry['cifar100_train'] = train_tf
    registry['cifar100_test'] = test_tf
    return registry


def get_all_datasets(data_dir, transform_tr, transform_ts):
    """Instantiate the CIFAR-10 and CIFAR-100 train and test datasets.

    All four datasets are constructed eagerly on every call, train split
    before test split for each dataset.

    Args:
        data_dir: passed as ``root`` to every dataset constructor.
        transform_tr: passed as ``transform`` to the train splits.
        transform_ts: passed as ``transform`` to the test splits.

    Returns:
        dict: keys ``'<dataset>_dataset'`` (train split) and
        ``'<dataset>_testset'`` (test split) for ``cifar10`` and ``cifar100``.
    """
    # Only the train splits request a download.
    # NOTE(review): the test splits presumably rely on the files fetched by
    # the preceding train-split construction -- confirm before reordering.
    train_kwargs = dict(root=data_dir, train=True, download=True, transform=transform_tr)
    test_kwargs = dict(root=data_dir, train=False, download=False, transform=transform_ts)

    datasets = {}
    datasets['cifar10_dataset'] = CIFAR10(**train_kwargs)
    datasets['cifar10_testset'] = CIFAR10(**test_kwargs)
    datasets['cifar100_dataset'] = CIFAR100(**train_kwargs)
    datasets['cifar100_testset'] = CIFAR100(**test_kwargs)
    return datasets


def get_all_networks():
    """Return the lookup table of available network constructors.

    Keys have the form ``'<architecture>_<method>'``. The values are the
    objects brought into scope by the imports at the top of this module;
    they are returned as-is, not called.

    For every architecture the ``*_mixup`` key maps to the same object as
    the ``*_origin`` key.
    NOTE(review): presumably plain mixup needs no model changes -- confirm.

    Returns:
        dict: a new dictionary on every call, mapping key -> constructor.
    """
    vgg_variants = {
        'vgg11_origin': vgg_origin,
        'vgg11_mixup': vgg_origin,
        'vgg11_manifoldMixup': vgg_manifold_mixup,
        'vgg11_AMA': vgg_AMA,
    }
    resnet_variants = {
        'resnet50_origin': resnet_origin,
        'resnet50_supCon': SupConResNet,
        'resnet50_supLinear': LinearClassifier,
        'resnet50_mixup': resnet_origin,
        'resnet50_manifoldMixup': resnet_manifold_mixup,
        'resnet50_AMA': resnet_AMA,
    }
    densenet_variants = {
        'dense_origin': dense_origin,
        'dense_mixup': dense_origin,
        'dense_manifoldMixup': DenseNetManifoldMixup,
        'dense_AMA': dense_AMA,
    }

    # Merge in the original order: VGG, then ResNet, then DenseNet.
    return {**vgg_variants, **resnet_variants, **densenet_variants}