import json
from typing import Any, Callable
from olym_gen.utils.utils import get_logger, UNKNOWN_FIELD

import asyncio

from olym_gen.generator.json_format_generator import (
    JsonFormatProblemProofGeneratorBase,
    JsonFormatError,
)
from olym_gen.generator.problem_proof_generator import ProblemProofGeneratorBase

# Module-wide logger obtained from the project's `get_logger` helper.
logger = get_logger()

class CheckMixin:
    """Shared configuration for generators that check one proof of one problem.

    Provides the record accessors (question, proofs, field, sample filter),
    the save names for the checked proof and the check result, and the
    progress / summary log messages.
    """

    @property
    def _save_orig_solution_name(self) -> str:
        # Save name for the proof being checked (presumably read by the base generator).
        return "checked_proof"

    @property
    def _save_solution_name(self) -> str:
        # Save name for the checker's verdict (presumably read by the base generator).
        return "check_result"

    @property
    def problem_retrieve(self) -> Callable[[dict[str, Any]], str]:
        """Callable returning a record's problem statement (its "question" entry)."""

        def _question_of(record: dict[str, Any]) -> str:
            return record["question"]

        return _question_of

    @property
    def proof_retrieve(self) -> Callable[[dict[str, Any]], list[str]]:
        """Callable returning a record's "proofs" entries, each converted to str."""

        def _proofs_of(record: dict[str, Any]) -> list[str]:
            return list(map(str, record["proofs"]))

        return _proofs_of

    @property
    def field_retrieve(self) -> Callable[[dict[str, Any]], str]:
        """Callable returning a record's "field", or UNKNOWN_FIELD when absent."""

        def _field_of(record: dict[str, Any]) -> str:
            return record.get("field", UNKNOWN_FIELD)

        return _field_of

    @property
    def sample_filter(self) -> Callable[[dict[str, Any]], bool]:
        """Callable that accepts every record (no filtering)."""

        def _keep(_record: dict[str, Any]) -> bool:
            return True

        return _keep

    def log_step_start(self, problem_index: int, proof_index: int) -> str:
        # Message emitted before checking a single problem/proof pair.
        target = f"problem {problem_index} and proof {proof_index}"
        return f"Checking proof for {target}..."

    def log_step_finish(self, problem_index: int, proof_index: int) -> str:
        # Message emitted after checking a single problem/proof pair.
        target = f"problem {problem_index} and proof {proof_index}"
        return f"Finished checking proof for {target}."

    def log_start(self, file: str, num_worker: int, num_returns: int) -> str:
        # Message emitted once before a file is processed.
        setup = (
            f"Started to process the file {file} with {num_worker} workers "
            f"and {num_returns} returns per problem."
        )
        return f"{setup} This will check the proofs for the problems in the file."

    def log_finish(
        self, file: str, num_pairs: int, num_returns: int, save_path: str
    ) -> str:
        # Summary message emitted once a file has been fully processed.
        summary = (
            f"Finished processing the file {file}. "
            f"Verified {num_pairs} problem-proof pairs with {num_returns} checks each."
        )
        return f"{summary} The thinking process and checking results are saved to {save_path}."


class CheckGeneratorAPI(
    JsonFormatProblemProofGeneratorBase[dict[str, bool]],
    CheckMixin,
):
    """Given a problem and one proof, check the correctness of the proof.

    The system prompt is read from ``prompts/check_system_prompt.txt``. The
    parsed response is validated by ``_check_json_format``: it must be a dict
    of booleans keyed exactly by ``condition1_satisfied`` ..
    ``condition4_satisfied`` and ``proof_correct``.
    """

    # NOTE(review): the base class is listed before CheckMixin here (the reverse
    # of CheckGeneratorBaseline), so any attribute the base class also defines
    # takes precedence over CheckMixin's version in the MRO — confirm intended.

    @property
    def system_prompt_file(self) -> str:
        return "prompts/check_system_prompt.txt"

    def _user_prompt(self, problem: str, proof: str) -> str:
        # User message: the problem statement followed by the proof to check.
        return f"Problem: {problem}\n\nProof need to check: {proof}\n\n"

    @property
    def _default_save_path(self) -> str:
        return "save/check"

    def _check_json_format(
        self, json_object: Any, other_info: dict[str, Any] | None = None
    ) -> None:
        """
        Validate the parsed checker response.

        ``json_object`` must be a dict whose keys are exactly
        ``condition1_satisfied`` .. ``condition4_satisfied`` and
        ``proof_correct``. Every value must be a ``bool``, and
        ``proof_correct`` may only be True when all four conditions are True.

        Args:
            json_object: The parsed JSON object to validate.
            other_info: Unused by this implementation.

        Raises:
            JsonFormatError: If any of the requirements above is violated.
        """
        # Single source of truth for the per-condition keys (used twice below).
        condition_keys = (
            "condition1_satisfied",
            "condition2_satisfied",
            "condition3_satisfied",
            "condition4_satisfied",
        )
        if not isinstance(json_object, dict):
            raise JsonFormatError(f"Expected a dictionary, got {type(json_object)}")
        if not all(isinstance(key, str) for key in json_object):
            raise JsonFormatError("All keys in the dictionary should be strings")
        if set(json_object) != {*condition_keys, "proof_correct"}:
            raise JsonFormatError(
                f"The dictionary should contain exactly these keys: `condition1_satisfied`,`condition2_satisfied`, `condition3_satisfied`, `condition4_satisfied` and `proof_correct`, but got {json_object.keys()}"
            )
        for key, value in json_object.items():
            if not isinstance(value, bool):
                raise JsonFormatError(
                    f"Expected the value of key `{key}` to be a boolean, got {type(value)}"
                )
        # A correct proof requires every individual condition to hold.
        if json_object["proof_correct"] and not all(
            json_object[key] for key in condition_keys
        ):
            raise JsonFormatError(
                "`proof_correct` is True, but not all conditions are satisfied."
            )
            
