import asyncio
import json
import math
import re
from pathlib import Path

import nest_asyncio

# Default prompt template for one task. It is filled in with str.format() and
# uses the placeholders {file_name}, {question}, {constraints} and {format}.
DABENCH = "You are required to solve the problem within a CSV file named {file_name}. \n**Problem**: {question} \n**Constraints**: Ensure that {constraints}, which must be strictly followed throughout the task. \n **Output Format**: The output format should be {format}"
# Relative directory of the DA-Agent data; the default question/label files are
# looked up here and table paths are built under its "da-dev-tables" subfolder.
DABENCH_PATH = "./di_dataset/InfiAgent/examples/DA-Agent/data"


# The accuracy metrics below are adapted from https://github.com/InfiAgent/InfiAgent/blob/main/examples/DA-Agent/eval_closed_form.py
def evaluate_accuracy_by_question(results):
    """Return the share of questions whose sub-answers are all correct.

    Each item of ``results`` is a dict; an item counts as correct only when it
    has a "correctness" mapping and every value in that mapping is truthy.
    Items without "correctness" still count towards the denominator.

    Returns the ratio rounded to 4 decimals, or 0 when ``results`` is empty.
    """
    fully_correct = 0
    for entry in results:
        if "correctness" in entry and all(entry["correctness"].values()):
            fully_correct += 1

    question_count = len(results)
    if question_count > 0:
        return round(fully_correct / question_count, 4)
    return 0


def evaluate_accuracy_by_sub_question(results):
    """Return the share of correct sub-answers pooled over all questions.

    Only items carrying a "correctness" mapping contribute; every entry of
    that mapping is one sub-question and its truthy value marks it correct.

    Returns the ratio rounded to 4 decimals, or 0 when there are no
    sub-questions at all.
    """
    graded = [entry["correctness"] for entry in results if "correctness" in entry]
    hits = sum(sum(flags.values()) for flags in graded)
    sub_question_count = sum(len(flags) for flags in graded)
    if sub_question_count > 0:
        return round(hits / sub_question_count, 4)
    return 0


def evaluate_accuracy_proportional_by_sub_question_adjusted(results):
    """Return the mean per-question score, giving partial credit.

    A question with k sub-questions awards 1/k per correct sub-question, so
    each question contributes a score in [0, 1]. Items without a "correctness"
    mapping (or with an empty one) contribute 0 but still count in the mean.

    Returns the mean rounded to 4 decimals, or 0 when ``results`` is empty.
    """
    accumulated = 0
    for entry in results:
        if "correctness" not in entry:
            continue
        flags = entry["correctness"]
        sub_question_count = len(flags)
        weight = 1 / sub_question_count if sub_question_count > 0 else 0
        accumulated += sum(flags.values()) * weight

    if results:
        return round(accumulated / len(results), 4)
    return 0


