import torch


# Convert the pretrained flow checkpoint into an encoder-only weight file:
# drop the linear layer stored at index 10 of ``module.model`` and rename every
# remaining parameter key from ``<a>.<b>.<rest>`` to ``<a>.encoder.<rest>``
# (e.g. ``module.model.0.weight`` -> ``module.encoder.0.weight``).

FLOW_CKPT_PATH = 'two_stream/ckpt/motion/pretrain/pretrain_best.pth'
FLOW_ENCODER_OUT_PATH = 'two_stream/ckpt/motion/flow_encoder_weight_no_five_crop.pth'
# Parameters of the linear layer that the encoder-only checkpoint must not contain.
LINEAR_HEAD_KEYS = ('module.model.10.weight', 'module.model.10.bias')


def _to_encoder_key(key):
    """Return *key* with its second dot-separated component replaced by 'encoder'.

    ``'module.model.0.weight'`` -> ``'module.encoder.0.weight'``.

    Raises:
        ValueError: if *key* has fewer than two dot-separated components,
            so there is no second component to replace.
    """
    parts = str(key).split('.')
    if len(parts) < 2:
        raise ValueError(
            f'cannot rename state_dict key {key!r}: '
            'expected at least two dot-separated components'
        )
    parts[1] = 'encoder'
    return '.'.join(parts)


# map_location='cpu' lets the checkpoint load on machines without the GPU it
# may have been saved from; load_state_dict() later copies tensors onto the
# target model's device regardless.
flow_ckpt = torch.load(FLOW_CKPT_PATH, map_location='cpu')

flow_dict = flow_ckpt['state_dict'].copy()

# Drop the linear layer; a KeyError here means the checkpoint does not have the
# expected layout, which should stop the conversion.
for head_key in LINEAR_HEAD_KEYS:
    del flow_dict[head_key]

# Build the renamed mapping in one pass, keeping insertion order and the same
# mapping type as the loaded state_dict. Renaming in place could silently
# overwrite a parameter when two source keys map to the same target key; the
# length check turns that into an explicit error instead.
renamed = type(flow_dict)((_to_encoder_key(k), v) for k, v in flow_dict.items())
if len(renamed) != len(flow_dict):
    raise ValueError(
        'renaming state_dict keys produced duplicate target keys; '
        'refusing to overwrite parameters'
    )
flow_dict = renamed

torch.save({'flow_encoder_state_dict': flow_dict}, FLOW_ENCODER_OUT_PATH)
