import torch
import os
import sys
from collections import OrderedDict

# Leading key prefix used in the backbone checkpoint's 'model' dict, and the
# prefix those weights must live under in the base model's state_dict.
BACKBONE_SRC_PREFIX = 'model.'
BACKBONE_DST_PREFIX = 'model.net.visual_encoder.backbone.'


def remap_backbone_key(key, src_prefix=BACKBONE_SRC_PREFIX,
                       dst_prefix=BACKBONE_DST_PREFIX):
    """Rewrite a leading ``src_prefix`` of *key* to ``dst_prefix``.

    Only the prefix is rewritten. A plain ``str.replace`` would also rewrite
    later occurrences, e.g. ``'submodel.'`` or a nested ``'model.'``, and
    corrupt the key. Keys that do not start with ``src_prefix`` are returned
    unchanged.
    """
    if key.startswith(src_prefix):
        return dst_prefix + key[len(src_prefix):]
    return key


def merge_state_dicts(model_params, backbone_params):
    """Build a new state dict from backbone weights plus non-backbone base weights.

    Args:
        model_params: mapping of the base model's parameters (the base
            checkpoint's ``'state_dict'``).
        backbone_params: mapping of the backbone's parameters (the backbone
            checkpoint's ``'model'``).

    Returns:
        An ``OrderedDict``. It first holds every backbone entry, with keys
        remapped by :func:`remap_backbone_key`, then every base entry whose key
        does not contain ``'backbone'``. On a key collision the base value wins,
        because it is written last.
    """
    new_params = OrderedDict()
    for key, value in backbone_params.items():
        new_params[remap_backbone_key(key)] = value

    for key, value in model_params.items():
        # Drop the base model's own backbone weights; they are replaced above.
        # NOTE(review): this is a substring match, so any key containing
        # 'backbone' anywhere is dropped. Confirm no non-backbone params match.
        if 'backbone' in key:
            continue
        new_params[key] = value
    return new_params


def main(argv=None):
    """Splice backbone weights into a base checkpoint and save the result.

    Usage: ``script BASE_CKPT BACKBONE_CKPT OUTPUT_CKPT``

    BASE_CKPT is read as a dict with a ``'state_dict'`` entry, and
    BACKBONE_CKPT as a dict with a ``'model'`` entry. The saved checkpoint is
    the base checkpoint with ``'state_dict'`` replaced by
    :func:`merge_state_dicts`. Every other base entry is kept as-is.

    Args:
        argv: argument list without the program name; defaults to
            ``sys.argv[1:]``. Extra trailing arguments are ignored.
    """
    if argv is None:
        argv = sys.argv[1:]
    # Validate before loading anything, so a missing output path fails fast
    # instead of raising IndexError after both checkpoints are loaded.
    if len(argv) < 3:
        sys.exit('usage: %s BASE_CKPT BACKBONE_CKPT OUTPUT_CKPT' % sys.argv[0])
    base_pth, backbone_pth, out_pth = argv[:3]

    # NOTE: torch.load unpickles the file and can execute arbitrary code;
    # only run this on trusted checkpoints.
    base = torch.load(base_pth, map_location='cpu')
    backbone = torch.load(backbone_pth, map_location='cpu')

    base['state_dict'] = merge_state_dicts(base['state_dict'], backbone['model'])
    torch.save(base, out_pth)


if __name__ == "__main__":
    main()