import torch
from torchvision.models import ResNet50_Weights, resnet50, convnext_base, ConvNeXt_Base_Weights
from torchvision.models import (vgg16, VGG16_Weights, 
                                efficientnet_v2_s, EfficientNet_V2_S_Weights, 
                                convnext_base, ConvNeXt_Base_Weights)

def load_moco(model_path="moco_v2_800ep_pretrain.pth.tar", map_location="cpu"):
    """Build a torchvision ResNet-50 initialised from a MoCo checkpoint's query encoder.

    Only entries of ``checkpoint["state_dict"]`` whose key starts with
    ``module.encoder_q`` are kept, excluding the ``module.encoder_q.fc`` head.
    The ``module.encoder_q.`` prefix is stripped so the remaining names line up
    with the parameters of ``torchvision.models.resnet50()``.

    Args:
        model_path: Path of the checkpoint file handed to ``torch.load``.
        map_location: Forwarded to ``torch.load``. Defaults to ``"cpu"`` so a
            checkpoint whose tensors were saved on GPU can still be read on a
            CPU-only host. ``load_state_dict`` copies values into the freshly
            built (CPU) model either way, so this only controls where the
            checkpoint tensors are staged.

    Returns:
        The ResNet-50 with backbone weights loaded from the checkpoint. Its
        ``fc`` layer keeps the random initialisation from ``resnet50()``.

    Raises:
        KeyError: If the loaded checkpoint has no ``"state_dict"`` entry.
        RuntimeError: If any non-``fc`` model parameter/buffer is absent from
            the filtered checkpoint, or the filtered checkpoint contains keys
            the model does not have (i.e. the file does not match the
            expected layout).
    """
    checkpoint = torch.load(model_path, map_location=map_location)
    prefix = "module.encoder_q."
    # Build a fresh dict instead of renaming in place: in-place rename-then-
    # delete can clobber an already-unprefixed key of the same name.
    state_dict = {
        key[len(prefix):]: value
        for key, value in checkpoint["state_dict"].items()
        if key.startswith("module.encoder_q")
        and not key.startswith("module.encoder_q.fc")
    }

    model = resnet50()
    result = model.load_state_dict(state_dict, strict=False)

    # The fc head was filtered out on purpose, so it is the only thing allowed
    # to be missing. Anything else means a silently random-initialised backbone.
    missing = set(result.missing_keys) - {"fc.weight", "fc.bias"}
    if missing or result.unexpected_keys:
        raise RuntimeError(
            f"MoCo checkpoint {model_path!r} does not match resnet50: "
            f"missing={sorted(missing)}, "
            f"unexpected={sorted(result.unexpected_keys)}"
        )
    return model

def moco_v2_model_target_layer():
    """Return the MoCo v2 ResNet-50, its target layer and an ImageNet classifier.

    Returns:
        A 3-tuple ``(model, target_layer, cls_model)``:
        ``model`` comes from ``load_moco()`` with its default checkpoint path;
        ``target_layer`` is ``model.layer4[-1]``, the last block of its final
        residual stage; ``cls_model`` is a torchvision ResNet-50 loaded with
        ``ResNet50_Weights.IMAGENET1K_V2``.
    """
    ssl_encoder = load_moco()
    last_block = ssl_encoder.layer4[-1]
    imagenet_classifier = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
    return ssl_encoder, last_block, imagenet_classifier

def swav_model_target_layer():
    """Return the SwAV ResNet-50 from torch.hub, its target layer and an ImageNet classifier.

    Returns:
        A 3-tuple ``(model, target_layer, cls_model)``:
        ``model`` is ``torch.hub.load('facebookresearch/swav:main', 'resnet50')``;
        ``target_layer`` is ``model.layer4[-1]``; ``cls_model`` is a
        torchvision ResNet-50 with ``ResNet50_Weights.IMAGENET1K_V2``.
    """
    hub_repo, hub_entry = 'facebookresearch/swav:main', 'resnet50'
    backbone = torch.hub.load(hub_repo, hub_entry)
    target = backbone.layer4[-1]
    reference = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
    return backbone, target, reference

