import os
import json
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from tqdm import tqdm
# from model_gen import generate_reply

# args
import argparse
# Command-line interface. All three options are optional and fall back to the
# defaults below.
parser = argparse.ArgumentParser(
    description='Generate a model reply for every prompt in a dataset and save the results as JSON.'
)
parser.add_argument(
    '--model', type=str, default="/default_path",
    help="Model path or identifier handed to from_pretrained.",
)
parser.add_argument(
    '--dataset', type=str, default="Ar-BeaverTails.json",
    help="JSON dataset file whose items hold the prompt under ['trans_prompt']['dst'].",
)
parser.add_argument(
    '--save_path', type=str, default="/default_path",
    help="Directory for the output file; the file is named after the model.",
)
args = parser.parse_args()
# Echo the effective configuration so a run's log records what it processed.
print(f"Processing:\nmodel = {args.model}\ndataset = {args.dataset}\nsave_path = {args.save_path}")

def generate_reply(model, tokenizer, prompt_input, *, max_new_tokens=4096,
                   temperature=0.6, top_p=0.9):
    """Sample one reply to ``prompt_input`` from ``model``.

    The prompt is wrapped in a fixed instruction/response template, encoded
    with ``tokenizer``, and passed to ``model.generate`` with sampling
    enabled. Only the newly generated tokens are decoded and returned.

    Args:
        model: Object exposing ``.device`` and ``.generate(...)``.
        tokenizer: Object exposing ``encode``, ``decode``,
            ``convert_tokens_to_ids``, ``eos_token_id`` and ``pad_token_id``.
        prompt_input: Instruction text inserted into the template.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature forwarded to ``generate``.
        top_p: Nucleus-sampling threshold forwarded to ``generate``.

    Returns:
        The decoded reply with special tokens stripped.
    """
    prompt = (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        f"### Instruction:\n{prompt_input}\n\n### Response:"
    )
    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(model.device)

    # "<|eot_id|>" only exists in some vocabularies. For others the lookup
    # gives None or the unknown-token id; neither may be used as a stop token
    # (None is not a valid id, and the unk id would end generation whenever an
    # unknown token is sampled).
    eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")
    if eot_id == getattr(tokenizer, "unk_token_id", None):
        eot_id = None

    terminators = []
    for token_id in (tokenizer.eos_token_id, eot_id):
        if token_id is not None and token_id not in terminators:
            terminators.append(token_id)

    # Take the pad id from the tokenizer instead of hard-coding the id of one
    # specific vocabulary; fall back to EOS when no pad token is defined.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    outputs = model.generate(
        input_ids,
        max_new_tokens=max_new_tokens,
        # An empty stop list is passed as None so generate() can apply its
        # own default.
        eos_token_id=terminators or None,
        pad_token_id=pad_token_id,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
    )

    # The returned sequence starts with the prompt tokens; keep only the
    # continuation.
    response = outputs[0][input_ids.shape[-1]:]

    return tokenizer.decode(response, skip_special_tokens=True)

# Checkpoint to load, taken from the --model command-line argument.
model_id = args.model

tokenizer = AutoTokenizer.from_pretrained(model_id)
# Weights are requested as bfloat16; device placement is delegated to the
# library through device_map="auto".
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map="auto"
)

# The dataset is a single JSON document; each item is expected to carry its
# prompt under item['trans_prompt']['dst'].
dataset_path = args.dataset
with open(dataset_path, encoding='utf-8') as dataset_file:
    data = json.load(dataset_file)

# Debug aid: uncomment to restrict the run to the first three items.
# data = data[:3]

prompts = [entry['trans_prompt']['dst'] for entry in data]

# One sampled reply per prompt, collected in dataset order.
results = [
    {"prompt": text, "output": generate_reply(model, tokenizer, text)}
    for text in tqdm(prompts, desc="Generating replies")
]

# Name the output file after the last component of the model id. normpath
# strips a trailing separator first, so "--model /ckpt/llama/" produces
# "llama.json" rather than a hidden file called ".json".
model_name = os.path.basename(os.path.normpath(model_id))

# exist_ok=True creates the directory if needed without a separate existence
# check, which could race with another process creating it.
os.makedirs(args.save_path, exist_ok=True)
save_path = os.path.join(args.save_path, f"{model_name}.json")

# ensure_ascii=False keeps non-ASCII prompt/reply text readable in the file.
with open(save_path, 'w', encoding='utf-8') as f:
    json.dump(results, f, ensure_ascii=False, indent=4)

print(f"Results are saved in {save_path}")