class CheckGeneratorBaseline(CheckMixin, ProblemProofGeneratorBase):
    """
    Given a problem and one proof, check the correctness of the proof.

    This is the plain-text (non-JSON) variant, used to test the baseline and
    our trained model. The system prompt is empty. Instead, the checking
    instructions from ``prompts/check_system_prompt_direct.txt`` are prepended
    to the user prompt, so the prompt may differ slightly from the JSON-format
    version. The verdict is extracted from the response by ``answer_retrieve``.
    """

    @property
    def system_prompt_file(self) -> str:
        # Placeholder: `_system_prompt` is overridden below, so this path is
        # not meant to be read.
        return "should not be used"

    @property
    def _system_prompt(self) -> str:
        # Intentionally empty: all instructions are carried in the user prompt.
        return ""

    def _user_prompt(self, problem: str, proof: str) -> str:
        """Build the user prompt.

        The prompt contains the checking instructions, then the ``<question>``
        and ``<proof>`` sections, then a reminder to follow the output format.
        The instruction file is re-read on every call.
        """
        # Explicit UTF-8 so decoding does not depend on the platform's locale.
        with open(
            "prompts/check_system_prompt_direct.txt", "r", encoding="utf-8"
        ) as f:
            system_prompt = f.read().strip() + "\n"
        return system_prompt + f"<question>\n{problem}\n</question>\n\n<proof>\n{proof}\n</proof>\n\nNote that your output should follow the **Output Format** above."

    @property
    def _default_save_path(self) -> str:
        return "save/check_baseline"

    @classmethod
    def answer_retrieve(cls, text_answer: str) -> str:
        """
        Extract the final verdict from a model response.

        Take the text after the last ``</think>`` (the whole response if there
        is none). Then take the text after the last ``###`` in that part, strip
        it and lowercase it. The expected answer is `true` or `false`.
        """
        # rpartition(...)[2] is the text after the last separator, or the
        # whole string when the separator is absent (same as split(...)[-1]).
        after_think = text_answer.rpartition("</think>")[2].strip()
        return after_think.rpartition("###")[2].strip().lower()


async def main(sys_argv: list[str] | None = None):
    """Parse command-line arguments and check the proofs in the given file.

    With ``--local_model`` the plain-text CheckGeneratorBaseline is used.
    Otherwise the JSON-format CheckGeneratorAPI is used.

    Args:
        sys_argv: Arguments to parse. ``None`` lets argparse read ``sys.argv``.
    """
    from argparse import ArgumentParser
    from olym_gen.generator.base_generator import common_parse_args

    parser = ArgumentParser(
        parents=[common_parse_args()],
        description="Check solutions for problems using a language model.",
    )
    parser.add_argument(
        "--local_model",
        action="store_true",
        help="Use local model, it avoid use the json format and provide a more direct prompt.",
    )
    args = parser.parse_args(sys_argv)

    # Pick the generator flavour, then build it once with the shared settings.
    if args.local_model:
        logger.info("Use local model for checking.")
        generator_cls = CheckGeneratorBaseline
    else:
        logger.info("Use API model for checking with JSON format.")
        generator_cls = CheckGeneratorAPI
    generator = generator_cls(provider=args.provider, model=args.model)

    process_kwargs = dict(
        lines=args.lines,
        indexes=args.indexes or None,  # an empty selection means "no filter"
        num_returns=args.num_returns,
        save_path=args.save_path,
        num_worker=args.num_worker,
        resume=args.resume,
        max_tokens=args.max_tokens,
        async_mode=not args.no_async,
        only_solve_id_conflict=args.only_solve_id_conflict,
        batch_id=args.batch_id,
    )
    await generator.process(args.file, **process_kwargs)
    logger.info("Finished processing all problems.")


if __name__ == "__main__":
    # Script entry point: parse sys.argv and run the async checker to completion.
    asyncio.run(main(sys_argv=None))
