import torch as T
import os
# Directory that holds the sharded checkpoint; the split outputs are written
# to the same directory. An empty string means the current working directory.
load_path = ""

# Merge the four checkpoint shards into one dict, loading tensors onto the CPU.
# NOTE: torch.load unpickles the file contents, so only run this on
# checkpoints from a trusted source.
pt = {}
for i in range(4):
    shard_file = os.path.join(load_path, f'pytorch_model-0000{i+1}-of-00004.bin')
    pt.update(T.load(shard_file, map_location='cpu'))
print(len(pt))

# Split the merged weights into two half-precision dicts:
#   mllm - entries whose key mentions embed_tokens, lm_head or edit_head
#          (original key kept as-is)
#   unet - entries whose key mentions unet, stored with 'unet.' removed
# Any key matching neither rule is dropped.
mllm, unet = {}, {}
for key, tensor in pt.items():
    if 'embed_tokens' in key or 'lm_head' in key or 'edit_head' in key:
        mllm[key] = tensor.half()
    elif 'unet' in key:
        # str.replace removes every 'unet.' occurrence, not only a leading one.
        unet[key.replace('unet.', '')] = tensor.half()
print(len(mllm))
print(len(unet))

# Build the output paths with os.path.join, the same way the shard paths are
# built: plain concatenation (load_path + '/mllm.pt') would resolve to the
# filesystem root whenever load_path is empty.
T.save(mllm, os.path.join(load_path, 'mllm.pt'))
T.save(unet, os.path.join(load_path, 'unet.pt'))