from pathlib import Path
import asyncio

from typing import Callable, Any
from olym_gen.utils.utils import get_logger, get_generator_base

from tqdm import tqdm

from typing import Sequence

from olym_gen.generator.base_generator import GeneratorBase
from olym_gen.generator.problem_proof_generator import (
    ProblemProofDataSample,
    ResumeMixin,
    SystemPromptMixin,
    LogProblemMixin,
    ProblemProofSaveMixin,
)

# Module-level logger shared by SolutionGenerator and the CLI entry point below.
logger = get_logger()


class SolutionGenerator(
    SystemPromptMixin,
    ResumeMixin,
    ProblemProofSaveMixin,
    LogProblemMixin,
):
    """
    Given a problem, generate some proof.

    The generator reads problems from a file, asks the configured model for
    ``num_returns`` solutions per problem (without showing it any reference
    proof), and hands each response to ``save_response`` for persistence.
    Prompt loading, resume bookkeeping, saving and log-message hooks are
    supplied by the mixins; this class only fills in the solution-specific
    pieces (prompt file, default save path, log texts, retrieval lambdas).
    """

    def __init__(
        self,
        provider: str = "dummy",
        model: str | None = None,
        extra_model_paras: dict[str, Any] | None = None,
    ) -> None:
        """
        Build the underlying model backend.

        Args:
            provider: Backend name forwarded to ``get_generator_base``.
            model: Optional model identifier forwarded to ``get_generator_base``.
            extra_model_paras: Optional extra model parameters forwarded to
                ``get_generator_base``.
        """
        # Use composition instead of inheritance for GeneratorBase
        self.generator_base: GeneratorBase = get_generator_base(provider, model, extra_model_paras)
        logger.debug(f"Initialized {self.__class__.__name__} with generator_base id: {id(self.generator_base)}")

    @property
    def system_prompt_file(self) -> str:
        """Relative path of the system prompt used for solution generation."""
        return "prompts/solution_system_prompt.txt"

    def _user_prompt(self, problem: str) -> str:
        """Format the user message sent to the model for ``problem``."""
        return f"Problem: {problem}\n\n"

    @property
    def _default_save_path(self) -> str:
        """Directory used by ``process`` when no ``save_path`` is given."""
        return "save/solution/data"

    @property
    def proof_retrieve(self) -> Callable[[dict[str, Any]], list[str]]:
        """
        Return a callable that extracts proofs from a saved record.

        The callable returns a one-element list holding only the first entry
        of the record's ``"solution"`` field.
        """
        return lambda d: [d["solution"][0]]

    def log_step_start(self, problem_index: int) -> str:
        """Message logged right before a problem is sent to the model."""
        return f"Trying to generate the proof for problem {problem_index} ..."

    def log_step_finish(self, problem_index: int) -> str:
        """Message logged right after the model responded for a problem."""
        return f"Finish generating the proof for problem {problem_index}."

    def log_start(self, file: str, num_worker: int, num_returns: int) -> str:
        """Message logged once when ``process`` starts working on ``file``."""
        return f"Started to process the file {file} with {num_worker} workers and {num_returns} returns per problem. This will generate proofs for the problems in the file without reference."

    def log_finish(
        self, file: str, num_pairs: int, num_returns: int, save_path: str
    ) -> str:
        """Message logged once when ``process`` has gone through all pairs."""
        return f"Finished processing the file {file}. Generated {num_pairs} problem pairs with {num_returns} returns each. The thinking process and generated solutions are saved to {save_path}."

    @property
    def proof_retrieve_with_answer(self) -> Callable[[dict[str, Any]], list[str]]:
        """
        Return a callable that extracts every solution plus its final answer.

        For each entry of the record's ``"solution"`` field the callable
        appends an ``Answer:`` line containing the record's ``"answer"`` value
        in a LaTeX box; a missing ``"answer"`` key yields an empty box.
        """
        return lambda d: [f"{solution}\nAnswer: \\boxed{{{d.get('answer', '')}}}" for solution in d["solution"]]

    async def _generate(
        self,
        problem: str,
        problem_index: int,
        shared_semaphore: asyncio.Semaphore,
        num_returns: int = 1,
        max_tokens: int = 8_192,
    ) -> list[tuple[str, str] | None]:
        """
        Request ``num_returns`` solutions for one problem from the configured model.

        Args:
            problem: Problem statement inserted into the user prompt.
            problem_index: Index of the problem; only used for log messages.
            shared_semaphore: Semaphore forwarded to ``single_turn_request``
                (presumably used there to bound concurrent requests - verify
                against its implementation).
            num_returns: Number of completions to request.
            max_tokens: Token limit forwarded to ``single_turn_request``.

        Returns:
            The list returned by ``single_turn_request``, unchanged.
        """
        logger.debug(self.log_step_start(problem_index))

        return_list = await self.single_turn_request(
            system_prompt=self._system_prompt,
            user_prompt=self._user_prompt(problem),
            shared_semaphore=shared_semaphore,
            num_returns=num_returns,
            max_tokens=max_tokens,
        )

        logger.debug(self.log_step_finish(problem_index))

        return return_list

    async def process(
        self,
        file: str,
        lines: int | None = None,
        indexes: Sequence[int] | None = None,
        num_returns: int = 1,
        save_path: str | None = None,
        num_worker: int = 1,
        resume: bool = False,
        max_tokens: int = 8_192,
        async_mode: bool = True,
        only_solve_id_conflict: bool = False,
        batch_id: str | None = None,
    ) -> None:
        """
        Generate and save solutions for the problems found in ``file``.

        Args:
            file: Path of the input file, loaded via ``load_problems_and_proofs``.
            lines: Optional line selection forwarded to the loader.
            indexes: Optional index selection forwarded to the loader.
            num_returns: Number of solutions wanted per problem.
            save_path: Output location; defaults to ``_default_save_path``.
            num_worker: Size of the semaphore shared by all requests.
            resume: Forwarded to ``need_generation_idx``.
            max_tokens: Token limit per request.
            async_mode: When true, all pairs are scheduled at once and awaited
                as they complete; otherwise pairs are processed one by one.
            only_solve_id_conflict: When true, ``need_generation_idx`` is still
                called for every pair but no model request is made.
            batch_id: Stored on the instance as ``self.batch_id``.
        """
        self.batch_id = batch_id
        if save_path is None:
            save_path = self._default_save_path

        # Use the class's own start-message hook, mirroring log_finish below,
        # instead of duplicating the text here.
        logger.info(self.log_start(file, num_worker, num_returns))

        pairs = self.load_problems_and_proofs(
            Path(file),
            lines=lines,
            indexes=indexes,
        )

        semaphore = asyncio.Semaphore(num_worker)

        async def process_pair(pair: ProblemProofDataSample):
            problem = pair["problem"]
            proof = pair["proof"]
            problem_idx = pair["problem_index"]

            # Always evaluated, even when only_solve_id_conflict is set.
            # NOTE(review): presumably this call resolves id conflicts as a
            # side effect, which is why it precedes the early return - confirm.
            need_generation = self.need_generation_idx(
                resume=resume,
                problem=problem,
                answer=proof,
                problem_index=problem_idx,
                proof_index=None,
                save_path=save_path,
                num_returns=num_returns,
                pairs=pairs,
            )

            if only_solve_id_conflict:
                return

            if not need_generation:
                logger.info(
                    f"Problem {problem_idx} already generated all {num_returns} proofs, skipping."
                )
                return

            # Only request as many completions as are still missing.
            response = await self._generate(
                problem,
                shared_semaphore=semaphore,
                num_returns=len(need_generation),
                problem_index=problem_idx,
                max_tokens=max_tokens,
            )

            self.save_response(
                input_data=pair,
                response=response,
                save_path=save_path,
                generation_index=need_generation,
                without_proof=True,
            )

        if async_mode:
            # Schedule every pair up front and advance the progress bar as
            # each one finishes, regardless of submission order.
            tasks = [process_pair(pair) for pair in pairs]
            for task in tqdm(
                asyncio.as_completed(tasks), total=len(tasks), desc="Processing pairs"
            ):
                await task
        else:
            for pair in tqdm(pairs, desc="Processing pairs"):
                await process_pair(pair)

        logger.info(self.log_finish(file, len(pairs), num_returns, save_path))


