import json
import re

from datasets import load_dataset
from tqdm import tqdm

from evaluation.schemas import OpenEndedTaskBase, GeneratorModelBase
from evaluation.tasks.generation.wildbench.fastchat_conversation import HF_Conversation, map_to_conv
from src.dispatch_openai_requests import dispatch_batch_openai_requests_with_retry, dispatch_openai_chat_requests


class WildBenchScore(OpenEndedTaskBase):
    """WildBench v2 open-ended generation task scored by a GPT judge.

    The evaluated model answers every WildBench conversation. Each answer is
    then turned into a judge prompt (``placeholder_generation``), graded
    (``run_eval``), rescaled, and averaged per ``primary_tag``.
    """

    def __init__(self):
        super().__init__()
        self.max_new_tokens, self.temperature, self.top_p = self.prepare_hparams()

    def prepare_hparams(self):
        """Return the task-specific ``(max_new_tokens, temperature, top_p)``."""
        # eval task specified arguments
        max_new_tokens = 4096
        temperature = 0.0
        top_p = 1.0
        return max_new_tokens, temperature, top_p

    @staticmethod
    def _load_bench_dataset():
        """Load the WildBench v2 test split (single place for the dataset coordinates)."""
        return load_dataset("allenai/WildBench", "v2", split="test")

    def load_metadata(self) -> dict:
        """Collect per-example metadata columns from the benchmark.

        Returns:
            Dict mapping ``"session_id"`` and ``"primary_tag"`` to lists that
            are index-aligned with the dataset rows.
        """
        dataset = self._load_bench_dataset()
        metadata = {"session_id": [], "primary_tag": []}

        for item in dataset:
            for key in metadata:
                assert key in item, f"Key {key} not found in metadata"
                metadata[key].append(item[key])

        return metadata

    def load(self) -> list[dict]:
        """Load the benchmark conversations.

        Returns:
            One dict per example with ``"chat_history"`` (the message contents
            of the conversation, in order) and ``"id"`` (the session id).
        """
        data = []
        for item in self._load_bench_dataset():
            conversation = item["conversation_input"]
            # The model is asked to answer the last turn, so it must be a user turn.
            assert conversation[-1]["role"] == "user"
            data.append({
                "chat_history": [turn["content"] for turn in conversation],
                "id": item["session_id"],
            })
        return data

    def clear_output(self, output: str) -> str:
        """
        You can customize the output clearing logic here based on the model_name.

        Replaces the ``<|endoftext|>`` and ``</s>`` markers with a space and
        strips surrounding whitespace.
        """
        output = output.replace("<|endoftext|>", " ")
        output = output.replace("</s>", " ")
        return output.strip()

    @staticmethod
    def _build_model_inputs(model: GeneratorModelBase, wildbench_data: list[dict]) -> list[dict]:
        """Render every chat history into a prompt with the model's conversation template."""
        model_inputs = []
        conv = None
        for data in wildbench_data:
            # An HF_Conversation is reused after clear(); any other conversation
            # object (or the initial None) is rebuilt from the template.
            if isinstance(conv, HF_Conversation):
                conv.clear()
            else:
                conv = map_to_conv(model.model_input_template_path_or_name)

            # Turns are assigned alternately to roles[0] and roles[1].
            for turn_index, turn in enumerate(data["chat_history"]):
                conv.append_message(conv.roles[turn_index % 2], turn)
            # Open an empty roles[1] turn for the model to complete.
            conv.append_message(conv.roles[1], None)
            model_inputs.append({"instruction": conv.get_prompt()})
        return model_inputs

    @staticmethod
    def _has_score(result: dict) -> bool:
        """Return True when the judge verdict was parsed and carries a score."""
        return bool(
            result["parsed"]
            and isinstance(result["parsed_result"], dict)
            and "score" in result["parsed_result"]
        )

    def _evaluate(self, model: GeneratorModelBase) -> tuple[dict, list[dict]]:
        """Generate completions, have them judged, and aggregate the scores.

        Args:
            model: Generator whose completions are evaluated.

        Returns:
            ``(averages, results)`` where ``averages`` maps each primary tag
            (plus ``"all"``) to the mean rescaled score, and ``results`` holds
            one judge record per example. ``averages["all"]`` is NaN when no
            record carries a usable score.
        """
        wildbench_data = self.data
        metadata = self.load_metadata()

        # use FastChat-based conversation template
        model_inputs = self._build_model_inputs(model, wildbench_data)

        # adhoc: patch model.model_input_template to "{instruction}"
        original_template = model.model_input_template
        model.model_input_template = "{instruction}"
        try:
            completions = model.generate_completions(model_inputs,
                                                     repetition_penalty=1.0,
                                                     n=1,
                                                     max_new_tokens=self.max_new_tokens,
                                                     temperature=self.temperature,
                                                     top_p=self.top_p)
        finally:
            # Restore the template even if generation fails, so the model
            # object is not left patched for later tasks.
            model.model_input_template = original_template
        assert len(completions) == len(self.data)

        target_model_data = []
        for i, (data, completion) in enumerate(zip(wildbench_data, completions)):
            entry = {
                "session_id": data["id"],
                "output": self.clear_output(completion),
                "generator": model.model_name_or_path,
            }
            for k in metadata:
                entry[k] = metadata[k][i]
            target_model_data.append(entry)

        bench_data = self._load_bench_dataset()

        # Score mode has no reference model; keep the placeholder list so the
        # downstream helpers receive index-aligned arguments.
        ref_model_data = [None] * len(target_model_data)
        histories = []
        last_queries = []
        checklists = []
        for b, t, r in zip(bench_data, target_model_data, ref_model_data):
            compose_eval_item(b, t, r, histories, last_queries, checklists)
        print(f"len(target_model_data)={len(target_model_data)}")
        print(f"len(ref_model_data)={len(ref_model_data)}")
        candidates = list(target_model_data)
        references = list(ref_model_data)
        results = placeholder_generation(candidates, references, histories, last_queries, checklists)
        results = run_eval(results)

        # rescaling scores: (score - 5) * 20
        for result in results:
            if not self._has_score(result):
                continue
            raw_score = result["parsed_result"]["score"]
            try:
                result["parsed_result"]["score"] = (int(raw_score) - 5.0) * 20
            except (TypeError, ValueError):
                # A malformed score must not abort the whole run; drop this
                # record from the averages and keep the reason for inspection.
                result["parsed"] = False
                result["error"] = f"Unparsable judge score: {raw_score!r}"

        assert len(results) == len(target_model_data)
        # put primary_tag in the results
        for i, result in enumerate(results):
            result["id"] = target_model_data[i]["session_id"]
            for k in metadata:
                result[k] = metadata[k][i]

        scores_by_tag = {}
        all_scores = []
        for result in results:
            if not self._has_score(result):
                continue
            score = result["parsed_result"]["score"]
            scores_by_tag.setdefault(result["primary_tag"], []).append(score)
            all_scores.append(score)

        avg_score_per_primary_tag = {
            primary_tag: sum(scores) / len(scores) for primary_tag, scores in scores_by_tag.items()
        }
        # With no usable score there is nothing to average; report NaN rather
        # than dividing by zero so the raw results are still returned.
        avg_score_per_primary_tag["all"] = sum(all_scores) / len(all_scores) if all_scores else float("nan")
        return avg_score_per_primary_tag, results

    def required_input_fields(self) -> list[str]:
        """
        Get the required input fields for the task.

        Returns:
            List of required input fields.
        """
        return ["instruction"]


