from typing import Callable, cast

from torch.nn import Module

from .create import ModelConfig
from .cifar import (
    resnet as cifar_resnet,
    densenet as cifar_densenet,
    vgg as cifar_vgg,
)

# Shape of a layer getter: given a model, return one of its submodules.
LayerAccessor = Callable[
    [Module],
    Module,
]

# Penultimate-layer getters for the CIFAR model variants, keyed by model type.
# Each entry is the ``get_penultimate_layer`` function of the matching module
# from the local ``cifar`` package.
cifar_penultimate_getters = {
    model_type: arch_module.get_penultimate_layer
    for model_type, arch_module in (
        ("resnet-18", cifar_resnet),
        ("densenet-121", cifar_densenet),
        ("vgg-11", cifar_vgg),
    )
}
# Penultimate-layer getters for the non-CIFAR (ImageNet-style) model variants,
# keyed by model type. Each getter takes the model and returns a submodule.
imagenet_penultimate_getters = dict(
    [
        ("resnet-18", lambda model: model.avgpool),
        # NOTE(review): in torchvision's DenseNet, norm5's output still goes
        # through a ReLU and global pooling before the classifier. Confirm
        # callers expect this unpooled, pre-activation map.
        ("densenet-121", lambda model: model.features.norm5),
        # Index 5 is the module just before classifier[6], which is the
        # slot the last-layer getter/setter for "vgg-11" use.
        ("vgg-11", lambda model: model.classifier[5]),
    ]
)

def get_penultimate_layer(model_config: ModelConfig, model: Module) -> Module:
    """Return the penultimate layer of *model*.

    ``model_config.domain == "cifar"`` selects the CIFAR getter table; any
    other domain selects the ImageNet table. The table is then indexed by
    ``model_config.type``, and the chosen getter is called on *model*.

    Raises:
        KeyError: if ``model_config.type`` has no entry in the selected table.
    """
    table = (
        cifar_penultimate_getters
        if model_config.domain == "cifar"
        else imagenet_penultimate_getters
    )
    getter = table[model_config.type]
    return getter(model)

# Final-layer getters for the CIFAR model variants, keyed by model type.
# Each entry is the ``get_last_layer`` function of the matching module from
# the local ``cifar`` package.
cifar_last_getters = {
    model_type: arch_module.get_last_layer
    for model_type, arch_module in (
        ("resnet-18", cifar_resnet),
        ("densenet-121", cifar_densenet),
        ("vgg-11", cifar_vgg),
    )
}
# Final-layer getters for the non-CIFAR (ImageNet-style) model variants,
# keyed by model type. Each getter takes the model and returns the submodule
# that the matching entry in ``imagenet_last_setters`` replaces.
imagenet_last_getters = {
    "resnet-18": lambda model: model.fc,
    # torchvision's DenseNet keeps its head as the direct child
    # ``classifier``, which is the same attribute ``densenet_set_last``
    # assigns. ``features`` is the convolutional trunk (it ends in ``norm5``)
    # and has no ``classifier`` child, so ``model.features.classifier``
    # raised AttributeError.
    "densenet-121": lambda model: model.classifier,
    "vgg-11": lambda model: model.classifier[6],
}

def get_last_layer(model_config: ModelConfig, model: Module) -> Module:
    """Return the final layer of *model*.

    The CIFAR getter table is used when ``model_config.domain == "cifar"``.
    Otherwise the ImageNet table is used. Within the table, the getter is
    looked up by ``model_config.type`` and called on *model*.

    Raises:
        KeyError: if ``model_config.type`` has no entry in the selected table.
    """
    if model_config.domain == "cifar":
        return cifar_last_getters[model_config.type](model)
    return imagenet_last_getters[model_config.type](model)

# Final-layer setters for the CIFAR model variants, keyed by model type.
# Each entry is the ``set_last_layer`` function of the matching module from
# the local ``cifar`` package.
cifar_last_setters = {
    model_type: arch_module.set_last_layer
    for model_type, arch_module in (
        ("resnet-18", cifar_resnet),
        ("densenet-121", cifar_densenet),
        ("vgg-11", cifar_vgg),
    )
}

def resnet_set_last(model: Module, layer: Module) -> Module:
    """Replace ``model.fc`` with *layer* in place and return *model* itself."""
    setattr(model, "fc", layer)
    return model

def densenet_set_last(model: Module, layer: Module) -> Module:
    """Replace ``model.classifier`` with *layer* in place and return *model*."""
    setattr(model, "classifier", layer)
    return model

def vgg_set_last(model: Module, layer: Module) -> Module:
    """Store *layer* at index 6 of ``model.classifier`` in place; return *model*.

    Index 6 is the same slot that ``imagenet_last_getters["vgg-11"]`` reads.
    """
    head = model.classifier
    head[6] = layer
    return model

# Final-layer setters for the non-CIFAR (ImageNet-style) model variants,
# keyed by model type. Each setter is called with ``model=`` and ``layer=``
# keyword arguments and returns the model.
imagenet_last_setters = dict(
    [
        ("resnet-18", resnet_set_last),
        ("densenet-121", densenet_set_last),
        ("vgg-11", vgg_set_last),
    ]
)

def set_last_layer(
    model_config: ModelConfig, model: Module, layer: Module
) -> Module:
    """Install *layer* as the final layer of *model* and return the setter's result.

    The CIFAR setter table is used when ``model_config.domain == "cifar"``.
    Otherwise the ImageNet table is used. The setter is looked up by
    ``model_config.type`` and called with ``model=`` and ``layer=`` keyword
    arguments.

    Raises:
        KeyError: if ``model_config.type`` has no entry in the selected table.
    """
    registry = (
        cifar_last_setters
        if model_config.domain == "cifar"
        else imagenet_last_setters
    )
    apply_setter = registry[model_config.type]
    return apply_setter(model=model, layer=layer)
