from abc import ABC, abstractmethod
from typing import Any

from pathlib import Path
import asyncio
import json
import os
import re

from tqdm import tqdm

from olym_gen.utils.utils import get_logger, get_generator_base
from olym_gen.generator.problem_proof_generator import LogProblemProofMixin
from olym_gen.generator.json_format_generator import JsonFormatOutputMixin, JsonFormatError
from olym_gen.generator.base_generator import SystemPromptMixin, GeneratorBase
from typing import TypedDict

# Module-level logger used by the mixins, the generator class and main() below.
logger = get_logger()

class UserProblemProofPairMixin():
    """
    Mixin that builds the user prompt for a problem together with two proofs of it.
    """

    def _user_prompt(self, problem: str, orig_proof: str, new_proof: str) -> str:
        """Assemble the user prompt from the problem and its two proofs.

        Args:
            problem: Text placed after the ``Problem:`` label.
            orig_proof: Text placed after the ``Proof 1:`` label.
            new_proof: Text placed after the ``Proof 2:`` label.

        Returns:
            The three labelled sections in that order, each one followed by
            a blank line.
        """
        labelled_sections = (
            ("Problem", problem),
            ("Proof 1", orig_proof),
            ("Proof 2", new_proof),
        )
        return "".join(f"{label}: {body}\n\n" for label, body in labelled_sections)

class AnswerCompareResponse(TypedDict):
    """Shape of the JSON object expected from the answer-comparison request."""

    # Free-text explanation accompanying the verdict.
    rationale: str
    # Verdict flag; `process` stores it as the file's `pass_check` value.
    answer_same: bool

class ExtractAnswerMixin(ABC):
    """
    Mixin class for extracting the boxed answer from a proof.
    """

    def extract_boxed_answer(self, proof: str) -> str:
        r"""
        Return the contents of the last ``\boxed{...}`` group in ``proof``.

        The group is delimited by brace matching, so nested groups such as
        ``\boxed{\frac{1}{2}}`` are returned whole. A backslash escapes the
        character that follows it, so literal braces written as ``\{`` or
        ``\}`` do not affect the nesting depth.

        Args:
            proof: Text that may contain one or more ``\boxed{...}`` groups.

        Returns:
            The text between the braces of the last ``\boxed{``, or ``""``
            (after logging a warning) when there is no ``\boxed{`` or its
            braces are never closed.
        """
        matches = list(re.finditer(r'\\boxed\s*\{', proof))
        if not matches:
            logger.warning("No \\boxed found in the proof.")
            return ""

        # Start scanning right after the opening brace of the last \boxed{.
        start = matches[-1].end()
        depth = 1
        pos = start
        length = len(proof)
        while pos < length:
            char = proof[pos]
            if char == '\\':
                # Escaped character (e.g. \{, \}, \\): it is literal text, not
                # a group delimiter, so skip it together with the backslash.
                pos += 2
                continue
            if char == '{':
                depth += 1
            elif char == '}':
                depth -= 1
                if depth == 0:
                    return proof[start:pos]
            pos += 1

        # The braces never balanced; report it and return an empty answer.
        logger.warning("Unmatched braces in the proof.")
        return ""

