
# Importing the base-networks used for meta-learning.
from experiments.resources.models.linear import LinearRegression
from experiments.resources.models.mlp import MultiLayerPerceptron
from experiments.resources.models.alexnet import AlexNet
from experiments.resources.models.allcnnc import AllCNNC
from experiments.resources.models.lenet5 import LeNet5
from experiments.resources.models.preresnet import PreResNet18
from experiments.resources.models.preresnet import PreResNet34
from experiments.resources.models.preresnet import PreResNet50
from experiments.resources.models.preresnet import PreResNet101
from experiments.resources.models.preresnet import PreResNet152
from experiments.resources.models.pyramidnet import PyramidNet
from experiments.resources.models.resnet import ResNet18
from experiments.resources.models.resnet import ResNet34
from experiments.resources.models.resnet import ResNet50
from experiments.resources.models.resnet import ResNet101
from experiments.resources.models.resnet import ResNet152
from experiments.resources.models.squeezenet import SqueezeNet
from experiments.resources.models.vgg import VGG11
from experiments.resources.models.vgg import VGG13
from experiments.resources.models.vgg import VGG16
from experiments.resources.models.vgg import VGG19
from experiments.resources.models.wideresnet import WideResNet404
from experiments.resources.models.wideresnet import WideResNet168
from experiments.resources.models.wideresnet import WideResNet2810

# Importing the meta-learning datasets.
from experiments.resources.datasets import Diabetes
from experiments.resources.datasets import CaliforniaHousing
from experiments.resources.datasets import Wine
from experiments.resources.datasets import Crime
from experiments.resources.datasets import MNIST
from experiments.resources.datasets import SVHN
from experiments.resources.datasets import CIFAR10
from experiments.resources.datasets import CIFAR100

# Importing utility functions for running experiments.
from experiments.resources.parser import register_configurations
from experiments.resources.parser import override_configurations
from experiments.resources.metrics import MultiErrorRate
from experiments.resources.metrics import BinaryErrorRate
from experiments.resources.exporter import export_results
from experiments.resources.exporter import export_model
from experiments.resources.exporter import export_offline_loss
from experiments.resources.exporter import export_online_loss

import inspect
import torch


def match_signature(func):
    """ Matches the given arguments to the function signature. """
    def wrapped_func(*args, **kwargs):
        return func(*args, **{key: value for (key, value) in kwargs.items()
                              if key in inspect.signature(func).parameters})
    return wrapped_func


dataset_archive = {
    "diabetes": {"data": Diabetes, "config": "experiments/resources/configurations/diabetes_config.yaml"},
    "california": {"data": CaliforniaHousing, "config": "experiments/resources/configurations/california_config.yaml"},
    "wine": {"data": Wine, "config": "experiments/resources/configurations/wine_config.yaml"},
    "crime": {"data": Crime, "config": "experiments/resources/configurations/crime_config.yaml"},
    "mnist": {"data": MNIST, "config": "experiments/resources/configurations/mnist_config.yaml"},
    "cifar10": {"data": CIFAR10, "config": "experiments/resources/configurations/cifar10_config.yaml"},
    "cifar100": {"data": CIFAR100, "config": "experiments/resources/configurations/cifar100_config.yaml"},
    "svhn": {"data": SVHN, "config": "experiments/resources/configurations/svhn_config.yaml"}
}


model_archive = {
    "linear": LinearRegression,
    "mlp": MultiLayerPerceptron,
    "alexnet": AlexNet,  # 23,272,266
    "allcnnc": AllCNNC,  # 1,372,254
    "lenet5": LeNet5,
    "preresnet18": PreResNet18,
    "preresnet34": PreResNet34,
    "preresnet50": PreResNet50,
    "preresnet101": PreResNet101,
    "preresnet152": PreResNet152,
    "pyramidnet": PyramidNet,
    "resnet18": ResNet18,  # 11,173,962
    "resnet34": ResNet34,  # 21,282,122
    "resnet50": ResNet50,  # 23,520,842
    "resnet101": ResNet101,  # 42,512,970
    "resnet152": ResNet152,  # 58,156,618
    "squeezenet": SqueezeNet,
    "vgg11": VGG11,  # 9,231,114
    "vgg13": VGG13,  # 9,416,010
    "vgg16": VGG16,  # 14,728,266
    "vgg19": VGG19,  # 20,040,522
    "wrn40-4": WideResNet404,  # 8,972,340
    "wrn16-8": WideResNet168,  # 11,007,540
    "wrn28-10": WideResNet2810  # 36,536,884
}

objective_archive = {
    "multierrorrate": MultiErrorRate(),
    "binaryerrorrate": BinaryErrorRate(),
    "nllloss": torch.nn.NLLLoss(),
    "bceloss": torch.nn.BCELoss(),
    "mseloss": torch.nn.MSELoss(),
    "celoss": torch.nn.CrossEntropyLoss(),
}

optimizer_archive = {
    "sgd": match_signature(torch.optim.SGD),
    "adam": match_signature(torch.optim.Adam)
}