async def main(sys_argv: list[str] | None = None):
    """
    Command-line entry point: parse arguments and run a SolutionGenerator.

    Args:
        sys_argv: Argument list handed to ``ArgumentParser.parse_args``.
            With ``None``, argparse falls back to the process arguments.
    """
    from argparse import ArgumentParser
    from olym_gen.generator.base_generator import common_parse_args

    # Extend the shared option set and pin this tool's default token limit.
    parser = ArgumentParser(
        description="Generate proofs for problems using a specified model.",
        parents=[common_parse_args()],
    )
    parser.set_defaults(max_tokens=8_192)
    cli = parser.parse_args(sys_argv)

    logger.info("Using SolutionGenerator to generate solutions without reference proofs.")

    solver = SolutionGenerator(provider=cli.provider, model=cli.model)
    process_options = {
        "lines": cli.lines,
        # An empty/falsy index selection is normalised to None.
        "indexes": cli.indexes or None,
        "num_returns": cli.num_returns,
        "save_path": cli.save_path,
        "num_worker": cli.num_worker,
        "resume": cli.resume,
        # The CLI exposes the negative flag; process() takes the positive one.
        "async_mode": not cli.no_async,
        "only_solve_id_conflict": cli.only_solve_id_conflict,
        "max_tokens": cli.max_tokens,
        "batch_id": cli.batch_id,
    }
    await solver.process(cli.file, **process_options)

    logger.info("Finished processing all problems.")


# Script entry point: run the async CLI in a fresh event loop when this module
# is executed directly (not when it is imported).
if __name__ == "__main__":
    asyncio.run(main())
