from .transformations import cifar_transform_tr, cifar_transform_ts
from torchvision.datasets import CIFAR100, CIFAR10
from .resnet import *

def get_all_transformations():
    """Return the registry of data transforms keyed by ``<dataset>_<split>``.

    Every dataset variant (plain and ``coarse_`` CIFAR-10/100) uses the same
    pair of transforms: ``cifar_transform_tr`` for the ``_train`` key and
    ``cifar_transform_ts`` for the ``_test`` key.

    Returns:
        dict mapping names such as ``'cifar10_train'`` to transform objects.
    """
    registry = {}
    # Iteration order reproduces the original key insertion order.
    for variant in ('coarse_cifar10', 'coarse_cifar100', 'cifar10', 'cifar100'):
        registry[variant + '_train'] = cifar_transform_tr
        registry[variant + '_test'] = cifar_transform_ts
    return registry


def get_all_datasets(data_dir, transform_tr, transform_ts):
    """Instantiate every supported CIFAR train/test dataset under ``data_dir``.

    For each variant, a training split (``train=True``, ``download=True``,
    ``transform_tr``) and a test split (``train=False``, ``download=False``,
    ``transform_ts``) are constructed. The ``coarse_*`` entries are built as
    separate dataset objects, independent of the plain ones.

    Args:
        data_dir: root directory passed to the torchvision dataset classes.
        transform_tr: transform applied to the training splits.
        transform_ts: transform applied to the test splits.

    Returns:
        dict mapping keys such as ``'cifar10_dataset'`` / ``'cifar10_testset'``
        to dataset instances.
    """
    variants = (
        ('cifar10', CIFAR10),
        ('coarse_cifar10', CIFAR10),
        ('cifar100', CIFAR100),
        ('coarse_cifar100', CIFAR100),
    )
    out = {}
    for prefix, dataset_cls in variants:
        # The train split is requested with download=True before the test
        # split, which uses download=False.
        # NOTE(review): this presumably relies on the downloaded archive also
        # containing the test data — confirm against torchvision.
        out[prefix + '_dataset'] = dataset_cls(
            root=data_dir, train=True, download=True, transform=transform_tr)
        out[prefix + '_testset'] = dataset_cls(
            root=data_dir, train=False, download=False, transform=transform_ts)
    return out

def get_all_networks():
    """Return the registry of network constructors keyed by experiment name.

    The values are the callables imported from ``.resnet``. Note that
    ``resnet50_mixup`` maps to the same constructor as ``resnet50_origin``
    (``SupCEResNet``). Presumably, mixup is handled outside the model
    definition; confirm in the training code.
    """
    networks = dict(
        resnet50_origin=SupCEResNet,
        resnet50_supCon=SupConResNet,
        resnet50_supLinear=LinearClassifier,
        resnet50_mixup=SupCEResNet,
        resnet50_manifoldMixup=resnet_manifold_mixup,
        resnet50_AMA=resnet_AMA,
    )
    return networks