from typing import Any, Iterable, Mapping
import numpy as np
import torch
from torch import nn
from .models32 import get_model
from .grudense import GRUDense
from .frn import FilterResponseNorm
from .resnet224 import resnet50

# standard save model function
def savemodel(
    to,
    modelname: str,
    modelargs: Iterable[Any],
    modelkwargs: Mapping[str, Any],
    model: nn.Module,
    **kwargs
) -> None:
    """Serialize a model together with the recipe needed to rebuild it.

    The checkpoint is a plain dict written with ``torch.save``. It holds
    the factory name, its positional and keyword arguments, the model's
    ``state_dict()`` and any extra ``kwargs`` supplied by the caller.

    Args:
        to: Destination handed straight to ``torch.save``.
        modelname: Name under which the model factory can be looked up.
        modelargs: Positional arguments for the factory; stored as a tuple.
        modelkwargs: Keyword arguments for the factory; stored as a new dict.
        model: Module whose ``state_dict()`` is stored under ``"modelstates"``.
        **kwargs: Extra entries merged into the checkpoint last, so an
            entry that reuses one of the standard keys replaces it.
    """
    checkpoint = {
        "modelname": modelname,
        "modelargs": tuple(modelargs),
        "modelkwargs": dict(modelkwargs),
        "modelstates": model.state_dict(),
    }
    # Extras are applied after the standard entries, as in a trailing ** unpack.
    checkpoint.update(kwargs)
    torch.save(checkpoint, to)


# standard load model function
def loadmodel(fromfile, device=torch.device("cpu")):
    """Rebuild a model from a checkpoint dict and restore its weights.

    The checkpoint must contain ``"modelname"``, ``"modelargs"`` and
    ``"modelstates"``; ``"modelkwargs"`` is optional and defaults to no
    keyword arguments.

    NOTE(review): ``torch.load`` deserializes pickled data, and the factory
    is resolved by name from this module's globals, so only load
    checkpoints that come from a trusted source.

    Args:
        fromfile: Source handed straight to ``torch.load``.
        device: Device used both as ``map_location`` and as the target of
            ``.to(...)`` on the rebuilt model.

    Returns:
        A ``(model, checkpoint)`` pair, where ``checkpoint`` is the loaded
        dict with the ``"modelstates"`` entry removed.

    Raises:
        KeyError: If a required entry is missing or the stored model name
            is not a global of this module.
    """
    checkpoint = torch.load(fromfile, map_location=device)

    # Resolve the factory by the name stored in the checkpoint.
    factory = globals()[checkpoint["modelname"]]
    positional = checkpoint["modelargs"]
    keyword = checkpoint.get("modelkwargs", {})
    model = factory(*positional, **keyword).to(device)

    # Remove the weights from the dict so only metadata is handed back.
    weights = checkpoint.pop("modelstates")
    model.load_state_dict(weights)
    return model, checkpoint


def resnet20(outclass: int, input_size: int = 32) -> torch.nn.Module:
    """Build the ``"resnet20_frn"`` architecture through ``get_model``.

    Args:
        outclass: Forwarded as ``data_info["num_classes"]``.
        input_size: Forwarded as ``data_info["input_size"]``.

    Returns:
        Whatever ``get_model`` returns for this configuration.
    """
    data_info = {"num_classes": outclass, "input_size": input_size}
    # The Identity class itself (not an instance) is passed as the activation.
    return get_model(
        "resnet20_frn", data_info=data_info, activation=torch.nn.Identity
    )
    

def softplus_inv(x: float) -> float:
    """Return the inverse of softplus at ``x``.

    Computes ``x + log(1 - exp(-x))``, which equals ``log(exp(x) - 1)``,
    i.e. the ``y`` satisfying ``log(1 + exp(y)) == x``. The result is only
    finite for ``x > 0``.
    """
    # -expm1(-x) is 1 - exp(-x), evaluated without cancellation for small x.
    one_minus_exp_neg = -np.expm1(-x)
    return x + np.log(one_minus_exp_neg)

def preresnet110(outclass: int, input_size: int = 32) -> torch.nn.Module:
    """Build the ``"preresnet110_frn"`` architecture through ``get_model``.

    Args:
        outclass: Forwarded as ``data_info["num_classes"]``.
        input_size: Forwarded as ``data_info["input_size"]``.

    Returns:
        Whatever ``get_model`` returns for this configuration.
    """
    data_info = {"num_classes": outclass, "input_size": input_size}
    # The Identity class itself (not an instance) is passed as the activation.
    return get_model(
        "preresnet110_frn", data_info=data_info, activation=torch.nn.Identity
    )


def resnet18wide(outclass: int, input_size: int = 32) -> torch.nn.Module:
    """Build the ``"resnet18"`` architecture through ``get_model``.

    Args:
        outclass: Forwarded as ``data_info["num_classes"]``.
        input_size: Forwarded as ``data_info["input_size"]``.

    Returns:
        Whatever ``get_model`` returns for this configuration.
    """
    data_info = {"num_classes": outclass, "input_size": input_size}
    return get_model("resnet18", data_info=data_info)


def densenet121(outclass: int, input_size: int = 32) -> torch.nn.Module:
    """Build the ``"densenet121"`` architecture through ``get_model``.

    Args:
        outclass: Forwarded as ``data_info["num_classes"]``.
        input_size: Forwarded as ``data_info["input_size"]``.

    Returns:
        Whatever ``get_model`` returns for this configuration.
    """
    data_info = {"num_classes": outclass, "input_size": input_size}
    return get_model("densenet121", data_info=data_info)


def gru_dense(vocab_size: int, num_classes: int, padding_idx: int) -> GRUDense:
    """Construct a ``GRUDense`` model.

    The three arguments are forwarded positionally, in the order
    ``vocab_size``, ``num_classes``, ``padding_idx``.
    """
    network = GRUDense(vocab_size, num_classes, padding_idx)
    return network


def resnet50_imagenet(outclass: int, input_size: int = 224) -> torch.nn.Module:
    """Build a ``resnet50`` with ``FilterResponseNorm`` layers.

    Args:
        outclass: Forwarded to ``resnet50`` as ``num_classes``.
        input_size: Accepted so the signature matches the other factories
            in this module; it is not used by this function.

    Returns:
        Whatever ``resnet50`` returns for this configuration.
    """
    return resnet50(
        activation=nn.Identity,
        norm_layer=FilterResponseNorm,
        num_classes=outclass,
    )

# Registry of available model factories, keyed by model name.
# Every factory listed here is defined in this module and takes
# ``(outclass, input_size=...)``.
# NOTE(review): gru_dense is defined in this module but is not listed here;
# confirm whether that omission is intentional.
STANDARDMODELS = {
    "resnet20": resnet20,
    "resnet18wide": resnet18wide,
    "preresnet110": preresnet110,
    "densenet121": densenet121,
    "resnet50_imagenet" : resnet50_imagenet,
}

# MODELS is bound to the same dict object as STANDARDMODELS (an alias, not a
# copy), so a mutation through either name is visible through both.
MODELS = STANDARDMODELS
