import argparse
import pickle

import numpy as np
import torch
import torch.nn as nn

# Command-line interface: a single optional ``--checkpoint`` string option
# that is ``None`` when the flag is not supplied.
parser = argparse.ArgumentParser()
parser.add_argument('--checkpoint', type=str, default=None)

# Arguments are parsed at module level, i.e. as soon as this file is executed
# or imported.
args = parser.parse_args()


# State-dict key of the patch-embedding projection weight that gets widened.
PATCH_EMBED_KEY = 'backbone.patch_embed.proj.weight'


def append_mask_channel(patch_weights):
    """Return *patch_weights* with one zero-filled channel appended on axis 1.

    Args:
        patch_weights: 4-D array-like. Presumably laid out as
            (out_channels, in_channels, kernel_h, kernel_w) -- TODO confirm
            against the model definition.

    Returns:
        A NumPy array with the same dtype as the input whose axis-1 size is
        one larger: the original values come first, the added last channel is
        all zeros.

    Raises:
        ValueError: if *patch_weights* is not 4-dimensional.
    """
    patch_weights = np.asarray(patch_weights)
    if patch_weights.ndim != 4:
        raise ValueError(
            'expected a 4-D weight array, got shape %r' % (patch_weights.shape,)
        )

    # Zero initialisation of the extra channel. Two alternatives were
    # considered here before: taking the last input channel of a freshly
    # initialised 4-input-channel ``nn.Conv2d``, or using the channel-sum of
    # the existing weights divided by 4 while scaling the existing weights
    # by 3/4.
    out_ch, _, k_h, k_w = patch_weights.shape
    # Build the zeros in the weights' own dtype so concatenation never
    # changes the precision of the stored parameters.
    mask_dim = np.zeros((out_ch, 1, k_h, k_w), dtype=patch_weights.dtype)
    return np.concatenate([patch_weights, mask_dim], axis=1)


def masked_checkpoint_path(checkpoint_path):
    """Return the output path for the converted checkpoint.

    A trailing ``.pkl`` is removed (only at the very end of the path, so a
    ``.pkl`` occurring inside a directory or file name is left alone) and
    ``_masked.pkl`` is appended.
    """
    suffix = '.pkl'
    stem = checkpoint_path
    if stem.endswith(suffix):
        stem = stem[:-len(suffix)]
    return stem + '_masked' + suffix


if __name__ == '__main__':
    # Fail with a usage message instead of a TypeError from ``open(None)``.
    if args.checkpoint is None:
        parser.error('--checkpoint is required')

    # NOTE(security): pickle.load can execute arbitrary code contained in the
    # file; only run this script on checkpoints from a trusted source.
    # encoding='latin1' lets Python 3 read pickles written by Python 2 that
    # contain NumPy arrays.
    with open(args.checkpoint, 'rb') as fp:
        save_dict = pickle.load(fp, encoding='latin1')

    ckpt = save_dict['model']
    # If the key is absent the state dict is written back unchanged.
    if PATCH_EMBED_KEY in ckpt:
        # pop + assign re-inserts the key at the end of the dict.
        widened = append_mask_channel(ckpt.pop(PATCH_EMBED_KEY))
        ckpt[PATCH_EMBED_KEY] = widened

    # Only the model weights and an author tag are written out; any other
    # top-level entries of the loaded dict are not carried over.
    save_dict = {
        'model': ckpt,
        '__author__': 'rw435',
    }

    # protocol=-1 selects the highest pickle protocol available.
    with open(masked_checkpoint_path(args.checkpoint), 'wb') as fp:
        pickle.dump(save_dict, fp, protocol=-1)
