import torch
import torchvision
from torch.hub import load_state_dict_from_url

def _replace_fc(model, output_dim):
    d = model.fc.in_features
    model.fc = torch.nn.Linear(d, output_dim)
    return model


# Checkpoint URLs from the VISSL model zoo.
SIMCLR_RN50_URL = "https://dl.fbaipublicfiles.com/vissl/model_zoo/simclr_rn50_800ep_simclr_8node_resnet_16_07_20.7e8feed1/model_final_checkpoint_phase799.torch"
BARLOWTWINS_RN50_URL = "https://dl.fbaipublicfiles.com/vissl/model_zoo/barlow_twins/barlow_twins_32gpus_4node_imagenet1k_1000ep_resnet50.torch"
MOCO_RN50_URL = "https://dl.fbaipublicfiles.com/vissl/model_zoo/moco_v2_1node_lr.03_step_b32_zero_init/model_final_checkpoint_phase199.torch"
SWAV_RN50_URL = "https://dl.fbaipublicfiles.com/vissl/model_zoo/swav_in1k_rn50_800ep_swav_8node_resnet_27_07_20.a0a6b676/model_final_checkpoint_phase799.torch"
# NOTE(review): the filename says "res200w2", yet this is registered as a
# ResNet-50 checkpoint below -- confirm it actually loads into resnet50.
BYOL_RN50_URL = "https://dl.fbaipublicfiles.com/vissl/model_zoo/baselines/converted_byol_pretrain_res200w2.torch"

DINO_VIT_URL = "https://dl.fbaipublicfiles.com/vissl/model_zoo/dino_300ep_deitsmall16/model_final_checkpoint_phase299.torch"
# Backward-compatible alias for the original (misspelled) constant name.
DINO_VIT_RL = DINO_VIT_URL

# Short model names accepted by ``imagenet_resnet50_pt`` -> checkpoint URL.
model_url = {
    'simclr-in': SIMCLR_RN50_URL,
    'barlow-in': BARLOWTWINS_RN50_URL,
    'moco-in': MOCO_RN50_URL,
    'swav-in': SWAV_RN50_URL,
    'byol-in': BYOL_RN50_URL,
}

def replace_module_prefix(state_dict, prefix, replace_with=""):
    """Return a new dict whose keys have a leading ``prefix`` swapped out.

    Keys that start with ``prefix`` have that leading occurrence replaced by
    ``replace_with``. All other keys are kept unchanged. Values are passed
    through as-is, and the input mapping is not modified. If two keys end up
    identical after renaming, the one iterated later wins.
    """
    renamed = {}
    for name, tensor in state_dict.items():
        if name.startswith(prefix):
            name = replace_with + name[len(prefix):]
        renamed[name] = tensor
    return renamed


def get_torchvision_state_dict(url, map_location="cpu"):
    """Download a VISSL checkpoint and return its trunk weights.

    Args:
        url: checkpoint URL, passed to ``torch.hub.load_state_dict_from_url``.
        map_location: where to place the deserialized tensors. Defaults to
            ``"cpu"`` so checkpoints saved from GPU still load on CPU-only
            hosts. Pass ``None`` to keep the devices recorded in the file.

    Returns:
        The ``classy_state_dict -> base_model -> model -> trunk`` sub-dict,
        with the leading ``"_feature_blocks."`` removed from its keys.

    Raises:
        KeyError: if the checkpoint does not have that nested layout.
    """
    checkpoint = load_state_dict_from_url(url, map_location=map_location)
    trunk = checkpoint["classy_state_dict"]["base_model"]["model"]["trunk"]
    # VISSL nests trunk parameters under "_feature_blocks."; strip it so the
    # keys match the torchvision ResNet that callers load them into.
    return replace_module_prefix(trunk, "_feature_blocks.")


def imagenet_resnet50_simclr(output_dim=None):
    """Return a ResNet-50 initialised from the VISSL SimCLR checkpoint.

    This previously duplicated ``imagenet_resnet50_pt`` line for line with the
    same URL (``model_url['simclr-in'] == SIMCLR_RN50_URL``). It now delegates,
    so both paths cannot drift apart.

    Args:
        output_dim: ``None`` returns the featurizer (``fc`` is an ``Identity``
            whose ``in_features`` attribute is set). Otherwise a new
            ``Linear`` head with ``output_dim`` outputs is attached.
    """
    return imagenet_resnet50_pt('simclr-in', output_dim)

