# import torch
# from safetensors.torch import save_file,load_file

# # Specify the input and output file paths
# bin_file_path = "/datanfs2/medllava/llava/Externalization_llava/checkpoints/llava-v1.5-7b-lora-June10-textvqa-bs8-extraction-1e-4/llava-lora/lora_trainables.bin"  # replace with the path to your .bin file
# safetensors_file_path = "/datanfs2/medllava/llava/Externalization_llava/checkpoints/llava-v1.5-7b-lora-June10-textvqa-bs8-extraction-1e-4/llava-lora/adapter_model.safetensors"  # path of the output .safetensors file
# bin_file_path_good = "/datanfs2/medllava/llava/basellava/LLaVA-main/checkpoints/llava-v1.5-lora-unlearning-textvqa/adapter_model.safetensors"  # replace with the path to your reference .safetensors file
# # Load the .bin file
# state_dict = torch.load(bin_file_path, map_location="cpu")
# i = 0
# j = 0
# new_state_dict = {}
# for key in list(state_dict.keys()):
#     j += 1
#     # print("bad_model",key)
#     if "base_model.base_model." in key:
#         i += 1
        
#         new_key = key.replace("base_model.base_model.", "base_model.")
#         new_key = new_key.replace(".lora.", ".")
#         new_state_dict[new_key] = state_dict[key]
#         # if "lora_B" in key:
#         #     print(key,state_dict[key])
#     # print("bad_model",key)
# print(f"Total keys: {j}, Filtered keys: {i}")
# for key in list(new_state_dict.keys()):
#     print("fixed_model", key)
# state_dict_good = load_file(bin_file_path_good)
# print()
# print()
# p = 0
# for key in list(state_dict_good.keys()):
#     # Check for and remove keys that start with "lora_"
#     p+=1
#     print("good_model",key)
#     if p ==20:
#         break
# # Save as a .safetensors file
# save_file(new_state_dict, safetensors_file_path)

# print(f"Successfully converted {bin_file_path} to {safetensors_file_path}")
import torch
from safetensors.torch import save_file, load_file
import argparse

def _rename_key(key):
    """Return *key* rewritten to the naming used in the output file.

    Two textual substitutions are applied, in order: the prefix
    ``language_model.base_model.base_model.model.model.`` becomes
    ``base_model.model.model.``, and every ``.default.`` becomes ``.``.
    Keys containing neither pattern are returned unchanged.
    """
    renamed = key.replace(
        "language_model.base_model.base_model.model.model.",
        "base_model.model.model.",
    )
    return renamed.replace(".default.", ".")


def main():
    """Convert a ``.bin`` checkpoint to ``.safetensors`` with renamed keys.

    Command-line arguments:
        --bin_file: path of the input file read with ``torch.load``.
        --safetensors_file: path of the output file written with
            ``save_file``.
        --compare_file: optional reference ``.safetensors`` file. When
            given, the key names of the converted state dict are compared
            with the reference and the differences are printed.

    Raises:
        ValueError: if two different input keys are renamed to the same
            output key, which would otherwise silently drop a tensor.
    """
    parser = argparse.ArgumentParser(description="Convert .bin to .safetensors with key renaming.")
    parser.add_argument('--bin_file', type=str, required=True, help='Path to input .bin file')
    parser.add_argument('--safetensors_file', type=str, required=True, help='Path to output .safetensors file')
    parser.add_argument('--compare_file', type=str, default=None, help='Path to reference .safetensors file (optional)')
    args = parser.parse_args()

    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only run this on checkpoints from a trusted source.
    state_dict = torch.load(args.bin_file, map_location="cpu")

    new_state_dict = {}
    for key, value in state_dict.items():
        new_key = _rename_key(key)
        if new_key in new_state_dict:
            # Overwriting here would lose a tensor without any indication.
            raise ValueError(
                f"Key collision after renaming: {key!r} maps to {new_key!r}, "
                "which is already present in the converted state dict."
            )
        new_state_dict[new_key] = value
    # No key is filtered out, so both counts describe the same set of keys.
    print(f"Total keys: {len(state_dict)}, Filtered keys: {len(new_state_dict)}")

    if args.compare_file:
        reference_keys = set(load_file(args.compare_file))
        converted_keys = set(new_state_dict)
        missing = sorted(reference_keys - converted_keys)
        unexpected = sorted(converted_keys - reference_keys)
        print()
        print()
        print(
            f"Reference keys: {len(reference_keys)}, "
            f"missing from converted: {len(missing)}, "
            f"not in reference: {len(unexpected)}"
        )
        # Show at most 20 names per side to keep the output readable.
        for key in missing[:20]:
            print("missing_from_converted", key)
        for key in unexpected[:20]:
            print("not_in_reference", key)

    save_file(new_state_dict, args.safetensors_file)
    print(f"Successfully converted {args.bin_file} to {args.safetensors_file}")

# Run the conversion only when executed as a script, not when imported.
if __name__ == "__main__":
    main()