from pathlib import Path
from functools import cached_property
import json
import os
import asyncio

from typing import Callable, Any, TypedDict
from olym_gen.utils.utils import get_logger

from tqdm import tqdm

from olym_gen.generator.problem_proof_generator import ProblemProofGeneratorBase


logger = get_logger()


class RephraseGenerator(ProblemProofGeneratorBase):
    """
    Rephrasing generator.

    Takes a problem together with its groundtruth proof and asks the model for
    a reworded proof that is semantically equivalent to the original. This
    subclass only supplies the system prompt file, the default save directory
    and the progress/summary message strings; everything else (e.g.
    ``process``) is inherited from ``ProblemProofGeneratorBase``.
    """

    @property
    def system_prompt_file(self) -> str:
        """Path of the system prompt used for the rephrase task."""
        prompt_path = "prompts/rephrase_system_prompt.txt"
        return prompt_path

    @property
    def _default_save_path(self) -> str:
        """Directory used for outputs when no explicit save path is given."""
        save_dir = "save/rephrase"
        return save_dir

    def log_step_start(self, problem_index: int, proof_index: int) -> str:
        """Build the message for the start of one (problem, proof) rephrasing step."""
        template = "Rephrasing proof for problem {} and proof {}..."
        return template.format(problem_index, proof_index)

    def log_step_finish(self, problem_index: int, proof_index: int) -> str:
        """Build the message for the end of one (problem, proof) rephrasing step."""
        template = "Finished rephrasing proof for problem {} and proof {}."
        return template.format(problem_index, proof_index)

    def log_start(self, file: str, num_worker: int, num_returns: int) -> str:
        """Build the message describing a file-level run before it begins."""
        return (
            f"Started to process the file {file} with {num_worker} workers "
            f"and {num_returns} returns per problem. "
            "This will rephrase the proofs for the problems in the file."
        )

    def log_finish(
        self, file: str, num_pairs: int, num_returns: int, save_path: str
    ) -> str:
        """Build the summary message for a completed file-level run."""
        return (
            f"Finished processing the file {file}. "
            f"Generated {num_pairs} problem-proof pairs with {num_returns} returns each. "
            f"The thinking process and rephrased proofs are saved to {save_path}."
        )
    

class AugmentGenerator(ProblemProofGeneratorBase):
    """
    Style-preserving rephrasing generator used for dataset augmentation.

    Given a problem and its groundtruth proof, asks the model for a reworded
    proof that is semantically equivalent to the original while keeping the
    original writing style. Only the prompt file, the default save directory
    and the message strings are defined here; the rest is inherited from
    ``ProblemProofGeneratorBase``.
    """

    @property
    def system_prompt_file(self) -> str:
        """Path of the system prompt used for the augment task."""
        prompt_path = "prompts/augment_system_prompt.txt"
        return prompt_path

    @property
    def _default_save_path(self) -> str:
        """Directory used for outputs when no explicit save path is given."""
        save_dir = "save/augment"
        return save_dir

    def log_step_start(self, problem_index: int, proof_index: int) -> str:
        """Build the message for the start of one style-preserving rephrasing step."""
        template = (
            "Rephrasing proof for problem {} and proof {} "
            "while keeping the original style..."
        )
        return template.format(problem_index, proof_index)

    def log_step_finish(self, problem_index: int, proof_index: int) -> str:
        """Build the message for the end of one rephrasing step."""
        template = "Finished rephrasing proof for problem {} and proof {}."
        return template.format(problem_index, proof_index)

    def log_start(self, file: str, num_worker: int, num_returns: int) -> str:
        """Build the message describing a file-level run before it begins."""
        return (
            f"Started to process the file {file} with {num_worker} workers "
            f"and {num_returns} returns per problem. "
            "This will rephrase the proofs for the problems in the file "
            "while keeping the original style."
        )

    def log_finish(
        self, file: str, num_pairs: int, num_returns: int, save_path: str
    ) -> str:
        """Build the summary message for a completed file-level run."""
        return (
            f"Finished processing the file {file}. "
            f"Generated {num_pairs} problem-proof pairs with {num_returns} returns each. "
            f"The thinking process and rephrased proofs are saved to {save_path}."
        )
    
class TranslateGenerator(ProblemProofGeneratorBase):
    """
    Translation generator.

    Given a problem and its groundtruth proof, asks the model for a translated
    proof that is semantically equivalent to the original. This subclass only
    provides the prompt file, the default save directory and the message
    strings; everything else is inherited from ``ProblemProofGeneratorBase``.
    """

    @property
    def system_prompt_file(self) -> str:
        """Path of the system prompt used for the translate task."""
        prompt_path = "prompts/translate_system_prompt.txt"
        return prompt_path

    @property
    def _default_save_path(self) -> str:
        """Directory used for outputs when no explicit save path is given."""
        save_dir = "save/translate"
        return save_dir

    def log_step_start(self, problem_index: int, proof_index: int) -> str:
        """Build the message for the start of one (problem, proof) translation step."""
        template = "Translating proof for problem {} and proof {}..."
        return template.format(problem_index, proof_index)

    def log_step_finish(self, problem_index: int, proof_index: int) -> str:
        """Build the message for the end of one (problem, proof) translation step."""
        template = "Finished translating proof for problem {} and proof {}."
        return template.format(problem_index, proof_index)

    def log_start(self, file: str, num_worker: int, num_returns: int) -> str:
        """Build the message describing a file-level run before it begins."""
        return (
            f"Started to process the file {file} with {num_worker} workers "
            f"and {num_returns} returns per problem. "
            "This will translate the proofs for the problems in the file."
        )

    def log_finish(
        self, file: str, num_pairs: int, num_returns: int, save_path: str
    ) -> str:
        """Build the summary message for a completed file-level run."""
        return (
            f"Finished processing the file {file}. "
            f"Generated {num_pairs} problem-proof pairs with {num_returns} returns each. "
            f"The thinking process and translated proofs are saved to {save_path}."
        )