def extract_values_from_json(json_string, keys=("score", "strengths", "weaknesses", "choice"), allow_no_quotes=False):
    """Pull the values of selected keys out of a JSON-like string with regexes.

    Intended for text that ``json.loads`` cannot parse. For every key the first
    occurrence in ``json_string`` wins.

    Args:
        json_string: Raw text to search.
        keys: Key names to look for. Names are matched literally (regex
            metacharacters in a key are escaped).
        allow_no_quotes: When true and no quoted value was found, also accept
            an unquoted value (everything up to the next comma or whitespace),
            first with a quoted key and then with an unquoted key.

    Returns:
        Dict mapping each key that was found to its value as a string. Keys
        that do not occur or cannot be matched are omitted.
    """
    extracted_values = {}
    for key in keys:
        if key not in json_string:
            continue
        # Escape the key so characters such as "." or "(" are matched literally.
        escaped_key = re.escape(key)

        # "key": "value" with a quoted value.
        match = re.search(rf'"{escaped_key}"\s*:\s*"([^"]*?)"', json_string)
        if match is None:
            # Second attempt kept from the original implementation: same shape,
            # with "." allowed to span newlines.
            match = re.search(rf'"{escaped_key}"\s*:\s*"(.*?)"', json_string, re.DOTALL)
        if match is None and allow_no_quotes:
            # to allow no quotes on the values
            match = re.search(rf'"{escaped_key}"\s*:\s*([^,\s]*)', json_string)
            if match is None:
                # to allow no quotes on the keys
                match = re.search(rf'{escaped_key}\s*:\s*([^,\s]*)', json_string)

        if match is not None:
            extracted_values[key] = match.group(1)
    return extracted_values


