from reward_sampling import RewardSampling
from datasets import load_dataset
import argparse
import json
from pathlib import Path
from tqdm import tqdm
import time
import re
import os
from huggingface_hub import login
# os.environ["HF_ENDPOINT"] = "https://hf-mirror.com/"
# login()
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSequenceClassification, LlamaForCausalLM, LlamaForSequenceClassification

from tqdm import tqdm

parser = argparse.ArgumentParser(
    description="Decode dataset prompts with reward-guided sampling (CARDS or ARGS) and "
                "dump the generations plus the run configuration to a JSON file."
)

# --- Data -------------------------------------------------------------------
# Only the datasets handled by the prompt loader are supported:
# Dahoas/full-hh-rlhf (reads --split) and stingning/ultrachat (reads "train").
parser.add_argument("--dataset", type=str, default="Dahoas/full-hh-rlhf",
                    help="Hugging Face dataset id: Dahoas/full-hh-rlhf or stingning/ultrachat.")
parser.add_argument("--split", type=str, default="test",
                    help="Dataset split to read (used for Dahoas/full-hh-rlhf).")

# --- Models -----------------------------------------------------------------
parser.add_argument("--Reward_model", default='RLHFlow/RewardModel-Mistral-7B-for-DPA-v1', type=str,
                    help="Reward model id/path, passed to RewardSampling as rm_dir.")
parser.add_argument("--LLM_Decoding", default='RLHFlow/LLaMA3-iterative-DPO-final', type=str,
                    help="Decoding LLM id/path, passed to RewardSampling as llm_dir "
                         "(its tokenizer is also loaded for token counting).")

# --- Generation budget ------------------------------------------------------
parser.add_argument("--sample_num", default=10, type=int,
                    help="Limit on how many prompts (from the start of the dataset) are decoded.")
parser.add_argument("--max_response_length", default=1000, type=int,
                    help="Forwarded to the decoder as max_new_token.")

# --- Decoding hyper-parameters ----------------------------------------------
# NOTE: defaults are left exactly as before (e.g. reward_threshold=10 is an int,
# since argparse does not convert non-string defaults); they appear verbatim in
# the output filename, so changing them would rename result files.
parser.add_argument("--reward_threshold", default=10, type=float,
                    help="CARDS only: forwarded to rs_generate as reward_threshold.")
parser.add_argument("--entropy_threshold", default=1.5, type=float,
                    help="CARDS only: forwarded to rs_generate as entropy_threshold.")
parser.add_argument("--topk", default=1, type=int,
                    help="Forwarded to both rs_generate (CARDS) and args_generate (ARGS).")
parser.add_argument("--beta", default=0.7, type=float,
                    help="CARDS only: forwarded to rs_generate as beta.")
parser.add_argument("--alpha", default=0.5, type=float,
                    help="CARDS only: forwarded to rs_generate as alpha.")

# --- Devices ----------------------------------------------------------------
# NOTE(review): none of these are passed to RewardSampling by this script;
# confirm whether device placement is configured elsewhere.
parser.add_argument("--LLM_Decoding_GPU", type=str, default="cuda:0",
                    help="Device string for the decoding LLM (currently unused here).")
parser.add_argument("--LLM_Prompt_GPU", type=str, default="cuda:0",
                    help="Device string for the prompt LLM (currently unused here).")
parser.add_argument("--Reward_model_GPU", type=str, default="cuda:0",
                    help="Device string for the reward model (currently unused here).")

# --- Output / bookkeeping ---------------------------------------------------
parser.add_argument("--out_file", default='./Experiment/LLama3_DPO', type=str,
                    help="Output root directory, joined onto this script's directory.")
# Only these two methods are implemented; any other value used to silently
# produce a results file containing nothing but the args entry.
parser.add_argument("--method", type=str, default='ARGS', choices=['CARDS', 'ARGS'],
                    help="Decoding method.")
parser.add_argument("--backbone", type=str, default="LLama3_DPO",
                    help="Backbone label; only recorded in the output via vars(args).")
args = parser.parse_args()

current_file_path = os.path.dirname(__file__)
print(vars(args))
print(current_file_path)


def _first_human_turn(raw_prompt):
    """Return the first human utterance of an hh-rlhf style prompt.

    The prompt is expected to look like "\\n\\nHuman: ...\\n\\nAssistant: ...".
    Everything after the first "Human:" is stripped of surrounding whitespace
    and cut at the first "\\n\\nAssistant:", so later turns of a multi-turn
    dialogue are dropped.

    Raises:
        IndexError: if ``raw_prompt`` contains no "Human:" marker.
    """
    after_human = raw_prompt.split('Human:', 1)[1].strip()
    return after_human.split('\n\nAssistant:')[0]


if args.dataset == "Dahoas/full-hh-rlhf":
    test_ds = load_dataset(args.dataset, split=args.split)
    origin_prompt = [_first_human_turn(p) for p in test_ds["prompt"]]
