import argparse
import os
import torch

def peel_prefixed_keys(state_dict, prefix='teacher'):
    """Delete, in place, every entry of *state_dict* whose key starts with *prefix*.

    Args:
        state_dict: mutable mapping of parameter names to values.
        prefix: key prefix to strip (default ``'teacher'``).

    Returns:
        The list of removed keys, in the mapping's iteration order.
    """
    # Collect first, then delete: a dict must not change size while iterated.
    removed = [name for name in state_dict if name.startswith(prefix)]
    for name in removed:
        del state_dict[name]
    return removed


def main(argv=None):
    """Load ``<path>/<src>``, drop prefixed weights from its ``state_dict``, save to ``<path>/<dst>``.

    Args:
        argv: argument list to parse; ``None`` means ``sys.argv[1:]``.

    Raises:
        KeyError: if the loaded checkpoint has no ``'state_dict'`` entry.
    """
    parser = argparse.ArgumentParser(
        description='Remove parameters with a given key prefix (default: '
                    '"teacher") from a checkpoint\'s state_dict.')
    parser.add_argument('--path', default=None, type=str,
                        help='directory containing --src and receiving --dst '
                             '(default: current directory)')
    parser.add_argument('--src', default='latest.pth', type=str)
    parser.add_argument('--dst', default='peeled.pth', type=str)
    parser.add_argument('--prefix', default='teacher', type=str,
                        help='state_dict key prefix to remove')
    args = parser.parse_args(argv)

    # os.path.join(None, ...) raises TypeError, so an omitted --path used to
    # crash; '' makes the file names resolve relative to the working directory.
    base_dir = args.path if args.path is not None else ''
    src_path = os.path.join(base_dir, args.src)
    dst_path = os.path.join(base_dir, args.dst)

    # NOTE: torch.load unpickles the file; only run this on trusted checkpoints.
    ckpt = torch.load(src_path, map_location='cpu')

    removed = peel_prefixed_keys(ckpt['state_dict'], args.prefix)

    torch.save(ckpt, dst_path)
    print(f'Removed {len(removed)} "{args.prefix}*" keys; saved to {dst_path}')


if __name__ == '__main__':
    main()