def parse_result(result_str, mode="json", eval_mode="score"):
    """Parse a judge response into a dict.

    The text is first parsed as JSON. If that fails, or if the JSON is valid
    but not an object, the ``"score"`` and ``"choice"`` fields are extracted
    with ``extract_values_from_json`` instead.

    Args:
        result_str: Raw judge response.
        mode: Accepted for interface compatibility; not used by this function.
        eval_mode: Must be ``"score"`` or ``"pairwise"``.

    Returns:
        The parsed dict, or ``{"N/A": "N/A"}`` if the fallback extraction
        itself raised.

    Raises:
        AssertionError: If ``eval_mode`` is not one of the supported values.
    """
    assert eval_mode in ["score", "pairwise"]
    result_str = result_str.strip()
    try:
        try:
            parsed_result = json.loads(result_str)
        except (ValueError, RecursionError):
            # json.JSONDecodeError is a ValueError; only decoding problems
            # should trigger the regex fallback.
            parsed_result = None
        if not isinstance(parsed_result, dict):
            # Callers index the result by key, so a bare list/number/string is
            # treated like unparsable text.
            parsed_result = extract_values_from_json(result_str, keys=["score", "choice"])
    except Exception as e:
        # Best effort: report the offending text and return a sentinel dict.
        print(result_str)
        print(e)
        parsed_result = {"N/A": "N/A"}
    return parsed_result


def run_eval(results) -> list[dict]:
    """Grade every pending result with the GPT judge, in batches of 16.

    Results whose ``"result"`` field already holds a verdict (anything other
    than the ``"N/A"`` placeholder, e.g. the canned verdict written for an
    empty model output) are parsed locally and not sent to the judge.

    Args:
        results: Records carrying at least a ``"prompt"``; updated in place.

    Returns:
        The same list. Every record gets ``"result"``, ``"parsed_result"``,
        ``"parsed"`` and ``"error"`` fields.
    """
    batch_size = 16

    pending = []
    for index, result in enumerate(results):
        if result.get("result", "N/A") == "N/A":
            pending.append(index)
            continue
        # Verdict was pre-filled upstream: keep it instead of overwriting it
        # with a judge call.
        result["parsed_result"] = parse_result(result["result"])
        result["parsed"] = result["parsed_result"] is not None
        result["error"] = "N/A"

    for start in tqdm(range(0, len(pending), batch_size)):
        batch_indices = pending[start:start + batch_size]
        messages = [[{"role": "system", "content": "You are a helpful AI assistant."},
                     {"role": "user", "content": results[index]["prompt"]}] for index in batch_indices]

        # Start from a failed state so a record with no matching output still
        # carries every field the caller reads.
        for index in batch_indices:
            results[index]["error"] = "No response returned by the judge."
            results[index]["result"] = ""
            results[index]["parsed_result"] = {"choice": "N/A"}
            results[index]["parsed"] = False

        outputs = dispatch_batch_openai_requests_with_retry(messages,
                                                            model="gpt-4o-2024-05-13",
                                                            dispatch_func=dispatch_openai_chat_requests,
                                                            temperature=0,
                                                            response_format={"type": "json_object"},
                                                            n=1,
                                                            frequency_penalty=0,
                                                            presence_penalty=0,
                                                            max_tokens=1024)
        for index, output in zip(batch_indices, outputs):
            record = results[index]
            try:
                record["result"] = output.choices[0].message.content.strip()
                record["parsed_result"] = parse_result(record["result"])
                record["parsed"] = record["parsed_result"] is not None
                record["error"] = "N/A"
            except Exception as e:
                # A malformed or missing response only fails this record.
                record["error"] = str(e)
                record["result"] = ""
                record["parsed_result"] = {"choice": "N/A"}
                record["parsed"] = False
    return results


def shorten(text, K=-1):
    """Keep at most the first ``K`` space-separated words of ``text``.

    A non-positive ``K`` disables truncation. When words are dropped, the
    marker ``"... (truncated)"`` is appended to the kept prefix.
    """
    if K > 0:
        words = text.split(" ")
        if len(words) > K:
            return " ".join(words[:K]) + "... (truncated)"
    return text


