from pathlib import Path
import json
import os
import asyncio

from typing import Callable, Any
from olym_gen.utils.utils import get_logger, get_generator_base

from tqdm import tqdm

from typing import Sequence

from olym_gen.generator.problem_proof_generator import (
    ProblemProofDataSample,
    ProblemProofGeneratorBase,
    ResumeMixin,
    SystemPromptMixin,
    LogProblemMixin,
    ProblemProofSaveMixin,
    GeneratorBase,
)

# Module-level logger used by the generator classes and by main() below.
logger = get_logger()


class ProofGenerator(
    SystemPromptMixin,
    ResumeMixin,
    ProblemProofSaveMixin,
    LogProblemMixin,
):
    """
    Given a problem, generate some proof.

    Model access goes through a composed ``GeneratorBase`` stored in
    ``self.generator_base`` instead of being inherited. The attributes
    ``_system_prompt``, ``load_problems_and_proofs``, ``need_generation_idx``
    and ``save_response`` are used below but not defined in this class; they
    are expected to be supplied by the mixins.
    """

    def __init__(
        self,
        provider: str = "dummy",
        model: str | None = None,
        extra_model_paras: dict[str, Any] | None = None,
    ) -> None:
        """
        Create the composed generator backend.

        Args:
            provider: Provider name, forwarded to ``get_generator_base``.
            model: Model name, forwarded to ``get_generator_base``.
            extra_model_paras: Extra model parameters, forwarded to
                ``get_generator_base``.
        """
        # Use composition instead of inheritance for GeneratorBase
        self.generator_base: GeneratorBase = get_generator_base(
            provider, model, extra_model_paras
        )
        logger.debug(f"Initialized {self.__class__.__name__} with generator_base id: {id(self.generator_base)}")

    async def single_turn_request(self, *args, **kwargs):
        """Delegate to the composed generator_base"""
        return await self.generator_base.single_turn_request(*args, **kwargs)

    @property
    def batch_id(self):
        """Delegate to the composed generator_base"""
        return self.generator_base.batch_id if self.generator_base else None

    @batch_id.setter
    def batch_id(self, value):
        """Delegate to the composed generator_base"""
        # The value is silently dropped when generator_base is falsy.
        if self.generator_base:
            self.generator_base.batch_id = value

    @property
    def system_prompt_file(self) -> str:
        """Relative path of the system prompt file for this generator."""
        return "prompts/proof_system_prompt.txt"

    def _user_prompt(self, problem: str) -> str:
        """Format the user message that carries the problem statement."""
        return f"Problem: {problem}\n\n"

    @property
    def _default_save_path(self) -> str:
        """Save location used by ``process`` when no ``save_path`` is given."""
        return "save/proof"

    @property
    def proof_retrieve(self) -> Callable[[dict[str, Any]], list[str]]:
        """
        Callable that takes a record ``d`` and returns a one-element list
        holding the first entry of ``d["solution"]``.
        """
        return lambda d: [d["solution"][0]]

    def log_step_start(self, problem_index: int) -> str:
        """Message logged before a request for one problem is sent."""
        return f"Trying to generate the proof for problem {problem_index} ..."

    def log_step_finish(self, problem_index: int) -> str:
        """Message logged after a request for one problem has returned."""
        return f"Finish generating the proof for problem {problem_index}."

    def log_start(self, file: str, num_worker: int, num_returns: int) -> str:
        """Message logged once at the beginning of ``process``."""
        return f"Started to process the file {file} with {num_worker} workers and {num_returns} returns per problem. This will generate proofs for the problems in the file without reference."

    def log_finish(
        self, file: str, num_pairs: int, num_returns: int, save_path: str
    ) -> str:
        """Message logged once at the end of ``process``."""
        return f"Finished processing the file {file}. Generated {num_pairs} problem pairs with {num_returns} returns each. The thinking process and generated solutions are saved to {save_path}."

    async def _generate(
        self,
        problem: str,
        problem_index: int,
        shared_semaphore: asyncio.Semaphore,
        num_returns: int = 1,
        max_tokens: int = 64_000,
        extra_model_paras: dict[str, Any] | None = None,
    ) -> list[tuple[str, str] | None]:
        """
        Request ``num_returns`` generations for one problem.

        The system prompt (``self._system_prompt``) and the formatted user
        prompt are sent through ``single_turn_request``; whatever that call
        returns is handed back unchanged.

        Args:
            problem: Problem statement placed in the user prompt.
            problem_index: Index used only in the debug log messages.
            shared_semaphore: Semaphore forwarded to the request call.
            num_returns: Number of generations to ask for.
            max_tokens: Token limit forwarded to the request call.
            extra_model_paras: Extra model parameters forwarded to the
                request call.

        Returns:
            The result of ``single_turn_request``.
        """
        logger.debug(self.log_step_start(problem_index))

        return_list = await self.single_turn_request(
            system_prompt=self._system_prompt,
            user_prompt=self._user_prompt(problem),
            shared_semaphore=shared_semaphore,
            num_returns=num_returns,
            max_tokens=max_tokens,
            extra_model_paras=extra_model_paras,
        )

        logger.debug(self.log_step_finish(problem_index))

        return return_list

    async def process(
        self,
        file: str,
        lines: int | None = None,
        indexes: Sequence[int] | None = None,
        num_returns: int = 1,
        save_path: str | None = None,
        num_worker: int = 1,
        resume: bool = False,
        max_tokens: int = 64_000,
        async_mode: bool = True,
        only_solve_id_conflict: bool = False,
        batch_id: str | None = None,
        extra_model_paras: dict[str, Any] | None = None,
    ) -> None:
        """
        Generate and save proofs for the problems loaded from ``file``.

        Args:
            file: Input file, passed to ``load_problems_and_proofs`` as a
                ``Path``.
            lines: Forwarded to ``load_problems_and_proofs``.
            indexes: Forwarded to ``load_problems_and_proofs``.
            num_returns: Number of proofs wanted per problem; forwarded to
                ``need_generation_idx``.
            save_path: Output location; ``_default_save_path`` is used when
                this is ``None``.
            num_worker: Initial value of the ``asyncio.Semaphore`` shared by
                every request.
            resume: Forwarded to ``need_generation_idx``.
            max_tokens: Token limit forwarded to every request.
            async_mode: When true, all pairs are scheduled together and
                awaited with ``asyncio.as_completed``; otherwise they are
                awaited one after another.
            only_solve_id_conflict: When true, each pair stops right after
                the ``need_generation_idx`` call: no request is sent and
                nothing is saved.
            batch_id: Stored on the composed ``generator_base``.
            extra_model_paras: Extra model parameters forwarded to every
                request.
        """
        self.batch_id = batch_id
        save_path = save_path if save_path is not None else self._default_save_path
        # Use the class's own hook so the start message is customised the
        # same way the finish message is (see log_finish at the end).
        logger.info(self.log_start(file, num_worker, num_returns))
        pairs = self.load_problems_and_proofs(
            Path(file),
            lines=lines,
            indexes=indexes,
        )

        semaphore = asyncio.Semaphore(num_worker)

        async def process_pair(pair: ProblemProofDataSample):
            """Handle one loaded sample: check, generate, save."""
            problem = pair["problem"]
            proof = pair["proof"]
            problem_idx = pair["problem_index"]

            # Runs before the only_solve_id_conflict check on purpose, so it
            # is still executed in that mode.
            need_generation = self.need_generation_idx(
                resume=resume,
                problem=problem,
                answer=proof,
                problem_index=problem_idx,
                proof_index=None,
                save_path=save_path,
                num_returns=num_returns,
                pairs=pairs,
            )

            if only_solve_id_conflict:
                return

            if not need_generation:
                logger.info(
                    f"Problem {problem_idx} already generated all {num_returns} proofs, skipping."
                )
                return

            # Only ask for as many generations as are still missing.
            response = await self._generate(
                problem,
                shared_semaphore=semaphore,
                num_returns=len(need_generation),
                problem_index=problem_idx,
                max_tokens=max_tokens,
                extra_model_paras=extra_model_paras,
            )

            self.save_response(
                input_data=pair,
                response=response,
                save_path=save_path,
                generation_index=need_generation,
                without_proof=True,
            )

        if not async_mode:
            for pair in tqdm(pairs, desc="Processing pairs"):
                await process_pair(pair)
        else:
            tasks = [process_pair(pair) for pair in pairs]
            for task in tqdm(
                asyncio.as_completed(tasks), total=len(tasks), desc="Processing pairs"
            ):
                await task

        logger.info(self.log_finish(file, len(pairs), num_returns, save_path))


