from olym_gen.generator.json_format_generator import (
    JsonFormatProblemProofGeneratorBase,
    JsonFormatError,
)
from olym_gen.generator.problem_proof_generator import ProblemProofDataSample

from tqdm import tqdm

from olym_gen.utils.utils import get_logger, UNKNOWN_FIELD
from olym_gen.utils.sample_utils import retrieve_id_from_name

import asyncio
import random
import itertools

from typing import TypedDict, Any, Sequence, Callable, Literal
from pathlib import Path
import json
import os
import re

# One proof step as produced by the splitting model. `ProofSplitter._check_json_format`
# requires exactly these three keys on every step (and `ProofMasker.load` reads
# "content" and "category"), so "category" is declared here as well.
StepTypedDict = TypedDict(
    "StepTypedDict",
    {
        "step_index": int,  # Non-negative; a "conclusion" step must have index len(steps) - 1.
        "content": str,  # Text of the step.
        "category": str,  # One of "leading", "supporting", "conclusion", "uncertain".
    },
)

# Top-level JSON object the splitting model must return.
SplitStepsTypedDict = TypedDict(
    "SplitStepsTypedDict",
    {
        "steps": list[StepTypedDict],
    },
)

# Module-level logger from the project's `get_logger` helper; used below for the
# "all steps masked" warning and for errors while replacing mask placeholders.
logger = get_logger()


class ProofSplitter(JsonFormatProblemProofGeneratorBase[SplitStepsTypedDict]):
    """
    Given a problem and a proof, split the proof into smaller parts.

    The model must answer with a JSON object of the form
    ``{"steps": [{"step_index": ..., "content": ..., "category": ...}, ...]}``;
    `_check_json_format` enforces the exact rules.
    """

    # Allowed values of a step's "category" field.
    _STEP_CATEGORIES = ("leading", "supporting", "conclusion", "uncertain")

    @property
    def system_prompt_file(self) -> str:
        return "prompts/split_proof_system_prompt.txt"

    def _user_prompt(self, problem: str, proof: str) -> str:
        return f"Problem: {problem}\n\nProof to split: {proof}\n\n"

    @property
    def _default_save_path(self) -> str:
        return "save/split_proof"

    def log_step_start(self, problem_index: int, proof_index: int) -> str:
        return f"Splitting proof for problem {problem_index} and proof {proof_index}..."

    def log_step_finish(self, problem_index: int, proof_index: int) -> str:
        return f"Finishing splitting proof for problem {problem_index} and proof {proof_index}."

    def log_start(self, file: str, num_worker: int, num_returns: int) -> str:
        return f"Started to process the file {file} with {num_worker} workers and {num_returns} returns per problem. This will split the proofs into steps for the problems in the file."

    def log_finish(
        self, file: str, num_pairs: int, num_returns: int, save_path: str
    ) -> str:
        return f"Finished processing the file {file}. Split {num_pairs} problem-proof pairs with {num_returns} times each. The split proofs are saved to {save_path}."

    @property
    def _save_solution_name(self) -> str:
        # Presumably the key the base class stores the parsed JSON under;
        # `ProofMasker.load` reads the steps back as data["steps"]["steps"].
        return "steps"

    def _check_json_format(
        self, json_object: Any, other_info: dict[str, Any] | None = None
    ) -> None:
        """
        Validate the model's split-proof JSON.

        Rules: a dict with the single key "steps", holding a non-empty list of
        dicts that each have exactly the keys "step_index" (non-negative int),
        "content" (str) and "category" (one of ``_STEP_CATEGORIES``). A
        "conclusion" step must have index ``len(steps) - 1``.

        Args:
            json_object: The parsed model output.
            other_info: Unused here; part of the base-class signature.

        Raises:
            JsonFormatError: If any rule above is violated.
        """
        if not isinstance(json_object, dict):
            raise JsonFormatError(
                f"The input should be a dictionary but got {type(json_object)}."
            )
        if list(json_object.keys()) != ["steps"]:
            raise JsonFormatError(
                f"The dictionary should have a single key 'steps' containing a list of proof steps, but got keys: {list(json_object.keys())}."
            )
        steps = json_object["steps"]
        if not isinstance(steps, list):
            raise JsonFormatError(
                f"The value of 'steps' should be a list but got {type(steps)}."
            )
        if not steps:
            # A proof split into zero steps cannot be masked or used downstream.
            raise JsonFormatError(
                "The value of 'steps' should be a non-empty list of proof steps."
            )
        for step in steps:
            if not isinstance(step, dict):
                raise JsonFormatError(
                    f"Each step should be a dictionary but got {type(step)}."
                )
            # JSON objects are unordered, so compare key sets, not key order.
            if set(step) != {"step_index", "content", "category"}:
                raise JsonFormatError(
                    "Each step should have keys 'step_index', 'content' and 'category'."
                )
            step_index = step["step_index"]
            # bool is a subclass of int; reject JSON true/false as an index.
            if (
                not isinstance(step_index, int)
                or isinstance(step_index, bool)
                or step_index < 0
            ):
                raise JsonFormatError(
                    f"Step index should be a non-negative integer but got {step['step_index']}."
                )
            if not isinstance(step["content"], str):
                raise JsonFormatError(
                    f"Step content should be a string but got {type(step['content'])}."
                )
            if not isinstance(step["category"], str):
                raise JsonFormatError(
                    f"Step category should be a string but got {type(step['category'])}."
                )
            if step["category"] not in self._STEP_CATEGORIES:
                raise JsonFormatError(
                    f"Step category should be one of 'leading', 'supporting', 'conclusion' or 'uncertain' but got {step['category']}."
                )
            if step["category"] == "conclusion" and step_index != len(steps) - 1:
                raise JsonFormatError(
                    f"Conclusion step should be the last step but got step index {step['step_index']} in a proof with {len(steps)} steps."
                )