class DisturbGenerator(ProblemProofGeneratorBase):
    """
    Given a problem and its groundtruth proof, generate a disturbed proof that is incorrect.

    Only the system prompt file, the default save directory and the
    progress/summary message strings are defined here; the generation itself
    (e.g. ``process``) is inherited from ``ProblemProofGeneratorBase``.
    """

    @property
    def system_prompt_file(self) -> str:
        """Path of the system prompt used for the disturb task."""
        return "prompts/disturb_system_prompt.txt"

    @property
    def _default_save_path(self) -> str:
        """Directory used for outputs when no explicit save path is given."""
        return "save/disturb"

    def log_step_start(self, problem_index: int, proof_index: int) -> str:
        """Build the message for the start of one (problem, proof) disturbing step."""
        return (
            f"Disturbing proof for problem {problem_index} and proof {proof_index}..."
        )

    def log_step_finish(self, problem_index: int, proof_index: int) -> str:
        """Build the message for the end of one (problem, proof) disturbing step."""
        return f"Finished disturbing proof for problem {problem_index} and proof {proof_index}."

    def log_start(self, file: str, num_worker: int, num_returns: int) -> str:
        """Build the message describing a file-level run before it begins."""
        # Previously read "into an incorrect proofs" (article/plural mismatch);
        # now phrased like the sibling generators' start messages.
        return (
            f"Started to process the file {file} with {num_worker} workers "
            f"and {num_returns} returns per problem. "
            "This will disturb the proofs for the problems in the file into incorrect proofs."
        )

    def log_finish(
        self, file: str, num_pairs: int, num_returns: int, save_path: str
    ) -> str:
        """Build the summary message for a completed file-level run."""
        return f"Finished processing the file {file}. Generated {num_pairs} problem-proof pairs with {num_returns} returns each. The thinking process and disturbed incorrect proofs are saved to {save_path}."


async def main(sys_argv: list[str] | None = None):
    """Command-line entry point for the proof rewriting generators.

    Extends the shared options from ``common_parse_args`` with three mode
    flags (``--disturb``, ``--keep_style``, ``--translate``). At most one mode
    flag may be given; with none, plain rephrasing (``RephraseGenerator``) is
    used. The selected generator is then run over ``args.file`` by awaiting
    its ``process`` coroutine.

    Args:
        sys_argv: Argument list to parse. When ``None``, argparse falls back
            to ``sys.argv[1:]``.

    Raises:
        ValueError: If more than one mode flag is set.
    """
    from argparse import ArgumentParser
    from olym_gen.generator.base_generator import common_parse_args

    parser = ArgumentParser(
        parents=[common_parse_args()],
        description="Generate rephrased solutions for problems using a language model.",
    )

    # Registered in this order so the generated --help output stays the same.
    mode_flags = (
        ("--disturb", "If set, generate disturbed solutions instead of rephrased solutions."),
        ("--keep_style", "If set, keep the original style of the proof when rephrasing."),
        ("--translate", "If set, generate translated solutions instead of rephrased solutions."),
    )
    for flag, help_text in mode_flags:
        parser.add_argument(flag, action="store_true", help=help_text)

    args = parser.parse_args(sys_argv)

    enabled_count = sum(1 for is_set in (args.disturb, args.keep_style, args.translate) if is_set)
    if enabled_count > 1:
        raise ValueError("Only one of --disturb, --keep_style and --translate can be set.")

    # First enabled mode wins (the check above guarantees at most one);
    # otherwise fall back to plain rephrasing.
    modes = (
        (args.disturb, "Generating disturbed solutions.", DisturbGenerator),
        (args.keep_style, "Generating style-preserved rephrased solutions.", AugmentGenerator),
        (args.translate, "Generating translated solutions.", TranslateGenerator),
    )
    message, generator_cls = next(
        ((msg, cls) for is_set, msg, cls in modes if is_set),
        ("Generating rephrased solutions.", RephraseGenerator),
    )
    logger.info(message)
    generator = generator_cls(provider=args.provider, model=args.model)

    await generator.process(
        args.file,
        lines=args.lines,
        indexes=args.indexes or None,  # empty/falsy selection -> None
        num_returns=args.num_returns,
        save_path=args.save_path,
        num_worker=args.num_worker,
        resume=args.resume,
        max_tokens=args.max_tokens,
        async_mode=not args.no_async,
        only_solve_id_conflict=args.only_solve_id_conflict,
        batch_id=args.batch_id,
    )
    logger.info("Finished processing all problems.")

if __name__ == "__main__":
    # Build the CLI coroutine (argv taken from sys.argv by argparse) and drive it to completion.
    cli_coroutine = main()
    asyncio.run(cli_coroutine)
