from datasets import load_dataset
import argparse
import json
from pathlib import Path
from tqdm import tqdm
import time
import re
import os
from huggingface_hub import login
# os.environ["HF_ENDPOINT"] = "https://hf-mirror.com/"
# login()
from Search import *

parser = argparse.ArgumentParser()
# Prompt source.
parser.add_argument("--dataset", type=str, default="Dahoas/full-hh-rlhf")
parser.add_argument("--split", type=str, default="test")

# Model identifiers forwarded to search_alignment.
parser.add_argument("--Reward_model", default='RLHFlow/RewardModel-Mistral-7B-for-DPA-v1', type=str)
parser.add_argument("--LLM_Decoding", default='RLHFlow/LLaMA3-iterative-DPO-final', type=str)
parser.add_argument("--LLM_Prompt", default='RLHFlow/LLaMA3-iterative-DPO-final', type=str)

# Length limits and sampling counts forwarded to search_alignment.
parser.add_argument("--Max_Prompt_length", type=int, default=500)
parser.add_argument("--Max_Node_length", type=int, default=200)
parser.add_argument("--Max_Response_length", type=int, default=2000)
# Upper bound on the number of prompts processed by the decoding loop.
parser.add_argument("--sample_num", default=100, type=int)
parser.add_argument("--Sample_Prompt_num", type=int, default=2)
parser.add_argument("--Sample_Node_num", type=int, default=2)


parser.add_argument("--Sample_Original_Prompt", type=int, default=1)

# Device strings for each model.
parser.add_argument("--LLM_Decoding_GPU", type=str, default="cuda:0")
parser.add_argument("--LLM_Prompt_GPU", type=str, default="cuda:0")
parser.add_argument("--Reward_model_GPU", type=str, default="cuda:0")

# Config file name (without extension) under Experiment/<backbone>/<method>/config/.
parser.add_argument("--config", default='overall_score',type=str)
parser.add_argument("--out_file",default='./Experiment/LLama3_DPO', type=str)
parser.add_argument("--method", type=str, default='our_method')
parser.add_argument("--backbone", type=str, default="LLama3_DPO")
# batch / topk are forwarded to beam_search_w_RM. Parse them as int to match
# their integer defaults; with type=str a command-line value would arrive as a
# string while the default arrives as an int.
parser.add_argument("--batch", type=int, default=8)
parser.add_argument("--topk", type=int, default=4)
args = parser.parse_args()

# Directory containing this script. abspath() guards against a relative
# __file__ (possible on older Pythons), where dirname() would return "" and a
# plain string concatenation would produce a root-anchored "/Experiment/..." path.
current_file_path = os.path.dirname(os.path.abspath(__file__))
cfg_path = (Path(current_file_path) / "Experiment" / args.backbone / args.method
            / "config" / (args.config + ".jsonl"))

# Despite the .jsonl extension, the file is parsed as one JSON document.
with open(cfg_path, "r", encoding="utf-8") as f:
    run_configs = json.load(f)

# Prompt strings to decode, in dataset order.
origin_prompt = []
if args.dataset == "Dahoas/full-hh-rlhf":
    test_ds = load_dataset(args.dataset, split=args.split)
    # Prompts have the form "\n\nHuman: <text>\n\nAssistant: ..."; keep only the
    # text after the first "Human:" and before the first "\n\nAssistant:".
    test_ds = test_ds["prompt"]

    for prompt in test_ds:
        _, marker, after_human = prompt.partition('Human:')
        # If the "Human:" marker is absent, use the whole prompt rather than
        # crashing with an IndexError.
        text = after_human if marker else prompt
        origin_prompt.append(text.strip().split('\n\nAssistant:')[0])
elif args.dataset == 'stingning/ultrachat':
    test_ds = load_dataset(args.dataset)
    test_ds = test_ds["train"]["data"]
    # Each row's "data" is a list; its first element is used as the prompt
    # (presumably the opening user turn — confirm against the dataset card).
    origin_prompt = [lst[0] for lst in test_ds]
else:
    # Fail fast instead of loading the models and writing an empty result file.
    raise ValueError(
        f"Unsupported --dataset {args.dataset!r}; expected "
        "'Dahoas/full-hh-rlhf' or 'stingning/ultrachat'")
print(vars(args))
print(run_configs)

