from backbones import Conv3, Conv4
from backbones import MLP2
from backbones import build_resnet12
from data.dataset_utils import DatasetEnum
from backbones import maml_nets

from enum import Enum

class BackboneEnum(Enum):
    """Backbone architectures that can be requested from this module.

    The factory functions and ``FEATURE_DIM_DICT`` compare against the
    member ``.name`` (e.g. ``"CONV4"``), not the lowercase ``.value``.
    """

    CONV4 = "conv4"
    # Only the non-MAML factory (get_backbone) has a RESNET12 branch.
    RESNET12 = "resnet12"
    CONV3 = "conv3"
    MLP2 = "mlp2"


# Number of output classes used when a config has no "class_num" entry.
DEFAULT_CLASS_NUM = 5  # 5-way classification

# Embedding size registered per (backbone name, dataset name) pair; both
# halves of the key are enum member *names*, not values. Only CONV4 is
# registered so far, so looking up any other backbone raises KeyError.
# NOTE(review): 1600 presumably is Conv4's flattened output size for these
# datasets' input resolution — confirm against the Conv4 definition.
FEATURE_DIM_DICT = {
    (BackboneEnum.CONV4.name, dataset.name): 1600
    for dataset in (DatasetEnum.MINI_IMAGENET, DatasetEnum.MetaDataset)
}


def get_embedding_dim(backbone_name, ds_name):
    """Return the embedding dimension registered for a backbone/dataset pair.

    Args:
        backbone_name: ``BackboneEnum`` member name, e.g. ``"CONV4"``.
        ds_name: dataset enum member name, e.g. ``"MINI_IMAGENET"``.

    Returns:
        The value stored in ``FEATURE_DIM_DICT`` for ``(backbone_name, ds_name)``.

    Raises:
        KeyError: if the pair is not registered. The message names the
            missing pair and lists the registered ones.
    """
    key = (backbone_name, ds_name)
    try:
        return FEATURE_DIM_DICT[key]
    except KeyError:
        # A bare tuple KeyError gives little hint that the table itself is
        # incomplete, so report what *is* registered.
        known = ", ".join(
            sorted("{}/{}".format(b, d) for b, d in FEATURE_DIM_DICT)
        )
        raise KeyError(
            "no embedding dim registered for backbone={!r}, dataset={!r}; "
            "registered pairs: {}".format(
                backbone_name, ds_name, known or "<none>"
            )
        ) from None


def get_backbone(backbone_name, ds_name, config):
    """Construct a backbone network and return it with its embedding dimension.

    Args:
        backbone_name: ``BackboneEnum`` member name (e.g. ``"CONV4"``).
        ds_name: dataset name. Together with ``backbone_name`` it is the key
            into ``FEATURE_DIM_DICT``. For MLP2 it also selects
            ``config[ds_name]["input_dim"]``.
        config: mapping. ``"class_num"`` is optional and defaults to
            ``DEFAULT_CLASS_NUM``.

    Returns:
        ``(backbone, embedding_dim)``.

    Raises:
        ValueError: if ``backbone_name`` is not one of the supported backbones.
        KeyError: if no embedding dim is registered for the pair, or (MLP2)
            ``config`` lacks ``[ds_name]["input_dim"]``.
    """
    supported = (
        BackboneEnum.CONV3.name,
        BackboneEnum.CONV4.name,
        BackboneEnum.MLP2.name,
        BackboneEnum.RESNET12.name,
    )
    # Check the name before the dim lookup. Otherwise an unknown name fails
    # inside FEATURE_DIM_DICT with a KeyError and this error is never reached.
    if backbone_name not in supported:
        raise ValueError("unknown backbone: {}".format(backbone_name))

    embedding_dim = get_embedding_dim(backbone_name, ds_name)
    class_num = config.get("class_num", DEFAULT_CLASS_NUM)
    if backbone_name == BackboneEnum.CONV3.name:
        backbone = Conv3(embedding_dim=embedding_dim)
    elif backbone_name == BackboneEnum.CONV4.name:
        backbone = Conv4(embedding_dim=embedding_dim, num_classes=class_num)
    elif backbone_name == BackboneEnum.MLP2.name:
        input_dim = config[ds_name]["input_dim"]
        backbone = MLP2(input_dim)
    else:  # RESNET12 — the only remaining supported name
        backbone = build_resnet12()

    return backbone, embedding_dim

def get_backbone_of_maml(backbone_name, ds_name, config):
    """Construct the MAML variant of a backbone and return it with its embedding dim.

    Uses the classes from ``maml_nets`` instead of the plain backbones.
    RESNET12 has no MAML variant here.

    Args:
        backbone_name: ``BackboneEnum`` member name (e.g. ``"CONV4"``).
        ds_name: dataset name. Together with ``backbone_name`` it is the key
            into ``FEATURE_DIM_DICT``. For MLP2 it also selects
            ``config[ds_name]["input_dim"]``.
        config: mapping. ``"class_num"`` is optional and defaults to
            ``DEFAULT_CLASS_NUM``.

    Returns:
        ``(backbone, embedding_dim)``.

    Raises:
        ValueError: if ``backbone_name`` has no MAML variant (including RESNET12).
        KeyError: if no embedding dim is registered for the pair, or (MLP2)
            ``config`` lacks ``[ds_name]["input_dim"]``.
    """
    supported = (
        BackboneEnum.CONV3.name,
        BackboneEnum.CONV4.name,
        BackboneEnum.MLP2.name,
    )
    # Check the name before the dim lookup. Otherwise an unsupported name
    # fails inside FEATURE_DIM_DICT with a KeyError and this error is never
    # reached.
    if backbone_name not in supported:
        raise ValueError("unknown backbone: {}".format(backbone_name))

    embedding_dim = get_embedding_dim(backbone_name, ds_name)
    class_num = config.get("class_num", DEFAULT_CLASS_NUM)
    if backbone_name == BackboneEnum.CONV3.name:
        backbone = maml_nets.Conv3(embedding_dim=embedding_dim)
    elif backbone_name == BackboneEnum.CONV4.name:
        backbone = maml_nets.Conv4(embedding_dim=embedding_dim, out_class_num=class_num)
    else:  # MLP2 — the only remaining supported name
        input_dim = config[ds_name]["input_dim"]
        backbone = maml_nets.MLP2(input_dim)

    return backbone, embedding_dim
