import argparse
import json
import os

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from peft import PeftModel, prepare_model_for_kbit_training, load_peft_weights


def parse_args(argv=None):
    """Parse the command-line options for the LoRA memory benchmark.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, exactly as before;
            passing a list lets callers and tests supply arguments directly.

    Returns:
        argparse.Namespace with string attributes ``base_model`` and
        ``lora_model``.

    Raises:
        SystemExit: when a required option is missing (standard argparse
            behavior).
    """
    parser = argparse.ArgumentParser(description="LLaMA with LoRA inference script")
    parser.add_argument("--base_model", type=str, required=True, help="Path to the base model")
    parser.add_argument("--lora_model", type=str, required=True, help="Path to the LoRA model")
    return parser.parse_args(argv)

args = parse_args()


# Base checkpoint location, supplied on the command line via --base_model.
base_model_name = args.base_model

tokenizer = AutoTokenizer.from_pretrained(base_model_name)
tokenizer.padding_side = "left"
# NOTE(review): assumes token id 0 is a safe pad id for this tokenizer
# (e.g. LLaMA's <unk>) — confirm before reusing with other model families.
tokenizer.pad_token_id = 0

base_model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    torch_dtype=torch.float16,  # load the weights in half precision
    device_map="auto",          # let transformers/accelerate place the layers on available devices
)

# Attach the LoRA adapter directly to the fp16 base model.
# PEFT's prepare_model_for_kbit_training is deliberately NOT called here: it is
# a training-time helper that freezes parameters and upcasts every fp16/bf16
# weight to fp32 on non-quantized models. That would undo torch_dtype=float16,
# roughly double the base model's GPU footprint, and skew this script's memory
# measurements. PeftModel.from_pretrained does not require it.
lora_model_name = args.lora_model
model = PeftModel.from_pretrained(base_model, lora_model_name)

# Inference only: disable dropout and other training-mode behavior.
model.eval()

# Adapter weight dicts kept alive on the GPU; the list length is the number of
# LoRA weight sets currently "served".
lora_weights_list = []

# Benchmark summary, written to <lora_model>/result.json after the first CUDA OOM.
#   max_sever_num: weight sets resident when memory ran out (key spelling kept
#                  unchanged because downstream readers may depend on it).
#   stage_memory:  MiB reported by torch.cuda.memory_allocated(), sampled every
#                  100 weight sets and keyed by that count as a string.
results = {
    "name": lora_model_name,
    "max_sever_num": 0,
    "stage_memory": {}
}

prompt = "today is very hot, I only want to stay in the room instead of outing. Please recommend some places with strong air condition."
# Tokenize once on the CPU; only the device copy is made each round.
prompt_ids = tokenizer.encode(prompt, return_tensors="pt")

with torch.no_grad():
    while True:
        try:
            # One generation per round: inference must still fit alongside every
            # adapter loaded so far, otherwise the resulting OOM ends the run.
            inputs = prompt_ids.to("cuda")
            generate_ids = model.generate(input_ids=inputs, max_length=100)
            # decode() the single returned sequence into one string
            # (batch_decode on a 1-D tensor would decode each token id separately).
            output = tokenizer.decode(generate_ids[0], skip_special_tokens=True,
                                      clean_up_tokenization_spaces=True)

            # Load ten more copies of the adapter weights onto the GPU.
            for _ in range(10):
                print(f"{lora_model_name} adding {len(lora_weights_list)} weights to GPU serving ......")
                adapters_weights = load_peft_weights(lora_model_name, device="cuda")
                lora_weights_list.append(adapters_weights)

                n = len(lora_weights_list)
                if n % 100 == 0:
                    results["stage_memory"][str(n)] = torch.cuda.memory_allocated() / 1024 / 1024

            print(
                f"{lora_model_name} current num of lora weights: {len(lora_weights_list)}, "
                f"memory allocated: {torch.cuda.memory_allocated() / 1024 / 1024} MB")

        # A CUDA out-of-memory error is the only expected stopping condition. Any
        # other exception (bad path, Ctrl-C, ...) propagates instead of being
        # recorded as a capacity result.
        except torch.cuda.OutOfMemoryError:
            print("Out Of Memory ! exit. ")
            break

results["max_sever_num"] = len(lora_weights_list)
with open(os.path.join(lora_model_name, "result.json"), "w") as file:
    json.dump(results, file)
