from dataclasses import dataclass
from typing import Literal

from torch.nn import Module
import torchvision

from .cifar import resnet, densenet, vgg


# Values understood by ``create_model`` (annotated as plain ``str``; nothing
# here validates them at construction time):
#   type   -> "resnet-18", "densenet-121", "vgg-11"
#   domain -> "cifar", "imagenet"

@dataclass
class ModelConfig:
    """Selection of the network that ``create_model`` should build.

    An unrecognised ``domain`` or ``type`` is only rejected when
    ``create_model`` is called, not when the config is constructed.
    """

    # Architecture key looked up in the model registries, e.g. "resnet-18".
    type: str
    # Implementation family: "cifar" or "imagenet".
    domain: str
    # Number of classes, forwarded to the model factory as ``num_classes``.
    num_classes: int

# Registries mapping a ``ModelConfig.type`` key to a model factory.
# Both tables use the same keys, so a config can switch implementation
# family by changing only ``ModelConfig.domain``.

# Project-local CIFAR implementations (imported from ``.cifar``).
cifar_models = {
    "resnet-18": resnet.ResNet18,
    "densenet-121": densenet.densenet_cifar,
    "vgg-11": vgg.VGG11,
}

# Stock torchvision constructors used for the "imagenet" domain.
imagenet_models = {
    "resnet-18": torchvision.models.resnet18,
    "densenet-121": torchvision.models.densenet121,
    "vgg-11": torchvision.models.vgg11,
}

def _replace_classifier_head(model: Module, num_classes: int) -> None:
    """Resize, in place, the last ``nn.Linear`` of *model* to emit *num_classes* outputs."""
    from torch import nn

    # named_modules() yields ("", model) first, so the root is reachable by "".
    modules = dict(model.named_modules())
    head_name = None
    for name, module in modules.items():
        if isinstance(module, nn.Linear):
            head_name = name  # keep the last one: the output layer
    if head_name is None:
        raise ValueError(
            f"{type(model).__name__} has no nn.Linear output layer to resize"
        )

    old_head = modules[head_name]
    if old_head.out_features == num_classes:
        return
    parent_name, _, attr = head_name.rpartition(".")
    # setattr on an nn.Module (including nn.Sequential with attr "6") re-registers
    # the child under the same name, replacing the old layer.
    setattr(
        modules[parent_name],
        attr,
        nn.Linear(old_head.in_features, num_classes, bias=old_head.bias is not None),
    )


def create_model(
    config: ModelConfig,
    pretrained: bool = False,
) -> Module:
    """Build the network described by *config*.

    Args:
        config: ``config.domain`` selects the registry ("cifar" or
            "imagenet"), ``config.type`` the architecture key within it, and
            ``config.num_classes`` the size of the output layer.
        pretrained: Load torchvision's pretrained weights. This only applies to
            the "imagenet" domain. For "cifar" it is ignored with a warning.
            When the requested ``num_classes`` differs from the checkpoint's,
            the pretrained network is loaded as-is and its final linear layer
            is replaced by a freshly initialised one of the requested size.

    Returns:
        A newly constructed model.

    Raises:
        ValueError: If ``config.domain`` is neither "cifar" nor "imagenet".
        KeyError: If ``config.type`` is not registered for that domain.
    """
    if config.domain == "cifar":
        models = cifar_models
    elif config.domain == "imagenet":
        models = imagenet_models
    else:
        raise ValueError(f"Invalid domain {config.domain}")

    try:
        factory = models[config.type]
    except KeyError:
        # Same exception type as before, but tell the caller what is valid.
        raise KeyError(
            f"Invalid model type {config.type!r} for domain {config.domain!r}; "
            f"expected one of {sorted(models)}"
        ) from None

    if config.domain == "cifar":
        if pretrained:
            import warnings

            warnings.warn(
                "pretrained=True is ignored for the 'cifar' domain",
                stacklevel=2,
            )
        return factory(num_classes=config.num_classes)

    if not pretrained:
        return factory(num_classes=config.num_classes, pretrained=False)

    # torchvision cannot combine pretrained weights with a non-default
    # num_classes. Older versions fail in load_state_dict with a size mismatch;
    # newer ones raise ValueError. So load the stock checkpoint and swap the head.
    model = factory(pretrained=True)
    _replace_classifier_head(model, config.num_classes)
    return model