def imagenet_resnet50_pt(model_name, output_dim=None):
    """Build a torchvision ResNet-50 initialised from a VISSL checkpoint.

    Args:
        model_name: key into ``model_url`` (e.g. ``'simclr-in'``, ``'moco-in'``).
        output_dim: ``None`` returns the featurizer, whose ``fc`` is an
            ``Identity``. Otherwise a new ``Linear`` head with ``output_dim``
            outputs is attached.

    Returns:
        The ResNet-50 module. ``model.fc.in_features`` always holds the
        representation width, even on the ``Identity`` featurizer head.

    Raises:
        KeyError: if ``model_name`` is not a key of ``model_url``.
    """
    # Fail fast, before building the network, and say what *is* accepted.
    if model_name not in model_url:
        raise KeyError(
            f"unknown pretrained model {model_name!r}; "
            f"expected one of {sorted(model_url)}"
        )
    model = torchvision.models.resnet50(pretrained=False)
    # Capture the head's input width (2048 for resnet50) before dropping it.
    rep_dim = model.fc.in_features
    # The VISSL trunk has no classifier weights, so use a parameter-free head
    # for the strict load_state_dict below.
    model.fc = torch.nn.Identity()
    model.load_state_dict(get_torchvision_state_dict(model_url[model_name]))
    # Expose the width on the Identity so callers and _replace_fc can read it.
    model.fc.in_features = rep_dim
    if output_dim is None:  # return featurizer
        return model
    return _replace_fc(model, output_dim)


def imagenet_resnet50_barlowtwins(output_dim=None):
    """Return a ResNet-50 initialised from the VISSL Barlow Twins checkpoint.

    This previously duplicated ``imagenet_resnet50_pt`` with the same URL
    (``model_url['barlow-in'] == BARLOWTWINS_RN50_URL``) and carried dead
    commented-out code. It now delegates so the loading logic lives in one place.

    Args:
        output_dim: ``None`` returns the featurizer (``fc`` is an ``Identity``
            whose ``in_features`` attribute is set). Otherwise a new
            ``Linear`` head with ``output_dim`` outputs is attached.
    """
    return imagenet_resnet50_pt('barlow-in', output_dim)


def resnet50_featurizer(config, model='resnet50'):
    """Build a model via ``initialize_model(config)`` and strip its head.

    Note: the ``model`` argument is never read; it is kept only for signature
    compatibility. The network comes from
    ``models.initializer.initialize_model(config, d_out=2048)``.

    Returns:
        The network with ``fc`` replaced by ``torch.nn.Identity``. The original
        head's ``in_features`` is copied onto the Identity so callers can read
        the representation width.
    """
    from models.initializer import initialize_model

    net = initialize_model(config, d_out=2048)
    feat_dim = net.fc.in_features
    head = torch.nn.Identity()
    head.in_features = feat_dim
    net.fc = head
    return net


def get_enc(config, weights_name, get_rep_dim=False, device='cuda'):
    """Return an image encoder selected by ``weights_name``.

    Args:
        config: forwarded to ``resnet50_featurizer`` on the fallback path.
        weights_name: ``'clip'`` selects CLIP RN50's visual tower.
            ``'simclr-in'`` / ``'barlow-in'`` select VISSL-pretrained ResNet-50s.
            Any other value falls back to ``resnet50_featurizer(config)``.
        get_rep_dim: if True, also return the representation width.
        device: device string passed to ``clip.load``. Only used on the CLIP
            path; defaults to ``'cuda'`` (the previously hard-coded value).

    Returns:
        ``encoder``, or ``(encoder, rep_dim)`` when ``get_rep_dim`` is True.
    """
    if weights_name == 'clip':
        import clip  # optional dependency, only needed on this path
        clip_model, _ = clip.load('RN50', device)
        encoder = clip_model.visual
        if get_rep_dim:
            return encoder, encoder.attnpool.c_proj.out_features
        return encoder

    # NOTE(review): 'moco-in', 'swav-in' and 'byol-in' exist in model_url but
    # are not listed here, so they silently fall through to
    # resnet50_featurizer(config) -- confirm this is intended.
    vissl_loaders = {'simclr-in': imagenet_resnet50_simclr,
                     'barlow-in': imagenet_resnet50_barlowtwins}

    if weights_name in vissl_loaders:
        print(f'initialized model weights with {weights_name} from vissl')
        encoder = vissl_loaders[weights_name]()
    else:
        encoder = resnet50_featurizer(config)
    return (encoder, encoder.fc.in_features) if get_rep_dim else encoder

