# Referencing https://github.com/lm-sys/FastChat/
import ast
import json
import os
import re
import time
from concurrent.futures import ThreadPoolExecutor

import numpy as np
import openai
import shortuuid
import torch
from fastchat.llm_judge.common import (
    load_model_answers,
    load_judge_prompts,
    Judge,
    MatchSingle,
    NEED_REF_CATS,
)
from fastchat.llm_judge.common import temperature_config
from fastchat.model import load_model, get_conversation_template
from openai import OpenAI
from tqdm import tqdm

from evaluation.schemas import OpenEndedTaskBase, GeneratorModelBase

# API setting constants
# Maximum number of attempts for one judge request.
API_MAX_RETRY = 16
# Seconds to wait between attempts (passed to time.sleep).
API_RETRY_SLEEP = 10
# Sentinel returned in place of a completion when every attempt failed.
API_ERROR_OUTPUT = "$ERROR$"
# Rating extraction patterns. Raw strings are required: "\[" and "\d" are not
# valid string escapes and trigger a SyntaxWarning on recent Python versions.
# Primary form: a number wrapped in double brackets, e.g. "[[7.5]]".
one_score_pattern = re.compile(r"\[\[(\d+\.?\d*)\]\]")
# Fallback form: a number wrapped in single brackets, e.g. "[7]".
one_score_pattern_backup = re.compile(r"\[(\d+\.?\d*)\]")


def chat_completion_openai(model, conv, temperature, max_tokens, api_dict=None):
    """Request a single chat completion from OpenAI, retrying transient failures.

    Args:
        model: Name of the chat model to query.
        conv: Conversation object exposing ``to_openai_api_messages()``.
        temperature: Sampling temperature forwarded to the API.
        max_tokens: Completion length limit forwarded to the API.
        api_dict: Optional mapping of client overrides. ``"api_key"`` takes
            precedence over the ``OPENAI_API_KEY`` environment variable and
            ``"api_base"`` selects a custom endpoint.
            NOTE(review): key names follow the FastChat convention -- confirm
            against callers that pass this argument.

    Returns:
        The message content of the first choice, or ``API_ERROR_OUTPUT`` when
        every attempt failed with a connection or rate-limit error.

    Raises:
        KeyError: If no API key is supplied through ``api_dict`` and
            ``OPENAI_API_KEY`` is not set.
    """
    overrides = api_dict or {}
    client_kwargs = {
        "api_key": overrides.get("api_key") or os.environ["OPENAI_API_KEY"]
    }
    if overrides.get("api_base"):
        client_kwargs["base_url"] = overrides["api_base"]
    client = OpenAI(**client_kwargs)

    output = API_ERROR_OUTPUT
    for attempt in range(API_MAX_RETRY):
        try:
            messages = conv.to_openai_api_messages()
            response = client.chat.completions.create(
                model=model,
                messages=messages,
                n=1,
                temperature=temperature,
                max_tokens=max_tokens,
            )
            output = response.choices[0].message.content
            break
        except (openai.APIConnectionError, openai.RateLimitError) as e:
            print(type(e), e)
            # Waiting after the last attempt would only delay the error result.
            if attempt < API_MAX_RETRY - 1:
                time.sleep(API_RETRY_SLEEP)

    return output


