import json
from typing import Any, Optional
from decimal import Decimal
import os

from src.tasks.instance.db_bench.task import (
    DBBench,
    DBBenchDatasetItem,
    AnswerType,
    DirectTypeAnswerValidator,
    DBBenchContainer,
    DBBenchSkillUtility,
)
from src.typings import LoggerConfig
from src.utils import SingletonLogger
from src.factories.data.standard_v0121.utility import StandardDataFactoryUtility


class PseudoDBBench:
    """
    Lightweight stand-in for a DBBench task object.

    It only holds a database container and the dataset item currently being
    handled. DBBenchStandardDataFactory passes an instance of this class as the
    first argument of ``DBBench._get_task_output``.
    """

    def __init__(self) -> None:
        self.container = DBBenchContainer()
        # Assigned by the user of this object before the getter below is called.
        self.current_dataset_item: Optional[DBBenchDatasetItem] = None

    def _get_current_dataset_item(self) -> DBBenchDatasetItem:
        """Return the current dataset item; fails if none has been assigned."""
        dataset_item = self.current_dataset_item
        assert dataset_item is not None
        return dataset_item


class DBBenchStandardDataFactory:
    """
    The class is used to format the data generated by lqk.
    """

    def __init__(
        self, raw_data_path: str, processed_data_path: str, log_file_path: str
    ):
        self.raw_data_path = raw_data_path
        self.processed_data_path = processed_data_path
        self.pseudo_db_bench = PseudoDBBench()
        logger_config = LoggerConfig(
            level="INFO",
            log_file_path=log_file_path,
            logger_name="db_bench_standard_data_factory",
        )
        self.logger = SingletonLogger.get_instance(logger_config)

    @staticmethod
    def _format_entry(entry: dict[str, Any]) -> dict[str, Any]:
        processed_entry = {
            "instruction": entry["description"],
            "sql": entry["sql"],
            "table_info": {
                "name": entry["table_info"]["name"],
                "columns": entry["table_info"]["columns"],
                "rows": entry["table_info"]["rows"],
            },
            "answer_md5": entry.get("answer_md5", None),
            "answer_direct": entry.get("answer_direct", None),
            "skills": entry["skills"],
        }
        if processed_entry["answer_direct"] is None:
            del processed_entry["answer_direct"]
        else:
            del processed_entry["answer_md5"]
        return processed_entry

    def _get_structured_ground_truth(
        self, dataset_item: DBBenchDatasetItem
    ) -> str | list[tuple[str | int | float, ...]]:
        self.pseudo_db_bench.container.conn.reconnect()
        cursor = self.pseudo_db_bench.container.conn.cursor()
        cursor.execute(f"use `{dataset_item.database_name}`")
        cursor.fetchall()
        cursor.execute(dataset_item.answer_info.ground_truth_sql)
        structured_sql_output = cursor.fetchall()
        self.pseudo_db_bench.container.conn.commit()
        ground_truth: str | list[tuple[str | int | float, ...]]
        match dataset_item.answer_info.answer_type:
            case AnswerType.MD5:
                # parameter "answer" is not going to be used, pass an empty string
                self.pseudo_db_bench.current_dataset_item = dataset_item
                ground_truth = DBBench._get_task_output(  # noqa
                    self.pseudo_db_bench, ""  # type: ignore[arg-type]
                )["answer"]
            case AnswerType.DIRECT:
                ground_truth = structured_sql_output  # type: ignore[assignment]
                # region Convert Decimal to float
                for row_index, raw_row_tuple in enumerate(ground_truth):
                    processed_row_list = []
                    for value in raw_row_tuple:
                        if isinstance(value, Decimal):
                            value = float(value)
                        processed_row_list.append(value)
                    ground_truth[row_index] = tuple(processed_row_list)  # type: ignore[index]
                # endregion
            case _:
                raise TypeError()
        return ground_truth

    def process(self) -> None:
        raw_data_dict: dict[str, Any] = json.load(open(self.raw_data_path, "r"))
        processed_data_dict: dict[str, Any] = {}
        for sample_index, entry in raw_data_dict.items():
            # region Format entry
            processed_entry = self._format_entry(entry)
            # endregion
            # region Prepare dataset_item and database
            dataset_item = DBBench._construct_dataset_item(processed_entry)  # noqa
            init_sql = DBBench._build_init_sql(dataset_item)  # noqa
            self.pseudo_db_bench.container.execute(init_sql)
            # endregion
            # region Get structured ground truth and set it
            ground_truth = self._get_structured_ground_truth(dataset_item)
            match dataset_item.answer_info.answer_type:
                case AnswerType.MD5:
                    assert "answer_md5" in processed_entry.keys()
                    processed_entry["answer_md5"] = ground_truth
                case AnswerType.DIRECT:
                    assert "answer_direct" in processed_entry.keys()
                    for row in ground_truth:
                        for value in row:
                            assert isinstance(value, (str, int, float))
                    processed_entry["answer_direct"] = ground_truth
                case _:
                    raise TypeError()
            processed_data_dict[sample_index] = processed_entry
            # endregion
            # region Clean up, log progress
            self.pseudo_db_bench.container.execute(
                f"drop database `{dataset_item.database_name}`"  # noqa
            )
            self.logger.info(f"sample_index: {sample_index:<3}. Processing completed.")
            # endregion
        json.dump(
            processed_data_dict, open(self.processed_data_path, "w"), indent=2  # noqa
        )

    def validate(self) -> None:
        processed_data_dict: dict[str, Any] = json.load(
            open(self.processed_data_path, "r")
        )
        for sample_index, entry in processed_data_dict.items():
            # region Prepare dataset_item and database
            dataset_item = DBBench._construct_dataset_item(entry)  # noqa
            init_sql = DBBench._build_init_sql(dataset_item)  # noqa
            self.pseudo_db_bench.container.execute(init_sql)
            # endregion
            # region Execute sql and validate the answer
            sql_execution_result = self.pseudo_db_bench.container.execute(
                dataset_item.answer_info.ground_truth_sql, dataset_item.database_name
            )
            self.pseudo_db_bench.current_dataset_item = dataset_item
            answer_dict = DBBench._get_task_output(  # noqa
                self.pseudo_db_bench,  # type: ignore[arg-type]
                sql_execution_result,
            )
            agent_answer = answer_dict["answer"]
            match dataset_item.answer_info.answer_type:
                case AnswerType.MD5:
                    correct_flag = agent_answer == dataset_item.answer_info.answer_md5
                case AnswerType.DIRECT:
                    ground_truth = dataset_item.answer_info.answer_direct
                    assert ground_truth is not None
                    correct_flag = DirectTypeAnswerValidator.validate(
                        agent_answer, ground_truth
                    )
                case _:
                    raise TypeError()
            # endregion
            # region Clean up, log progress
            self.pseudo_db_bench.container.execute(
                f"drop database `{dataset_item.database_name}`"  # noqa
            )
            if correct_flag:
                self.logger.info(
                    f"sample_index: {sample_index:<3}. "
                    f"answer_type: {dataset_item.answer_info.answer_type:<6}. "
                    f"Validation passed."
                )
            else:
                self.logger.error(
                    f"sample_index: {sample_index:<3}. "
                    f"answer_type: {dataset_item.answer_info.answer_type:<6}. "
                    f"Validation failed."
                )
            # endregion

    def count_skill(self) -> None:
        """
        Log how the processed samples are distributed over levels and skills.

        Three counters are collected from ``self.processed_data_path``:
        the number of samples per difficulty level, the number of occurrences
        of every skill, and the number of samples in which a skill is
        "effective", i.e. its skill level equals the sample's difficulty level.
        The counters are written to the logger as two tables.
        """
        # region Preparation
        # The context manager closes the file handle as soon as it is parsed.
        with open(self.processed_data_path, "r") as processed_file:
            processed_data_dict: dict[str, Any] = json.load(processed_file)
        sample_level_to_count_dict: dict[int, int] = {
            key: 0 for key in DBBenchSkillUtility.get_skill_level_list()
        }
        skill_to_count_dict: dict[str, int] = {
            key: 0 for key in DBBenchSkillUtility.get_all_skill_list()
        }
        effective_skill_to_count_dict: dict[str, int] = {
            key: 0 for key in DBBenchSkillUtility.get_all_skill_list()
        }
        # endregion
        # region Count skill and related information
        for entry in processed_data_dict.values():
            # region Preparation
            dataset_item = DBBench._construct_dataset_item(entry)  # noqa
            difficulty_level: int = dataset_item.get_difficulty_level()
            # A set, so each effective skill is counted once per sample.
            effective_skill_set: set[str] = set()
            # endregion
            # region Count skill and collect effective_skill_set
            for skill in dataset_item.get_skill_list():
                skill_to_count_dict[skill] += 1
                if DBBenchSkillUtility.get_skill_level(skill) == difficulty_level:
                    effective_skill_set.add(skill)
            # endregion
            # region Set outer variables
            for effective_skill in effective_skill_set:
                effective_skill_to_count_dict[effective_skill] += 1
            sample_level_to_count_dict[difficulty_level] += 1
            # endregion
        # endregion
        # region Log information
        self.logger.info("Sample level information:")
        self.logger.info("| Level | Count")
        for level, count in sample_level_to_count_dict.items():
            self.logger.info(f"| {level:<5} | {count:<5} |")
        self.logger.info("Skill information:")
        self.logger.info(f"| {'Skill':<42} | Level | Count | Effective Count")
        for skill, skill_count in skill_to_count_dict.items():
            skill_level = DBBenchSkillUtility.get_skill_level(skill)
            self.logger.info(
                f"| {skill:<42} | "
                f"{skill_level:<5} | "
                f"{skill_count:<5} | "
                f"{effective_skill_to_count_dict[skill]:<15} |"
            )
        # endregion