# One split proof, as returned by `ProofMasker.load`.
SplitProofDataSample = TypedDict(
    "SplitProofDataSample",
    {
        "problem": str,  # The problem statement.
        "splitted_proof": list[tuple[str, str]],  # (content, category) per step, in proof order.
        "problem_index": int,  # The index of the problem in the dataset.
        "proof_index": int,  # The index of the proof in the dataset.
        "field": str,  # The field of the problem, e.g., "Geometry".
    },
)


class MaskedProblemProofDataSample(ProblemProofDataSample):
    """
    A masked problem-proof data sample, which includes the problem, proof with some parts masked, and the field.

    Built by `ProofMasker._create_masked_sample`: the ``proof`` key (presumably
    declared on ProblemProofDataSample) holds the masked proof text, while
    ``splitted_proof`` keeps the original, unmasked steps.
    """

    splitted_proof: list[tuple[str, str]]  # Original (content, category) steps, unmasked.
    masked_steps: list[int]  # The indices of the steps that are masked in the proof.


# Step-selection strategies understood by `ProofMasker.process_one`.
MaskMethod = Literal[
    "last",  # runs of steps ending just before the final step
    "select_one",  # one middle step per sample
    "select_some",  # subsets of the middle steps
    "select_one_leading",  # one "leading" step per sample
    "select_some_leading",  # subsets of the "leading" steps
]