def vicreg_resnet_model_target_layer():
    """Return the VICRegL ResNet-50 (alpha 0.9 hub entry), its target layer and an ImageNet classifier.

    Returns:
        A 3-tuple ``(model, target_layer, cls_model)``:
        ``model`` is the ``'resnet50_alpha0p9'`` entry of the
        ``'facebookresearch/vicregl:main'`` hub repo; ``target_layer`` is
        ``model.layer4[-1]``; ``cls_model`` is a torchvision ResNet-50 with
        ``ResNet50_Weights.IMAGENET1K_V2``.
    """
    hub_repo, hub_entry = 'facebookresearch/vicregl:main', 'resnet50_alpha0p9'
    backbone = torch.hub.load(hub_repo, hub_entry)
    target = backbone.layer4[-1]
    reference = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
    return backbone, target, reference

def vicreg_convnext_model_target_layer():
    """Return the VICRegL ConvNeXt-Base (alpha 0.9 hub entry), its target layer and an ImageNet classifier.

    Returns:
        A 3-tuple ``(model, target_layer, cls_model)``:
        ``model`` is the ``'convnext_base_alpha0p9'`` entry of the
        ``'facebookresearch/vicregl:main'`` hub repo; ``target_layer`` is the
        ``dwconv`` submodule of block index 2 inside ``model.stages[3]``
        (presumably the last block of the last stage — confirm against the
        hub model definition); ``cls_model`` is a torchvision ConvNeXt-Base
        with ``ConvNeXt_Base_Weights.IMAGENET1K_V1``.
    """
    backbone = torch.hub.load('facebookresearch/vicregl:main', 'convnext_base_alpha0p9')
    final_stage = backbone.stages[3]
    target = final_stage[2].dwconv
    reference = convnext_base(weights=ConvNeXt_Base_Weights.IMAGENET1K_V1)
    return backbone, target, reference

def barlowtwins_model_target_layer():
    """Return the Barlow Twins ResNet-50 from torch.hub, its target layer and an ImageNet classifier.

    Returns:
        A 3-tuple ``(model, target_layer, cls_model)``:
        ``model`` is ``torch.hub.load('facebookresearch/barlowtwins:main', 'resnet50')``;
        ``target_layer`` is ``model.layer4[-1]``; ``cls_model`` is a
        torchvision ResNet-50 with ``ResNet50_Weights.IMAGENET1K_V2``.
    """
    hub_repo, hub_entry = 'facebookresearch/barlowtwins:main', 'resnet50'
    backbone = torch.hub.load(hub_repo, hub_entry)
    target = backbone.layer4[-1]
    reference = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
    return backbone, target, reference

def resnet50_model_target_layer():
    """Return a torchvision ResNet-50 (``IMAGENET1K_V1``) and its target layer.

    Returns:
        A 2-tuple ``(model, target_layer)`` with ``target_layer`` being
        ``model.layer4[-1]``. No separate classifier is returned.
    """
    pretrained = ResNet50_Weights.IMAGENET1K_V1
    net = resnet50(weights=pretrained)
    target_layer = net.layer4[-1]
    return net, target_layer

def convnext_base_model_target_layer():
    """Return a torchvision ConvNeXt-Base (``IMAGENET1K_V1``) and its target layer.

    Returns:
        A 2-tuple ``(model, target_layer)`` where ``target_layer`` is the
        model's whole ``features`` submodule.
    """
    pretrained = ConvNeXt_Base_Weights.IMAGENET1K_V1
    net = convnext_base(weights=pretrained)
    target_layer = net.features
    return net, target_layer

def vgg16_model_target_layer():
    """Return a torchvision VGG-16 (``IMAGENET1K_V1``) and its target layer.

    Returns:
        A 2-tuple ``(model, target_layer)`` where ``target_layer`` is the
        model's whole ``features`` submodule.
    """
    pretrained = VGG16_Weights.IMAGENET1K_V1
    net = vgg16(weights=pretrained)
    target_layer = net.features
    return net, target_layer