def main() -> None:
    """
    Script entry point.

    Validates the processed file of every source data set under the v0124
    directory, then logs skill statistics for the merged db_bench.json file.
    The processing and merging steps are currently disabled and kept below as
    comments.
    """
    log_file_path = "./outputs/db_bench_standard_data_factory.log"
    root_path = "data/v0121/db_bench/v0124"
    # region Process data
    for source_identifier in (
        "output3_total",
        "output4_total",
        "output5_total",
        "train_db",
    ):
        factory = DBBenchStandardDataFactory(
            os.path.join(root_path, f"{source_identifier}.json"),
            os.path.join(root_path, f"{source_identifier}_processed.json"),
            log_file_path,
        )
        # factory.process()
        factory.validate()
    # endregion
    # region Merge data (disabled)
    # source_identifier_list = ["output3_total", "output4_total", "output5_total", "train_db"]
    # source_info_list = [
    #     (
    #         source_identifier,
    #         os.path.join(
    #             root_path, f"{source_identifier}_processed.json"
    #         ),
    #     )
    #     for source_identifier in source_identifier_list
    # ]
    # output_merged_dict_path = "data/v0121/db_bench/db_bench.json"
    # StandardDataFactoryUtility.merge_data_dict(
    #     source_info_list,
    #     output_merged_dict_path,
    #     "data/v0121/db_bench/merged_source_information.json",
    #     lambda x: x["instruction"],
    # )
    # endregion
    # region Count skill
    # No raw file is needed here: count_skill only reads the processed path.
    merged_factory = DBBenchStandardDataFactory(
        "",
        "data/v0121/db_bench/db_bench.json",
        log_file_path,
    )
    merged_factory.count_skill()
    # endregion


if __name__ == "__main__":
    main()