def placeholder_generation(candidates, references, histories, last_queries, checklists):
    """Build one judge record (including the filled-in prompt) per candidate.

    Args:
        candidates: Dicts with ``"session_id"``, ``"output"`` (a string or a
            list whose first element is used) and ``"generator"``.
        references: Sequence of the same length as ``candidates``; its items
            are not read here.
        histories: Conversation history text per candidate.
        last_queries: Final user query per candidate.
        checklists: Iterable of checklist items per candidate.

    Returns:
        List of dicts with the candidate's identifying fields, the rendered
        ``"prompt"``, and ``"result"``: a canned score-1 JSON verdict when the
        output is empty or whitespace-only, otherwise the placeholder ``"N/A"``.

    Raises:
        AssertionError: If ``candidates`` and ``references`` differ in length.
    """
    # NOTE: the template is runtime text sent to the judge; keep it verbatim.
    eval_template = """# Instruction 

You are an expert evaluator. Your task is to evaluate the quality of the responses generated by AI models. 
We will provide you with the user query and an AI-generated responses.
You should first read the user query and the conversation history carefully for analyzing the task, and then evaluate the quality of the responses based on and rules provided below.

# Conversation between User and AI

## History
<|begin_of_history|>

{$history}

<|end_of_history|> 

## Current User Query
<|begin_of_query|>

{$user_query}

<|end_of_query|>

## AI Response
<|begin_of_response|>

{$model_output}

<|end_of_response|>
 

# Evaluation   

## Checklist 

<|begin_of_checklist|>

{$checklist}

<|end_of_checklist|>

Please use this checklist to guide your evaluation, but do not limit your assessment to the checklist.

## Rules 

You should compare the above response based on your analysis of the user queries and the conversation history.
You should first write down your analysis and the checklist that you used for the evaluation, and then provide your assessment according to the checklist.
The scores are in the range of 1~10, where 1 means the response is very poor and 10 means the response is perfect.
Here are more detailed criteria for the scores:

- Score 1~2: The response is very poor and does not make sense at all.
- Score 3~4: The response is poor and does help user solve the problem in a meaningful way.
- Score 5~6: The response is fair but has some issues (e.g., factual errors, hallucinations, missing key information).
- Score 7~8: The response is good enough but could be improved in some ways.
- Score 9~10: The response is perfect and provides helpful information that can help user solve the problem.

## Output Format 
First, please output your analysis for the model response, and then summarize your assessment to two aspects: "strengths" and "weaknesses"; Finally, please write down your rating for the assessment.

Please provide your evaluation results in the following json format by filling in the placeholders in []:
```
{
    "strengths": "[analysis for the strengths of the response]",
    "weaknesses": "[analysis for the weaknesses of the response]",
    "score": "[1~10]"
}
```"""
    assert len(candidates) == len(references)

    results = []
    for item, _reference, history, last_query, checklist in zip(candidates, references, histories, last_queries,
                                                                checklists):
        raw_output = item["output"]
        if isinstance(raw_output, list):
            # Only the first sample is judged; an empty list counts as an empty output.
            output_text = raw_output[0] if raw_output else ""
        else:
            output_text = raw_output

        checklist_markdown = "".join(f"- {checklist_item}\n" for checklist_item in checklist)

        # Placeholders are substituted in a fixed order: history, query, output, checklist.
        prompt = eval_template
        prompt = prompt.replace("{$history}", shorten(history, -1))
        prompt = prompt.replace("{$user_query}", shorten(last_query, -1))
        prompt = prompt.replace("{$model_output}", shorten(output_text, -1))
        prompt = prompt.replace("{$checklist}", checklist_markdown)

        if output_text.strip() == "":
            # Nothing to judge: record the lowest score directly.
            result = json.dumps(
                {"strengths": "N/A", "weaknesses": "The model output is empty.", "score": "1"})
        else:
            result = "N/A"

        results.append({
            "session_id": item["session_id"],
            "history": history,
            "last_query": last_query,
            "model_output": raw_output,
            "generator": item["generator"],
            "eval_config": {"mode": "score", "gpt": "gpt-4o-2024-05-13", "max_words": -1},
            "prompt": prompt,
            "result": result,
        })

    return results


def compose_eval_item(b, t, r, histories, last_queries, checklists):
    """Append the history text, final query and checklist of one benchmark item.

    Args:
        b: Benchmark row with ``"session_id"``, ``"checklist"`` and
            ``"conversation_input"`` (list of ``{"role", "content"}`` turns).
        t: Target-model record; its ``"session_id"`` must match ``b``'s.
        r: Optional reference-model record; when given, its ``"session_id"``
            must match as well.
        histories: Output list; receives every turn but the last, rendered as
            ``"USER: ..."`` / ``"ASSISTANT: ..."`` paragraphs. Turns with any
            other role are skipped.
        last_queries: Output list; receives the content of the last turn.
        checklists: Output list; receives the row's checklist.
    """
    session_id = b["session_id"]
    assert session_id == t["session_id"]
    if r is not None:
        assert t["session_id"] == r["session_id"]

    checklist = b["checklist"]
    conversation = b["conversation_input"]

    rendered_turns = []
    for turn in conversation[:-1]:
        role = turn["role"]
        if role == "user":
            speaker = "USER: "
        elif role == "assistant":
            speaker = "ASSISTANT: "
        else:
            continue
        rendered_turns.append(speaker + turn["content"] + "\n\n")

    final_query = conversation[-1]["content"]

    histories.append("".join(rendered_turns))
    last_queries.append(final_query)
    checklists.append(checklist)