class ReferenceProofGenerator(ProblemProofGeneratorBase):
    """
    Generate a new proof for a problem while being shown one of its
    ground-truth proofs; the new proof is meant to be semantically
    equivalent to that reference.

    Only the prompt, save-path and log-message hooks are overridden here.
    """

    @property
    def system_prompt_file(self) -> str:
        """Relative path of the system prompt file for this generator."""
        prompt_file = "prompts/reference_proof_system_prompt.txt"
        return prompt_file

    def _user_prompt(self, problem: str, proof: str) -> str:
        """Format the user message carrying the problem and its reference proof."""
        # The trailing empty section yields the final blank-line separator.
        sections = (f"Problem: {problem}", f"Reference Proof: {proof}", "")
        return "\n\n".join(sections)

    @property
    def _default_save_path(self) -> str:
        """Save location returned as this generator's default."""
        default_dir = "save/reference_proof"
        return default_dir

    def log_step_start(self, problem_index: int, proof_index: int) -> str:
        """Message for the start of one (problem, reference proof) step."""
        target = f"problem {problem_index} with reference proof {proof_index}"
        return f"Trying to generate the proof for {target} ..."

    def log_step_finish(self, problem_index: int, proof_index: int) -> str:
        """Message for the end of one (problem, reference proof) step."""
        target = f"problem {problem_index} with reference proof {proof_index}"
        return f"Finish generating the proof for {target}."

    def log_start(self, file: str, num_worker: int, num_returns: int) -> str:
        """Message for the start of a whole run."""
        summary = (
            f"Started to process the file {file} with {num_worker} workers "
            f"and {num_returns} returns per problem using ReferenceProofGenerator."
        )
        purpose = "This will generate new proofs based on the reference proofs provided in the dataset."
        return f"{summary} {purpose}"

    def log_finish(
        self, file: str, num_pairs: int, num_returns: int, save_path: str
    ) -> str:
        """Message for the end of a whole run."""
        report = (
            f"Finished processing the file {file}. "
            f"Generated {num_pairs} problem-proof pairs with {num_returns} returns."
        )
        destination = f" Saved the generated thinking steps and new solutions to {save_path}."
        # The two parts are separated by a newline, as in the single-string form.
        return "\n".join((report, destination))


