from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
from datasets import load_dataset
import torch
from huggingface_hub import Repository, snapshot_download
import numpy as np
import random
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest
import argparse
import time
from accelerate import Accelerator
import openai
import os
import llm_blender
import tqdm


parser = argparse.ArgumentParser(
    description="Generate responses with alignment-handbook/zephyr-7b-sft-full "
                "(optionally with a LoRA adapter) and score them with GPT-4.")

parser.add_argument(
    "--model", type=str, required=True,
    help='Hugging Face repo id of a LoRA adapter applied on top of '
         '"alignment-handbook/zephyr-7b-sft-full" '
         '(e.g. "YYYYYYibo/zephyr-7b-lora-64-no-quant-6k"), or the literal '
         '"original" to generate with the base model and no adapter.')
parser.add_argument(
    "--name", type=str, required=True,
    help="Column name under which the generated responses are stored. Unless "
         "--model is 'original', it is also the suffix of the Hub dataset "
         "repo the results are pushed to.")
parser.add_argument(
    "--eval", action="store_true",
    help="Evaluation flag (currently only printed by the __main__ block).")

args = parser.parse_args()

model_dir = args.model
col_name = args.name
do_eval = args.eval

# Fail fast with a readable CLI error instead of a bare KeyError traceback
# (or, for an empty value, a stream of retried API failures later on).
_openai_api_key = os.environ.get("OPENAI_API_KEY")
if not _openai_api_key:
    parser.error("the OPENAI_API_KEY environment variable must be set")
openai.api_key = _openai_api_key

print(model_dir)

###############
# Load datasets
###############
# parser = H4ArgumentParser((ModelArguments, DataArguments, RDPOConfig))
# model_args, data_args, training_args = parser.parse()
# print(model_args)
# print(data_args)
# print(training_args)

# model = AutoModelForCausalLM.from_pretrained(
#     model_args.model_name_or_path,  torch_dtype=torch.bfloat16)
# model = LLM("alignment-handbook/zephyr-7b-sft-full",
#             enable_lora=True, max_lora_rank=64)

# ref_model = AutoModelForCausalLM.from_pretrained(
#     "alignment-handbook/zephyr-7b-sft-full",  torch_dtype=torch.bfloat16)

# tokenizer = AutoTokenizer.from_pretrained(model_dir)


def generate_response_vllm(dataset, col_name):
    """Generate one completion per row with vLLM and add it as a new column.

    The base model is "alignment-handbook/zephyr-7b-sft-full". When the
    module-level ``model_dir`` is not "original", it is treated as a Hub repo
    id of a LoRA adapter. The adapter is downloaded with ``snapshot_download``
    and applied to every generation request.

    Args:
        dataset: a ``datasets.Dataset`` with a ``"chosen"`` column. Each entry
            is a list of chat messages. The last message is dropped and the
            remaining turns form the prompt (presumably the dropped message is
            the chosen assistant reply; TODO confirm against the dataset schema).
        col_name: name of the column to add with the generated text.

    Returns:
        ``dataset`` with an extra ``col_name`` column holding one generated
        string per row.
    """
    base_model = "alignment-handbook/zephyr-7b-sft-full"
    model = LLM(base_model, enable_lora=True, max_lora_rank=64)
    tokenizer = AutoTokenizer.from_pretrained(base_model)
    sampling_params = SamplingParams(
        max_tokens=1024,
        stop=tokenizer.eos_token,
        skip_special_tokens=True,
    )

    # Drop the final message of each conversation, then render the remaining
    # turns with the chat template, ending with the assistant generation prompt.
    chat_prompts = [
        tokenizer.apply_chat_template(
            messages[:-1], tokenize=False, add_generation_prompt=True)
        for messages in dataset['chosen']
    ]

    # lora_request=None makes vLLM generate with the plain base model.
    lora_request = None
    if model_dir != "original":
        lora_path = snapshot_download(repo_id=model_dir)
        # The adapter name is only a label; the integer must be a unique id
        # for this adapter within the LLM instance.
        lora_request = LoRARequest("sql_adapter", 2, lora_path)

    with torch.inference_mode():
        outputs = model.generate(
            chat_prompts, sampling_params, lora_request=lora_request)
    responses = [output.outputs[0].text for output in outputs]
    return dataset.add_column(col_name, responses)


