from abc import ABC, abstractmethod
from typing import Any, Sequence

from pathlib import Path
import asyncio
import json
import os
import re

from tqdm import tqdm

from olym_gen.utils.utils import get_logger, get_generator_base
from olym_gen.generator.problem_proof_generator import (
    LogProblemProofMixin,
    UserProblemProofMixin,
    ResumeMixin,
)
from olym_gen.generator.base_generator import SystemPromptMixin, GeneratorBase

# Module-level logger from the project's `get_logger` helper; used by the
# classes below for debug/info/warning/error reporting.
logger = get_logger()


class JsonFormatError(Exception):
    """Raised when a response decodes as valid JSON but its structure is not the expected one."""


class JsonFormatOutputMixin[ResponseJsonType](ABC):

    @abstractmethod
    def _check_json_format(
        self, json_object: Any, other_info: dict[str, Any] | None = None
    ) -> None:
        """
        Check whether a json.loads object is in the expected format. If not, raise a JsonFormatError. Notice that this function is not used to check whether a string can be parsed as valid JSON.
        Args:
            json_object (Any): The JSON object to check. It should be the json.loads object.
            other_info (dict[str, Any] | None): Additional information to check the JSON format. This can be used to provide more context for the check. The param use the return value of `prepare_info_for_json_check` function.
        Raises:
            JsonFormatError: If the JSON object is not in the expected format.
        """
        ...

    def prepare_info_for_json_check(self, problem: str, proof: str) -> dict[str, Any]:
        """
        Prepare additional information for checking the JSON format. The return value will be passed to `_check_json_format` function as `other_info` parameter. Default is an empty dictionary.
        """
        return dict()

    @staticmethod
    def pre_fix_escape(
        response: str,
        allow_escape: re.Pattern | None = None,
        match_latex: bool | Sequence[str] = True,
    ) -> str:
        r"""
        Pre-process the response string to fix common escape issues before JSON decoding. Try to recognize common latex commands that could be incorrectly read as escape sequences, fix them by doubling the backslashes, and return the modified string.

        Args:
            response (str): The JSON response string that may contain invalid escape sequences.
            allow_escape (re.Pattern | None): A list of escape sequences that are allowed. Each should be 2 characters (e.g., '\n', '\t'). If None, allow `\n`, `\"` and `\\` in the response. These sequences will not be doubled, except fot it is further matched by match_latex. Default: `re.compile(r'\\[ntr\\"]')`.
            match_latex (bool | Sequence[str]): If True, a string beginning with a allow_escape string but matching a latex command will be doubled. A sequence of strings can also be provided to specify the LaTeX commands to match. If True, use the default LaTeX commands: [`\not`, `\neq`]
        """
        if allow_escape is None:

            allow_escape = re.compile(r'\\[ntr\\"]')
            # allow_escape = [r'\n', r'\\', r'\"']

        if match_latex is True:
            match_latex = (
                [  # collect from https://www.bu.edu/math/files/2013/08/LongTeX1.pdf
                    r"\not",
                    r"\neq",
                    r"\nabla",
                    "\ne",
                    r"\neg",
                    r"\newcommand",
                    r"\newenvironment",
                    r"\newline",
                    r"\newpage",
                    r"\newtheorem",
                    r"\ni",
                    r"\nu",
                    r"\alpha",
                    r"\angle",
                    r"\approx",
                    r"\arccos",
                    r"\arcsin",
                    r"\arctan",
                    r"\arg",
                    r"\ast",
                    r"\atop",
                    r"\aleph",
                    r"\amalg",
                    r"\and",
                    r"\asymp",
                    r"\backslash",
                    r"\bar",
                    r"\beta",
                    r"\begin",
                    r"\bf",
                    r"\bigcap",
                    r"\bigcup",
                    r"\bigcirc",
                    r"\bigodot",
                    r"\bigoplus",
                    r"\bigotimes",
                    r"\bigsqcup",
                    r"\bigvee",
                    r"\bigwedge",
                    r"\boldmath",
                    r"\bot",
                    r"\Box",
                    r"\breve",
                    r"\bullet",
                    r"\ref",
                    r"\rangle",
                    r"\rbrace",
                    r"\rbrack",
                    "\rceil",
                    r"\renewcommand",
                    r"\renewenvironment",
                    r"\rho",
                    r"\right",
                    r"\rightarrow",
                    r"\rm",
                    r"\rq",
                    r"\tau",
                    r"\tan",
                    r"\tanh",
                    r"\tfrac",
                    r"\text",
                    r"\theta",
                    r"\therefore",
                    r"\thinspace",
                    r"\tilde",
                    r"\times",
                    r"\to",
                    r"\top",
                    r"\triangle",
                    r"\triangledown",
                    r"\theta",
                    r"\topfraction",
                    r"\triangle",
                    r"\triangleleft",
                    r"\triangleright",
                    r"\tr",
                    r"\tt",
                ]
            )
        elif isinstance(match_latex, Sequence):
            # If match_latex is a sequence, use it directly
            match_latex = [str(cmd) for cmd in match_latex]
        else:
            match_latex = []

        # When no_escape is True, double all backslashes except for allowed escape sequences
        backslash_pattern = (
            r"(?<!\\)(\\)"  # Match a backslash not preceded by another backslash
        )
        all_backslash_index = [
            m.start() for m in re.finditer(backslash_pattern, response)
        ]
        all_allow_escape = [m.start() for m in re.finditer(allow_escape, response)]
        latex_pattern = re.compile(r"\\[abnrt][a-zA-Z]*\b")
        all_recall_by_latex = [
            m.start()
            for m in re.finditer(latex_pattern, response)
            if m.group(0) in match_latex
        ]
        # print('debug', all_backslash_index)
        # print('debug', all_allow_escape)
        # print('debug', all_recall_by_latex)
        all_need_double_backslash = list(
            set(all_backslash_index)
            - (set(all_allow_escape) - set(all_recall_by_latex))
        )
        all_need_double_backslash.sort()
        # Double the backslashes that are not in the allowed escape sequences
        for index in all_need_double_backslash:
            assert index < len(
                response
            ), f"Index {index} is out of bounds for response of length {len(response)}."
            assert (
                response[index] == "\\"
            ), f"Character at index {index} should be a backslash, but got: {response[index]}"
        parts = [
            response[i + 1 : j]
            for i, j in zip(
                [-1] + all_need_double_backslash,
                all_need_double_backslash + [len(response)],
            )
        ]
        response = "\\\\".join(parts)
        return response

    @staticmethod
    def fix_escape(response: str, error: str) -> str:
        """
        Deepseek models sometimes will fail to escape the backslash correctly, especially when the output has latex command and also use the json format. Here we fix the escape issue by replacing the invalid escape sequences with a valid one. This function use the error message to find the position of the invalid escape sequence and replace it with a double backslash. If there is no error message, please use `pre_fix_escape` function to pre-process the response string.

        Args:
            response (str): The JSON response string that may contain invalid escape sequences.
            error (str | None): The error message from the JSON decoding process, which should contain 'Invalid \\escape'. Can be None when no_escape=True.
            no_escape (bool): If True, avoid all escape sequences except those listed in allow_escapes.
        """

        if "Invalid \\escape" not in error:
            raise ValueError(
                f"Expected 'Invalid \\escape' in error message, but got: {error}. Please use pre_fix_escape function to pre-process the response string."
            )

        # find the invalid escape place, an example error information is: Error: Invalid \escape: line 5 column 614 (char 669).
        position = re.search(r"line \d+ column \d+ \(char (\d+)\)", error)
        if position is None:
            raise ValueError(
                f"Failed to find the position of the invalid escape in the error message: {error}. The response is: {response}."
            )

        char_index = int(position.group(1))
        # Ensure we don't go out of bounds
        assert (
            char_index < len(response) and response[char_index] == "\\"
        ), f"Character index {char_index} should be less than the length of response {len(response)} and should be a backslash, but got: {response[char_index] if char_index < len(response) else 'OUT_OF_BOUNDS'}"

        response = response[:char_index] + "\\\\" + response[char_index + 1 :]
        logger.debug(
            f"Try to fix the escape issue by replacing the backslash with a double backslash at index {char_index}.\nThe response is now: {response}"
        )

        return response

    def _response_to_json(
        self, response: str, info_for_json_check: dict[str, Any] | None = None
    ) -> ResponseJsonType | None:
        """
        Convert a response string to a JSON object and check its format.
        The `response` string is expected to be a valid JSON string.
        Use function `prepare_info_for_json_check` to get `info_for_json_check`.
        """
        # first try to pre-process the response string to fix common escape issues before JSON decoding
        response = self.pre_fix_escape(response)
        while True:
            try:
                json_object = json.loads(response)
                self._check_json_format(json_object, info_for_json_check)
                return json_object
            except json.JSONDecodeError as e:
                if "Invalid \\escape" in str(e):
                    response = self.fix_escape(response, str(e))
                    continue
                else:
                    logger.error(
                        f"Failed to decode JSON from response: {response}. Error: {e}."
                    )
                    return None
            except JsonFormatError as e:
                logger.error(
                    f"JSON format is not corresponding to the expected one: {e}. Response: {response}."
                )
                return None
            except Exception as e:
                logger.error(
                    f"Unexpected error when processing response: {response}. Error: {e}."
                )
                return None

    def remove_invalid_json(self, path: str) -> int:
        """
        Remove invalid JSON files from the given path.
        Args:
            path (str): The path to the directory containing JSON files.
        Returns:
            int: The number of invalid JSON files removed.
        """
        invalid_count = 0
        for file in Path(path).glob("*.json"):
            try:
                with open(file, "r", encoding="utf-8") as f:
                    d = json.load(f)
                    if "pass_check" not in d:
                        raise ValueError(
                            f"JSON file {file} does not contain 'pass_check' key. We cannot remove invalid json file without this key."
                        )
                    if not d["pass_check"]:
                        logger.warning(f"Removing invalid JSON file: {file}")
                        file.unlink()
                        invalid_count += 1
            except json.JSONDecodeError:
                logger.warning(f"Removing invalid JSON file: {file}")
                file.unlink()
                invalid_count += 1
        return invalid_count


