import argparse
import json
import os
import time

import torch
from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

from peft import PeftModel, prepare_model_for_kbit_training

def parse_args():
    """Define the command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace with the required string options ``base_model``,
        ``lora_model`` and ``dataset_name``, plus ``batch_size`` (int,
        default 8).
    """
    cli = argparse.ArgumentParser(description="LLaMA with LoRA inference script")

    # All mandatory options are plain strings; only their flag and help differ.
    required_string_options = (
        ("--base_model", "Path to the base model"),
        ("--lora_model", "Path to the LoRA model"),
        ("--dataset_name", "Name of the dataset to use"),
    )
    for flag, help_text in required_string_options:
        cli.add_argument(flag, type=str, required=True, help=help_text)

    cli.add_argument("--batch_size", type=int, default=8, help="Batch size for inference")

    namespace = cli.parse_args()
    return namespace

args = parse_args()

# The tokenizer and the base weights are both read from the same checkpoint path.
base_model_name = args.base_model

tokenizer = AutoTokenizer.from_pretrained(base_model_name)
# Batched generation with a decoder-only model needs left padding so that each
# prompt ends exactly where the generated continuation begins.
tokenizer.padding_side = "left"
# Pad with token id 0 (presumably LLaMA's <unk> id -- confirm for other tokenizers).
tokenizer.pad_token_id = 0

# Half-precision weights, placed across the available devices by accelerate.
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_name, torch_dtype=torch.float16, device_map="auto"
)

# Inference only: deliberately no peft.prepare_model_for_kbit_training() call.
# That helper is training-time preparation. It freezes all parameters, and on
# a non-quantized model it casts every fp16/bf16 parameter to fp32. That would
# silently undo the torch.float16 base-model load and inflate the memory and
# speed figures this script records.

data = load_dataset(args.dataset_name)

# Wrap the base model with the LoRA adapter weights stored at --lora_model.
lora_model_name = args.lora_model
model = PeftModel.from_pretrained(base_model, lora_model_name)

# Evaluation mode: dropout layers (including LoRA dropout) become no-ops.
model.eval()


batch_size = args.batch_size  # number of samples per inference batch
num_batches = 50  # the benchmark times at most this many batches
train_split = data["train"]
# Never index past the end of the split: on small datasets this yields fewer
# than num_batches batches instead of empty prompt batches.
num_samples = min(batch_size * num_batches, len(train_split))

start_time = time.time()
# Elapsed wall-clock seconds since start_time, appended after every batch
# (cumulative timestamps, not per-batch durations).
time_costs = []

with torch.no_grad():
    for batch_start in tqdm(range(0, num_samples, batch_size)):
        # Prepare the batched inputs. Slicing a datasets.Dataset returns a dict
        # of column lists, so each batch is fetched once rather than row by row.
        batch_end = min(batch_start + batch_size, num_samples)
        rows = train_split[batch_start:batch_end]
        prompts = [
            instruction + extra_input
            for instruction, extra_input in zip(rows["instruction"], rows["input"])
        ]

        inputs = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True).to("cuda")

        outputs = model.generate(
            input_ids=inputs.input_ids,
            # Without the mask, the left-padding tokens of shorter prompts would
            # be attended to and change what gets generated for them.
            attention_mask=inputs.attention_mask,
            pad_token_id=tokenizer.pad_token_id,
            max_length=1024,  # total length: prompt tokens + generated tokens
            num_beams=5,  # beam search
            # top_k only takes effect if sampling is enabled (e.g. by the model's
            # generation_config); kept so both configurations behave as before.
            top_k=50,
        )

        # Wait for queued CUDA work before reading the clock, so asynchronous
        # kernel launches cannot make a batch look faster than it was.
        torch.cuda.synchronize()
        time_costs.append(time.time() - start_time)

# device_map="auto" may shard the model over several GPUs, and the torch.cuda
# memory counters are per device. Sum them over every visible device; on a
# single GPU this equals the plain no-argument call. Values are in bytes.
device_ids = range(torch.cuda.device_count())
result = {
    "peak_memory": sum(torch.cuda.max_memory_allocated(dev) for dev in device_ids),
    "memory": sum(torch.cuda.memory_allocated(dev) for dev in device_ids),
    "time_costs": time_costs,
}

# The report is written next to the adapter weights as speed.json.
with open(os.path.join(lora_model_name, "speed.json"), "w") as fh:
    json.dump(result, fh)
