import argparse
import os
import pickle

import torch
import torch.nn as nn

# Command-line interface. Parsed at import time so `args` is available module-wide.
parser = argparse.ArgumentParser(
    description="Convert the 'model' entry of a PyTorch checkpoint into a "
                "pickled .pkl file written next to the input checkpoint."
)
parser.add_argument(
    '--checkpoint',
    type=str,
    # Required: the conversion cannot run without an input file. Previously this
    # defaulted to None and only failed later, inside torch.load.
    required=True,
    help="path to the input checkpoint (a .pth file containing a 'model' entry)"
)
parser.add_argument(
    '--swin',
    action="store_true",
    help="input is a vanilla Swin checkpoint: prefix every key with 'backbone.'"
)
args = parser.parse_args()


def add_backbone_prefix(state_dict):
    """Prepend ``'backbone.'`` to every key of *state_dict*, in place.

    All entries are popped first and then re-inserted under their new names.
    This way the rename of ``<k>`` can never overwrite a pre-existing
    ``'backbone.<k>'`` entry, which the naive rename-one-key-at-a-time loop
    could do.

    Keeping the same dict object preserves its type (e.g. ``OrderedDict``) and
    any attributes attached to it. Key order is preserved.

    Args:
        state_dict: mutable mapping of parameter names to values.

    Returns:
        The same *state_dict* object, now with prefixed keys.
    """
    renamed = [('backbone.' + key, state_dict.pop(key)) for key in list(state_dict)]
    state_dict.update(renamed)
    return state_dict


def pkl_path_for(checkpoint_path):
    """Return the ``.pkl`` output path that sits next to *checkpoint_path*.

    Everything in the file name from the first ``'.pth'`` onward is replaced
    by ``'.pkl'``. A name without ``'.pth'`` just gets ``'.pkl'`` appended.
    Only the file name is inspected, so a directory containing ``'.pth'``
    (e.g. ``runs.pth_old/model.pth``) is left intact.
    """
    directory, filename = os.path.split(checkpoint_path)
    return os.path.join(directory, filename.split('.pth')[0] + '.pkl')


if __name__ == '__main__':
    # NOTE: torch.load unpickles the file, which can execute arbitrary code.
    # Only convert checkpoints from a trusted source.
    ckpt = torch.load(args.checkpoint, map_location='cpu')['model']

    if args.swin:
        # Vanilla Swin checkpoint: nest every weight under 'backbone.'
        # (presumably the target model's backbone namespace; confirm).
        ckpt = add_backbone_prefix(ckpt)
    # Otherwise (e.g. Mask2Former training output) the keys are written
    # unchanged; no renaming is applied.

    # NOTE(review): this {'model', '__author__'} layout appears to target
    # detectron2's .pkl checkpoint loader. Confirm against the consumer.
    save_dict = {
        'model': ckpt,
        '__author__': 'rw435',
    }

    # protocol=-1 selects the highest pickle protocol available.
    with open(pkl_path_for(args.checkpoint), 'wb') as fp:
        pickle.dump(save_dict, fp, protocol=-1)

