import os
import sys

import argparse
import json
import time
import openai
from tqdm import tqdm
import os


# Retry policy for calls to the judge API.
API_MAX_RETRY = 16  # maximum number of attempts per request
API_RETRY_SLEEP = 0.8  # seconds to sleep after a failed attempt
API_ERROR_OUTPUT = "$ERROR$"  # returned when every attempt failed

# Sampling temperature keyed by what appear to be question categories.
# NOTE(review): not referenced anywhere else in this file.
temperature_config = dict(
    writing=0.7,
    roleplay=0.7,
    extraction=0.0,
    math=0.0,
    coding=0.0,
    reasoning=0.0,
    stem=0.1,
    humanities=0.1,
)

# Maps each model placeholder to the other one.
# NOTE(review): not referenced anywhere else in this file.
reverse_model_map = dict(model_1="model_2", model_2="model_1")

def read_jsonl(train_fn):
    """Load samples from a JSON-lines file.

    Lines that are not valid JSON, or whose decoded value has no usable
    ``"evaluation"`` entry, are skipped. If at least one loaded sample is
    evaluated with ``"rouge"`` or ``"f1"``, only the samples using one of
    those two metrics are kept; otherwise every loaded sample is returned.

    Args:
        train_fn: Path of the JSON-lines file to read (decoded as UTF-8).

    Returns:
        The list of decoded samples, in file order.
    """
    ngram_metrics = ("rouge", "f1")
    res = []
    evaluators = set()
    with open(train_fn, encoding="utf-8") as f:
        for line in f:
            try:
                sample = json.loads(line)
                evaluators.add(sample["evaluation"])
            except (ValueError, KeyError, TypeError):
                # Best effort: skip blank/malformed lines and records without
                # an "evaluation" field, but let KeyboardInterrupt and other
                # unexpected failures propagate instead of hiding them.
                continue
            res.append(sample)
    if not evaluators.isdisjoint(ngram_metrics):
        res = [sample for sample in res if sample["evaluation"] in ngram_metrics]
    print(f"loading from {train_fn}, there are {len(res)} samples")
    return res


def write_jsonl(data, fn):
    """Write every item of ``data`` to ``fn`` as one JSON document per line.

    The file is opened in ``"w"`` mode, so existing content is replaced.
    """
    with open(fn, "w") as out:
        out.writelines(f"{json.dumps(record)}\n" for record in data)

def play_a_match_pair(sample, output_file: str, model_1, model_2):
    """Judge one sample twice, once per answer order, and log both verdicts.

    The first game presents ``model_1``'s answer as assistant A; the second
    game swaps the two answers. Relies on the module-level globals
    ``judge_prompts``, ``args`` and ``start_idx``; ``start_idx`` is advanced
    by 0.5 after each game.

    Args:
        sample: Dict holding ``"query"``, ``"gt"`` and one answer under each
            of the keys ``model_1`` and ``model_2``.
        output_file: Path the JSON result line is appended to; a falsy value
            disables logging.
        model_1: Key of the first model's answer in ``sample``.
        model_2: Key of the second model's answer in ``sample``.

    Returns:
        ``[game_1_winner, game_2_winner]`` where each entry is a model name
        (its key with ``"_pred"`` removed), ``"tie"`` or ``"error"``.
    """
    global start_idx
    answer_1 = sample[model_1]
    answer_2 = sample[model_2]
    question = sample["query"]
    # From here on the names are used for reporting, without the "_pred" tag.
    model_1 = model_1.replace("_pred", '')
    model_2 = model_2.replace("_pred", '')
    ref_answer = sample["gt"]

    g1_winner, g1_user_prompt, g1_judgment = run_judge_pair(
        question, answer_1, answer_2, judge_prompts, ref_answer
    )
    start_idx += 0.5

    # Second game: same question with the answers in swapped positions.
    g2_winner, g2_user_prompt, g2_judgment = run_judge_pair(
        question, answer_2, answer_1, judge_prompts, ref_answer
    )
    start_idx += 0.5

    # Translate the positional verdict (A/B) of each game into a model slot;
    # "tie" and "error" pass through unchanged.
    g1_map = {"A": "model_1", "B": "model_2"}
    g2_map = {"A": "model_2", "B": "model_1"}
    g1_winner = g1_map.get(g1_winner, g1_winner)
    g2_winner = g2_map.get(g2_winner, g2_winner)

    convert = {"model_1": model_1, "model_2": model_2, "tie": "tie", "error": "error"}
    g1_name = convert[g1_winner]
    g2_name = convert[g2_winner]
    result = {
        "g1_winner": g1_name,
        "g2_winner": g2_name,
        "judge": args.judge_model,
        "g1_user_prompt": g1_user_prompt,
        "g1_judgment": g1_judgment,
        "g2_user_prompt": g2_user_prompt,
        "g2_judgment": g2_judgment,
        "tstamp": time.time(),
    }
    print("winner of this assessment:", g1_name, "(before swap)", g2_name, "(after swap)")

    if output_file:
        out_dir = os.path.dirname(output_file)
        # A bare file name has an empty dirname, and os.makedirs("") raises
        # FileNotFoundError, so only create the directory when there is one.
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        with open(output_file, "a") as fout:
            fout.write(json.dumps(result) + "\n")

    return [g1_name, g2_name]


