# This script takes the LoRA weights recovered from a ZeRO-stage checkpoint and merges them into the base model.
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import torch
from safetensors.torch import save_file, load_file
import os
import shutil


import argparse
# Command-line interface. All three options are mandatory: the paths are fed
# straight into os.path.join / from_pretrained further down, so an omitted
# option would otherwise only surface later as an opaque error such as
# TypeError from os.path.join(None, ...). Marking them required makes argparse
# print a usage message and exit instead.
parser = argparse.ArgumentParser(
    description="Merge a LoRA adapter into its base model and save the merged model."
)
parser.add_argument("--output_dir", type=str, required=True, help="where to put the output model")
parser.add_argument("--lora_dir", type=str, required=True, help="where to find the lora model")
parser.add_argument("--base_model", type=str, required=True, help="what is the base model")
args = parser.parse_args()

# Directory under which the merged model is written.
OUTPUT_DIR = args.output_dir
# Directory holding the LoRA adapter files.
LORA_DIR = args.lora_dir
# Must be the same base model the LoRA adapter belongs to.
BASE_MODEL = args.base_model

# Make sure LORA_DIR contains the adapter weights in safetensors format.
# The weights are expected as a torch-pickled `adapter_model.bin`; they are
# re-saved next to it as `adapter_model.safetensors`.
STATE_DICT = os.path.join(LORA_DIR, 'adapter_model.bin')
TENSOR_PATH = os.path.join(LORA_DIR, 'adapter_model.safetensors')

if os.path.isfile(STATE_DICT):
    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code -- only point --lora_dir at checkpoints you trust.
    full_state_dict = torch.load(STATE_DICT, map_location='cpu')
    # save_file rejects non-contiguous tensors, so normalise the memory layout
    # first (a no-op for tensors that are already contiguous).
    full_state_dict = {
        name: tensor.contiguous() for name, tensor in full_state_dict.items()
    }
    save_file(full_state_dict, TENSOR_PATH)
elif os.path.isfile(TENSOR_PATH):
    # No .bin to convert, but a safetensors file is already present: reuse it
    # instead of failing, and keep `full_state_dict` defined either way.
    full_state_dict = load_file(TENSOR_PATH)
else:
    raise FileNotFoundError(
        f"No adapter weights found in {LORA_DIR!r}: expected "
        "'adapter_model.bin' or 'adapter_model.safetensors'"
    )

# Options for loading the base model: allow custom modelling code shipped with
# the checkpoint, use the eager attention implementation, and load the weights
# as bfloat16.
load_kwargs = dict(
    trust_remote_code=True,
    attn_implementation='eager',
    torch_dtype=torch.bfloat16,
)
base_model = AutoModelForCausalLM.from_pretrained(BASE_MODEL, **load_kwargs)

# Wrap the base model with the adapter stored in LORA_DIR, then merge the
# adapter into it; `merged_model` is what gets saved afterwards.
model_to_merge = PeftModel.from_pretrained(base_model, LORA_DIR)
merged_model = model_to_merge.merge_and_unload()

# Everything produced by this script lands in <OUTPUT_DIR>/merged_model.
merged_dir = os.path.join(OUTPUT_DIR, 'merged_model')
os.makedirs(merged_dir, exist_ok=True)
merged_model.save_pretrained(merged_dir)

# Copy the tokenizer files from the adapter directory into the merged model
# directory so both live in the same place.
copy_files = ['special_tokens_map.json', 'tokenizer_config.json', 'tokenizer.json']
for tokenizer_file in copy_files:
    source_path = os.path.join(LORA_DIR, tokenizer_file)
    shutil.copy(source_path, merged_dir)