import argparse
import os
from timm.models import create_model
# , apply_test_time_pool, load_checkpoint, is_model, list_models
from utils import ROOT_PATH
import torch
from torch import nn
from Normalize import Normalize, TfNormalize
from custommodels import *


# Model identifiers accepted by get_model(); each one is passed verbatim to
# timm.models.create_model.
MODEL_NAMES = [
    'resnet18',                      # ResNet-18
    'resnet50',                      # ResNet-50
    'resnet101',                     # ResNet-101
    'vgg19_bn',                      # VGG19 with batch normalization
    'densenet121',                   # DenseNet-121
    'inception_v3',                  # Inception V3
    'vit_tiny_patch16_224',          # ViT-Tiny/16
    'vit_small_patch16_224',         # ViT-Small/16
    'vit_base_patch16_224',          # ViT-Base/16
    'swin_tiny_patch4_window7_224',  # Swin-Tiny
    'pvt_v2_b2_li',                  # pvt_v2_b2_li
    'mobilevit_s',                   # mobilevit_s
]

# Local checkpoint files keyed by model name. The paths are relative, so if
# opened as-is they resolve against the process's current working directory.
WEIGHT_PATHS = {
    'resnet18': './checkpoints/resnet18-5c106cde.pth',
    'resnet50': './checkpoints/resnet50_a1_0-14fe96d1.pth',
    'resnet101': './checkpoints/resnet101_a1h-36d3f2aa.pth',
    'densenet121': './checkpoints/densenet121_ra-50efcf5c.pth',  # DenseNet-121
}

def remove_module_prefix(state_dict, prefixes=('module.', 'model.')):
    """Strip wrapper prefixes from the keys of a state dict.

    Checkpoints saved from a wrapped network carry extra leading key
    components, e.g. ``nn.DataParallel``/DDP prepend ``'module.'``, and a
    wrapper holding the network in a ``model`` attribute presumably prepends
    ``'model.'``. Each key has its leading run of ``prefixes`` removed
    (``'module.model.fc.weight'`` -> ``'fc.weight'``). Occurrences later in
    the key are kept, so names like ``'backbone.submodel.x'`` are untouched.

    Args:
        state_dict: Mapping of parameter/buffer names to values (tensors).
        prefixes: Key prefixes to strip; tried repeatedly until none matches.
            Empty strings are ignored.

    Returns:
        A new ``dict`` with the same values under the adjusted keys. The
        input mapping is not modified. If two keys collapse to the same name,
        the later one wins.
    """
    new_state_dict = {}
    for key, value in state_dict.items():
        new_key = key
        stripped = True
        # Loop so stacked wrappers (e.g. DataParallel around a wrapper) are
        # fully unwrapped regardless of the order they were applied in.
        while stripped:
            stripped = False
            for prefix in prefixes:
                # Skip empty prefixes: they would always match and never end.
                if prefix and new_key.startswith(prefix):
                    new_key = new_key[len(prefix):]
                    stripped = True
        new_state_dict[new_key] = value
    return new_state_dict

def get_model(model_name, pre_trained=True, weight_path=None):
    """Build a timm classifier and optionally load local weights into it.

    Args:
        model_name: timm model identifier; must be one of ``MODEL_NAMES``.
        pre_trained: If True, timm loads its own pretrained weights. If
            False, the architecture is created without timm weights and the
            checkpoint at ``weight_path`` is loaded instead.
        weight_path: Checkpoint file used when ``pre_trained`` is False.
            When omitted, falls back to ``WEIGHT_PATHS[model_name]``.

    Returns:
        The module returned by ``timm.models.create_model`` (1000 classes,
        3 input channels), with the custom checkpoint loaded when
        ``pre_trained`` is False.

    Raises:
        ValueError: If ``model_name`` is not in ``MODEL_NAMES``, or if
            ``pre_trained`` is False and no checkpoint path is available.
    """
    if model_name not in MODEL_NAMES:
        # Previously this fell through and failed with UnboundLocalError.
        raise ValueError(
            f'Unsupported model {model_name!r}; expected one of {MODEL_NAMES}.')

    # Resolve the checkpoint path before building the model so a missing
    # path fails fast instead of after model construction.
    if not pre_trained:
        if weight_path is None:
            weight_path = WEIGHT_PATHS.get(model_name)
        if weight_path is None:
            raise ValueError(
                f'No weight_path given and no entry in WEIGHT_PATHS for '
                f'{model_name!r}; cannot load custom weights.')

    model = create_model(
        model_name,
        pretrained=pre_trained,
        num_classes=1000,
        in_chans=3,
        global_pool=None,
        scriptable=False)

    if pre_trained:
        print(f'Loading Model {model_name} with pretrained weights.')
    else:
        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        # map_location='cpu' lets GPU-saved checkpoints load on any host;
        # load_state_dict then copies the tensors into the model's parameters.
        state_dict = torch.load(weight_path, map_location='cpu')
        # Some checkpoints wrap the weights as {'state_dict': {...}}.
        if isinstance(state_dict, dict) and 'state_dict' in state_dict:
            state_dict = state_dict['state_dict']
        # Normalize key names (strip the wrapper-module prefix) so they match
        # this model's own parameter names.
        adjusted_state_dict = remove_module_prefix(state_dict)
        model.load_state_dict(adjusted_state_dict)
        print(f'Loading Model {model_name} with custom weights from {weight_path}.')

    return model

def get_custom_model(model_name, num_classes=1000):
    """Wrap ``model_name`` in the matching project-specific model class.

    The wrapper class is selected by name. Every wrapper is constructed with
    ``pretrained=True`` and the requested ``num_classes``.

    Args:
        model_name: Backbone identifier forwarded to the wrapper.
        num_classes: Number of output classes forwarded to the wrapper.

    Returns:
        The constructed wrapper instance.
    """
    # NOTE(review): 'densenet121' is routed to CustomInceptionV3Model, and
    # 'resnet18'/'resnet101' fall through to CustomViTModel rather than
    # CustomResNetModel -- confirm these pairings against custommodels.
    if model_name == 'resnet50':
        wrapper_cls = CustomResNetModel
    elif model_name in ('inception_v3', 'densenet121'):
        wrapper_cls = CustomInceptionV3Model
    elif model_name == 'swin_tiny_patch4_window7_224':
        wrapper_cls = CustomSwinModel
    else:
        # Every other name, including unlisted ones, uses the ViT wrapper.
        wrapper_cls = CustomViTModel
    return wrapper_cls(model_name=model_name, num_classes=num_classes, pretrained=True)