import argparse
from functools import partial

import torch
import torch.nn as nn

from mask2former.modeling.backbone.vit import VisionTransformer

# Command-line interface. Abbreviated option names are rejected
# (allow_abbrev=False), and every option falls back to None when omitted.
parser = argparse.ArgumentParser(allow_abbrev=False)
# Checkpoint file to read.
parser.add_argument('--weights', type=str, default=None, help='Path to pretrained MAE ViT weights')
# Side length, in pixels, of the (square) target image size.
parser.add_argument('--new_img_dim', type=int, default=None, help='New (square) image dimensions')
# Destination file for the converted checkpoint.
parser.add_argument('--save_path', type=str, default=None, help='New checkpoint save path')
# NOTE: parsing happens at module level, so sys.argv is read as soon as this
# file is imported or executed.
args = parser.parse_args()

# --------------------------------------------------------
# Interpolate position embeddings for high-resolution
# References:
# DeiT: https://github.com/facebookresearch/deit
# --------------------------------------------------------
def interpolate_pos_embed(model, checkpoint_model):
    """Resize ``checkpoint_model['pos_embed']`` in place to fit ``model``'s patch grid.

    The leading non-patch tokens (e.g. class/dist tokens) are copied over
    unchanged. The remaining patch position tokens are treated as a square
    grid and resampled with bicubic interpolation.

    Args:
        model: Object exposing ``patch_embed.num_patches`` and ``pos_embed``;
            the second-to-last dimension of ``pos_embed`` is read as the total
            token count (patches plus extra tokens).
        checkpoint_model: Mutable state-dict-like mapping. It is left untouched
            when it has no ``'pos_embed'`` entry or when its patch-token count
            already equals the model's.

    Raises:
        ValueError: If the patch-token counts differ and either count is not a
            positive perfect square, so a square-grid resize is impossible.
    """
    if 'pos_embed' not in checkpoint_model:
        return

    # Indexed below as [:, tokens] with the embedding size last, i.e. the
    # tensor is assumed to be (batch, tokens, dim) -- TODO confirm for new callers.
    pos_embed_checkpoint = checkpoint_model['pos_embed']
    embedding_size = pos_embed_checkpoint.shape[-1]
    num_patches = model.patch_embed.num_patches
    # Tokens in the model's embedding that do not correspond to image patches.
    num_extra_tokens = model.pos_embed.shape[-2] - num_patches
    num_ckpt_patches = pos_embed_checkpoint.shape[-2] - num_extra_tokens

    # Same number of patch tokens: nothing to resize (grid need not be square).
    if num_ckpt_patches == num_patches:
        return

    # height (== width) of the checkpoint grid and of the new grid
    orig_size = round(max(num_ckpt_patches, 0) ** 0.5)
    new_size = round(max(num_patches, 0) ** 0.5)
    # Refuse to truncate a non-square token count: that would either fail later
    # in reshape with an opaque error or silently yield the wrong token count.
    if (orig_size < 1 or new_size < 1
            or orig_size * orig_size != num_ckpt_patches
            or new_size * new_size != num_patches):
        raise ValueError(
            "Cannot interpolate position embedding: expected square patch grids, "
            "got %d checkpoint patch tokens and %d model patch tokens"
            % (num_ckpt_patches, num_patches))

    print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
    # class_token and dist_token are kept unchanged
    extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
    # only the position tokens are interpolated
    pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
    # (B, N, C) -> (B, C, H, W), the layout interpolate() expects
    pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
    pos_tokens = torch.nn.functional.interpolate(
        pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
    # (B, C, H, W) -> (B, H*W, C)
    pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
    checkpoint_model['pos_embed'] = torch.cat((extra_tokens, pos_tokens), dim=1)

def vit_base_patch16(**kwargs):
    """Build a ``VisionTransformer`` with 16x16 patches and base-size settings.

    Fixed settings: embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
    qkv_bias=True, and LayerNorm with eps=1e-6. Any extra keyword arguments
    (for example ``img_size``) are passed straight to the constructor.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return VisionTransformer(
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=layer_norm,
        **kwargs
    )

if __name__ == '__main__':
    # Every option defaults to None, so report a missing one as a usage error
    # here instead of failing obscurely inside torch.load or the model constructor.
    missing = [
        flag
        for flag, value in (
            ('--weights', args.weights),
            ('--new_img_dim', args.new_img_dim),
            ('--save_path', args.save_path),
        )
        if value is None
    ]
    if missing:
        parser.error('missing required argument(s): ' + ', '.join(missing))
    if args.new_img_dim <= 0:
        parser.error('--new_img_dim must be a positive integer')

    # NOTE(security): torch.load may unpickle arbitrary objects; only run this
    # on checkpoints from a trusted source.
    ckpt = torch.load(args.weights, map_location='cpu')
    # The state dict is expected under the 'model' key of the checkpoint.
    checkpoint_model = ckpt['model']

    # Build a model at the target resolution only to learn the new token
    # layout, then resize the checkpoint's position embedding to match it.
    model = vit_base_patch16(img_size=(args.new_img_dim, args.new_img_dim))
    interpolate_pos_embed(model, checkpoint_model)

    # Now add 'backbone' prefix to all weights.
    # Iterate over a snapshot of the keys because the dict is mutated in the loop.
    for k in list(checkpoint_model.keys()):
        checkpoint_model["backbone." + k] = checkpoint_model.pop(k)

    # Save checkpoint
    torch.save({"model": checkpoint_model}, args.save_path)