async def main(sys_argv: list[str] | None = None):
    """
    Command-line entry point: parse arguments, pick a generator, run it.

    Args:
        sys_argv: Argument list handed to ``parse_args``; with ``None``
            argparse falls back to the process's command-line arguments.
    """
    from argparse import ArgumentParser
    from olym_gen.generator.base_generator import common_parse_args

    # Every option except --ref is expected to come from the shared parent
    # parser; only --ref is added here.
    parser = ArgumentParser(
        parents=[common_parse_args()],
        description="Generate proofs for problems using a specified model.",
    )
    parser.add_argument(
        "--ref",
        action="store_true",
        help="Use the reference proof as the ground truth for generation. If not set, no reference proof is used.",
    )
    options = parser.parse_args(sys_argv)

    # Choose the generator class and its announcement in one place.
    if options.ref:
        generator_cls = ReferenceProofGenerator
        banner = "Using ReferenceProofGenerator to generate proofs with reference proofs."
    else:
        generator_cls = ProofGenerator
        banner = "Using ProofGenerator to generate proofs without reference proofs. This is the default behavior."
    logger.info(banner)
    generator = generator_cls(provider=options.provider, model=options.model)

    run_options = {
        "lines": options.lines,
        # An empty/falsy index selection is normalised to None.
        "indexes": options.indexes or None,
        "num_returns": options.num_returns,
        "save_path": options.save_path,
        "num_worker": options.num_worker,
        "resume": options.resume,
        "max_tokens": options.max_tokens,
        # The CLI flag is negative (--no-async style); process() takes the positive form.
        "async_mode": not options.no_async,
        "only_solve_id_conflict": options.only_solve_id_conflict,
        "batch_id": options.batch_id,
        "extra_model_paras": options.extra_model_paras,
    }
    await generator.process(options.file, **run_options)
    logger.info("Finished processing all problems.")


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    asyncio.run(main())
