
from torchvision.models import EfficientNet_B0_Weights,GoogLeNet_Weights,MobileNet_V2_Weights,ResNet18_Weights
from .mi_models.GoogleNet import CSTA_GoogleNet

def set_model_from_config(config):
    """Build a model from the settings carried by a config object.

    Each hyper-parameter accepted by ``set_model`` is looked up as an
    attribute of the same name on ``config`` and passed through as a
    keyword argument.

    Args:
        config: Object exposing every attribute listed in ``field_names``
            below. A missing attribute raises ``AttributeError``.

    Returns:
        The model produced by ``set_model`` for these settings.
    """
    field_names = (
        'model_name',
        'Scale',
        'Softmax_axis',
        'Balance',
        'Positional_encoding',
        'Positional_encoding_shape',
        'Positional_encoding_way',
        'Dropout_on',
        'Dropout_ratio',
        'Classifier_on',
        'CLS_on',
        'CLS_mix',
        'mask_type',
        'read_out',
        'key_value_emb',
        'Skip_connection',
        'Layernorm',
    )
    # Attributes are read in the order listed above, one per keyword.
    settings = {field: getattr(config, field) for field in field_names}
    return set_model(**settings)

def _load_pretrained_googlenet(backbone):
    """Overwrite ``backbone``'s weights with torchvision's ImageNet GoogLeNet weights."""
    pretrained = GoogLeNet_Weights.IMAGENET1K_V1.get_state_dict(progress=False)
    # Drop checkpoint entries whose key starts with 'aux' (presumably the
    # auxiliary classifier heads, which the backbone is not expected to use).
    pretrained = {k: v for k, v in pretrained.items() if not k.startswith('aux')}
    # Start from the backbone's own state dict so that any of its entries
    # missing from the checkpoint keep their current values; checkpoint
    # entries overwrite the rest. load_state_dict stays strict (the default).
    merged = backbone.state_dict()
    merged.update(pretrained)
    backbone.load_state_dict(merged)


def set_model(model_name,
              Scale,
              Softmax_axis,
              Balance,
              Positional_encoding,
              Positional_encoding_shape,
              Positional_encoding_way,
              Dropout_on,
              Dropout_ratio,
              Classifier_on,
              CLS_on,
              CLS_mix,
              mask_type,
              read_out,
              key_value_emb,
              Skip_connection,
              Layernorm):
    """Construct a CSTA model and initialise its backbone from pretrained weights.

    Args:
        model_name: Architecture identifier. Only ``'GoogleNet'`` and
            ``'GoogleNet_Attention'`` are supported.
        Scale, Softmax_axis, Balance, Positional_encoding,
        Positional_encoding_shape, Positional_encoding_way, Dropout_on,
        Dropout_ratio, Classifier_on, CLS_on, CLS_mix, mask_type, read_out,
        key_value_emb, Skip_connection, Layernorm: Forwarded unchanged to
            ``CSTA_GoogleNet``. Their meaning is defined by that class.

    Returns:
        The ``CSTA_GoogleNet`` instance. Its ``googlenet`` submodule is loaded
        with torchvision's ``GoogLeNet_Weights.IMAGENET1K_V1`` weights, minus
        the ``aux*`` entries.

    Raises:
        NotImplementedError: If ``model_name`` is not supported. This is a
            subclass of ``RuntimeError``, which the previous bare ``raise``
            produced, so existing handlers keep working.
    """
    if model_name not in ('GoogleNet', 'GoogleNet_Attention'):
        raise NotImplementedError(
            f"Unsupported model_name {model_name!r}; expected one of "
            f"'GoogleNet', 'GoogleNet_Attention'"
        )

    model = CSTA_GoogleNet(
        model_name=model_name,
        Scale=Scale,
        Softmax_axis=Softmax_axis,
        Balance=Balance,
        Positional_encoding=Positional_encoding,
        Positional_encoding_shape=Positional_encoding_shape,
        Positional_encoding_way=Positional_encoding_way,
        Dropout_on=Dropout_on,
        Dropout_ratio=Dropout_ratio,
        Classifier_on=Classifier_on,
        CLS_on=CLS_on,
        CLS_mix=CLS_mix,
        mask_type=mask_type,
        read_out=read_out,
        key_value_emb=key_value_emb,
        Skip_connection=Skip_connection,
        Layernorm=Layernorm
    )
    _load_pretrained_googlenet(model.googlenet)
    return model