
import os
import warnings
warnings.filterwarnings("ignore")

from models.med import BertConfig, BertModel
from transformers import BertTokenizer

import torch
from torch import nn
import torch.nn.functional as F

from urllib.parse import urlparse
from timm.models.hub import download_cached_file


def init_tokenizer(tokenizer_path='./checkpoints/hfl_bert_wwm'):
    """Load a BertTokenizer and register the special tokens this model uses.

    ``[DEC]`` is registered as the BOS token and ``[ENC]`` as an additional
    special token; the id of ``[ENC]`` is stored on the tokenizer as
    ``enc_token_id`` for later use by callers.

    Args:
        tokenizer_path: Path or identifier handed to
            ``BertTokenizer.from_pretrained``. Defaults to the local
            checkpoint directory that was previously hard-coded.

    Returns:
        The configured ``BertTokenizer`` instance.
    """
    tokenizer = BertTokenizer.from_pretrained(tokenizer_path)
    tokenizer.add_special_tokens({'bos_token': '[DEC]'})
    tokenizer.add_special_tokens({'additional_special_tokens': ['[ENC]']})
    # '[ENC]' is the only additional special token registered above, so it is at index 0.
    tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0]
    return tokenizer

def is_url(url_or_filename):
    """Return True when ``url_or_filename`` uses the ``http`` or ``https`` scheme."""
    scheme = urlparse(url_or_filename).scheme
    return scheme == "http" or scheme == "https"

def interpolate_pos_embedding(pos_embed_checkpoint, visual_encoder):
    """Resize a checkpoint position-embedding table to ``visual_encoder``'s patch grid.

    The table is treated as ``num_extra_tokens`` leading rows (e.g. a class
    token), which are copied unchanged, followed by a square
    ``orig_size x orig_size`` grid of patch positions, which is resized with
    bicubic interpolation to ``new_size x new_size`` where
    ``new_size = int(sqrt(visual_encoder.embeddings.num_patches))``.

    Args:
        pos_embed_checkpoint: 2-D tensor ``(num_positions, embedding_size)``
            taken from the checkpoint.
        visual_encoder: Module exposing ``embeddings.num_patches`` and
            ``embeddings.position_embedding.weight``.

    Returns:
        The resized 2-D table, or ``pos_embed_checkpoint`` itself when the
        grid side lengths already match.

    Raises:
        ValueError: If resizing is needed but the checkpoint's patch rows do
            not form a square grid.
    """
    embedding_size = pos_embed_checkpoint.shape[-1]
    num_patches = visual_encoder.embeddings.num_patches
    # Rows of the model's table that are not patch positions (class token etc.).
    num_extra_tokens = visual_encoder.embeddings.position_embedding.weight.shape[-2] - num_patches
    num_ckpt_patches = pos_embed_checkpoint.shape[-2] - num_extra_tokens
    # Side lengths of the checkpoint grid and the model grid (both assumed square).
    orig_size = int(num_ckpt_patches ** 0.5)
    new_size = int(num_patches ** 0.5)
    if orig_size == new_size:
        return pos_embed_checkpoint

    if orig_size * orig_size != num_ckpt_patches:
        raise ValueError(
            'checkpoint has %d patch position embeddings, which is not a square grid'
            % num_ckpt_patches)

    # Extra tokens are kept unchanged; only the patch positions are interpolated.
    extra_tokens = pos_embed_checkpoint[:num_extra_tokens, :]
    pos_tokens = pos_embed_checkpoint[num_extra_tokens:, :]

    # (H*W, D) -> (1, D, H, W) so the grid can be resized like an image.
    pos_tokens = pos_tokens.reshape(1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
    # Bicubic upsampling of half-precision tensors on CPU is unsupported in some
    # PyTorch versions, so interpolate in float32 and cast back afterwards.
    orig_dtype = pos_tokens.dtype
    if orig_dtype in (torch.float16, torch.bfloat16):
        pos_tokens = pos_tokens.float()
    pos_tokens = F.interpolate(
        pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
    # (1, D, h, w) -> (h*w, D). An explicit reshape (instead of squeeze) keeps the
    # result 2-D even when new_size == 1 or embedding_size == 1.
    pos_tokens = pos_tokens.permute(0, 2, 3, 1).reshape(-1, embedding_size).to(orig_dtype)

    return torch.cat((extra_tokens, pos_tokens), dim=0)

def load_checkpoint(model, url_or_filename):
    """Load the ``'model'`` state dict from a local file or http(s) URL into ``model``.

    Vision-encoder position embeddings (``visual_encoder`` and, if present,
    ``visual_encoder_m``) are resized to the model's patch grid with
    ``interpolate_pos_embedding``. Any checkpoint tensor whose shape differs
    from the model's corresponding entry is dropped, and the rest is loaded
    with ``strict=False``.

    Args:
        model: The ``nn.Module`` to load weights into (modified in place).
        url_or_filename: An http(s) URL or a path to an existing checkpoint file.

    Returns:
        Tuple ``(model, msg)`` where ``msg`` is the value returned by
        ``model.load_state_dict`` (missing / unexpected keys).

    Raises:
        RuntimeError: If ``url_or_filename`` is neither a URL nor an existing file.
    """
    if is_url(url_or_filename):
        checkpoint_path = download_cached_file(url_or_filename, check_hash=False, progress=True)
    elif os.path.isfile(url_or_filename):
        checkpoint_path = url_or_filename
    else:
        raise RuntimeError('checkpoint url or path is invalid')

    # NOTE(review): torch.load unpickles arbitrary objects; only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    state_dict = checkpoint['model']

    # Build the model's state dict once; state_dict() creates a new mapping on every call.
    model_state = model.state_dict()

    for encoder_name in ('visual_encoder', 'visual_encoder_m'):
        key = encoder_name + '.embeddings.position_embedding.weight'
        # Guard on both sides so a checkpoint lacking the table does not raise KeyError.
        if key in model_state and key in state_dict:
            state_dict[key] = interpolate_pos_embedding(state_dict[key],
                                                        getattr(model, encoder_name))

    # Drop tensors whose shape disagrees with the model so strict=False loading
    # skips them instead of failing with a size-mismatch error.
    for key, param in model_state.items():
        if key in state_dict and state_dict[key].shape != param.shape:
            del state_dict[key]

    msg = model.load_state_dict(state_dict, strict=False)
    print('load checkpoint from %s' % url_or_filename)
    return model, msg