def run_judge_pair(question, answer_a, answer_b, judge, ref_answer):
    """Ask the judge model which of two answers is better.

    While the module-level ``start_idx`` is below 1, the user prompt is
    printed and the function waits for Enter before calling the API, and the
    raw judgment is printed afterwards.

    Args:
        question: The user question.
        answer_a: Answer shown to the judge as assistant A.
        answer_b: Answer shown to the judge as assistant B.
        judge: Dict with ``"judge_model"``, ``"system_prompt"``,
            ``"prompt_template"`` and ``"output_format"``.
        ref_answer: Reference answer inserted into the prompt.

    Returns:
        ``(winner, user_prompt, judgment)`` where ``winner`` is ``"A"``,
        ``"B"``, ``"tie"`` or ``"error"``.

    Raises:
        ValueError: If the judge model or the output format is unsupported.
    """
    model = judge["judge_model"]
    system_prompt = judge["system_prompt"]
    user_prompt = judge["prompt_template"].format(
        question=question,
        answer_a=answer_a,
        answer_b=answer_b,
        ref_answer=ref_answer,
    )

    if model not in ("gpt-3.5-turbo", "gpt-4"):
        raise ValueError(f"Invalid judge model name: {model}")

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
    # Only the first sample (both orderings) is shown for confirmation.
    preview = start_idx < 1
    if preview:
        banner = "=" * 30
        print(banner)
        print("after swapping the answers:" if start_idx == 0.5 else "An example of input prompt:")
        print(banner)
        print(messages[1]["content"])
        print("------------------ end of user prompt ---------------")
        input("Press Enter to confirm...")

    judgment = chat_compeletion_openai(model, messages, temperature=0, max_tokens=2048)

    if judge["output_format"] != "[[A]]":
        raise ValueError(
            f"invalid output format: {judge['output_format']}"
        )

    # The first verdict tag found, checked in this fixed order, decides.
    winner = "error"
    for tag, verdict in (("[[A]]", "A"), ("[[B]]", "B"), ("[[C]]", "tie")):
        if tag in judgment:
            winner = verdict
            break

    if preview:
        print(judgment)

    return winner, user_prompt, judgment