class JsonFormatProblemProofGeneratorBase[ResponseJsonType](
    ResumeMixin,
    JsonFormatOutputMixin[ResponseJsonType],
    LogProblemProofMixin,
    UserProblemProofMixin,
    SystemPromptMixin,
    ABC,
):
    """
    Base class for generators that produce JSON formatted outputs.
    Now uses composition instead of inheritance for GeneratorBase functionality.
    """
    
    def __init__(
        self,
        provider: str = "dummy",
        model: str | None = None,
        extra_model_paras: dict[str, Any] | None = None,
    ) -> None:
        """
        Initialize the generator with provider configuration.
        Creates the generator_base using the provided configuration.
        """
        self.generator_base: GeneratorBase = get_generator_base(provider, model, extra_model_paras)
        logger.debug(f"Initialized {self.__class__.__name__} with generator_base id: {id(self.generator_base)}")
    
    @property
    def batch_id(self):
        """Delegate to the composed generator_base"""
        return self.generator_base.batch_id if self.generator_base else None
    
    @batch_id.setter  
    def batch_id(self, value):
        """Delegate to the composed generator_base"""
        if self.generator_base:
            self.generator_base.batch_id = value
    
    async def single_turn_request(self, *args, **kwargs):
        """Delegate to the composed generator_base"""
        if not self.generator_base:
            raise ValueError("generator_base not initialized. Make sure to call get_generator_base in __init__")
        return await self.generator_base.single_turn_request(*args, **kwargs)

    async def _generate(
        self,
        problem: str,
        proof: str,
        problem_index: int,
        proof_index: int,
        shared_semaphore: asyncio.Semaphore,
        num_returns: int = 1,
        max_tokens: int = 64_000,
    ) -> list[tuple[str, str | ResponseJsonType, bool] | None]:
        """
        Generate a response for the given problem and proof, checking the JSON format of the response.
        """

        logger.debug(self.log_step_start(problem_index, proof_index))

        system_prompt = self._system_prompt
        user_prompt = self._user_prompt(problem, proof)
        return_list = await self.single_turn_request(
            system_prompt=system_prompt,
            user_prompt=user_prompt,
            shared_semaphore=shared_semaphore,
            num_returns=num_returns,
            max_tokens=max_tokens,
            use_json=True,
        )
        json_decode_return_list: list[
            tuple[str, str | ResponseJsonType, bool] | None
        ] = []

        # check for the return list
        for generation in return_list:
            if generation is None:
                json_decode_return_list.append(None)
                continue
            if not isinstance(generation, tuple) or len(generation) != 2:
                raise ValueError(
                    f"Expected a tuple of (thinking, answer), got {generation}"
                )
            info_for_json_check = self.prepare_info_for_json_check(
                problem=problem, proof=proof
            )
            json_object = self._response_to_json(generation[1], info_for_json_check)
            pass_check = json_object is not None

            json_decode_return_list.append(
                (
                    generation[0],
                    json_object if json_object is not None else generation[1],
                    pass_check,
                )
            )
        logger.debug(self.log_step_finish(problem_index, proof_index))

        return json_decode_return_list

    async def process(
        self,
        file: str,
        lines: int | None = None,
        indexes: list[int] | None = None,
        num_returns: int = 1,
        save_path: str | None = None,
        num_worker: int = 1,
        resume: bool = False,
        max_tokens=64_000,
        retry_times: int = 0,
        async_mode: bool = True,
        only_solve_id_conflict: bool = False,
        batch_id: str | None = None,
    ):
        """
        Process the input file and generate responses for each problem-proof pair.
        """
        self.batch_id = batch_id
        save_path = save_path if save_path is not None else self._default_save_path
        logger.info(self.log_start(file, num_worker, num_returns))

        pairs = self.load_problems_and_proofs(
            Path(file),
            lines=lines,
            indexes=indexes,
        )

        semaphore = asyncio.Semaphore(num_worker)

        async def process_pair(pair):
            problem, proof, field, problem_idx, proof_idx = (
                pair["problem"],
                pair["proof"],
                pair["field"],
                pair["problem_index"],
                pair["proof_index"],
            )
            need_generation = self.need_generation_idx(
                resume=resume,
                problem=problem,
                answer=proof,
                problem_index=problem_idx,
                proof_index=proof_idx,
                save_path=save_path,
                num_returns=num_returns,
                pairs=pairs,
            )

            if only_solve_id_conflict:
                return
            
            if len(need_generation) == 0:
                logger.info(
                    f"Skipping problem {problem_idx}, proof {proof_idx} as all generations exist."
                )
                return

            response = await self._generate(
                problem,
                proof,
                shared_semaphore=semaphore,
                num_returns=len(need_generation),
                problem_index=problem_idx,
                proof_index=proof_idx,
                max_tokens=max_tokens,
            )

            for j, res in enumerate(response):
                if res is None:
                    logger.warning(
                        f"Response for problem {problem_idx}, proof {proof_idx} is None, skipping."
                    )
                    continue
                (thinking, solution, pass_check) = res
                if not pass_check:
                    logger.warning(
                        f"Response for problem {problem_idx}, proof {proof_idx} failed JSON format check, skipping."
                    )
                    continue
                save_name = self.save_name(
                    save_path,
                    problem_index=problem_idx,
                    proof_index=proof_idx,
                    generation_index=need_generation[j],
                )
                logger.debug(f"Saving response to {save_name}")
                with open(save_name, "w", encoding="UTF-8") as f:
                    json.dump(
                        {
                            "pass_check": pass_check,
                            self._save_question_name: problem,
                            self._save_orig_solution_name: proof,
                            self._save_thinking_name: thinking,
                            self._save_solution_name: solution,
                            self._save_field_name: field,
                        },
                        f,
                        indent=4,
                        ensure_ascii=False,
                    )

        os.makedirs(save_path, exist_ok=True)

        for retry in range(retry_times + 1):
            if retry == 0:  # In the first time, we will not check the json files.
                pass
            else:
                rm_files = self.remove_invalid_json(save_path)
                if rm_files:
                    logger.info(
                        f"Find {rm_files} invalid JSON files in retry time {retry}. We will remove them and retry."
                    )
                    resume = True  # to ensure we only reprocess the pairs with invalid JSON files.
                else:
                    logger.info(f"All JSON files have passed check. We will not retry.")
                    break

            if not async_mode:
                for pair in tqdm(
                    pairs,
                    desc=(
                        "Processing pairs"
                        if retry == 0
                        else f"Retrying pairs {retry + 1}"
                    ),
                ):
                    await process_pair(pair)
            else:
                tasks = [process_pair(pair) for pair in pairs]
                for task in tqdm(
                    asyncio.as_completed(tasks),
                    total=len(tasks),
                    desc=(
                        "Processing pairs"
                        if retry == 0
                        else f"Retrying pairs {retry + 1}"
                    ),
                ):
                    await task

        logger.info(self.log_finish(file, len(pairs), num_returns, save_path))
