
import torch
import re

# Rename the per-layer quantization-parameter keys in a checkpoint from
# "self_attn.{q,k,v,o}_proj" / "mlp.{gate,up,down}_proj" naming to
# "attention.w{q,k,v,o}" / "feed_forward.w{1,3,2}" naming, then save the
# result beside the input file.
# NOTE(review): the target names presumably match the LaVIN model's module
# names -- confirm against the model definition.

# Input and output live in the same directory. The original script loaded
# from "LaVIN-7b-lite" but saved to "LaVIN-7B-lite", which is a different
# directory on case-sensitive filesystems.
QUANT_DIR = "./LaVIN-7b-lite/quant_params"
SRC_PATH = f"{QUANT_DIR}/llama-7b-w4a16g128.pth"
DST_PATH = f"{QUANT_DIR}/llama-7b-w4a16g128_modify.pth"

# (old substring, new substring) pairs, checked in order. Only the first
# match is applied, mirroring an if/elif chain.
_ATTN_RENAMES = (("k_proj", "wk"), ("v_proj", "wv"), ("q_proj", "wq"), ("o_proj", "wo"))
_MLP_RENAMES = (("down_proj", "w2"), ("up_proj", "w3"), ("gate_proj", "w1"))


def rename_key(key):
    """Return the renamed form of a single parameter key.

    Keys containing "self_attn" have it replaced by "attention", and the
    first matching projection name is mapped (k_proj->wk, v_proj->wv,
    q_proj->wq, o_proj->wo). Keys containing "mlp" have it replaced by
    "feed_forward", and the first matching projection is mapped
    (down_proj->w2, up_proj->w3, gate_proj->w1). Any other key is returned
    unchanged.
    """
    if "self_attn" in key:
        new_key = key.replace("self_attn", "attention")
        renames = _ATTN_RENAMES
    elif "mlp" in key:
        new_key = key.replace("mlp", "feed_forward")
        renames = _MLP_RENAMES
    else:
        # Leave unrelated keys (e.g. norms) untouched. The original left
        # new_key unset here, which crashed or reused a stale name.
        return key
    for old, new in renames:
        if old in key:
            return new_key.replace(old, new)
    return new_key


def convert_checkpoint(ckpt, verbose=True):
    """Return {layer_index: {renamed_key: value}} for every layer in *ckpt*.

    Args:
        ckpt: A sequence (or a dict keyed 0..n-1) of per-layer dicts that map
            parameter names to values. Values are passed through untouched.
        verbose: If True, print a header per layer and each old -> new key.

    Returns:
        A new dict keyed by integer layer index.
    """
    updated = {}
    # Index access works whether ckpt is a list or a dict keyed by 0..n-1.
    for i in range(len(ckpt)):
        if verbose:
            print(f"===== layer {i} =====")
        layer_out = {}
        for key, value in ckpt[i].items():
            new_key = rename_key(key)
            layer_out[new_key] = value
            if verbose:
                print(f"{key} ======> {new_key}")
        updated[i] = layer_out
    return updated


def main():
    """Load SRC_PATH, rename its keys, and save the result to DST_PATH."""
    ckpt = torch.load(SRC_PATH)
    torch.save(convert_checkpoint(ckpt), DST_PATH)


if __name__ == "__main__":
    main()