class JsonFormatProblemProofPairGenerator( JsonFormatOutputMixin[AnswerCompareResponse], LogProblemProofMixin, ExtractAnswerMixin, UserProblemProofPairMixin, SystemPromptMixin):
    """
    Generator that compares the answers of two proofs of the same problem.

    For every ``problem_<i>_generate_<j>.json`` file in a directory it first
    compares the boxed answers textually and, when that is inconclusive, asks
    the composed generator for a JSON verdict (``rationale``, ``answer_same``).
    The verdict is written back to the file under the ``pass_check`` key.
    """

    # Exact file-name shape of a problem/proof pair file; the two groups are
    # the problem index and the proof index.
    _PAIR_FILE_PATTERN = re.compile(r"problem_(\d+)_generate_(\d+)\.json")

    def __init__(
        self,
        provider: str = "dummy",
        model: str | None = None,
        extra_model_paras: dict[str, Any] | None = None,
    ) -> None:
        """
        Build the underlying generator via ``get_generator_base``.

        Args:
            provider: Provider name forwarded to ``get_generator_base``.
            model: Model name forwarded to ``get_generator_base``.
            extra_model_paras: Extra model parameters forwarded unchanged.
        """
        # Use composition instead of inheritance for GeneratorBase
        self.generator_base: GeneratorBase = get_generator_base(provider, model, extra_model_paras)
        logger.debug(f"Initialized {self.__class__.__name__} with generator_base id: {id(self.generator_base)}")

    async def single_turn_request(self, *args, **kwargs):
        """Delegate to the composed generator_base"""
        return await self.generator_base.single_turn_request(*args, **kwargs)

    @property
    def system_prompt_file(self) -> str:
        """Relative path of the system prompt used for answer comparison."""
        return 'prompts/compare_answer_system_prompt.txt'

    def log_step_start(self, problem_index: int, proof_index: int) -> str:
        """Return the message logged before one problem/proof pair is checked."""
        return f"Checking proof for problem {problem_index} and proof {proof_index}..."

    def log_step_finish(self, problem_index: int, proof_index: int) -> str:
        """Return the message logged after one problem/proof pair is checked."""
        return f"Finished checking proof for problem {problem_index} and proof {proof_index}."

    def log_start(self, file: str, num_worker: int, num_returns: int) -> str:
        """Return the message describing the start of a processing run."""
        return f"Started to process the file {file} with {num_worker} workers and {num_returns} returns per problem. This will check the proofs for the problems in the file."

    def log_finish(self, file: str, num_pairs: int, num_returns: int, save_path: str) -> str:
        """Return the message describing the end of a processing run."""
        return f'Finished processing the file {file}. Verified {num_pairs} problem-proof pairs with {num_returns} checks each. The thinking process and checking results are saved to {save_path}.'

    async def _generate(
        self,
        problem: str,
        orig_proof: str,
        new_proof: str,
        problem_index: int,
        proof_index: int,
        shared_semaphore: asyncio.Semaphore,
        num_returns: int = 1,
    ) -> list[tuple[str, str | AnswerCompareResponse, bool] | None]:
        """
        Request a comparison of the two proofs and decode each reply as JSON.

        Args:
            problem: Problem statement placed in the user prompt.
            orig_proof: First proof placed in the user prompt.
            new_proof: Second proof placed in the user prompt.
            problem_index: Problem number, used for log messages only.
            proof_index: Proof number, used for log messages only.
            shared_semaphore: Semaphore forwarded to ``single_turn_request``.
            num_returns: Number of replies requested.

        Returns:
            One entry per reply: ``None`` when the reply itself is ``None``,
            otherwise ``(thinking, payload, pass_check)`` where ``payload`` is
            the decoded JSON object when ``_response_to_json`` returned one
            and the raw answer otherwise, and ``pass_check`` tells which.

        Raises:
            ValueError: If a reply is not a 2-tuple of (thinking, answer).
        """

        logger.debug(self.log_step_start(problem_index, proof_index))

        system_prompt = self._system_prompt
        user_prompt = self._user_prompt(problem, orig_proof, new_proof)
        return_list = await self.single_turn_request(
            system_prompt=system_prompt,
            user_prompt=user_prompt,
            shared_semaphore=shared_semaphore,
            num_returns=num_returns,
            use_json=True,
            max_tokens=8_192
        )
        json_decode_return_list: list[tuple[str, str | AnswerCompareResponse, bool] | None] = []

        # Validate and decode every reply.
        for generation in return_list:
            if generation is None:
                json_decode_return_list.append(None)
                continue
            if not isinstance(generation, tuple) or len(generation) != 2:
                raise ValueError(
                    f"Expected a tuple of (thinking, answer), got {generation}"
                )
            thinking, answer = generation
            json_object = self._response_to_json(answer)
            pass_check = json_object is not None

            # Keep the raw answer when it could not be decoded.
            json_decode_return_list.append(
                (thinking, json_object if pass_check else answer, pass_check)
            )

        logger.debug(self.log_step_finish(problem_index, proof_index))

        return json_decode_return_list

    async def process(
        self,
        load_path: str,
        lines: int,
        num_worker: int = 1,
        resume: bool = False,
        async_mode: bool = True,
        batch_id: str | None = None,
    ):
        """
        Check every problem/proof pair file in ``load_path``.

        Files are visited in sorted name order so that the subset selected by
        ``lines`` is the same on every run. Each file's verdict is written
        back into the file as ``pass_check``; files judged incorrect are also
        copied into ``<load_path>/incorrect``.

        Args:
            load_path: Directory holding ``problem_<i>_generate_<j>.json`` files.
            lines: Maximum number of matching files to examine. Files skipped
                because of ``resume`` count towards this limit.
            num_worker: Size of the semaphore passed to the requests.
            resume: When true, files that already contain ``pass_check`` are
                not re-checked; their stored verdict is reused.
            async_mode: Run all files concurrently when true, otherwise one
                after another.
            batch_id: Stored on the instance as ``self.batch_id``.

        Returns:
            Names of the files whose verdict is negative, including files
            skipped by ``resume`` whose stored verdict is negative.
        """
        self.batch_id = batch_id
        # (file name, problem index, proof index) for every file still to check.
        file_list: list[tuple[str, int, int]] = []
        incorrect_files: list[str] = []
        count = 0    # matching files examined, resumed ones included
        correct = 0  # files with a positive verdict

        for file in sorted(os.listdir(load_path)):
            # Enforce the cap before looking at another file so it also holds
            # on the resume path.
            if count >= lines:
                break
            name_match = self._PAIR_FILE_PATTERN.fullmatch(file)
            if name_match is None:
                continue
            logger.debug(f"Processing file: {file}")
            with open(os.path.join(load_path, file), 'r', encoding='UTF-8') as f:
                data = json.load(f)
            count += 1
            if resume and 'pass_check' in data:
                logger.debug(f"Skipping file {file} as it already has pass_check field")
                if data['pass_check']:
                    correct += 1
                else:
                    incorrect_files.append(file)
                continue
            file_list.append((file, int(name_match.group(1)), int(name_match.group(2))))

        semaphore = asyncio.Semaphore(num_worker)

        async def process_file(file: str, problem_index: int, proof_index: int):
            """Decide the verdict for one file and write it back."""
            nonlocal correct
            file_path = os.path.join(load_path, file)
            with open(file_path, 'r', encoding='UTF-8') as f:
                data = json.load(f)
            problem = data.get("question", "Not Available")
            orig_proof = data.get("orig_solution", "Not Available")
            new_proof = data.get("new_solution", "Not Available")

            # Cheap path: identical boxed answers (ignoring braces) need no
            # model request.
            orig_answer = self.extract_boxed_answer(orig_proof)
            new_answer = self.extract_boxed_answer(new_proof)
            if orig_answer != "" and new_answer != "":
                if orig_answer.replace("{", "").replace("}", "") == new_answer.replace("{", "").replace("}", ""):
                    logger.debug(f"Skipping file {file} as the answers are the same: {orig_answer}")
                    data['pass_check'] = True
                    with open(file_path, 'w', encoding='UTF-8') as f:
                        json.dump(data, f, ensure_ascii=False, indent=4)
                    correct += 1
                    return

            response = await self._generate(problem, orig_proof, new_proof, problem_index, proof_index, shared_semaphore=semaphore)

            # The last usable reply decides the verdict; None means no reply
            # was usable and the file is left untouched.
            verdict: bool | None = None
            for res in response:
                if res is None:
                    logger.warning(f"Response for problem {problem_index}, proof {proof_index} is None, skipping.")
                    continue
                (_thinking, solution, pass_check) = res
                if not pass_check:
                    logger.warning(f"Response for problem {problem_index}, proof {proof_index} did not pass check, skipping.")
                    continue
                if isinstance(solution, str):
                    logger.warning(
                        f"Response for problem {problem_index}, proof {proof_index} is a string, expected JSON object."
                    )
                    continue
                verdict = solution["answer_same"]

            if verdict is None:
                return

            data['pass_check'] = verdict
            with open(file_path, 'w', encoding='UTF-8') as f:
                json.dump(data, f, ensure_ascii=False, indent=4)
            if verdict:
                correct += 1
            else:
                incorrect_files.append(file)

        if not async_mode:
            for entry in tqdm(file_list, desc='Processing files'):
                await process_file(*entry)
        else:
            tasks = [process_file(*entry) for entry in file_list]
            for task in tqdm(asyncio.as_completed(tasks), total=len(tasks), desc='Processing files'):
                await task

        # Report against the files actually examined, not the requested cap;
        # files without a usable verdict are counted as neither passed nor
        # incorrect.
        correct_rate = correct / count * 100 if count else 0.0
        logger.info(f"Processed {count} files, {correct} of them passed the check.")
        logger.info(f"Correct rate : {correct_rate:.2f}%")

        # Keep a copy of every incorrect file in a dedicated sub-directory.
        os.makedirs(os.path.join(load_path, "incorrect"), exist_ok=True)
        for file in incorrect_files:
            with open(os.path.join(load_path, file), 'r', encoding='UTF-8') as f:
                data = json.load(f)
            with open(os.path.join(load_path,"incorrect", file), 'w', encoding='UTF-8') as f:
                json.dump(data, f, ensure_ascii=False, indent=4)

        return incorrect_files

    def _check_json_format(self, json_object: Any, other_info: dict[str, Any] | None = None) -> None:
        """
        Validate that ``json_object`` has the ``AnswerCompareResponse`` shape.

        The object must be a dict with string keys, exactly the keys
        ``rationale`` (a string) and ``answer_same`` (a boolean).

        Args:
            json_object: Decoded JSON value to validate.
            other_info: Not used by this check.

        Raises:
            JsonFormatError: If any of the requirements above is violated.
        """
        if not isinstance(json_object, dict):
            raise JsonFormatError(
                f"Expected a dictionary, got {type(json_object)}")
        if not all(isinstance(key, str) for key in json_object.keys()):
            raise JsonFormatError(
                "All keys in the dictionary should be strings")
        if set(json_object.keys()) != {
                "rationale", "answer_same"
        }:
            raise JsonFormatError(
                f"The dictionary should contain exactly these keys: `rationale`, `answer_same`, but got {json_object.keys()}"
            )

        if not isinstance(json_object["rationale"], str):
            raise JsonFormatError(
                f"Expected the value of key `rationale` to be a string, got {type(json_object['rationale'])}"
            )
        if not isinstance(json_object["answer_same"], bool):
            # Single quotes inside the f-string keep this valid before Python 3.12.
            raise JsonFormatError(
                f"Expected the value of key `answer_same` to be a boolean, got {type(json_object['answer_same'])}"
            )

async def main(sys_argv: list[str] | None = None):
    """
    Command-line entry point: parse arguments and run the answer check.

    Args:
        sys_argv: Argument list handed to the parser; ``None`` lets argparse
            read the process arguments.
    """
    from argparse import ArgumentParser
    from olym_gen.generator.base_generator import common_parse_args

    # Extend the shared argument parser with a description for this tool.
    check_parser = ArgumentParser(
        parents=[common_parse_args()],
        description="Check solutions for problems using a language model."
    )
    args = check_parser.parse_args(sys_argv)

    generator = JsonFormatProblemProofPairGenerator(provider=args.provider, model=args.model)
    process_options = {
        "lines": args.lines,
        "num_worker": args.num_worker,
        "resume": args.resume,
        "async_mode": not args.no_async,
        "batch_id": args.batch_id,
    }
    incorrect_files = await generator.process(args.file, **process_options)

    logger.info("Finished processing all problems.")
    if not incorrect_files:
        logger.debug("All files passed the check.")
        return
    logger.debug(f"Files with incorrect answers: {', '.join(incorrect_files)}")

# Run the asynchronous entry point when the module is executed as a script.
if __name__ == "__main__":
    asyncio.run(main())