def chat_compeletion_openai(model, messages, temperature, max_tokens):
    """Request one chat completion, retrying on OpenAI errors.

    Up to ``API_MAX_RETRY`` attempts are made, sleeping ``API_RETRY_SLEEP``
    seconds after each failed one.

    Returns:
        The content of the first choice, or ``API_ERROR_OUTPUT`` if every
        attempt raised an ``openai.error.OpenAIError``.
    """
    request = {
        "model": model,
        "messages": messages,
        "n": 1,
        "temperature": temperature,
        "max_tokens": max_tokens,
    }
    for _attempt in range(API_MAX_RETRY):
        try:
            reply = openai.ChatCompletion.create(**request)
            return reply["choices"][0]["message"]["content"]
        except openai.error.OpenAIError as err:
            print(type(err), err)
            time.sleep(API_RETRY_SLEEP)

    return API_ERROR_OUTPUT

def calculate_win_rate():
    """Print every model's win rate and append it to the judge log.

    Reads the module-level ``statistics`` dict (verdict label -> count,
    including ``"tie"`` and ``"error"``) and writes to the module-level open
    file ``f_out``. Error verdicts are excluded from the denominator and each
    tie counts as half a win for every model. ``statistics`` is left
    unmodified, so the function can be called more than once.
    """
    print("drop", statistics["error"], " error output")
    n_tie = statistics["tie"]
    wins = {
        key: count
        for key, count in statistics.items()
        if key not in ("tie", "error")
    }
    total = n_tie + sum(wins.values())
    if total == 0:
        # No valid verdicts at all (no samples, or only judge errors): the
        # rate is undefined, so report that instead of dividing by zero.
        message = "no valid judgments, win rate is undefined"
        print(message)
        print(message, file=f_out)
        return

    score_of_tie = 0.5 * n_tie
    for key, n_wins in wins.items():
        line = f"win rate of {key} = {(score_of_tie + n_wins) / total}"
        print(line)
        print(line, file=f_out)


def construct_pairs(pred_file, battle_file):
    """Attach the opponent's answers to the samples of the evaluated model.

    Both files are loaded with ``read_jsonl``. Samples are matched on their
    ``"query"`` and ``"gt"`` values (compared as strings); pred samples
    without a match in the battle file are dropped. Matching pred samples are
    updated in place with the opponent's answer.

    Args:
        pred_file: JSON-lines file of the evaluated model; its answer key is
            the first key of the first sample that contains ``"_pred"``.
        battle_file: JSON-lines file of the opponent; its answer key is the
            last key of the first sample that contains ``"_pred"``.

    Returns:
        ``(paired_samples, pred_key, battle_key)``.

    Raises:
        AssertionError: If a file yields no samples, the two files differ in
            length, or no ``"_pred"`` key is found.
    """
    pred_samples = read_jsonl(pred_file)  # keys: query, gt (ground truth), <model_name>_pred
    battle_samples = read_jsonl(battle_file)  # keys: query, gt (ground truth), <model_name>_pred, <turbo-16k-0613>_pred
    print(len(pred_samples), len(battle_samples))

    assert len(pred_samples) == len(battle_samples), (
        f"sample count mismatch: {len(pred_samples)} in {pred_file} "
        f"vs {len(battle_samples)} in {battle_file}"
    )
    assert pred_samples, f"no samples loaded from {pred_file}"

    # First "*_pred" key of the pred file.
    pred_key = next((key for key in pred_samples[0] if "_pred" in key), "")
    # Last "*_pred" key of the battle file, which may carry several.
    battle_key = ""
    for key in battle_samples[0]:
        if "_pred" in key:
            battle_key = key
    assert "_pred" in pred_key, f"no '_pred' key found in {pred_file}"
    assert "_pred" in battle_key, f"no '_pred' key found in {battle_file}"
    print("model name of pred file:", pred_key)
    print("model name of battle file", battle_key)

    def pair_id(sample):
        # str() on both sides so non-string values (e.g. a list "gt") match
        # consistently; a tuple avoids collisions of concatenated strings.
        return str(sample["query"]), str(sample["gt"])

    battle_key2pred = {pair_id(sample): sample[battle_key] for sample in battle_samples}
    paired_samples = []
    for sample in pred_samples:
        sample_id = pair_id(sample)
        if sample_id not in battle_key2pred:
            continue
        sample[battle_key] = battle_key2pred[sample_id]
        paired_samples.append(sample)
    return paired_samples, pred_key, battle_key