def get_eval(sys_prompt, user_prompt, max_retries=10):
    """Send a system + user prompt to GPT-4 and return its stripped reply.

    Any exception raised while making the request or reading the reply is
    printed and the request is retried, up to ``max_retries`` attempts in
    total. Previously the attempt counter was never incremented, so a
    persistent failure looped forever.

    Args:
        sys_prompt: content of the system message.
        user_prompt: content of the user message.
        max_retries: maximum number of attempts before giving up.

    Returns:
        The assistant message content with surrounding whitespace removed.

    Raises:
        Exception: "API Error" once every attempt has failed. The last
            underlying error is attached as ``__cause__``.
    """
    last_error = None
    for _ in range(max_retries):
        try:
            # NOTE(review): openai.ChatCompletion is the pre-1.0 openai-python
            # interface (removed in openai>=1.0); confirm the pinned version.
            response = openai.ChatCompletion.create(
                model="gpt-4",
                messages=[
                    {"role": "system", "content": sys_prompt},
                    {"role": "user", "content": user_prompt},
                ],
                temperature=0,
                max_tokens=1024,
                top_p=0.6,
                presence_penalty=0,
                frequency_penalty=0,
            )
            return response["choices"][0]["message"]["content"].strip()
        except Exception as e:
            # KeyboardInterrupt derives from BaseException, not Exception,
            # so Ctrl-C still aborts immediately instead of being retried.
            print(e)
            last_error = e
    raise Exception("API Error") from last_error


def get_score(prompt, answer):
    try_num = 0

    system_prompt = "A chat between a curious user and an artificial intelligence expert. The expert gives helpful, specific, and concise answers to the user's questions."

    feedback_prompt = \
        """Given my answer to an instruction, your role is to provide specific and constructive feedback for me. You should find the best way for me to learn from your feedback and improve my performance.

    You should consider multiple aspects of my answer, including helpfulness, truthfulness, honesty, and to what extent the answer follows instructions.
    ---

    ### Instruction
    {instruction}

    ### Answer
    {completion}
    ---

    Please act as a teacher and provide specific and constructive feedback. Besides describing the weaknesses of the answer, you should also provide specific suggestions to guide me toward understanding how to improve. Please note, however, that your suggestions should help me better complete the instructions, but you should not introduce new requirements that are not mentioned in the instructions. Your feedback should focus on enhancing my ability to think critically and respond accurately. However, never explicitly provide the reference answer, nor do polite phrases be required. Only respond with concise feedback in chat style. Finally, score the overall quality of the answer from 1 to 10, where 1 is the worst and 10 is the best.

    You should follow this format:
    *Format*
    ### Feedback
    [Your feedback]
    Overall Score: [1-10]

    ---

    ### Feedback
    """
    while try_num < 4:
        try:
            response = get_eval(system_prompt, feedback_prompt.format(
                instruction=prompt, completion=answer))
            response = response.split("\nOverall Score: ")
            assert len(response) == 2
            critique, score = response[0].strip(), response[1].split(".")[
                0].strip()
            score = score if "/" not in score else (
                eval(score.split("/")[0].strip()))
            final_score = int(score)
            assert 1 <= final_score <= 10
            return final_score
        except Exception as e:
            print("error:", e)
    raise Exception("API Error: not an integer score")