def run_judge_single(question, answer, judge, ref_answer, multi_turn=False):
    """Ask the judge model to grade one answer and parse the numeric rating.

    Args:
        question: Question record; ``question["turns"]`` holds the user turns.
        answer: Answer record; ``answer["choices"][0]["turns"]`` holds the
            model replies, one per question turn.
        judge: Judge object providing ``model_name`` and a ``prompt_template``
            mapping with the keys ``"prompt_template"``, ``"system_prompt"``
            and ``"output_format"``.
        ref_answer: Reference answer record with the same layout as
            ``answer``, or ``None`` when the judge is not reference based.
        multi_turn: When true, both turns are inserted into the judge prompt.

    Returns:
        A tuple ``(rating, user_prompt, judgment)``. ``rating`` is ``-1`` when
        no rating could be extracted from the judgment text.

    Raises:
        ValueError: If the judge's output format is not ``"[[rating]]"``.
    """
    template = judge.prompt_template
    model = judge.model_name

    # Reject unsupported configurations before spending an API call on them.
    if template["output_format"] != "[[rating]]":
        raise ValueError(f"invalid output format: {template['output_format']}")

    kwargs = {}
    if ref_answer is not None:
        kwargs["ref_answer_1"] = ref_answer["choices"][0]["turns"][0]
        if multi_turn:
            kwargs["ref_answer_2"] = ref_answer["choices"][0]["turns"][1]

    if multi_turn:
        user_prompt = template["prompt_template"].format(
            question_1=question["turns"][0],
            question_2=question["turns"][1],
            answer_1=answer["choices"][0]["turns"][0],
            answer_2=answer["choices"][0]["turns"][1],
            **kwargs,
        )
    else:
        user_prompt = template["prompt_template"].format(
            question=question["turns"][0],
            answer=answer["choices"][0]["turns"][0],
            **kwargs,
        )

    conv = get_conversation_template(model)
    conv.set_system_message(template["system_prompt"])
    conv.append_message(conv.roles[0], user_prompt)
    conv.append_message(conv.roles[1], None)

    judgment = chat_completion_openai(model, conv, temperature=0, max_tokens=2048)

    # The API may return no text content; treat that as an unparseable judgment
    # instead of letting the regex search raise a TypeError.
    judgment_text = judgment or ""
    match = one_score_pattern.search(judgment_text)
    if not match:
        match = one_score_pattern_backup.search(judgment_text)

    if match:
        score_text = match.group(1)
        # int()/float() accept zero-padded ratings such as "08", which
        # ast.literal_eval rejects with a SyntaxError.
        rating = float(score_text) if "." in score_text else int(score_text)
    else:
        rating = -1

    return rating, user_prompt, judgment


def play_a_match_single(match: MatchSingle, output_file: str):
    """Grade one match with its judge and append the result to a JSONL file.

    Args:
        match: Match carrying ``question``, ``model``, ``answer``, ``judge``,
            ``ref_answer`` and ``multi_turn``. The question must provide both
            ``"id"`` and ``"question_id"``.
        output_file: Path of the JSONL file to append to. A falsy value skips
            writing.

    Returns:
        The result record (the same dict that is written to ``output_file``).

    Raises:
        ValueError: If the judge's prompt template type is not ``"single"``.
    """
    question = match.question
    model = match.model
    answer = match.answer
    judge = match.judge
    ref_answer = match.ref_answer
    multi_turn = match.multi_turn

    judge_type = judge.prompt_template["type"]
    if judge_type != "single":
        # The type lives on the prompt template; the judge object itself is
        # not subscriptable.
        raise ValueError(f"invalid judge type: {judge_type}")

    score, user_prompt, judgment = run_judge_single(
        question, answer, judge, ref_answer, multi_turn=multi_turn
    )

    question_id = question["question_id"]
    turn = 1 if not multi_turn else 2
    result = {
        "id": question["id"],
        "question_id": question_id,
        "model": model,
        "judge": (judge.model_name, judge.prompt_template["name"]),
        "user_prompt": user_prompt,
        "judgment": judgment,
        "score": score,
        "turn": turn,
        "tstamp": time.time(),
    }
    print(
        f"question: {question_id}, turn: {turn}, model: {model}, "
        f"score: {score}, "
        f"judge: {(judge.model_name, judge.prompt_template['name'])}"
    )

    if output_file:
        output_dir = os.path.dirname(output_file)
        # A bare file name has no directory part, and os.makedirs("") raises.
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        with open(output_file, "a") as fout:
            fout.write(json.dumps(result) + "\n")

    return result