class DABench:
    """Question/label store, prompt builder and scorer for the DA-Agent dev set.

    Questions and labels are read from two JSONL files and indexed by the
    integer value of each record's ``id`` field.
    """

    def __init__(
        self,
        questions_file=Path(DABENCH_PATH) / "da-dev-questions.jsonl",
        answers_file=Path(DABENCH_PATH) / "da-dev-labels.jsonl",
        template="",
    ):
        """Load the questions and labels and select the prompt template.

        Args:
            questions_file: JSONL file with one question object per line.
            answers_file: JSONL file with one label object per line.
            template: ``str.format`` template for prompts; the module-level
                ``DABENCH`` template is used when this is empty.
        """
        self.questions = self._load_jsonl(questions_file)
        self.answers = self._load_jsonl(answers_file)
        self.template = template if template else DABENCH

    @staticmethod
    def _load_jsonl(path):
        """Read a JSONL file into a dict keyed by ``int(record["id"])``."""
        records = {}
        # JSON text is UTF-8; do not depend on the platform default encoding.
        with open(path, "r", encoding="utf-8") as file:
            for line in file:
                if not line.strip():
                    continue  # tolerate blank lines instead of failing to decode
                record = json.loads(line)  # parse every line exactly once
                records[int(record["id"])] = record
        return records

    def get_question(self, question_id):
        """Return the question record for ``question_id``.

        Returns the string "Question not found." when the id is unknown.
        """
        return self.questions.get(question_id, "Question not found.")

    def get_prompt(self, question_id):
        """Build the prompt for ``question_id`` by filling in the template.

        The template receives ``question``, ``constraints``, ``format``,
        ``file_name`` (path of the table under ``DABENCH_PATH``) and ``level``.
        """
        temp = self.get_question(question_id)
        return self.template.format(
            question=temp["question"],
            constraints=temp["constraints"],
            format=temp["format"],
            file_name=str(DABENCH_PATH) + "/da-dev-tables/" + temp["file_name"],
            level=temp["level"],
        )

    def get_answer(self, answer_id):
        """Return the label record for ``answer_id``.

        Returns the string "Answer not found." when the id is unknown.
        """
        return self.answers.get(answer_id, "Answer not found.")

    def eval(self, id: int, prediction: str) -> tuple:
        """Score one prediction, asking the LLM to reformat it if needed.

        The prediction is first parsed as-is (after dropping braces and single
        quotes). If that does not yield a full match, ``reformat`` is run and
        its output is parsed and compared instead.

        Args:
            id: Question id; records are keyed by int.
            prediction: Raw model output containing ``@name[value]`` answers.

        Returns:
            ``(text, is_correct)``. ``text`` is the original prediction when
            the direct parse matched, otherwise the reformatted text (or ""
            when reformatting raised).
        """
        true_label = self.get_answer(id)["common_answers"]
        cleaned_prediction = prediction.replace("{", "").replace("}", "").replace("'", "")
        if cleaned_prediction:  # Ensure it's not empty
            try:
                pred_dict = parse_prediction(cleaned_prediction)
                if compare_predictions(pred_dict, true_label):
                    return (prediction, True)
            except Exception:
                # Best-effort: any parse/compare failure falls through to the
                # LLM reformat below. Not a bare except, so Ctrl-C still works.
                print("format errer, using gpt to refomat")

        # Only the fallback needs asyncio.run(); patch the loop here so it also
        # works when an event loop is already running (e.g. in a notebook).
        nest_asyncio.apply()
        try:
            question = self.get_question(id)
            reformatted = asyncio.run(reformat(question["question"], question["format"], prediction))
            # Escaped backslash: same text as before, without the invalid "\A" escape.
            _prediction = reformatted.replace("\\Answer{", "").strip()
            pred_dict = parse_prediction(_prediction)
            if compare_predictions(pred_dict, true_label):
                return (_prediction, True)
        except Exception as e:
            _prediction = ""
            print(f"Error during async reformat: {e}")
            # Skip this step if there's an error

        return (_prediction, False)

    def eval_all(self, id_list, predictions):
        """Score all predictions and return the three accuracy rates.

        Args:
            id_list: Question ids, paired positionally with ``predictions``.
            predictions: Raw model outputs containing ``@name[value]`` answers.

        Returns:
            Dict with "accuracy_by_question", "accuracy_by_sub_question" and
            "proportional_accuracy_by_sub_question".
        """

        def single_eval(id, prediction):
            """Return ``{metric: bool}`` telling which sub-answers are correct."""
            true_label = self.get_answer(id)["common_answers"]
            cleaned = prediction.replace("{", "").replace("}", "").replace("'", "")
            pred_dict = parse_prediction(cleaned)
            # Every labelled metric starts out wrong until proven correct.
            correctness = {metric: False for metric, _ in true_label}
            for metric, true_value in true_label:
                try:
                    true_value = float(true_value)
                except (TypeError, ValueError):
                    # Non-numeric label: compare as text with commas removed.
                    true_value = true_value.replace(",", "")
                if metric not in pred_dict:
                    continue
                predicted = pred_dict[metric]
                # Numeric answers must agree within a small tolerance.
                if (
                    isinstance(true_value, (int, float))
                    and isinstance(predicted, (int, float))
                    and abs(predicted - true_value) < 1e-6
                ):
                    correctness[metric] = True
                if isinstance(true_value, str) and str(predicted) == str(true_value):
                    correctness[metric] = True
            return correctness

        results = [
            {"id": id, "correctness": single_eval(id, prediction)}
            for id, prediction in zip(id_list, predictions)
        ]

        return {
            "accuracy_by_question": evaluate_accuracy_by_question(results),
            "accuracy_by_sub_question": evaluate_accuracy_by_sub_question(results),
            "proportional_accuracy_by_sub_question": evaluate_accuracy_proportional_by_sub_question_adjusted(
                results
            ),
        }

async def ask_and_print(question, system_prompt):
    """Send ``question`` to a freshly created metagpt ``LLM`` and return its reply.

    ``system_prompt`` is passed as the single system message. Despite the
    name, nothing is printed here; the response is only returned.
    """
    # Imported lazily so that metagpt is only required when the LLM is used.
    from metagpt.llm import LLM

    return await LLM().aask(question, system_msgs=[system_prompt])