elif args.dataset == 'stingning/ultrachat':
    # Only the "train" split is read here; --split is ignored for this dataset.
    test_ds = load_dataset(args.dataset)["train"]
    # Each row's "data" is presumably the list of conversation turns; the first
    # element is used as the prompt.
    origin_prompt = [turns[0] for turns in test_ds["data"]]
else:
    # Previously an unknown dataset silently produced zero prompts.
    raise ValueError(
        f"Unsupported --dataset {args.dataset!r}; "
        "expected 'Dahoas/full-hh-rlhf' or 'stingning/ultrachat'."
    )


print("[INFO]: Done")


# Build the reward-guided sampler from the decoding LLM and reward-model ids
# (presumably loads both models -- see RewardSampling). No access token is used.
rs = RewardSampling(
    access_token=None,
    llm_dir=args.LLM_Decoding,
    rm_dir=args.Reward_model,
)

def sample_response(prompt):
    """Generate one response for ``prompt`` with the decoder selected by ``--method``.

    CARDS dispatches to ``rs.rs_generate`` with the entropy/reward thresholds,
    ``alpha``, ``beta`` and ``topk``; ARGS dispatches to ``rs.args_generate``
    with ``topk`` only. Both receive ``--max_response_length`` as
    ``max_new_token``.

    Args:
        prompt: User prompt text (an element of ``origin_prompt``).

    Returns:
        Whatever the selected RewardSampling method returns. The caller
        concatenates it with the prompt, so it is expected to be a str.

    Raises:
        ValueError: if ``args.method`` is neither 'CARDS' nor 'ARGS'. The
            function used to fall off the end and return ``None`` in that case.
    """
    if args.method == 'CARDS':
        return rs.rs_generate(
            prompt,
            entropy_threshold=args.entropy_threshold,
            reward_threshold=args.reward_threshold,
            alpha=args.alpha,
            beta=args.beta,
            topk=args.topk,
            max_new_token=args.max_response_length,
        )
    if args.method == 'ARGS':
        return rs.args_generate(
            prompt,
            topk=args.topk,
            max_new_token=args.max_response_length,
        )
    raise ValueError(f"Unsupported --method {args.method!r}; expected 'CARDS' or 'ARGS'.")

# Fail fast, before loading the tokenizer, on a method sample_response cannot
# handle. The old loop silently skipped every prompt in that case.
if args.method not in ('CARDS', 'ARGS'):
    raise ValueError(f"Unsupported --method {args.method!r}; expected 'CARDS' or 'ARGS'.")

data = []
tokenizer_gen = AutoTokenizer.from_pretrained(args.LLM_Decoding)

# Slice instead of counting so exactly --sample_num prompts are decoded. The
# former `if i > args.sample_num: break` let sample_num + 1 prompts through.
# Negative values decode nothing, as before.
for current_prompt in tqdm(origin_prompt[:max(args.sample_num, 0)]):
    start = time.perf_counter()
    answer = sample_response(current_prompt)
    gen_seconds = time.perf_counter() - start

    # Count only the generated text's tokens. Wrapping the answer in a chat
    # template (as before) also counted the template's role/header tokens and
    # inflated the throughput figure.
    num_tokens = len(tokenizer_gen.encode(answer, add_special_tokens=False))

    data.append({
        "prompt": current_prompt,
        "result": answer,
        "response": current_prompt + answer,
        # Despite the key name this is throughput (tokens per second); the key
        # is kept as "elapsed" for compatibility with existing result files.
        "elapsed": num_tokens / gen_seconds,
        "method": args.method,
    })

# The final entry records the run configuration alongside the generations.
data.append({"args": vars(args)})

# Filename encodes the dataset and decoding hyper-parameters; f"{x}" formats
# ints/floats exactly like str(x), so names match earlier runs.
result_stem = (
    f"{args.dataset.replace('/', '_')}"
    f"_reward_threshold_{args.reward_threshold}"
    f"_entropy_threshold_{args.entropy_threshold}"
    f"_topk_{args.topk}"
    f"_beta_{args.beta}"
    f"_alpha_{args.alpha}"
)
out_path = os.path.join(current_file_path, args.out_file, args.method, "Decoding_results", result_stem)
output_file = Path(out_path + ".jsonl")

# The results directory may not exist yet, and open() would raise
# FileNotFoundError.
output_file.parent.mkdir(parents=True, exist_ok=True)

# NOTE: despite the .jsonl suffix, the file holds one indented JSON array
# (json.dump), not one object per line, so read it back with json.load.
# The explicit UTF-8 encoding is needed because ensure_ascii=False emits raw
# non-ASCII text.
with open(output_file, "w", encoding="utf-8") as outfile:
    json.dump(data, outfile, ensure_ascii=False, indent=4)
