from .transformations import *
from torchvision.datasets import CIFAR100, CIFAR10
from dataset.cifar import *
from .resnet import *
from .resnet import resnet32_manifold_mixup, resnet32_AMA
from .resnet32SyncWithSupCon import SupConResNet, LinearClassifier
def get_all_transformations():
    """Return the registry of CIFAR transform pipelines keyed by dataset/split.

    Keys follow the pattern ``'<dataset>_<split>'`` where dataset is
    ``cifar10`` or ``cifar100`` and split is ``train`` or ``test``. Values are
    the module-level transform objects themselves (not copies), presumably
    defined in ``.transformations`` and pulled in by the star import.
    """
    table = {}
    table['cifar10_train'] = cifar10_transform_tr
    table['cifar10_test'] = cifar10_transform_ts
    table['cifar100_train'] = cifar100_transform_tr
    table['cifar100_test'] = cifar100_transform_ts
    return table

def get_all_datasets(data_dir, transform_tr, transform_ts, imb_factor):
    """Build the imbalanced CIFAR-10/100 train and test datasets.

    Args:
        data_dir: root directory passed to every dataset constructor.
        transform_tr: transform applied to the two training sets.
        transform_ts: transform applied to the two test sets.
        imb_factor: imbalance factor for the training sets; the test sets
            always use ``imb_factor=1.0``.

    Returns:
        A dict with keys ``'cifar10_dataset'``, ``'cifar10_testset'``,
        ``'cifar100_dataset'`` and ``'cifar100_testset'``.

    NOTE(review): all four datasets are constructed eagerly, in the order
    listed above. Both CIFAR-10 and CIFAR-100 are therefore downloaded and
    loaded on every call, even if the caller uses only one of them. Each
    training set is built (with ``download=True``) before its matching test
    set (``download=False``), so the files should already be present by the
    time the test split is requested.
    """
    # Shared keyword arguments; a fresh kwargs dict is unpacked per call.
    train_kwargs = {
        'root': data_dir,
        'train': True,
        'download': True,
        'transform': transform_tr,
        'imb_factor': imb_factor,
    }
    test_kwargs = {
        'root': data_dir,
        'train': False,
        'download': False,
        'transform': transform_ts,
        'imb_factor': 1.0,
    }
    return {
        'cifar10_dataset': IMBALANCECIFAR10(**train_kwargs),
        'cifar10_testset': IMBALANCECIFAR10(**test_kwargs),
        'cifar100_dataset': IMBALANCECIFAR100(**train_kwargs),
        'cifar100_testset': IMBALANCECIFAR100(**test_kwargs),
    }


def get_all_networks():
    """Return the registry of network constructors keyed by model name.

    Values are the constructors/classes themselves; nothing is instantiated
    here. Note that ``'resnet32_mixup'`` and ``'resnet32_origin'`` map to the
    same ``resnet32`` constructor. Presumably mixup is applied outside the
    model (e.g. in the training loop); confirm against the caller.
    """
    entries = (
        ('resnet32_origin', resnet32),
        ('resnet32_AMA', resnet32_AMA),
        ('resnet32_mixup', resnet32),
        ('resnet32_supCon', SupConResNet),
        ('resnet32_supLinear', LinearClassifier),
        ('resnet32_manifoldMixup', resnet32_manifold_mixup),
    )
    return {name: builder for name, builder in entries}