class ProofMasker:
    """
    Turn split proofs into masked proofs (no LLM call involved).

    Each input sample stores its proof as a list of ``(content, category)``
    steps. A mask method picks step indices to hide; every run of consecutive
    hidden steps becomes one ``mask_token`` and the remaining steps are joined
    with blank lines.
    """

    def __init__(self, mask_token: str = "[MASKED_PROOF]"):
        # Placeholder written in place of each run of consecutive masked steps.
        self.mask_token = mask_token
        # Output root used by `process` when no save_path is given; the mask
        # method name is appended as a sub-directory.
        self._default_save_path = "save/masked_proof"

    def _build_masked_proof(self, steps: list[str], masked_steps: list[int]) -> str:
        """
        Join ``steps`` with blank lines, replacing each run of consecutive
        masked steps by a single ``self.mask_token``.

        Args:
            steps: Step contents in proof order.
            masked_steps: Indices into ``steps`` to hide; indices outside
                ``range(len(steps))`` have no effect.

        Returns:
            The masked proof text.
        """
        masked = set(masked_steps)  # O(1) membership per step
        parts: list[str] = []
        previous_masked = False
        for index, content in enumerate(steps):
            is_masked = index in masked
            if not is_masked:
                parts.append(content)
            elif not previous_masked:
                # First step of a masked run: one token stands for the whole run.
                parts.append(self.mask_token)
            previous_masked = is_masked
        return "\n\n".join(parts)

    def _create_masked_sample(
        self,
        sample: SplitProofDataSample,
        steps: list[tuple[str, str]],
        masked_steps: list[int],
    ) -> MaskedProblemProofDataSample:
        """
        Create a masked problem proof data sample.

        ``steps`` (the split proof) is stored unmasked under ``splitted_proof``;
        ``proof`` receives the masked text built from the step contents.

        Raises:
            AssertionError: If ``masked_steps`` is empty or longer than the proof.
        """
        assert masked_steps and len(masked_steps) <= len(
            sample["splitted_proof"]
        ), f"Invalid masked_steps {masked_steps} for problem {sample['problem_index']} and proof {sample['proof_index']}. Steps: {steps}."
        step_texts = [step[0] for step in steps]
        masked_proof = self._build_masked_proof(step_texts, masked_steps)
        return {
            "problem": sample["problem"],
            "splitted_proof": steps,
            "proof": masked_proof,
            "masked_steps": masked_steps,
            "problem_index": sample["problem_index"],
            "proof_index": sample["proof_index"],
            "field": sample["field"],
        }

    def load(
        self,
        file_path: Path,
        problem_retrieve: Callable[[dict[str, Any]], str] | None = None,
        splitted_proof_retrieve: (
            Callable[[dict[str, Any]], list[tuple[str, str]]] | None
        ) = None,
        field_retrieve: Callable[[dict[str, Any]], str] | None = None,
        sample_filter: Callable[[dict[str, Any]], bool] | None = None,
    ) -> list[SplitProofDataSample]:
        """
        Load split proofs from the ``*.json`` files directly inside ``file_path``.

        Files are visited in sorted path order. Directory listing order is
        arbitrary, and `process` derives each sample's random seed from its
        position in the returned list, so sorting keeps seeded runs reproducible.

        Args:
            file_path: Directory holding one JSON file per split proof; the
                problem/proof indices come from each file name.
            problem_retrieve: Extracts the problem text (default: ``"question"``).
            splitted_proof_retrieve: Extracts ``(content, category)`` steps
                (default: ``data["steps"]["steps"]``).
            field_retrieve: Extracts the field (default: ``"field"``).
            sample_filter: Returns False to skip a file. The default skips
                samples whose field mentions "geometry" (case-insensitive) or
                whose ``pass_check`` is falsy/missing.

        Returns:
            One `SplitProofDataSample` per accepted file.

        Raises:
            AssertionError: If a file name has a generation id other than 0.
        """
        if problem_retrieve is None:
            problem_retrieve = lambda x: str(x.get("question", ""))
        if splitted_proof_retrieve is None:
            splitted_proof_retrieve = lambda x: [
                (str(step["content"]), str(step["category"]))
                for step in x["steps"]["steps"]
            ]
        if field_retrieve is None:
            field_retrieve = lambda x: str(x.get("field", ""))
        if sample_filter is None:
            sample_filter = lambda d: (
                "geometry" not in str(d.get("field", "")).lower()
                and bool(d.get("pass_check", False))
            )

        return_list: list[SplitProofDataSample] = []

        for file_name in sorted(file_path.glob("*.json")):
            problem_index, proof_index, generation_id = retrieve_id_from_name(
                file_name.name
            )
            assert (
                generation_id == 0
            ), f"Expected generation_id to be 0 but got {generation_id}."
            with open(file_name, "r", encoding="utf-8") as f:
                data = json.load(f)
            if not sample_filter(data):
                continue
            problem = problem_retrieve(data)
            proof_steps: list[tuple[str, str]] = splitted_proof_retrieve(data)
            field = field_retrieve(data)

            return_list.append(
                {
                    "problem": problem,
                    "splitted_proof": proof_steps,
                    "problem_index": problem_index,
                    "proof_index": proof_index,
                    "field": field,
                }
            )
        return return_list

    def process_one(
        self,
        sample: SplitProofDataSample,
        method: MaskMethod,
        prob: float = 0.3,
        random_seed: int | None = None,
        return_num: int | None = None,
    ) -> list[MaskedProblemProofDataSample]:
        """
        Mask the split proof of ``sample`` and return the masked variants.

        Methods (``n = len(sample["splitted_proof"])``):

        * ``"last"``: for ``k = n - 2`` down to ``1``, mask the ``k`` steps
          immediately before the final step.
        * ``"select_one"``: one sample per step index in ``1 .. n - 2``.
        * ``"select_one_leading"``: one sample per step whose category is
          ``"leading"``.
        * ``"select_some"``: masks are subsets of the step indices ``1 .. n - 2``.
        * ``"select_some_leading"``: masks are subsets of the ``"leading"`` steps.

        ``"last"``, ``"select_one"`` and ``"select_some"`` never mask the first
        or the last step; the ``*_leading`` variants may mask any step labelled
        ``"leading"``.

        For the ``select_some*`` methods: if ``return_num`` is at least the number
        of distinct non-empty masks, every non-empty subset of the candidate
        steps is returned. Otherwise ``return_num`` distinct non-empty masks are
        sampled with ``random.Random(random_seed)``, each candidate being masked
        independently with probability ``prob``.

        Args:
            sample: The split proof to mask.
            method: Masking strategy (see above).
            prob: Per-step masking probability used for random sampling.
            random_seed: Seed for random sampling (``None``: unseeded).
            return_num: Number of masks wanted; required for ``select_some*``.

        Returns:
            The masked samples (empty e.g. for proofs with at most two steps
            under the non-leading methods).

        Raises:
            ValueError: Unknown ``method``; ``return_num`` missing for a
                ``select_some*`` method; or a ``prob`` with which ``return_num``
                distinct non-empty masks can never be drawn.
        """
        return_list: list[MaskedProblemProofDataSample] = []
        steps = sample["splitted_proof"]

        if method == "last":
            # num - 1 steps are masked: indices len(steps) - num .. len(steps) - 2.
            for num in range(len(steps) - 1, 1, -1):
                masked_steps = list(range(len(steps) - num, len(steps) - 1))
                return_list.append(
                    self._create_masked_sample(sample, steps, masked_steps)
                )

        elif method in ("select_one", "select_one_leading"):
            if method == "select_one":
                # Exclude first and last step.
                candidates = list(range(1, len(steps) - 1))
            else:
                candidates = [
                    i for i, step in enumerate(steps) if step[1] == "leading"
                ]
            for index in candidates:
                return_list.append(
                    self._create_masked_sample(sample, steps, [index])
                )

        elif method in ("select_some", "select_some_leading"):
            if return_num is None:
                raise ValueError(
                    "When using 'select_some' or 'select_some_leading' method, return_num must be specified."
                )

            if method == "select_some_leading":
                candidates = [
                    i for i, step in enumerate(steps) if step[1] == "leading"
                ]
            else:
                candidates = list(range(1, len(steps) - 1))

            # Distinct masks that can be returned (the empty mask never is).
            num_nonempty_masks = 2 ** len(candidates) - 1

            if return_num >= num_nonempty_masks:
                # At least every possible mask was requested: enumerate them all
                # instead of rejection-sampling until each has been drawn.
                for size in range(1, len(candidates) + 1):
                    for combo in itertools.combinations(candidates, size):
                        masked_steps = list(combo)
                        if len(masked_steps) == len(steps):
                            logger.warning(
                                f"All steps are masked for problem {sample['problem_index']} and proof {sample['proof_index']}. It could lead to unexpected behavior."
                            )
                        return_list.append(
                            self._create_masked_sample(sample, steps, masked_steps)
                        )
            else:
                # The rejection loop below only terminates if a draw can mask
                # something (prob > 0) and, when several distinct masks are
                # wanted, can also keep something (prob < 1, since
                # rng.random() lies in [0, 1)). Otherwise it would spin forever.
                if return_num > 0 and not prob > 0.0:
                    raise ValueError(
                        f"prob must be positive to sample masks with method '{method}', but got {prob}."
                    )
                if return_num > 1 and not prob < 1.0:
                    raise ValueError(
                        f"prob must be smaller than 1 to sample {return_num} distinct masks with method '{method}', but got {prob}."
                    )

                rng = random.Random(random_seed)
                seen_masks: set[tuple[int, ...]] = set()

                while len(return_list) < return_num:
                    # Exactly one draw per candidate, in candidate order, so a
                    # given seed always produces the same sequence of masks.
                    keep = [rng.random() > prob for _ in candidates]
                    masked_steps = [
                        idx for idx, kept in zip(candidates, keep) if not kept
                    ]
                    if not masked_steps or tuple(masked_steps) in seen_masks:
                        continue  # empty or duplicate mask: draw again
                    seen_masks.add(tuple(masked_steps))
                    return_list.append(
                        self._create_masked_sample(sample, steps, masked_steps)
                    )
        else:
            raise ValueError(
                f"Unknown mask method: {method}. Supported methods are 'last', 'select_one', 'select_some', 'select_one_leading' and 'select_some_leading'."
            )

        return return_list

    def process(
        self,
        load_path: str | Path = "save/split_proof",
        method: MaskMethod = "select_some",
        prob: float = 0.3,
        random_seed: int | None = None,
        return_num: int | None = None,
        save_path: str | None = None,
    ):
        """
        Mask every split proof under ``load_path`` and write one JSON file per
        masked sample.

        Files are named ``problem_{p}_proof_{q}_generate_{j}.json`` and written
        to ``save_path`` (default: ``save/masked_proof/<method>``). The i-th
        loaded sample uses seed ``random_seed + 132 * i`` when a seed is given.

        Raises:
            ValueError: If a loaded item is malformed, or from `process_one`.
        """
        data = self.load(
            file_path=Path(load_path),
        )
        output_dir = (
            Path(save_path) if save_path else Path(self._default_save_path) / method
        )

        for i, item in enumerate(data):
            if (
                not isinstance(item, dict)
                or "problem" not in item
                or "splitted_proof" not in item
            ):
                raise ValueError(
                    f"Invalid data format: {item}. Expected a dictionary with 'problem' and 'splitted_proof' keys."
                )

            returned = self.process_one(
                sample=item,
                method=method,
                prob=prob,
                random_seed=(
                    (random_seed + 132 * i) if random_seed is not None else None
                ),
                return_num=return_num,
            )

            problem_id, proof_id = item["problem_index"], item["proof_index"]
            for j, masked_sample in enumerate(returned):
                save_file = (
                    output_dir
                    / f"problem_{problem_id}_proof_{proof_id}_generate_{j}.json"
                )
                save_file.parent.mkdir(parents=True, exist_ok=True)
                with open(save_file, "w", encoding="utf-8") as f:
                    json.dump(masked_sample, f, ensure_ascii=False, indent=4)

    @staticmethod
    def from_json_to_jsonl(read_path: str, save_path: str):
        """
        Merge the per-sample files written by `process` into a JSONL file with
        one line per (problem_index, proof_index) pair.

        Each line holds the question, the original steps (``orig_proofs``), and
        parallel lists ``masked_proofs`` / ``masked_steps`` / ``source`` ordered
        by generation index. Lines are sorted by problem index, then proof index.
        Files whose names ``retrieve_id_from_name`` rejects are ignored.

        Args:
            read_path: Directory containing the masked-sample JSON files.
            save_path: Output JSONL path; its parent directory is created.
        """
        entries: list[tuple[int, int, int, str]] = []
        for file_name in os.listdir(read_path):
            try:
                problem_index, proof_index, generation_index = retrieve_id_from_name(
                    file_name
                )
            except ValueError:
                continue  # not a per-sample output file
            entries.append((problem_index, proof_index, generation_index, file_name))
        # os.listdir order is arbitrary; sort so each line's lists follow the
        # generation index deterministically.
        entries.sort()

        grouped: dict[int, dict[int, dict[str, Any]]] = {}
        for problem_index, proof_index, _, file_name in entries:
            with open(os.path.join(read_path, file_name), "r", encoding="utf-8") as f:
                data = json.load(f)

            per_problem = grouped.setdefault(problem_index, {})
            if proof_index not in per_problem:
                # The first file seen for this pair supplies the shared fields.
                per_problem[proof_index] = {
                    "question": data["problem"],
                    "orig_proofs": data["splitted_proof"],
                    "masked_proofs": [],
                    "field": data.get("field", ""),
                    "total_steps": len(data["splitted_proof"]),
                    "masked_steps": [],
                    "problem_index": problem_index,
                    "proof_index": proof_index,
                    "source": [],
                }
            record = per_problem[proof_index]
            record["masked_proofs"].append(data["proof"])
            record["masked_steps"].append(data["masked_steps"])
            record["source"].append(file_name)

        save_dir = os.path.dirname(save_path)
        if save_dir:  # os.makedirs("") raises for a bare file name
            os.makedirs(save_dir, exist_ok=True)
        with open(save_path, "w", encoding="utf-8") as f:
            for problem_index in sorted(grouped):
                for proof_index in sorted(grouped[problem_index]):
                    record = grouped[problem_index][proof_index]
                    line = {
                        "question": record["question"],
                        "orig_proofs": record["orig_proofs"],
                        "masked_proofs": record["masked_proofs"],
                        "masked_steps": record["masked_steps"],
                        "field": record["field"],
                        "total_steps": record["total_steps"],
                        "source": record["source"],
                        "problem_index": record["problem_index"],
                        "proof_index": record["proof_index"],
                    }
                    f.write(json.dumps(line, ensure_ascii=False) + "\n")


