from typing import cast
from torch import nn

from ..create import ModelConfig
from ..access import get_last_layer, set_last_layer
# from .layer_accessors import get_last_layer_accessor


def disable_gradients(model: nn.Module) -> nn.Module:
    """Freeze ``model`` in place and switch every submodule to eval mode.

    After this call no parameter of ``model`` requires gradients. Every
    module reachable from ``model`` (including ``model`` itself) is in
    evaluation mode. Eval mode stops layers such as BatchNorm from updating
    their running statistics while a new head is trained on top of the
    frozen model.

    Parameters
    ----------
    model: torch.nn.Module
        The module to freeze. It is modified in place.

    Returns
    -------
    torch.nn.Module
        The same ``model`` object that was passed in, for call chaining.
    """
    # Equivalent to setting ``requires_grad = False`` on every parameter.
    model.requires_grad_(False)

    # Visit the model and all of its descendants so each one is in eval mode.
    for submodule in model.modules():
        submodule.eval()

    return model

# Based on https://pytorch.org/tutorials/beginner/finetuning_torchvision_models_tutorial.html
def prepare_torchvision_finetuning(
    config: ModelConfig,
    model: nn.Module,
    feature_extract: bool,
) -> nn.Module:
    """Prepare a torchvision model for finetuning or feature extraction.

    The model's current final layer is located with ``get_last_layer``. It
    is replaced, via ``set_last_layer``, by a freshly initialised
    ``nn.Linear`` with ``config.num_classes`` outputs and the same number
    of input features.

    Parameters
    ----------
    config: ModelConfig
        Model configuration. ``config.num_classes`` sets the number of
        outputs of the new final layer. The whole config is also passed to
        ``get_last_layer`` / ``set_last_layer`` to locate the layer to
        replace.
    model: torch.nn.Module
        The model instance that should be finetuned. Its current final
        layer must expose an ``in_features`` attribute (as ``nn.Linear``
        does).
    feature_extract: bool
        If True, every existing parameter is frozen and every existing
        submodule is put in eval mode *before* the new final layer is
        created. Only the new layer is therefore trainable.
        If False, the existing layers are left untouched (their
        ``requires_grad`` flags and train/eval mode are not modified).

    Returns
    -------
    model: torch.nn.Module
        The model with the new final layer installed, as returned by
        ``set_last_layer``.

    Raises
    ------
    ValueError
        If the current final layer has no ``in_features`` attribute, e.g.
        a convolutional classifier head. Errors raised by
        ``get_last_layer`` / ``set_last_layer`` propagate unchanged.
    """
    if feature_extract:
        # Freeze before the new head exists, so the new head's parameters
        # keep their default requires_grad=True and remain trainable.
        # disable_gradients mutates and returns the same module object.
        model = disable_gradients(model)

    last_layer = get_last_layer(config, model)
    try:
        # cast() is for the type checker only; nn.Module attribute access
        # is typed loosely, and it performs no runtime conversion.
        num_features = cast(int, last_layer.in_features)
    except AttributeError as err:
        raise ValueError(
            f"Cannot replace final layer of type {type(last_layer).__name__}: "
            "it has no 'in_features' attribute"
        ) from err

    new_last_layer = nn.Linear(num_features, config.num_classes)

    # Place the new head on the replaced layer's device and floating-point
    # dtype. A model that was already moved (e.g. to GPU or half precision)
    # then does not end up with a CPU/float32 final layer.
    reference = next(last_layer.parameters(), None)
    if reference is not None and reference.is_floating_point():
        new_last_layer = new_last_layer.to(
            device=reference.device, dtype=reference.dtype
        )

    return set_last_layer(config, model, layer=new_last_layer)