# Every CLI option the search engine needs, gathered in one place and
# forwarded by keyword to search_alignment.
search_kwargs = dict(
    config=run_configs,
    method=args.method,
    LLM_Prompt=args.LLM_Prompt,
    LLM_Decoding=args.LLM_Decoding,
    Reward_model=args.Reward_model,
    Max_Prompt_length=args.Max_Prompt_length,
    Max_Node_length=args.Max_Node_length,
    Max_Response_length=args.Max_Response_length,
    Sample_Prompt_num=args.Sample_Prompt_num,
    Sample_Node_num=args.Sample_Node_num,
    Sample_Original_Prompt=args.Sample_Original_Prompt,
    LLM_Decoding_GPU=args.LLM_Decoding_GPU,
    LLM_Prompt_GPU=args.LLM_Prompt_GPU,
    Reward_model_GPU=args.Reward_model_GPU,
)
search = search_alignment(**search_kwargs)

print("[INFO]: Done")

def sample_response(prompt):
    """Run the search method selected by ``--method`` on a single prompt.

    Args:
        prompt: The prompt string to decode.

    Returns:
        For ``our_method``: the tuple ``(answer, revise_prompt, response_set)``
        returned by ``search.our_method``.
        For ``beam_search_w_RM``: only the first value returned by
        ``search.beam_search_w_RM`` (its second value is discarded).

    Raises:
        ValueError: If ``args.method`` is not a supported method.
    """
    if args.method == 'our_method':
        answer, revise_prompt, response_set = search.our_method(prompt)
        return answer, revise_prompt, response_set

    if args.method == 'beam_search_w_RM':
        # The second return value is unused; don't rebind the ``prompt`` parameter to it.
        final_response, _ = search.beam_search_w_RM(prompt, batch=args.batch, topk=args.topk)
        return final_response

    # Previously fell through and returned None, which callers would not expect.
    raise ValueError(f"Unsupported --method {args.method!r}")


i = 0  # number of prompts processed successfully so far
data = []
tokenizer_gen = AutoTokenizer.from_pretrained(args.LLM_Decoding)


def _tokens_per_second(answer, start):
    """Return the chat-templated token count of ``answer`` divided by the seconds since ``start``."""
    seconds = time.time() - start
    token_ids = tokenizer_gen.apply_chat_template(
        [{"role": "user", "content": answer}], return_tensors="pt")
    return token_ids.shape[-1] / seconds


for idx, ds_row in enumerate(tqdm(origin_prompt)):
    # ">=" so that at most ``sample_num`` prompts are processed.
    if i >= args.sample_num:
        break
    current_prompt = ds_row
    try:
        start = time.time()
        if args.method == 'our_method':
            answer, revise_prompt, response_set = sample_response(current_prompt)
            # Stored under "elapsed" for output compatibility, but the value is
            # throughput (tokens per second), not a duration.
            elapsed = _tokens_per_second(answer, start)
            data.append({"prompt": current_prompt, "revise_prompt": revise_prompt, "result": answer,
                         "response": current_prompt + answer, "elapsed": elapsed, "response_set": response_set,
                         "method": args.method})

        elif args.method == 'beam_search_w_RM':
            answer = sample_response(current_prompt)
            elapsed = _tokens_per_second(answer, start)
            data.append({"prompt": current_prompt, "result": answer, "response": current_prompt + answer,
                         "elapsed": elapsed, "method": args.method})

        i += 1
    except Exception as err:
        # Best-effort: skip a failing prompt but report it. Unlike a bare
        # ``except:``, this still lets KeyboardInterrupt / SystemExit stop the run.
        print(f"[WARN]: skipping prompt {idx}: {err!r}")
        continue

# Append run metadata so the results file is self-describing.
data.append({"args": vars(args)})
data.append({"config": run_configs})

dataset_tag = str(args.dataset).replace('/', '_')
if args.method == 'beam_search_w_RM':
    file_stem = f"{dataset_tag}_batch_{args.batch}"
else:
    file_stem = (f"{dataset_tag}{args.config}"
                 f"_Sample_Prompt_num_{args.Sample_Prompt_num}"
                 f"_Sample_Node_num_{args.Sample_Node_num}"
                 f"_Sample_Original_Prompt_{args.Sample_Original_Prompt}"
                 f"_Max_Node_length_{args.Max_Node_length}")
out_path = os.path.join(current_file_path, args.out_file, args.method, "Decoding_results", file_stem)

# Create the target directory so the results of a long decoding run are not
# lost to a FileNotFoundError at write time.
os.makedirs(os.path.dirname(out_path), exist_ok=True)
# Despite the .jsonl suffix, a single JSON array is written. utf-8 is explicit
# because ensure_ascii=False emits raw non-ASCII characters.
with open(Path(out_path + ".jsonl"), "w", encoding="utf-8") as outfile:
    json.dump(data, outfile, ensure_ascii=False, indent=4)