def eval_dataset_gpt(dataset, col_name):
    """Score every response in ``dataset[col_name]`` with GPT-4.

    Each row's ``"prompt"`` and ``col_name`` fields are passed to
    ``get_score``. The per-row score and running average are printed as the
    loop progresses, followed by a summary with the final average. An empty
    dataset no longer raises ZeroDivisionError.

    Args:
        dataset: a ``datasets.Dataset`` with ``"prompt"`` and ``col_name``
            columns.
        col_name: name of the column holding the responses to score.

    Returns:
        ``dataset`` with an added ``"gpt_score"`` column (one int per row).
    """
    scores = []
    total = 0
    for row in dataset:
        score = get_score(row["prompt"], row[col_name])
        scores.append(score)
        # Keep a running total rather than re-summing the list every row.
        total += score
        print("score:", score)
        print("average:", total / len(scores))
    print("=" * 80)
    print(col_name)
    if scores:
        print("after average score:", total / len(scores))
    else:
        print("after average score: n/a (empty dataset)")
    print("=" * 80)

    return dataset.add_column("gpt_score", scores)


@torch.no_grad()
def eval_dataset_pairrm(dataset, col_name):
    blender = llm_blender.Blender()
    blender.loadranker("llm-blender/PairRM")
    with torch.inference_mode():
        prompts = dataset["prompt"]
        # response_index_list = ["chosen", "rejected", "max_pi", "random"]
        # chosen_list = [row[1]["content"] for row in dataset["chosen"]]
        # rejected_list = [row[1]["content"] for row in dataset["rejected"]]
        # opt_list = dataset["minpi"]
        # random_list = dataset["random"]
        before_list = dataset["original"]
        after_list = dataset[col_name]
        ds_size = len(prompts)
        candidates_texts = [[before_list[idx]] + [after_list[idx]]
                            for idx in range(ds_size)]
        rank = blender.rank(prompts, candidates_texts, return_scores=False)

        chosen_indices = np.argmin(rank, axis=1)
        rejected_indices = np.argmax(rank, axis=1)
        win_rate = np.sum(chosen_indices == 1)/len(chosen_indices)
        # print(chosen_indices)
        print(col_name, win_rate)

    return dataset


if __name__ == "__main__":
    # NOTE(review): --eval is parsed and printed but does not change the
    # pipeline below; the earlier eval-only path survives only as the
    # commented-out code further down.
    print(do_eval)

    # Both modes run the same pipeline on the first 200 rows. They differ
    # only in the Hub repo the scored dataset is pushed to.
    # NOTE(review): newer `datasets` releases deprecate/remove
    # `ignore_verifications` in favour of `verification_mode`; confirm
    # against the pinned version.
    train_dataset = load_dataset(
        "YYYYYYibo/ultrafeedback_binarized_with_response_full_part1",
        split="train_prefs",
        download_mode="force_redownload",
        ignore_verifications=True,
    ).select(range(200))
    train_dataset = generate_response_vllm(train_dataset, col_name)
    train_dataset = eval_dataset_gpt(train_dataset, col_name)

    if model_dir == "original":
        output_repo = "YYYYYYibo/eval-dataset-original"
    else:
        output_repo = "YYYYYYibo/eval-dataset-" + col_name
        # if do_eval:
        #     train_dataset = load_dataset("YYYYYYibo/eval-dataset-"+col_name, split="train_prefs",
        #                                  download_mode="force_redownload", ignore_verifications=True).select(range(0, 200))
        #     train_dataset = eval_dataset_gpt(train_dataset, col_name)
        #     train_dataset.push_to_hub(
        #         "YYYYYYibo/eval-dataset-with-score-"+col_name, split="train_prefs", private=False)

        # else:
        #     train_dataset = load_dataset("YYYYYYibo/eval-dataset-original", split="train_prefs",
        #                                  download_mode="force_redownload", ignore_verifications=True)
        #     train_dataset = generate_response_vllm(train_dataset, col_name)

    train_dataset.push_to_hub(output_repo, split="train_prefs", private=False)