class MaskedProofGenerator(JsonFormatProblemProofGeneratorBase):
    """
    Given a problem and a proof with some parts masked, try to complete the masked parts of the proof.

    Input lines are read with the keys "question", "masked_proofs" and "field"
    (presumably the JSONL written by `ProofMasker.from_json_to_jsonl`). The
    model must answer with
    ``{"completion": [{"completion_index": k, "completion": "..."}, ...]}``,
    where completion k fills the k-th mask token.
    """

    # Placeholder marking a masked span in the proofs given to the model.
    _MASK_TOKEN = "[MASKED_PROOF]"

    @property
    def system_prompt_file(self) -> str:
        return "prompts/masked_proof_completion_system_prompt.txt"

    def _user_prompt(self, problem: str, proof: str) -> str:
        return f"Problem: {problem}\n\nMasked proof to complete: {proof}\n\nThere are {proof.count(self._MASK_TOKEN)} {self._MASK_TOKEN} in the proof."

    @property
    def _default_save_path(self) -> str:
        return "save/masked_proof_completion"

    def log_step_start(self, problem_index: int, proof_index: int) -> str:
        return f"Completing masked proof for problem {problem_index} and proof {proof_index}..."

    def log_step_finish(self, problem_index: int, proof_index: int) -> str:
        return f"Finishing completing masked proof for problem {problem_index} and proof {proof_index}."

    def log_start(self, file: str, num_worker: int, num_returns: int) -> str:
        return f"Started to process the file {file} with {num_worker} workers and {num_returns} returns per problem. This will complete the masked proofs for the problems in the file."

    def log_finish(
        self, file: str, num_pairs: int, num_returns: int, save_path: str
    ) -> str:
        return f"Finished processing the file {file}. Completed {num_pairs} masked proofs with {num_returns} times each. The completed proofs are saved to {save_path}."

    @property
    def problem_retrieve(self) -> Callable[[dict[str, Any]], str]:
        return lambda d: d["question"]

    @property
    def proof_retrieve(self) -> Callable[[dict[str, Any]], list[str]]:
        # Every masked variant of a proof is treated as a separate "proof".
        return lambda d: [str(solution) for solution in d["masked_proofs"]]

    @property
    def field_retrieve(self) -> Callable[[dict[str, Any]], str]:
        return lambda d: d.get("field", UNKNOWN_FIELD)

    @property
    def sample_filter(self) -> Callable[[dict[str, Any]], bool]:
        return lambda d: True

    @property
    def _save_solution_name(self) -> str:
        return "completion"

    @property
    def _save_orig_solution_name(self) -> str:
        return "masked_proof"

    def prepare_info_for_json_check(self, problem: str, proof: str) -> dict[str, Any]:
        # Passed to _check_json_format as `other_info`.
        return {"num_masked": proof.count(self._MASK_TOKEN)}

    def _check_json_format(
        self, json_object: Any, other_info: dict[str, Any] | None = None
    ) -> None:
        """
        Validate the model's completion JSON.

        Rules: a dict with the single key "completion", holding a list of dicts
        with exactly the keys "completion_index" (int) and "completion" (str).
        The indices must be a permutation of ``0 .. n-1``, and ``n`` must equal
        ``other_info["num_masked"]``.

        Raises:
            JsonFormatError: If the JSON violates the rules above.
            ValueError: If ``other_info`` lacks "num_masked".
            AssertionError: If ``num_masked`` is not a positive integer.
        """
        if not isinstance(json_object, dict):
            raise JsonFormatError(
                f"The input should be a dictionary but got {type(json_object)}."
            )
        if list(json_object.keys()) != ["completion"]:
            raise JsonFormatError(
                f"The dictionary should have a single key 'completion' containing the list of completions, but got keys: {list(json_object.keys())}."
            )
        completion = json_object["completion"]
        if not isinstance(completion, list):
            raise JsonFormatError(
                f"The value of 'completion' should be a list but got {type(completion)}."
            )
        indices: list[int] = []
        for step in completion:
            if not isinstance(step, dict):
                raise JsonFormatError(
                    f"Each item in 'completion' should be a dictionary but got {type(step)}."
                )
            # JSON objects are unordered, so compare key sets, not key order.
            if set(step) != {"completion_index", "completion"}:
                raise JsonFormatError(
                    "Each item in 'completion' should have keys 'completion_index' and 'completion'."
                )
            completion_index = step["completion_index"]
            # bool is a subclass of int; reject JSON true/false as an index.
            if not isinstance(completion_index, int) or isinstance(
                completion_index, bool
            ):
                raise JsonFormatError(
                    f"Completion index should be an integer but got {type(step['completion_index'])}."
                )
            if not isinstance(step["completion"], str):
                raise JsonFormatError(
                    f"Completion content should be a string but got {type(step['completion'])}."
                )
            indices.append(completion_index)
        if set(indices) != set(range(len(indices))):
            raise JsonFormatError(
                f"Completion indices should be continuous from 0 to {len(indices) - 1} but got {', '.join(str(index) for index in indices)}."
            )
        if other_info is None or "num_masked" not in other_info:
            raise ValueError(
                f"Other info must contain 'num_masked' key to check the number of masked parts. But get {other_info}."
            )
        num_masked = other_info["num_masked"]
        assert (
            isinstance(num_masked, int) and num_masked > 0
        ), f"num_masked should be a positive integer but got {num_masked}."
        if num_masked != len(indices):
            raise JsonFormatError(
                f"The number of masked parts in the proof ({num_masked}) does not match the number of completed steps ({len(indices)})."
            )

    @staticmethod
    def replace_placeholder(
        root_dir: str,
        save_suffix: str = "_replace",
        source_dataset: str | None = None,
        save_dir: str | None = None,
    ):
        """
        Fill the mask tokens of recovered samples with the model's completions.

        For every ``*.json`` file in ``root_dir`` whose ``pass_check`` is truthy,
        the entries of ``completion.completion`` are substituted into
        ``masked_proof`` in ``completion_index`` order and stored as
        ``completed_proof``. When ``source_dataset`` (a JSONL file) is given,
        line ``problem_index`` of it supplies ``groundtruth_proof``,
        ``masked_steps``, ``total_steps`` and ``source``. Files that fail any
        check are logged and skipped; the rest are written to the output dir.

        Args:
            root_dir: Directory of recovery outputs.
            save_suffix: Appended to ``root_dir``'s name to form the default
                output directory (a sibling of ``root_dir``).
            source_dataset: Optional JSONL path with the masking metadata.
            save_dir: Explicit output directory; overrides ``save_suffix``.
        """
        if save_dir is None:
            # normpath drops a trailing separator: otherwise basename("out/run/")
            # is "" and the results would go to "out/run/_replace", i.e. inside
            # root_dir itself.
            normalized_root = os.path.normpath(root_dir)
            save_dir = os.path.join(
                os.path.dirname(normalized_root),
                os.path.basename(normalized_root) + save_suffix,
            )
        else:
            save_dir = os.path.abspath(save_dir)
        os.makedirs(save_dir, exist_ok=True)

        mask_token = MaskedProofGenerator._MASK_TOKEN
        # Lines of `source_dataset`, read on first use rather than once per file.
        source_lines: list[str] | None = None

        for file_name in tqdm(
            os.listdir(root_dir), desc="Replacing placeholders in masked proofs"
        ):
            if not file_name.endswith(".json"):
                continue
            with open(os.path.join(root_dir, file_name), "r", encoding="utf-8") as f:
                data = json.load(f)

            try:
                if not data["pass_check"]:
                    continue
                completion = data["completion"]["completion"]
                masked_proof = data["masked_proof"]
                # Text around the mask tokens: n tokens give n + 1 segments.
                segments = masked_proof.split(mask_token)
                masked_num = len(segments) - 1
                if masked_num != len(completion):
                    logger.error(
                        f"Masked proof {file_name} has {masked_num} masked parts but completion has {len(completion)} parts. Skipping this file."
                    )
                    continue
                for step in completion:
                    assert isinstance(
                        step, dict
                    ), f"Each step in completion should be a dictionary but got {type(step)} in file {file_name}."
                    assert (
                        "completion_index" in step and "completion" in step
                    ), f"Each step in completion should have 'completion_index' and 'completion' keys but got {list(step.keys())} in file {file_name}."
                    assert isinstance(
                        step["completion_index"], int
                    ), f"Completion index should be an integer but got {type(step['completion_index'])} in file {file_name}."
                    assert isinstance(
                        step["completion"], str
                    ), f"Completion content should be a string but got {type(step['completion'])} in file {file_name}."
                    if (
                        step["completion_index"] < 0
                        or step["completion_index"] >= masked_num
                    ):
                        raise ValueError(
                            f"Completion index {step['completion_index']} out of range in file {file_name}. Expected range is 0 to {masked_num - 1}."
                        )
                # Completion k fills the k-th token. _check_json_format accepts
                # the indices in any order, so sort rather than require list order.
                ordered = sorted(completion, key=lambda item: item["completion_index"])
                for i, step in enumerate(ordered):
                    assert (
                        step["completion_index"] == i
                    ), f"Completion index {step['completion_index']} does not match expected index {i} in file {file_name}."
                # Plain string interleaving (no regex), so backslashes such as
                # "\begin" in completions are copied verbatim.
                pieces = [segments[0]]
                for step, segment in zip(ordered, segments[1:]):
                    pieces.append(step["completion"])
                    pieces.append(segment)
                completed_proof = "".join(pieces)
                assert (
                    mask_token not in completed_proof
                ), f"Masked proof {file_name} still contains unmasked parts after replacement."
                data["completed_proof"] = completed_proof
                if source_dataset is not None:
                    # try to get the problem index and proof index from the file name
                    problem_index, proof_index, _ = retrieve_id_from_name(file_name)
                    if source_lines is None:
                        with open(source_dataset, "r", encoding="utf-8") as f:
                            source_lines = f.readlines()
                    # Line `problem_index` of the source JSONL describes this sample.
                    q = json.loads(source_lines[problem_index])
                    assert (
                        q["question"] == data["question"]
                    ), f"Question in source dataset does not match the question in the file {file_name}."
                    data["groundtruth_proof"] = q["orig_proofs"]
                    data["masked_steps"] = q["masked_steps"][proof_index]
                    data["total_steps"] = q["total_steps"]
                    data["source"] = q["source"][proof_index]
            except Exception as e:
                logger.error(f"Error processing file {file_name}: {e}, skipping this file.")
                continue
            with open(
                os.path.join(save_dir, file_name), "w", encoding="utf-8"
            ) as wf:
                json.dump(data, wf, ensure_ascii=False, indent=4)


