import json
import re
import string
from pathlib import Path
from typing import Any, Dict, List, Optional
import logging
from tabulate import tabulate

# Module-level logger named after this module; the helpers below log through it.
logger = logging.getLogger(__name__)


def normalize_str(input_str, remove_punct=True) -> str:
    """Canonicalise a string for lenient comparison.

    Every whitespace character is removed and the text is lower-cased. When
    ``remove_punct`` is true, all characters listed in ``string.punctuation``
    are stripped as well.

    Args:
        input_str: Text to normalise.
        remove_punct: Whether ASCII punctuation should also be removed.

    Returns:
        The normalised string.
    """
    collapsed = re.sub(r"\s", "", input_str).lower()
    if not remove_punct:
        return collapsed

    punctuation_table = str.maketrans("", "", string.punctuation)
    return collapsed.translate(punctuation_table)


def split_string(s: str, char_list: Optional[List[str]] = None) -> list[str]:
    """Split ``s`` on any of the given delimiter characters.

    Args:
        s: The string to split.
        char_list: Delimiter characters; defaults to ``[",", ";"]``. Each
            character of each entry is treated as a literal delimiter.

    Returns:
        The pieces of ``s`` between delimiters. Pieces are not stripped, and
        empty pieces are kept. When ``char_list`` is empty there is nothing
        to split on and ``[s]`` is returned.
    """
    if char_list is None:
        char_list = [",", ";"]
    if not char_list:
        # An empty character class ("[]") is not a valid regex.
        return [s]
    # Escape each delimiter so regex metacharacters such as "^", "]", "-" or
    # "\" are matched literally instead of corrupting the character class.
    pattern = f"[{''.join(re.escape(char) for char in char_list)}]"
    return re.split(pattern, s)


def normalize_number_str(number_str: str) -> float:
    """Parse a human-formatted number into a float.

    The characters ``$``, ``%`` and ``,`` are removed before conversion.

    Args:
        number_str: Text expected to contain a number.

    Returns:
        The parsed value, or ``float("inf")`` when the cleaned text cannot be
        converted (an error is logged in that case).
    """
    # Strip currency, percent and thousands-separator characters in one pass.
    number_str = number_str.translate(str.maketrans("", "", "$%,"))
    try:
        parsed = float(number_str)
    except ValueError:
        logger.error(f"String {number_str} cannot be normalized to number str.")
        parsed = float("inf")
    return parsed


def question_scorer(model_answer: str, ground_truth: str) -> bool:
    """Decide whether a model answer matches the ground truth.

    The comparison mode is chosen from the ground truth:

    * if it parses as a float, the answer is normalised with
      ``normalize_number_str`` and compared numerically;
    * otherwise, if it contains ``,`` or ``;``, both strings are split into
      lists and compared element by element (numerically where the ground
      truth element is a float, otherwise as normalised strings that keep
      their punctuation);
    * otherwise both are compared as normalised strings with punctuation
      removed.

    Args:
        model_answer: The answer produced by the model.
        ground_truth: The reference answer.

    Returns:
        ``True`` when the answer matches. Any exception raised during
        evaluation is logged and results in ``False``.
    """

    def is_float(element: Any) -> bool:
        try:
            float(element)
        except ValueError:
            return False
        return True

    def elements_match(answer_part: str, truth_part: str) -> bool:
        # Compare one list element, numerically when the reference is a number.
        if is_float(truth_part):
            return normalize_number_str(answer_part) == float(truth_part)
        return normalize_str(answer_part, remove_punct=False) == normalize_str(
            truth_part, remove_punct=False
        )

    try:
        if is_float(ground_truth):
            logger.info(f"Evaluating {model_answer} as a number.")
            return normalize_number_str(model_answer) == float(ground_truth)

        if "," in ground_truth or ";" in ground_truth:
            logger.info(f"Evaluating {model_answer} as a comma separated list.")
            truth_parts = split_string(ground_truth)
            answer_parts = split_string(model_answer)

            if len(truth_parts) != len(answer_parts):
                logger.warning("Answer lists have different lengths, returning False.")
                return False

            # Evaluate every pair (no short-circuit) before combining results.
            outcomes = [
                elements_match(answer_part, truth_part)
                for answer_part, truth_part in zip(answer_parts, truth_parts)
            ]
            return all(outcomes)

        logger.info(f"Evaluating {model_answer} as a string.")
        return normalize_str(model_answer) == normalize_str(ground_truth)
    except Exception as e:
        logger.error(f"Error during evaluation: {e}")
        return False


def load_dataset_meta(path: str, split: str = "validation") -> List[Dict[str, Any]]:
    """Load task records from ``<path>/<split>/metadata.jsonl`` as a list.

    Records whose ``task_id`` is ``"0-0-0-0-0"`` are skipped. For records with
    a non-empty ``file_name``, that value is replaced by the path of the file
    inside the split directory.

    Args:
        path: Root directory of the dataset.
        split: Name of the sub-directory holding ``metadata.jsonl``.

    Returns:
        The parsed records, in file order.

    Raises:
        OSError: If the metadata file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record lacks ``task_id`` or ``file_name``.
    """
    data_dir = Path(path) / split

    dataset = []
    with open(data_dir / "metadata.jsonl", "r", encoding="utf-8") as metaf:
        # Stream the file line by line instead of reading it all up front.
        for line in metaf:
            if not line.strip():
                # Blank lines (e.g. trailing newlines) are not JSON documents.
                continue
            data = json.loads(line)
            if data["task_id"] == "0-0-0-0-0":
                continue
            if data["file_name"]:
                data["file_name"] = data_dir / data["file_name"]
            dataset.append(data)
    return dataset