async def reformat(question, format, response):
    """Ask the LLM to rewrite ``response`` in the ``@name[value]`` answer format.

    The conversation sent is: the question (user), the original response
    (assistant), then an instruction (user) holding two few-shot examples and
    the format requirements of this question.

    Args:
        question: The task question text.
        format: The format requirements text of the question.
        response: The answer text that should be reformatted.

    Returns:
        Whatever ``ask_and_print`` returns for that conversation.
    """
    system_prompt = "You are a helpful assistant."
    # Raw string: "\Format" and "\Answer" carry literal backslashes, which are
    # invalid escape sequences in a normal string literal. The doubled braces
    # reach the model unchanged because this text is substituted as a value.
    # The second example's answer name now matches the name its format
    # declares, so the demo no longer teaches the model to rename answers.
    demons = r"""\Format{{
@shapiro_wilk_statistic[test_statistic]
@shapiro_wilk_p_value[p_value]
where "test_statistic" is a number between 0 and 1 representing the Shapiro-Wilk test statistic. Rounding off the answer to two decimal places.
where "p_value" is a number between 0 and 1 representing the p-value from the Shapiro-Wilk test. Rounding off the answer to four decimal places.
}}
\Answer{{
@shapiro_wilk_statistic[0.56]
@shapiro_wilk_p_value[0.0002]
}}

\Format{{
@total_votes_outliers_num[outlier_num]
where "outlier_num" is an integer representing the number of values considered outliers in the 'total_votes' column.
}}
\Answer{{
@total_votes_outliers_num[10]
}}
"""
    reformat_template = """You should strictly follow the output requirements in the Format part. Here're some examples: {demons}.
Your answer should contain all the \"@answer_name[answer]\" in the order mentioned, each \"answer\" should be in the range of value as required. You need to keep the original numbers and text, just reformat without making any changes.
The format requirements of this question is:
{format}. You need to keep the original numbers and text, just reformat without making any changes. Please give your answer:"""

    messages = [
        {"role": "user", "content": question},
        {"role": "assistant", "content": response},
        {"role": "user", "content": reformat_template.format(demons=demons, format=format)},
    ]
    return await ask_and_print(messages, system_prompt)


# One answer tag: "@" + name (letters, digits, underscore), optionally followed
# by a bracketed value that runs up to the first "]".
_AT_TAG_PATTERN = re.compile(r'@([a-zA-Z0-9_]+(?:\[[^\]]*\])?)')
_BRACKET_PATTERN = re.compile(r'[\[\]]')


def extract_content_after_at_symbol(text):
    """Return all ``@name[value]`` tags found in ``text``, joined by "@".

    The leading "@" of each tag is dropped, so "x @a[1] y @b[2]" becomes
    "a[1]@b[2]". Returns "" when there is no tag.
    """
    return "@".join(_AT_TAG_PATTERN.findall(text))


def parse_prediction(prediction: str) -> dict:
    """Parse ``@metric[value]`` tags into a dictionary of metric-value pairs.

    Commas and colons are removed from each value; the value is stored as a
    float when it parses as one and as a string otherwise. If a metric occurs
    more than once, the last tag that has a value wins.

    Tags are taken straight from the regex matches rather than from a
    "@"-joined string, so a value that itself contains "@" stays intact.
    Tags without a usable value (bare "@name" or "@name[]") are skipped, so
    they no longer store the metric name as its own value.
    """
    pred_dict = {}
    for tag in _AT_TAG_PATTERN.findall(prediction):
        pieces = [piece.replace(",", "") for piece in _BRACKET_PATTERN.split(tag)]
        pieces = [piece for piece in pieces if piece]
        if len(pieces) < 2:
            continue  # name only, nothing to record
        metric = pieces[0]
        value = pieces[-1].replace(":", "")

        try:
            value = float(value)
        except ValueError:
            pass  # Keep value as string if conversion fails

        pred_dict[metric] = value
    return pred_dict

def _prediction_matches(pred_dict: dict, metric, true_value) -> bool:
    """Tell whether ``pred_dict[metric]`` agrees with one labelled value."""
    if metric not in pred_dict:
        return False
    predicted = pred_dict[metric]

    try:
        expected = float(true_value)
    except ValueError:
        # Non-numeric label: compare as text with commas removed.
        return str(predicted) == true_value.replace(",", "")

    # A textual prediction can never satisfy a numeric label; report a
    # mismatch instead of raising TypeError on the subtraction below.
    if not isinstance(predicted, (int, float)):
        return False
    # NaN makes every ordering comparison False, so a NaN prediction would
    # otherwise slip through the tolerance check and count as correct.
    if math.isnan(expected) or math.isnan(predicted):
        return math.isnan(expected) and math.isnan(predicted)
    return not abs(predicted - expected) > 1e-6


def compare_predictions(pred_dict: dict, true_label: list) -> bool:
    """Return True when every labelled metric is matched by the prediction.

    Args:
        pred_dict: Parsed predictions, metric name -> float or str.
        true_label: Iterable of ``(metric, value)`` pairs. A value that parses
            as a float is compared numerically with an absolute tolerance of
            1e-6; any other value is compared as text after removing commas.

    Returns:
        False as soon as a metric is missing or differs, True otherwise
        (including for an empty ``true_label``).
    """
    return all(
        _prediction_matches(pred_dict, metric, true_value)
        for metric, true_value in true_label
    )


    