async def main(sys_argv: list[str] | None = None):
    """
    Command-line entry point for the masked-proof pipeline.

    ``--task`` selects the stage:

    * ``split``   -- ask the LLM to split proofs into categorized steps (`ProofSplitter`).
    * ``mask``    -- build masked proofs from the split-proof directory given by
      ``--file`` (`ProofMasker`, no LLM call).
    * ``recover`` -- ask the LLM to fill in the masked parts (`MaskedProofGenerator`).

    All other options come from ``common_parse_args`` plus the mask-specific
    flags defined below.

    Args:
        sys_argv: Arguments to parse; ``None`` lets argparse read ``sys.argv[1:]``.

    Raises:
        ValueError: If ``--task mask`` is given without ``--mask_method``.
    """
    from argparse import ArgumentParser
    from olym_gen.generator.base_generator import common_parse_args

    base_parser = common_parse_args()
    mask_parser = ArgumentParser(
        parents=[base_parser],
        description="Masked proof related generation. Includes splitting proofs into steps by an LLM, masking and composing masked proofs and recovering masked proofs.",
    )
    mask_parser.add_argument(
        "--task",
        required=True,
        type=str,
        choices=["split", "mask", "recover"],
        help="The task to perform. 'split' will split proofs into steps, 'mask' will mask parts of proofs and provide masked proofs, and 'recover' will recover masked parts of proofs.",
    )
    mask_parser.add_argument(
        "--random_seed",
        type=int,
        default=None,
        help="The random seed for masking.",
    )
    mask_parser.add_argument(
        "--mask_method",
        type=str,
        default=None,
        choices=[
            "last",
            "select_one",
            "select_some",
            "select_one_leading",
            "select_some_leading",
        ],
        help="The method to use for masking. 'last' will mask a run of steps right before the final step, 'select_one' will mask one step at a time, and 'select_some' will randomly select some steps to mask. The '_leading' variants only mask steps categorized as 'leading'.",
    )
    mask_parser.add_argument(
        "--return_num",
        type=int,
        default=10,
        help="The number of masked proofs to return when masking. Only used when task is 'mask' and method is 'select_some' or 'select_some_leading'.",
    )
    mask_parser.add_argument(
        "--mask_prob",
        type=float,
        default=0.3,
        help="The probability of masking a step when using 'select_some' or 'select_some_leading' method. Only used when task is 'mask' and the method is 'select_some' or 'select_some_leading'.",
    )
    args = mask_parser.parse_args(sys_argv)

    if args.task == "mask":
        if args.mask_method is None:
            raise ValueError("mask_method must be specified when task is 'mask'.")
        ProofMasker().process(
            load_path=args.file,
            random_seed=args.random_seed,
            return_num=args.return_num,
            method=args.mask_method,
            prob=args.mask_prob,
            save_path=args.save_path,
        )
        return

    # "split" and "recover" are LLM generation runs driven by identical arguments.
    generator_classes = {"split": ProofSplitter, "recover": MaskedProofGenerator}
    if args.task not in generator_classes:
        raise ValueError(
            f"Unknown task: {args.task}. Supported tasks are 'split', 'mask', and 'recover'."
        )
    generator = generator_classes[args.task](provider=args.provider, model=args.model)
    await generator.process(
        file=args.file,
        lines=args.lines,
        indexes=args.indexes if args.indexes else None,
        num_worker=args.num_worker,
        num_returns=args.num_returns,
        save_path=args.save_path,
        resume=args.resume,
        max_tokens=args.max_tokens,
        async_mode=not args.no_async,
        only_solve_id_conflict=args.only_solve_id_conflict,
        batch_id=args.batch_id,
    )


if __name__ == "__main__":
    # main() with no argument lets argparse read sys.argv[1:]; asyncio.run drives
    # the async split/recover tasks.
    asyncio.run(main())
