# -*- coding: utf-8 -*-

from typing import List
from functools import partial

import torch.nn as nn
import torch

from .cifar import CIFAR_GEN
from .tiny_imagenet import TINYIMNET_GEN
from .imagenet import IMNET_GEN


# Public API of this package: only the ``create`` factory is exported by
# ``from <package> import *``.
__all__ = ["create"]


def create(name: str, protos: List[torch.Tensor]) -> nn.Module:
    """Instantiate the dataset-specific ``*_GEN`` module selected by ``name``.

    Args:
        name: Dataset identifier. Supported values are ``"cifar100"``,
            ``"tinyimagenet200"``, ``"imagenet100"`` and ``"imagenet1000"``.
        protos: Passed unchanged as the sole positional argument to the
            selected class's constructor. (Presumably per-class prototype
            tensors, judging by the name -- confirm against the ``*_GEN``
            classes.)

    Returns:
        The module constructed for ``name``.

    Raises:
        ValueError: If ``name`` is not one of the supported identifiers.
    """
    if name == "cifar100":
        return CIFAR_GEN(protos)
    if name == "tinyimagenet200":
        return TINYIMNET_GEN(protos)
    if name in ("imagenet100", "imagenet1000"):
        # Both ImageNet variants are built with the same class.
        return IMNET_GEN(protos)
    # Fail loudly here instead of implicitly returning None, which would
    # violate the ``nn.Module`` return type and surface as a confusing
    # error far away from this call.
    raise ValueError(
        f"Unknown dataset name {name!r}; expected one of: "
        "cifar100, tinyimagenet200, imagenet100, imagenet1000"
    )