def load_questions(question_file: str) -> list[dict]:
    """Load questions from a JSON Lines file.

    Args:
        question_file: Path to a file holding one JSON object per line.

    Returns:
        The decoded objects in file order. Blank or whitespace-only lines are
        skipped.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    questions = []
    with open(question_file, "r", encoding="utf-8") as ques_file:
        for line in ques_file:
            # Lines read from a file keep their trailing newline, so the text
            # must be stripped before testing whether it is empty.
            stripped = line.strip()
            if stripped:
                questions.append(json.loads(stripped))
    return questions


class MTBench(OpenEndedTaskBase):
    """MT-Bench task: generate answers with a local model, then grade them with GPT-4.

    Answers, judgments, reference answers and judge prompts are read from and
    written to files located next to this module.
    """

    def required_input_fields(self) -> list[str]:
        """Return the required input fields (none for this task)."""
        # This task uses the official MT-Bench format, so it does not follow
        # the usual input structure.
        return []

    def __init__(self):
        super().__init__()
        self.max_new_tokens, self.temperature, self.top_p = self.prepare_hparams()
        # Sampling temperature per question category.
        self.temperature_config = {
            "writing": 0.7,
            "roleplay": 0.7,
            "extraction": 0.0,
            "math": 0.0,
            "coding": 0.0,
            "reasoning": 0.0,
            "stem": 0.1,
            "humanities": 0.1,
            "arena-hard-200": 0.0,
        }

    def prepare_hparams(self):
        """Return the task's ``(max_new_tokens, temperature, top_p)`` defaults."""
        # eval task specified arguments
        max_new_tokens = 1024
        temperature = -1  # not predefined, because it depends on the question's category
        top_p = 1.0
        return max_new_tokens, temperature, top_p

    def load(self) -> list[dict]:
        """
        Load the task data from ``question.jsonl`` located next to this module.

        Returns:
            List of task data.
        """
        current_path = os.path.dirname(__file__)
        question_file = os.path.join(current_path, "question.jsonl")
        return load_questions(question_file)

    @torch.inference_mode()
    def get_model_answers(
            self,
            model_path,
            model_id,
            questions,
            answer_file,
            max_new_token,
            num_choices,
            num_gpus_per_model,
            max_gpu_memory,
            dtype,
            revision,
    ):
        """Generate answers for every question and append them to a JSONL file.

        Args:
            model_path: Model location forwarded to ``load_model``.
            model_id: Identifier used to pick the conversation template and
                recorded in each answer record.
            questions: Question records providing ``"id"``, ``"question_id"``,
                ``"category"`` and ``"turns"``.
            answer_file: Destination JSONL path; ``~`` is expanded.
            max_new_token: Maximum number of tokens generated per turn.
            num_choices: Number of answers sampled per question; choice ``i``
                is generated after seeding torch with ``i``.
            num_gpus_per_model: Forwarded to ``load_model``.
            max_gpu_memory: Forwarded to ``load_model``.
            dtype: Forwarded to ``load_model``.
            revision: Forwarded to ``load_model``.
        """
        model, tokenizer = load_model(
            model_path,
            revision=revision,
            device="cuda",
            num_gpus=num_gpus_per_model,
            max_gpu_memory=max_gpu_memory,
            dtype=dtype,
            load_8bit=False,
            cpu_offloading=False,
            debug=False,
        )

        # Expand "~" once so the directory created is the one written into.
        answer_file = os.path.expanduser(answer_file)
        answer_dir = os.path.dirname(answer_file)
        if answer_dir:
            os.makedirs(answer_dir, exist_ok=True)

        for question in tqdm(questions):
            # Use this task's own per-category table; unknown categories
            # fall back to 0.7.
            temperature = self.temperature_config.get(question["category"], 0.7)
            # A (near-)zero temperature means greedy decoding.
            do_sample = temperature >= 1e-4

            choices = []
            for i in range(num_choices):
                torch.manual_seed(i)
                conv = get_conversation_template(model_id)
                turns = []
                for qs in question["turns"]:
                    conv.append_message(conv.roles[0], qs)
                    conv.append_message(conv.roles[1], None)
                    prompt = conv.get_prompt()
                    input_ids = tokenizer([prompt]).input_ids

                    # some models may error out when generating long outputs
                    try:
                        output_ids = model.generate(
                            torch.as_tensor(input_ids).cuda(),
                            do_sample=do_sample,
                            temperature=temperature,
                            max_new_tokens=max_new_token,
                            pad_token_id=tokenizer.eos_token_id
                        )
                        if model.config.is_encoder_decoder:
                            output_ids = output_ids[0]
                        else:
                            # Drop the prompt tokens echoed back by the model.
                            output_ids = output_ids[0][len(input_ids[0]):]

                        # be consistent with the template's stop_token_ids
                        if conv.stop_token_ids:
                            stop_positions = [
                                pos
                                for pos, token_id in enumerate(output_ids)
                                if token_id in conv.stop_token_ids
                            ]
                            if stop_positions:
                                output_ids = output_ids[: stop_positions[0]]

                        output = tokenizer.decode(
                            output_ids,
                            spaces_between_special_tokens=False,
                        )
                        # NOTE(review): the `> 0` tests below leave a stop
                        # string at index 0 untruncated -- confirm this is
                        # intended before changing it.
                        if conv.stop_str and isinstance(conv.stop_str, list):
                            stop_str_indices = sorted(
                                [
                                    output.find(stop_str)
                                    for stop_str in conv.stop_str
                                    if output.find(stop_str) > 0
                                ]
                            )
                            if stop_str_indices:
                                output = output[: stop_str_indices[0]]
                        elif conv.stop_str and output.find(conv.stop_str) > 0:
                            output = output[: output.find(conv.stop_str)]

                        for special_token in tokenizer.special_tokens_map.values():
                            if isinstance(special_token, list):
                                for special_tok in special_token:
                                    output = output.replace(special_tok, "")
                            else:
                                output = output.replace(special_token, "")
                        output = output.strip()

                        if conv.name == "xgen" and output.startswith("Assistant:"):
                            output = output.replace("Assistant:", "", 1).strip()
                    except RuntimeError as e:
                        print("ERROR question ID: ", question["question_id"])
                        # Surface the cause instead of discarding it.
                        print(type(e), e)
                        output = "ERROR"

                    conv.update_last_message(output)
                    turns.append(output)

                choices.append({"index": i, "turns": turns})

            # Dump answers after every question so finished work is kept.
            with open(answer_file, "a") as fout:
                ans_json = {
                    "id": question["id"],
                    "question_id": question["question_id"],
                    "category": question["category"],
                    "answer_id": shortuuid.uuid(),
                    "model_id": model_id,
                    "choices": choices,
                    "tstamp": time.time(),
                }
                fout.write(json.dumps(ans_json) + "\n")

    @staticmethod
    def make_judge_single(judge_model, judge_prompts):
        """Build the four single-answer judges.

        Args:
            judge_model: Name of the judge model.
            judge_prompts: Mapping of prompt name to prompt template; must
                contain ``"single-v1"``, ``"single-math-v1"``,
                ``"single-v1-multi-turn"`` and ``"single-math-v1-multi-turn"``.

        Returns:
            Dict with the keys ``"default"``, ``"math"``, ``"default-mt"`` and
            ``"math-mt"``.
        """
        return {
            "default": Judge(judge_model, judge_prompts["single-v1"]),
            "math": Judge(
                judge_model, judge_prompts["single-math-v1"], ref_based=True
            ),
            "default-mt": Judge(
                judge_model, judge_prompts["single-v1-multi-turn"], multi_turn=True
            ),
            "math-mt": Judge(
                judge_model,
                judge_prompts["single-math-v1-multi-turn"],
                ref_based=True,
                multi_turn=True,
            ),
        }

    @staticmethod
    def make_match_single(
            questions,
            models,
            model_answers,
            judge,
            baseline_model=None,
            ref_answers=None,
            multi_turn=False,
    ):
        """Create one match per (question, model) pair for the given judge.

        Args:
            questions: Question records.
            models: Model identifiers to grade.
            model_answers: Nested mapping ``model -> question_id -> answer``.
            judge: Judge used for every created match.
            baseline_model: Unused; accepted for signature compatibility.
            ref_answers: Optional nested mapping
                ``judge model name -> question_id -> reference answer``.
            multi_turn: When true, questions without exactly two turns are
                skipped and the matches are flagged as multi-turn.

        Returns:
            List of ``MatchSingle`` objects.
        """
        matches = []
        for q in questions:
            if multi_turn and len(q["turns"]) != 2:
                continue
            q_id = q["question_id"]
            for m in models:
                a = model_answers[m][q_id]
                if ref_answers is not None:
                    ref = ref_answers[judge.model_name][q_id]
                    matches.append(
                        MatchSingle(
                            dict(q), m, a, judge, ref_answer=ref, multi_turn=multi_turn
                        )
                    )
                else:
                    matches.append(MatchSingle(dict(q), m, a, judge, multi_turn=multi_turn))
        return matches

    def _metrics(self, results: list[dict]) -> dict:
        """Aggregate judge scores overall, per turn and per category.

        A score of ``-1`` marks a judgment from which no rating could be
        parsed. Such rows are excluded so that judge failures do not drag the
        averages down. A group without any valid score yields ``nan``.

        Args:
            results: Result records providing ``"score"``, ``"turn"`` and
                ``"category"``.

        Returns:
            Dict with ``"total_score"``, ``"first_turn_score"``,
            ``"second_turn_score"`` and ``"score_per_category"``.
        """

        def mean_score(rows):
            scores = [r["score"] for r in rows]
            # np.mean([]) also gives nan but emits a RuntimeWarning.
            return np.mean(scores) if scores else np.nan

        valid = [r for r in results if r["score"] != -1]
        categories = sorted({r["category"] for r in results})
        return {
            "total_score": mean_score(valid),
            "first_turn_score": mean_score([r for r in valid if r["turn"] == 1]),
            "second_turn_score": mean_score([r for r in valid if r["turn"] == 2]),
            "score_per_category": {
                category: mean_score([r for r in valid if r["category"] == category])
                for category in categories
            },
        }

    def _judge(self, questions: list[dict], model_id: str) -> tuple[dict, list[dict]]:
        """Grade the stored answers of ``model_id`` with the GPT-4 judge.

        Existing judgments for the model are deleted first. Matches are played
        concurrently and appended to a JSONL file, which is read back to build
        the results.

        Args:
            questions: Question records providing ``"question_id"``,
                ``"category"`` and ``"turns"``.
            model_id: Identifier of the model whose answers are graded.

        Returns:
            A tuple ``(metrics, results)``. Every result gains a
            ``"category"`` field and its ``"id"`` becomes ``"<id>/<turn>"``.
        """
        judge_model = "gpt-4"
        base_dir = os.path.dirname(__file__)

        output_file = os.path.join(base_dir, "model_judgment",
                                   f"{model_id}_{judge_model}_single.jsonl")
        os.makedirs(os.path.dirname(output_file), exist_ok=True)

        # check if the model is already evaluated
        if os.path.exists(output_file):
            print(f" *** WARNING: {model_id} already has judgments. now resetting... *** ")
            os.remove(output_file)

        # Load answers
        baseline_model = None
        answer_dir = os.path.join(base_dir, "model_answer")
        os.makedirs(answer_dir, exist_ok=True)
        ref_answer_dir = os.path.join(base_dir, "reference_answer")
        os.makedirs(ref_answer_dir, exist_ok=True)

        model_answers = load_model_answers(answer_dir)
        ref_answers = load_model_answers(ref_answer_dir)
        models = [model_id]

        judge_file = os.path.join(base_dir, "judge_prompts.jsonl")
        judge_prompts = load_judge_prompts(judge_file)
        judges = self.make_judge_single(judge_model, judge_prompts)
        play_a_match_func = play_a_match_single
        make_match_func = self.make_match_single

        # Categories in NEED_REF_CATS are graded against reference answers.
        question_math = [q for q in questions if q["category"] in NEED_REF_CATS]
        question_default = [q for q in questions if q["category"] not in NEED_REF_CATS]

        # Make matches
        matches = []
        matches += make_match_func(
            question_default, models, model_answers, judges["default"], baseline_model
        )
        matches += make_match_func(
            question_math,
            models,
            model_answers,
            judges["math"],
            baseline_model,
            ref_answers,
        )
        matches += make_match_func(
            question_default,
            models,
            model_answers,
            judges["default-mt"],
            baseline_model,
            multi_turn=True,
        )
        matches += make_match_func(
            question_math,
            models,
            model_answers,
            judges["math-mt"],
            baseline_model,
            ref_answers,
            multi_turn=True,
        )

        match_stat = {
            "mode": "single",
            "judge": judge_model,
            "baseline": baseline_model,
            "model_list": models,
            "total_num_questions": len(questions),
            "total_num_matches": len(matches),
            "output_path": output_file,
        }

        print("Stats:")
        print(json.dumps(match_stat, indent=4))

        # Play matches
        def play_a_match_wrapper(match):
            play_a_match_func(match, output_file=output_file)

        np.random.seed(0)
        np.random.shuffle(matches)

        with ThreadPoolExecutor(16) as executor:
            # Draining the iterator drives the progress bar and re-raises any
            # exception from a worker.
            for _ in tqdm(
                    executor.map(play_a_match_wrapper, matches), total=len(matches)
            ):
                pass

        # load output_file again to get the numbers.
        cat_per_question_id = {
            q["question_id"]: q["category"] for q in questions
        }
        results = []
        # The file does not exist when no match was played.
        if os.path.exists(output_file):
            with open(output_file, "r") as fin:
                results = [json.loads(line) for line in fin]
        results = [r for r in results if r["model"] == model_id]
        for r in results:
            r["category"] = cat_per_question_id[r["question_id"]]
            r["id"] = f"{r['id']}/{r['turn']}"

        return self._metrics(results), results

    def _generate(self, model_id: str, model_name_or_path: str, questions: list[dict]):
        """Generate fresh answers for ``model_id``, replacing any stored ones.

        Args:
            model_id: Identifier used for the answer file name.
            model_name_or_path: Model location forwarded to answer generation.
            questions: Question records to answer.
        """
        completions_output_path = os.path.join(
            os.path.dirname(__file__), "model_answer", f"{model_id}.jsonl"
        )
        if os.path.exists(completions_output_path):
            print(f" *** WARNING: {model_id} already has completions. now resetting... *** ")
            os.remove(completions_output_path)

        self.get_model_answers(
            model_name_or_path,
            model_id,
            questions,
            completions_output_path,
            self.max_new_tokens,
            1,  # num_choices
            1,  # num_gpus_per_model
            None,  # max_gpu_memory
            torch.bfloat16,  # dtype
            "main",  # revision
        )

    def _evaluate(self, model: GeneratorModelBase) -> tuple[dict, list[dict]]:
        """Generate answers with ``model`` and grade them.

        Args:
            model: Generator whose ``model_name_or_path`` identifies the model.

        Returns:
            The ``(metrics, results)`` tuple produced by ``_judge``.
        """
        questions = self.data
        # Slashes are replaced so the id can be used as a file name.
        model_id = "_".join(model.model_name_or_path.split("/"))
        self._generate(model_id, model.model_name_or_path, questions)

        return self._judge(questions, model_id)
