# -*- coding: utf-8 -*-
import os
import argparse
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import shutil


def merge_peft_model(checkpoint_path, base_model_path, output_model_path):
    """
    Merge a PEFT adapter into its base model and save the standalone result.

    The base model and tokenizer are loaded from ``base_model_path``, the
    adapter stored at ``checkpoint_path`` is folded into the base weights, and
    the merged model together with the tokenizer is written to
    ``output_model_path``.

    Args:
        checkpoint_path: Path of the PEFT adapter checkpoint.
        base_model_path: Path of the base model and its tokenizer.
        output_model_path: Directory that receives the merged model and the
            tokenizer files; it is created when missing.

    Returns:
        0 once the model and tokenizer have been saved.
    """
    print(f"处理PEFT模型：{checkpoint_path}")

    # Load the base model and the tokenizer.
    tokenizer = AutoTokenizer.from_pretrained(base_model_path, trust_remote_code=True)
    base_model = AutoModelForCausalLM.from_pretrained(base_model_path, trust_remote_code=True)

    # Attach the adapter weights to the base model.
    peft_model = PeftModel.from_pretrained(base_model, checkpoint_path)

    # merge_and_unload() returns the plain base model with the adapter folded
    # in. Keep that return value and work on it from here on, rather than
    # relying on the PEFT wrapper having been mutated in place.
    merged_model = peft_model.merge_and_unload()

    # Bring the embedding matrix in line with the tokenizer when they disagree.
    tokenizer_vocab_size = len(tokenizer)
    model_vocab_size = merged_model.get_input_embeddings().weight.shape[0]
    if tokenizer_vocab_size != model_vocab_size:
        print(f"word embedding layer size：{model_vocab_size} -> {tokenizer_vocab_size}")
        merged_model.resize_token_embeddings(tokenizer_vocab_size)
    print(merged_model)

    # Save the merged model and the tokenizer side by side.
    print(f"保存到：{output_model_path}")
    os.makedirs(output_model_path, exist_ok=True)
    # trust_remote_code is a loading option, so it is not passed when saving.
    merged_model.save_pretrained(output_model_path, safe_serialization=True)
    tokenizer.save_pretrained(output_model_path)

    # NOTE(review): models loaded with trust_remote_code may need their custom
    # code files (e.g. configuration_*.py) present in the output directory;
    # confirm whether save_pretrained already copies them for your model.
    return 0

def main():
    """Read the three required paths from the command line and run the merge."""
    cli = argparse.ArgumentParser(description="合并PEFT模型到基础模型并保存结果")

    # Every option is a mandatory string path, so register them in one pass.
    required_options = (
        ("--checkpoint_path", "PEFT模型的路径"),
        ("--base_model_path", "基础模型的路径"),
        ("--output_model_path", "输出模型的路径"),
    )
    for flag, help_text in required_options:
        cli.add_argument(flag, type=str, required=True, help=help_text)

    options = cli.parse_args()

    # Hand the parsed paths to the merge routine.
    merge_peft_model(
        checkpoint_path=options.checkpoint_path,
        base_model_path=options.base_model_path,
        output_model_path=options.output_model_path,
    )

# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == '__main__':
    main()