if __name__ == "__main__":

    # OpenAI key for the GPT-4 / GPT-3.5 judge. It is read from the
    # OPENAI_API_KEY environment variable; alternatively replace the empty
    # fallback below with your key.
    openai.api_key = os.environ.get("OPENAI_API_KEY", "")

    parser = argparse.ArgumentParser()
    parser.add_argument("--judge_model", type=str, default="gpt-4", choices=["gpt-4", "gpt-3.5-turbo"])
    parser.add_argument('--pred_file', default="Predictions/llm_gpt4_eval/<model_name>.pred.jsonl")
    # NOTE(review): this value is passed to construct_pairs() as a file path,
    # but the default looks like a model name -- confirm the intended default.
    parser.add_argument('--battle_with', default="turbo-16k-0613")

    args = parser.parse_args()

    # Judge configuration read by play_a_match_pair() / run_judge_pair().
    judge_prompts = {
        "judge_model": args.judge_model,
        "name": "pair-v2",
        "type": "pairwise",
        "system_prompt": "Please act as an impartial judge and evaluate the quality of the responses provided by two AI assistants to the user question about the content of a long document.  You will be given a reference answer written by human, assistant A's answer, and assistant B's answer. Your job is to evaluate which assistant's answer is better. Begin your evaluation by comparing both assistants' answers with the reference answer. Additional details or information that are not mentioned in reference answer cannot be considered as advantages and do not let them sway your judgment. Your evaluation should also consider the relevance to user's question but it is more important to avoid factual errors according to the reference answer. Avoid any position biases and ensure that the order in which the responses were presented does not influence your decision. Do not allow the length of the responses to influence your evaluation. Do not favor certain names of the assistants. Be as objective as possible. After providing your explanation, output your final verdict by strictly following this format: \"[[A]]\" if assistant A is better, \"[[B]]\" if assistant B is better, and \"[[C]]\" for a tie.",
        "prompt_template": "[User Question]\n{question}\n\n[The Start of Reference Answer]\n{ref_answer}\n[The End of Reference Answer]\n\n[The Start of Assistant A's Answer]\n{answer_a}\n[The End of Assistant A's Answer]\n\n[The Start of Assistant B's Answer]\n{answer_b}\n[The End of Assistant B's Answer]",
        "description": "prompt without long documents as input",
        "category": "general",
        "output_format": "[[A]]",
    }

    # Derive "<root>.judge<ext>" from the prediction path. Only the final
    # extension is touched, so "jsonl" elsewhere in the path is left alone
    # and the judge log can never be the prediction file itself.
    pred_root, pred_ext = os.path.splitext(args.pred_file)
    output_file = f"{pred_root}.judge{pred_ext or '.jsonl'}"

    # Load the paired answers of both models.
    samples, model_1, model_2 = construct_pairs(args.pred_file, args.battle_with)

    # start_idx counts finished samples (0.5 per game); while it is below 1,
    # run_judge_pair() shows the prompt and waits for confirmation.
    start_idx = 0
    statistics = {"tie": 0, "error": 0}
    for sample in tqdm(samples):
        res = play_a_match_pair(sample, output_file, model_1, model_2)
        for output in res:
            statistics[output] = statistics.get(output, 0) + 1

    print(f"among {len(samples)} samples:")
    for key in statistics:
        if key != "tie" and key != "error":
            print(f"n_wins of {key} = {statistics[key]}")

    # f_out is read as a global by calculate_win_rate(); the with-block makes
    # sure the summary is flushed and the file closed.
    with open(output_file, "a") as f_out:
        print("n_draws =", statistics["tie"], "n_error =", statistics["error"])
        print("n_draws =", statistics["tie"], "n_error =", statistics["error"], file=f_out)
        calculate_win_rate()
    print("Details of the output during evaluation are saved in", output_file)