def load_dataset_meta_dict(
    path: str, split: str = "validation"
) -> Dict[str, Dict[str, Any]]:
    """Load task records from ``<path>/<split>/metadata.jsonl`` keyed by task id.

    Records whose ``task_id`` is ``"0-0-0-0-0"`` are skipped. For records with
    a non-empty ``file_name``, that value is replaced by the path of the file
    inside the split directory. If a ``task_id`` occurs more than once, the
    last record wins.

    Args:
        path: Root directory of the dataset.
        split: Name of the sub-directory holding ``metadata.jsonl``.

    Returns:
        A mapping from ``task_id`` to the parsed record.

    Raises:
        OSError: If the metadata file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record lacks ``task_id`` or ``file_name``.
    """
    data_dir = Path(path) / split

    dataset = {}
    with open(data_dir / "metadata.jsonl", "r", encoding="utf-8") as metaf:
        # Stream the file line by line instead of reading it all up front.
        for line in metaf:
            if not line.strip():
                # Blank lines (e.g. trailing newlines) are not JSON documents.
                continue
            data = json.loads(line)
            if data["task_id"] == "0-0-0-0-0":
                continue
            if data["file_name"]:
                data["file_name"] = data_dir / data["file_name"]
            dataset[data["task_id"]] = data
    return dataset


def add_file_path(
    task: Dict[str, Any], file_path: str = "./gaia_dataset", split: str = "validation"
) -> Dict[str, Any]:
    """Append a hint about the task's attached file to its question text.

    When ``task["file_name"]`` is non-empty, the file's location is built as
    ``<file_path>/<split>/<file_name>`` and a sentence naming it is appended
    to ``task["Question"]``. The wording depends on the file extension
    (document, image, table, python, or generic); the extension is matched
    case-insensitively.

    Args:
        task: Task record with ``"file_name"`` and ``"Question"`` keys. It is
            modified in place.
        file_path: Root directory of the dataset.
        split: Dataset split sub-directory.

    Returns:
        The same ``task`` dict that was passed in.
    """
    if not task["file_name"]:
        return task

    full_path = Path(f"{file_path}/{split}") / task["file_name"]
    # Lower-case so that e.g. ".PNG" and ".png" are classified alike.
    suffix = full_path.suffix.lower()

    if suffix in (".pdf", ".docx", ".doc", ".txt"):
        task["Question"] += f" Here are the necessary document files: {full_path}"
    elif suffix in (".jpg", ".jpeg", ".png"):
        task["Question"] += f" Here are the necessary image files: {full_path}"
    elif suffix in (".xlsx", ".xls", ".csv"):
        # ".xls" needs its leading dot to ever match Path.suffix.
        task["Question"] += (
            f" Here are the necessary table files: {full_path}, for processing excel file,"
            " you can use the excel tool or write python code to process the file"
            " step-by-step and get the information."
        )
    elif suffix == ".py":
        task["Question"] += f" Here are the necessary python files: {full_path}"
    else:
        task["Question"] += f" Here are the necessary files: {full_path}"

    return task


def report_results(entries):
    """Log overall and per-level accuracy tables for scored entries.

    Args:
        entries: A sized collection of dict-like records. Each record's
            ``"level"`` (``None`` when absent) selects the group it is counted
            in, and a truthy ``"is_correct"`` marks it as correct.

    Returns:
        None. The statistics are rendered with ``tabulate`` and written to the
        module logger.
    """
    total_entries = len(entries)
    total_correct = 0

    # Per-level counters, keyed by the entry's "level" value.
    level_stats = {}
    for entry in entries:
        stats = level_stats.setdefault(
            entry.get("level"), {"total": 0, "correct": 0, "accuracy": 0}
        )
        stats["total"] += 1
        if entry.get("is_correct", False):
            total_correct += 1
            stats["correct"] += 1

    # A level only exists once it has been counted, so "total" is never zero.
    for stats in level_stats.values():
        stats["accuracy"] = (stats["correct"] / stats["total"]) * 100

    # Guard against an empty result set instead of dividing by zero.
    overall_accuracy = (total_correct / total_entries) * 100 if total_entries else 0.0

    # A standard logging.Logger has no ``success`` method; use it only when
    # the logger provides one and fall back to ``info`` otherwise.
    log_success = getattr(logger, "success", logger.info)

    logger.info("Overall Statistics:")
    overall_table = [
        ["Total Entries", total_entries],
        ["Total Correct", total_correct],
        ["Overall Accuracy", f"{overall_accuracy:.2f}%"],
    ]
    log_success(tabulate(overall_table, tablefmt="grid"))
    logger.info("")

    logger.info("Statistics by Level:")
    headers = ["Level", "Total Entries", "Correct Answers", "Accuracy"]

    try:
        ordered_levels = sorted(level_stats)
    except TypeError:
        # Levels of mixed types (e.g. None and int) cannot be compared
        # directly; order them by their string form instead.
        ordered_levels = sorted(level_stats, key=str)

    level_table = [
        [
            level,
            level_stats[level]["total"],
            level_stats[level]["correct"],
            f"{level_stats[level]['accuracy']:.2f}%",
        ]
        for level in ordered_levels
    ]
    log_success(tabulate(level_table, headers=headers, tablefmt="grid"))