from abc import ABC, abstractmethod
from typing import Callable, Any, TypedDict, Literal, Sequence
from tqdm import tqdm
import json
import os
import shutil

from functools import cache

from olym_gen.utils.utils import get_logger, UNKNOWN_INDEX, UNKNOWN_FIELD, get_generator_base
from olym_gen.utils.sample_utils import retrieve_id_from_name
from pathlib import Path
from collections import defaultdict

import asyncio

# Module-level logger obtained from the project's get_logger() helper; used throughout this module.
logger = get_logger()


class ProblemProofDataSample(TypedDict):
    """
    One (problem, proof) pair loaded from a dataset file.

    A problem that comes with several proofs yields one sample per proof, all sharing the
    same ``problem`` text and ``problem_index``.
    """

    problem: str  # problem statement text
    proof: str  # a single proof / solution text for the problem
    problem_index: int  # line number in the jsonl source, the record's own index, or UNKNOWN_INDEX
    proof_index: int  # position of this proof among the problem's proofs
    field: str  # subject field of the problem; UNKNOWN_FIELD when the record has none


class ProblemProofLoadMixin(ABC):
    """
    Mixin that reads problem/proof pairs from ``.jsonl`` (one record per line) or ``.json``
    (one record per file) files.

    The ``*_retrieve`` and ``sample_filter`` properties define how a raw JSON record is
    interpreted. Sub-classes can override them, or callers can pass replacements to
    ``load_problems_and_proofs``.
    """

    @staticmethod
    def retrieve_id_from_name(file_name: str) -> tuple[int, int, int]:
        """
        Parse the ``(problem_index, proof_index, generation_index)`` triple encoded in a
        saved file name, delegating to ``olym_gen.utils.sample_utils.retrieve_id_from_name``.
        """
        # Inside the body this name resolves to the module-level import (class scope is
        # not searched), so this is a delegation, not a recursive call.
        return retrieve_id_from_name(file_name)

    @property
    def problem_retrieve(self) -> Callable[[dict[str, Any]], str]:
        """Default problem extractor: ``str(record["problem"])``."""
        return lambda d: str(d["problem"])

    @property
    def proof_retrieve(self) -> Callable[[dict[str, Any]], list[str]]:
        """Default proofs extractor: every entry of ``record["solution"]`` as ``str``."""
        return lambda d: [str(solution) for solution in d["solution"]]

    @property
    def field_retrieve(self) -> Callable[[dict[str, Any]], str]:
        """Default field extractor: ``record["subfield"]``, or ``UNKNOWN_FIELD`` if absent."""
        return lambda d: str(d.get("subfield", UNKNOWN_FIELD))

    @property
    def sample_filter(self) -> Callable[[dict[str, Any]], bool]:
        """Default filter: reject records whose ``subfield`` mentions "geometry" (any case)."""
        return lambda d: "geometry" not in str(d.get("subfield", "")).lower()

    def load_problems_and_proofs(
        self,
        file_path: Path,
        lines: int | None = None,
        indexes: Sequence[int] | None = None,
        problem_retrieve: Callable[[dict[str, Any]], str] | None = None,
        proof_retrieve: Callable[[dict[str, Any]], list[str]] | None = None,
        field_retrieve: Callable[[dict[str, Any]], str] | None = None,
        sample_filter: Callable[[dict[str, Any]], bool] | None = None,
        no_log: bool = False,
        file_type: Literal["jsonl", "json"] = "jsonl",
        keep_original_question_index: bool = False,
    ) -> list[ProblemProofDataSample]:
        """
        Read problems and their natural-language proofs from ``file_path``.

        Every (problem, proof) combination becomes its own sample, so a problem with ``k``
        proofs contributes ``k`` samples. How a record is interpreted is controlled by the
        retrieve/filter callables (or the corresponding properties).

        Args:
            file_path (Path): The path to the jsonl or json file.
            lines (int | None): Only consider the first ``lines`` lines of a jsonl file, to
                save time. Lines rejected by ``sample_filter`` do not consume this budget:
                the limit is extended by one for each of them. If None, read all lines.
            indexes (Sequence[int] | None): 0-based line numbers to load from a jsonl file.
                If None, load every line. Indexes beyond the ``lines`` budget are ignored.
            problem_retrieve (Callable[[dict[str, Any]], str] | None): Extracts the problem
                from a record. Defaults to the ``problem_retrieve`` property.
            proof_retrieve (Callable[[dict[str, Any]], list[str]] | None): Extracts the list
                of proofs from a record. Defaults to the ``proof_retrieve`` property.
            field_retrieve (Callable[[dict[str, Any]], str] | None): Extracts the field from
                a record. Defaults to the ``field_retrieve`` property.
            sample_filter (Callable[[dict[str, Any]], bool] | None): Returns False for records
                to skip. Defaults to the ``sample_filter`` property (drops geometry problems).
                Only applied to jsonl files.
            no_log (bool): If True, disable the progress bar and all log messages.
            file_type (Literal["jsonl", "json"]): "jsonl" reads one JSON record per line;
                "json" reads the whole file as a single record. For "json", ``lines``,
                ``indexes``, ``sample_filter`` and ``keep_original_question_index`` are
                ignored and ``problem_index`` is ``UNKNOWN_INDEX``.
            keep_original_question_index (bool): If True, use the record's own
                ``problem_index`` value (falling back to the line number) instead of the
                line number.

        Returns:
            list[ProblemProofDataSample]: One dict per (problem, proof) pair, holding the
            texts, the field and the problem/proof indexes.

        Raises:
            ValueError: If ``file_type`` is neither "jsonl" nor "json", or if a non-blank
                line (or the json file) is not valid JSON.
        """
        if problem_retrieve is None:
            problem_retrieve = self.problem_retrieve
        if proof_retrieve is None:
            proof_retrieve = self.proof_retrieve
        if field_retrieve is None:
            field_retrieve = self.field_retrieve
        if sample_filter is None:
            sample_filter = self.sample_filter

        # Membership is tested on every line, so build a set once instead of scanning a
        # possibly long sequence each time.
        wanted = None if indexes is None else set(indexes)

        return_list: list[ProblemProofDataSample] = []
        with file_path.open(encoding="UTF-8") as f:
            if not no_log:
                logger.info(f"Started to process the file {file_path} ...")

            if file_type == "json":
                sample = json.load(f)
                problem = problem_retrieve(sample)
                field = field_retrieve(sample)
                proofs = proof_retrieve(sample)
                return [
                    {
                        "problem": problem,
                        "proof": proof,
                        "problem_index": UNKNOWN_INDEX,
                        "proof_index": i,
                        "field": field,
                    }
                    for i, proof in enumerate(proofs)
                ]

            if file_type != "jsonl":
                raise ValueError(
                    f"Unsupported file type: {file_type}. Only 'jsonl' and 'json' are supported."
                )

            for i, line in tqdm(
                enumerate(f), total=lines, desc="load datasets", disable=no_log
            ):
                # Check the budget before the index filter: once it is exhausted no later
                # line can be loaded, so stop instead of scanning the rest of the file.
                if lines is not None and i >= lines:
                    break
                if wanted is not None and i not in wanted:
                    continue
                if not line.strip():
                    # Tolerate blank lines (e.g. an extra trailing newline) instead of
                    # failing inside json.loads.
                    continue
                sample = json.loads(line)
                if not sample_filter(sample):
                    if not no_log:
                        logger.info(f"Skip the question {i} because of the filter.")
                    # Filtered-out lines do not count toward the ``lines`` budget.
                    if lines is not None:
                        lines += 1
                    continue
                problem = problem_retrieve(sample)
                field = field_retrieve(sample)
                proofs = proof_retrieve(sample)
                problem_index = (
                    sample.get("problem_index", i) if keep_original_question_index else i
                )
                return_list.extend(
                    {
                        "problem": problem,
                        "proof": proof,
                        "problem_index": problem_index,
                        "proof_index": j,
                        "field": field,
                    }
                    for j, proof in enumerate(proofs)
                )
        if not no_log:
            logger.info(
                f"Finish read the files {file_path}. Loaded {len(return_list)} problem proof pairs."
            )
        return return_list


class ProblemProofSaveMixin(ABC):
    """
    Mixin that stores every generated response as its own JSON file and can merge those
    files back into one JSONL dataset.

    The JSON keys written by ``save_response`` come from the ``_save_*_name`` properties,
    so sub-classes can rename them. Sub-classes must implement ``_default_save_path``.
    """

    @property
    def _save_question_name(self) -> str:
        """JSON key holding the problem statement."""
        return "question"

    @property
    def _save_orig_solution_name(self) -> str:
        """JSON key holding the original (input) proof."""
        return "orig_solution"

    @property
    def _save_thinking_name(self) -> str:
        """JSON key holding the generated thinking text."""
        return "thinking"

    @property
    def _save_solution_name(self) -> str:
        """JSON key holding the newly generated solution."""
        return "new_solution"

    @property
    def _save_field_name(self) -> str:
        """JSON key holding the subject field."""
        return "field"

    def save_name(
        self,
        save_path: str,
        problem_index: int,
        proof_index: int | None,
        generation_index: int,
    ) -> str:
        """
        Build the path of the JSON file that stores one generation.

        Args:
            save_path (str): Directory of the file.
            problem_index (int): Index of the problem.
            proof_index (int | None): Index of the proof; None omits the proof part.
            generation_index (int): Index of the generation for this problem/proof.

        Returns:
            str: ``<save_path>/problem_<p>_generate_<g>.json`` when ``proof_index`` is None,
            otherwise ``<save_path>/problem_<p>_proof_<q>_generate_<g>.json``.
        """
        if proof_index is None:
            file_name = f"problem_{problem_index}_generate_{generation_index}.json"
        else:
            file_name = (
                f"problem_{problem_index}_proof_{proof_index}"
                f"_generate_{generation_index}.json"
            )
        return os.path.join(save_path, file_name)

    @property
    @abstractmethod
    def _default_save_path(self) -> str:
        """
        Directory used when no explicit save/read path is given. Sub-classes must
        implement this.
        """
        ...

    def save_response(
        self,
        input_data: ProblemProofDataSample,
        response: list[tuple[str, str] | None],
        save_path: str | None,
        generation_index: list[int],
        without_proof: bool,
    ) -> None:
        """
        Save each response to its own JSON file, named with ``save_name`` from the problem
        index, the proof index (unless ``without_proof``) and the generation index.

        Args:
            input_data: The problem/proof sample the responses were generated for.
            response: One ``(thinking, solution)`` tuple per generation. None entries (e.g.
                failed generations) are skipped with a warning.
            save_path: Output directory; defaults to ``_default_save_path``. Created if
                missing.
            generation_index: Generation index for each entry of ``response``, in the same
                order and of the same length.
            without_proof: If True, the file name omits the proof index.

        Raises:
            ValueError: If ``response`` and ``generation_index`` differ in length.
        """
        problem = input_data["problem"]
        proof = input_data["proof"]
        field = input_data["field"]
        problem_idx = input_data["problem_index"]
        proof_idx = None if without_proof else input_data["proof_index"]

        if save_path is None:
            save_path = self._default_save_path
        os.makedirs(save_path, exist_ok=True)

        if len(response) != len(generation_index):
            logger.error("Meet Error when saving the response.")
            logger.error(f"Length of response: {len(response)}")
            for res in response:
                logger.error(f"Response item: {res}")
            logger.error(f"generation_index: {generation_index}")
            raise ValueError(
                f"The length of response and generation_index should be the same but got "
                f"{len(response)} response and {len(generation_index)} generation index."
            )

        for gen_idx, res in zip(generation_index, response):
            if res is None:
                logger.warning(
                    f"Response for problem {problem_idx} proof {proof_idx} generation {gen_idx} is None, skipping. It could be due to previous errors. Please check the logs for more details."
                )
                continue

            thinking, solution = res
            save_name = self.save_name(
                save_path=save_path,
                problem_index=problem_idx,
                proof_index=proof_idx,
                generation_index=gen_idx,
            )
            # Build the record before opening, so a failure cannot leave a truncated file.
            record = {
                self._save_question_name: problem,
                self._save_orig_solution_name: proof,
                self._save_thinking_name: thinking,
                self._save_solution_name: solution,
                self._save_field_name: field,
            }
            logger.debug(f"Saving response to {save_name}")
            with open(save_name, "w", encoding="UTF-8") as f:
                json.dump(record, f, indent=4, ensure_ascii=False)

    def from_json_to_jsonl(
        self,
        save_path: str,
        read_path: str | Sequence[str] | None = None,
        new_question_name: str | None = None,
        new_orig_solution_name: str | None = None,
        new_field_name: str | None = None,
        new_solution_name: str | None = None,
        other_domain: dict[str, Callable[[str, Any, Any], Any]] | None = None,
    ) -> None:
        """
        Merge the per-generation JSON files into a single JSONL file.

        The pipeline reads datasets as JSONL, while generations are saved as one JSON file
        each so they stay organised and easy to read. This method converts them back.

        Files are grouped by the problem index parsed from their name, giving one line per
        problem in ascending problem-index order. Each line contains:
        - the question, original solution and field of the first file seen;
        - the list of generated solutions;
        - ``source``, the paths the solutions came from;
        - ``problem_index``, so the output can be aligned with the input files.

        Files whose name cannot be parsed are skipped.

        Args:
            save_path (str): Path of the output JSONL file. Its parent directory is created
                if needed.
            read_path (str | Sequence[str] | None): Directory, or directories, holding the
                JSON files. Defaults to ``_default_save_path``. A single path-like object is
                treated like a single directory.
            new_question_name (str | None): Output key for the question. Defaults to the
                saved question key.
            new_orig_solution_name (str | None): Output key for the original solution.
                Defaults to the saved original-solution key.
            new_field_name (str | None): Output key for the field. Defaults to the saved
                field key.
            new_solution_name (str | None): Output key for the list of generated solutions.
                Defaults to the saved solution key.
            other_domain (dict[str, Callable[[str, Any, Any], Any]] | None): Extra output
                keys. For every file, ``func(file_name, data, jsonl_dict)`` is called with:
                the file name; the file's JSON content; and the dict being built
                (problem index -> output record). Its return value is stored under the key,
                in the record of the file's problem.

        Raises:
            FileNotFoundError: If a read directory does not exist.
        """
        if read_path is None:
            read_path = self._default_save_path
        if isinstance(read_path, (str, os.PathLike)):
            read_paths = [read_path]
        else:
            read_paths = list(read_path)

        if new_question_name is None:
            new_question_name = self._save_question_name
        if new_orig_solution_name is None:
            new_orig_solution_name = self._save_orig_solution_name
        if new_field_name is None:
            new_field_name = self._save_field_name
        if new_solution_name is None:
            new_solution_name = self._save_solution_name

        if other_domain is None:
            other_domain = dict()

        jsonl_dict: dict[int, dict[str, Any]] = {}

        for directory in read_paths:
            if not os.path.exists(directory):
                raise FileNotFoundError(f"File {directory} does not exist.")
            logger.debug(f"Reading files from {directory} ...")

            # Keep only names that parse as saved generations. For other names
            # retrieve_id_from_name raises ValueError (presumably files not produced by
            # save_name).
            entries: list[tuple[tuple[int, int, int], str]] = []
            for file_name in os.listdir(directory):
                try:
                    ids = retrieve_id_from_name(file_name)
                except ValueError:
                    continue
                entries.append((ids, file_name))
            # os.listdir returns names in arbitrary order; sort by the parsed
            # (problem, proof, generation) indices so the output is reproducible.
            entries.sort()

            for (problem_index, _, _), file_name in entries:
                file_path = os.path.join(directory, file_name)
                with open(file_path, "r", encoding="UTF-8") as f:
                    data = json.load(f)

                record = jsonl_dict.get(problem_index)
                if record is None:
                    record = {
                        new_question_name: data[self._save_question_name],
                        # NOTE: a question may have several original solutions. The one kept
                        # here comes from the first file seen and need not be the one that
                        # was rephrased; it only helps readers understand the problem.
                        new_orig_solution_name: data[self._save_orig_solution_name],
                        new_field_name: data[self._save_field_name],
                        new_solution_name: [],
                        "source": [],
                    }
                    jsonl_dict[problem_index] = record
                record[new_solution_name].append(data[self._save_solution_name])
                record["source"].append(file_path)

                for domain_name, domain_process_func in other_domain.items():
                    record[domain_name] = domain_process_func(
                        file_name, data, jsonl_dict
                    )

        # dirname is "" when save_path is a bare file name, and os.makedirs("") would raise.
        out_dir = os.path.dirname(save_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        with open(save_path, "w", encoding="UTF-8") as f:
            for problem_index in sorted(jsonl_dict):
                content = jsonl_dict[problem_index]
                content["problem_index"] = problem_index
                f.write(json.dumps(content, ensure_ascii=False) + "\n")


class ResumeMixin(ProblemProofSaveMixin, ProblemProofLoadMixin):
    """
    Mixin that lets a generator skip generations whose output files already exist.

    Saved JSON files are grouped by their ``f"{question}__{orig_solution}"`` content. When a
    file's name (problem/proof index) disagrees with the indices of the matching pair in
    the loaded input ``pairs``, files are moved so that names and contents agree again.
    """

    # NOTE(review): functools.cache on an instance method keys on ``self`` and keeps every
    # instance alive for the lifetime of the cache (ruff B019). The cache lives on the
    # function, so ``cache_clear()`` clears it for all instances. Within this file the cache
    # is only cleared by _move_to_solve_conflict, so files written by save_response after
    # the first call are not seen by later calls with the same arguments.
    @cache
    def load_resume_files(
        self, resume: bool, save_path: Path, check_: bool = False
    ) -> defaultdict[str, list[str]]:
        """
        Load all saved files in ``save_path`` and group their file names by content.

        Each file matching the ``save_name`` extension is read with
        ``load_problems_and_proofs(file_type="json")`` using the saved question and
        original-solution keys. Its name (not its full path) is appended under the key
        ``f"{problem}__{proof}"``.

        Args:
            resume: If False, return an empty dict without touching the disk.
            save_path: Directory holding the saved files.
            check_: If True, verify that all files sharing the same content also share the
                same problem/proof index. On a mismatch, try to fix the file names using
                ``self.pairs`` and reload.

        Returns:
            defaultdict[str, list[str]] mapping ``problem__proof`` to the file names that
            contain this pair. Empty if ``resume`` is False or ``save_path`` does not exist.

        Raises:
            ValueError: In check mode, when a mismatch is found but ``self.pairs`` is missing,
                or when the mismatching content is not found in ``self.pairs``.
        """
        if not resume:
            logger.debug("Resume is set to False, do not load existing files.")
            return defaultdict(list)
        if not save_path.exists():
            logger.debug("Save path does not exist, do not resume from existing files.")
            return defaultdict(list)
        logger.info(
            f"Try to resume from existing files in the save path {str(save_path)}"
        )
        return_dict: defaultdict[str, list[str]] = defaultdict(list)
        # Take the extension from save_name (".json" for ProblemProofSaveMixin.save_name)
        # so the glob matches whatever naming the class uses.
        file_name_ext = os.path.splitext(self.save_name("", 0, 0, 0))[1]
        for file in save_path.glob(f"*{file_name_ext}"):
            if not file.is_file():
                logger.error(
                    f"Expected a file but found a directory or non-file object: {file}"
                )
                continue
            # Read the file with the save-side keys. The proof list has a single element, so
            # each file normally yields one sample.
            # NOTE(review): the problem is used as stored while the proof is str()-converted;
            # keys only match check_resume's f"{problem}__{answer}" if both are strings.
            data = self.load_problems_and_proofs(
                file,
                problem_retrieve=lambda d: d[self._save_question_name],
                proof_retrieve=lambda d: [str(d[self._save_orig_solution_name])],
                no_log=True,
                file_type="json",
            )
            for d in data:
                problem = d["problem"]
                proof = d["proof"]
                key = f"{problem}__{proof}"
                return_dict[key].append(file.name)
        have_exist = sum(len(v) for v in return_dict.values())
        logger.info(
            f"Find {have_exist} existing files in the save path {str(save_path)}."
        )

        # Optional consistency check: files with identical content must share one index.
        if check_:
            # TODO: now it will lead to unexpected behavior if the same problem and proof pair exists with different index. Check and try to solve it.
            try:
                for _, files in return_dict.items():
                    # All files holding the same QA must carry the same problem/proof index;
                    # the first mismatch raises ValueError, handled below.
                    p_id, a_id, _ = self.retrieve_id_from_name(files[0])
                    for file in files[1:]:
                        p_id2, a_id2, _ = self.retrieve_id_from_name(file)
                        if not (p_id == p_id2 and a_id == a_id2):
                            logger.error(
                                f"Files {files[0]} and {file} have same qa but different problem or proof index."
                            )
                            raise ValueError(
                                f"Files {files[0]} and {file} have same qa but different problem or proof index."
                            )
            except ValueError as e:
                # A mismatch was found in the group ``files`` whose first file has index
                # (p_id, a_id).
                # NOTE(review): this also catches a ValueError from retrieve_id_from_name
                # itself; if that happens on files[0], p_id/a_id are unbound or stale.
                # Confirm that this is acceptable.
                if not hasattr(self, "pairs"):
                    raise ValueError(
                        "We find error but no pairs provided. We cannot fix the error without pairs. We do not add the pairs attribute anywhere. Please add it by yourself."
                    )
                pairs = self.pairs  # type: ignore
                logger.info(f"Try to fix the error by checking the pairs.")
                # Recover where this content belongs from generation 0 of (p_id, a_id).
                # NOTE(review): assumes that generation-0 file exists and that a_id is a
                # proof index save_name should embed (it is not mapped to None).
                with open(self.save_name(str(save_path), p_id, a_id, 0), "r", encoding="UTF-8") as f:  # type: ignore
                    data = json.load(f)
                    problem = data[self._save_question_name]
                    proof = data[self._save_orig_solution_name]
                    expected_problem_id, expected_proof_id = (
                        self._find_problem_proof_index_from_pairs(
                            pairs=pairs,
                            problem=problem,
                            proof=proof,
                        )
                    )
                    if expected_problem_id == UNKNOWN_INDEX:
                        logger.error(
                            f"The problem {problem} and proof {proof} does not exist in the pairs. Cannot fix the error."
                        )
                        raise ValueError(
                            f"The problem {problem} and proof {proof} does not exist in the pairs. Cannot fix the error."
                        )
                # Move each misplaced file of the group to its expected name. If that name is
                # already taken, delete the misplaced file instead. A file without a proof
                # index (UNKNOWN_INDEX) only has to match on the problem index.
                for file in files:  # type: ignore
                    ppp, aaa, ggg = self.retrieve_id_from_name(file)
                    if ppp != expected_problem_id or (
                        aaa != expected_proof_id and aaa != UNKNOWN_INDEX
                    ):
                        if not os.path.exists(
                            self.save_name(
                                str(save_path),
                                expected_problem_id,
                                expected_proof_id,
                                ggg,
                            )
                        ):
                            logger.info(
                                f"Move the file {file} to the expected place {self.save_name(str(save_path), expected_problem_id, expected_proof_id, ggg)}."
                            )
                            shutil.move(
                                os.path.join(save_path, file),
                                self.save_name(
                                    str(save_path),
                                    expected_problem_id,
                                    expected_proof_id,
                                    ggg,
                                ),
                            )
                        else:
                            logger.info(
                                f"File {file} already exists in the expected place {self.save_name(str(save_path), expected_problem_id, expected_proof_id, ggg)}. Will remove the file {file} and try to load again."
                            )
                            os.remove(os.path.join(save_path, file))
                # Re-scan the directory (through the cached wrapper) after the moves/removals;
                # this repeats the check and fixes one mismatching group per pass.
                return self.load_resume_files(
                    resume=resume,
                    save_path=save_path,
                    check_=check_,
                )

        return return_dict

    def check_resume(
        self,
        resume: bool,
        expected_file_name: str,
        problem: str,
        answer: str,
        pairs: list[ProblemProofDataSample],
    ) -> bool:
        """
        Decide whether the generation stored at ``expected_file_name`` can be skipped.

        Args:
            resume: If False, always return False, so everything is regenerated.
            expected_file_name: Complete file path built with ``save_name`` from this pair's
                input (jsonl) indices. Its directory is used as the save path.
            problem: Problem text of the pair.
            answer: Proof text of the pair.
            pairs: All loaded input pairs, used to find where misplaced content belongs.

        Returns:
            True if a saved file with exactly this name holds this ``problem``/``answer``
            pair, False otherwise. Before answering, conflicting files may be moved:

            * No saved file holds this pair, but ``expected_file_name`` exists with other
              content. That file is moved to the name its content has in ``pairs``, or left
              to be overwritten if its content is not in ``pairs``. Returns False.
            * The pair is saved only under other generation indices of the same problem/proof
              index. Returns False, since this generation is still missing.
            * The pair is saved under a different problem/proof index. The file is moved with
              ``_try_to_solve_conflict`` and the check is repeated; its result is returned.

        Raises:
            AssertionError: From internal consistency checks while resolving conflicts.
        """
        # Keep the full path for file-system access; the resume dict stores bare file names,
        # so lookups use only the name.
        complete_expected_file_name = expected_file_name
        expected_file_name = os.path.split(expected_file_name)[-1]

        if not resume:
            return False

        # Saved file names grouped by "problem__proof" content (cached across calls).
        resume_file_dict = self.load_resume_files(
            resume=resume, save_path=Path(os.path.dirname(complete_expected_file_name))
        )

        key = f"{problem}__{answer}"
        if key not in resume_file_dict or not resume_file_dict[key]:
            # No saved file holds this pair.
            if not os.path.exists(complete_expected_file_name):
                logger.debug(
                    f"File {complete_expected_file_name} does not exist, will generate."
                )
                return False
            logger.debug(
                f"File {expected_file_name} has been existint but save different problem or proof. Will generate but first try to move the existing file to the expected place."
            )
            # The expected name is occupied by another pair. Note that ``problem`` and
            # ``answer`` are rebound to that file's content here; they are only used inside
            # this branch afterwards.
            with open(complete_expected_file_name, "r", encoding="UTF-8") as f:
                data = json.load(f)
                problem = data[self._save_question_name]
                answer = data[self._save_orig_solution_name]
            expected_problem_id_for_another, expected_proof_id_for_another = (
                self._find_problem_proof_index_from_pairs(
                    pairs=pairs,
                    problem=problem,
                    proof=answer,
                )
            )
            if expected_problem_id_for_another == UNKNOWN_INDEX:
                logger.debug(
                    f"The current content in {expected_file_name} does not match any problem-proof pair in the input jsonl file. The content will be overwritten."
                )
            else:
                # Move the occupying file to the name its content maps to, freeing this one.
                (
                    existing_problem_id_for_another,
                    existing_proof_id_for_another,
                    existing_generation_id_for_another,
                ) = self.retrieve_id_from_name(expected_file_name)
                assert (
                    existing_problem_id_for_another != expected_problem_id_for_another
                ) or (
                    existing_proof_id_for_another != expected_proof_id_for_another
                ), f"The expected problem and proof index {expected_problem_id_for_another}, {expected_proof_id_for_another} should not be the same as the existing problem and proof index {existing_problem_id_for_another}, {existing_proof_id_for_another}."
                self._try_to_solve_conflict(
                    save_path=os.path.dirname(complete_expected_file_name),
                    expected_problem_index=expected_problem_id_for_another,
                    expected_proof_index=(
                        expected_proof_id_for_another
                        if expected_proof_id_for_another != UNKNOWN_INDEX
                        else None
                    ),
                    expected_generation_index=existing_generation_id_for_another,
                    existing_problem_index=existing_problem_id_for_another,
                    existing_proof_index=(
                        existing_proof_id_for_another
                        if existing_proof_id_for_another != UNKNOWN_INDEX
                        else None
                    ),
                    pairs=pairs,
                )
            return False
        if key in resume_file_dict and expected_file_name in resume_file_dict[key]:
            logger.debug(
                f"File {expected_file_name} already exists. Skipping generation."
            )
            return True
        if (
            key in resume_file_dict
            and resume_file_dict[key]
            and not expected_file_name in resume_file_dict[key]
        ):
            # The pair is saved, but not under the expected name. expected_* are parsed
            # from expected_file_name (built from the input indices); existing_* come from
            # the first saved file holding the same content.
            expected_problem_id, expected_proof_id, expected_generation_id = (
                self.retrieve_id_from_name(expected_file_name)
            )
            existing_file_name = resume_file_dict[key][0]
            existing_problem_id, existing_proof_id, _ = self.retrieve_id_from_name(
                existing_file_name
            )
            if (
                expected_problem_id == existing_problem_id
                and expected_proof_id == existing_proof_id
            ):
                logger.debug(
                    f"Problem for {expected_file_name} have partially generated with different generation id. Will generate."
                )
                return False
            else:
                # NOTE(review): _try_to_solve_conflict moves the existing file that has
                # *expected_generation_id*, not necessarily existing_file_name; it asserts
                # that file exists. Confirm that this always holds.
                self._try_to_solve_conflict(
                    save_path=os.path.dirname(complete_expected_file_name),
                    expected_problem_index=expected_problem_id,
                    expected_proof_index=expected_proof_id,
                    expected_generation_index=expected_generation_id,
                    existing_problem_index=existing_problem_id,
                    existing_proof_index=existing_proof_id,
                    pairs=pairs,
                )
                return self.check_resume(
                    resume=resume,
                    expected_file_name=complete_expected_file_name,
                    problem=problem,
                    answer=answer,
                    pairs=pairs,
                )
        # Not reached: the two branches above cover every remaining case.
        return False

    def _find_problem_proof_index_from_pairs(
        self,
        pairs: list[ProblemProofDataSample],
        problem: str,
        proof: str,
    ) -> tuple[int, int]:
        """
        Return ``(problem_index, proof_index)`` of the first pair whose problem and proof
        text equal the given ones, or ``(UNKNOWN_INDEX, UNKNOWN_INDEX)`` if none does.
        This is a linear scan over ``pairs``.
        """
        for pair in pairs:
            if pair["problem"] == problem and pair["proof"] == proof:
                return pair["problem_index"], pair["proof_index"]
        return UNKNOWN_INDEX, UNKNOWN_INDEX

    def _try_to_solve_conflict(
        self,
        save_path: str,
        expected_problem_index: int,
        expected_proof_index: int | None,
        expected_generation_index: int,
        existing_problem_index: int,
        existing_proof_index: int | None,
        pairs: list[ProblemProofDataSample],
        max_times: int = 1000,
    ) -> bool:
        """
        Move a saved file from its existing (problem, proof) index to the expected one.

        The source is ``save_name(save_path, existing_problem_index, existing_proof_index,
        expected_generation_index)``. The target is ``save_name(save_path,
        expected_problem_index, expected_proof_index, expected_generation_index)``. For the
        target, a proof index of UNKNOWN_INDEX or None gives a name without the proof part.
        If the target is occupied by another pair, that pair is first relocated recursively
        to the index it has in ``pairs``.

        Args:
            save_path: Directory containing the saved files.
            expected_problem_index: The index of the problem in the loaded jsonl file.
            expected_proof_index: The index of the proof in the loaded jsonl file, or None.
            expected_generation_index: Generation index shared by the source and target names.
            existing_problem_index: Problem index in the name the content is saved under now.
            existing_proof_index: Proof index in that name, or None.
            pairs: Loaded input pairs, used to look up where displaced content belongs.
            max_times: Remaining recursion budget.

        Returns:
            True once the file has been moved.

        Raises:
            RuntimeError: If the recursion budget is exhausted.
            AssertionError: If the two indices are identical, the source file is missing, or
                the target is still occupied after the recursive relocation.
        """

        assert (
            expected_problem_index != existing_problem_index
            or expected_proof_index != existing_proof_index
        ), f"The expected problem and proof index {expected_problem_index}, {expected_proof_index} should not be the same as the existing problem and proof index {existing_problem_index}, {existing_proof_index}."

        # NOTE(review): max_times is <= 0 here, so the message reports the remaining budget
        # (0) rather than the number of attempts made.
        if max_times <= 0:
            raise RuntimeError(
                f"Failed to solve the index conflict after {max_times} times. "
            )

        # NOTE(review): whether the source name includes the proof part is decided by
        # expected_proof_index, not existing_proof_index. This relies on both names using
        # the same (with/without proof) scheme; confirm it is intended.
        existing_file_name = self.save_name(
            save_path,
            existing_problem_index,
            existing_proof_index if expected_proof_index != UNKNOWN_INDEX else None,
            expected_generation_index,
        )
        # This saved file must exist but include another problem or proof. Otherwise, it is not a conflict.
        assert os.path.exists(
            existing_file_name
        ), f"Expected file {existing_file_name} does not exist."

        logger.warning(
            f"Find the problem and question but saved at a different place. "
            f"Expected to save problem index: {expected_problem_index}, proof index: {expected_proof_index}. "
            f"Existing problem index: {existing_problem_index}, proof index: {existing_proof_index}."
            "Try to solve the conflict by moving the existing file to the expected name. "
        )

        expected_file_name = self.save_name(
            save_path,
            expected_problem_index,
            expected_proof_index if expected_proof_index != UNKNOWN_INDEX else None,
            expected_generation_index,
        )
        if not os.path.exists(expected_file_name):
            self._move_to_solve_conflict(
                source=existing_file_name, target=expected_file_name
            )
            return True

        # Otherwise the target already exists and holds another QA pair. Relocate that pair
        # first (recursively), then move our file into the freed name.
        with open(expected_file_name, "r", encoding="UTF-8") as f:
            data = json.load(f)
        problem = data[self._save_question_name]
        answer = data[self._save_orig_solution_name]
        (
            existing_problem_index_for_another,
            existing_proof_index_for_another,
            existing_generation_index_for_another,
        ) = self.retrieve_id_from_name(expected_file_name)

        # NOTE(review): if the displaced content is not in ``pairs`` this yields
        # UNKNOWN_INDEX, and the recursive call targets a name built from UNKNOWN_INDEX.
        # Confirm that this is acceptable.
        expected_problem_index_for_another, expected_proof_index_for_another = (
            self._find_problem_proof_index_from_pairs(
                pairs=pairs,
                problem=problem,
                proof=answer,
            )
        )

        self._try_to_solve_conflict(
            save_path=save_path,
            expected_problem_index=expected_problem_index_for_another,
            expected_proof_index=(
                expected_proof_index_for_another
                if expected_proof_index_for_another != UNKNOWN_INDEX
                else None
            ),
            expected_generation_index=existing_generation_index_for_another,
            existing_problem_index=existing_problem_index_for_another,
            existing_proof_index=(
                existing_proof_index_for_another
                if existing_proof_index_for_another != UNKNOWN_INDEX
                else None
            ),
            pairs=pairs,
            max_times=max_times - 1,
        )
        assert not os.path.exists(
            expected_file_name
        ), f"Expected file {expected_file_name} still exists after trying to solve the conflict."
        # The recursive call returned, so the target name is free now.
        self._move_to_solve_conflict(
            source=existing_file_name, target=expected_file_name
        )
        return True

    def _move_to_solve_conflict(
        self,
        source: str,
        target: str,
    ) -> None:
        """
        Move ``source`` to ``target`` and clear the ``load_resume_files`` cache, so the next
        lookup re-scans the save directory.

        Raises:
            AssertionError: If ``source`` is missing or ``target`` already exists.
        """
        assert os.path.exists(source), f"Source file {source} does not exist."
        assert not os.path.exists(target), f"Target file {target} already exists."
        logger.info(f"Moving file {source} to {target} to solve the conflict.")
        shutil.move(source, target)
        self.load_resume_files.cache_clear()

    def need_generation_idx(
        self,
        resume: bool,
        problem: str,
        answer: str,
        problem_index: int,
        proof_index: int | None,
        save_path: str,
        pairs: list[ProblemProofDataSample],
        num_returns: int = 1,
    ) -> list[int]:
        """
        Return the generation indices in ``range(num_returns)`` that still need generating.

        If ``resume`` is False, every index is returned. Otherwise index ``i`` is kept when
        ``check_resume`` reports that ``save_name(save_path, problem_index, proof_index, i)``
        cannot be reused. As a side effect, ``check_resume`` may move conflicting files.
        """
        if not resume:
            logger.debug("Resume is set to False, the old generation will be covered.")
            return list(range(num_returns))

        need_generation = []
        for i in range(num_returns):
            if not self.check_resume(
                resume=resume,
                expected_file_name=self.save_name(
                    save_path,
                    problem_index=problem_index,
                    proof_index=proof_index,
                    generation_index=i,
                ),
                problem=problem,
                answer=answer,
                pairs=pairs,
            ):
                need_generation.append(i)
        return need_generation


from olym_gen.generator.base_generator import SystemPromptMixin, GeneratorBase


class LogMixin(ABC):
    """
    Abstract interface for the messages emitted around a whole generation run.

    Implementations return the message text rather than logging it themselves;
    the caller passes the returned string to its logger.
    """

    @abstractmethod
    def log_start(self, file: str, num_worker: int, num_returns: int) -> str:
        """Return the message announcing that the whole generation run starts."""

    @abstractmethod
    def log_finish(
        self, file: str, num_pairs: int, num_returns: int, save_path: str
    ) -> str:
        """Return the message announcing that the whole generation run finished."""


class LogProblemMixin(LogMixin, ABC):
    """
    Logging interface for generators that work one problem at a time.

    Adds per-problem step messages on top of the run-level messages
    required by ``LogMixin``. Each method returns the message text.
    """

    @abstractmethod
    def log_step_start(self, problem_index: int) -> str:
        """Return the message for the start of the step handling ``problem_index``."""

    @abstractmethod
    def log_step_finish(self, problem_index: int) -> str:
        """Return the message for the end of the step handling ``problem_index``."""


class LogProblemProofMixin(LogMixin, ABC):
    """
    Logging interface for generators that work on (problem, proof) pairs.

    Adds per-pair step messages on top of the run-level messages required
    by ``LogMixin``. Each method returns the message text.
    """

    @abstractmethod
    def log_step_start(self, problem_index: int, proof_index: int) -> str:
        """Return the message for the start of the step for one problem/proof pair."""

    @abstractmethod
    def log_step_finish(self, problem_index: int, proof_index: int) -> str:
        """Return the message for the end of the step for one problem/proof pair."""


class UserProblemProofMixin:
    """
    Provides the user-prompt builder for problem/proof generators.
    """

    def _user_prompt(self, problem: str, proof: str) -> str:
        """
        Render the problem and its proof as a labelled user prompt.

        Each section is followed by a blank line.
        """
        prompt = (
            f"Problem: {problem}\n\n"
            f"Proof: {proof}\n\n"
        )
        return prompt


class ProblemProofGeneratorBase(
    ResumeMixin,
    ProblemProofLoadMixin,
    SystemPromptMixin,
    LogProblemProofMixin,
    UserProblemProofMixin,
    ABC,
):
    """
    Base class for generating thinking steps and proof for a given problem and proof.

    Model access is delegated to a composed ``GeneratorBase`` stored in
    ``self.generator_base`` (composition instead of inheriting from
    ``GeneratorBase``). Data loading, resume / id-conflict handling, prompt
    building and log-message building come from the mixin bases.
    """

    def __init__(
        self,
        provider: str = "dummy",
        model: str | None = None,
        extra_model_paras: dict[str, Any] | None = None,
    ) -> None:
        """
        Initialize the generator with provider configuration.

        Args:
            provider: Provider name forwarded to ``get_generator_base``.
            model: Optional model name forwarded to ``get_generator_base``.
            extra_model_paras: Optional extra parameters forwarded to
                ``get_generator_base``.
        """
        self.generator_base: GeneratorBase = get_generator_base(provider, model, extra_model_paras)
        logger.debug(f"Initialized {self.__class__.__name__} with generator_base id: {id(self.generator_base)}")

    def _composed_generator_base(self):
        """Return ``self.generator_base``, or None if it was never assigned."""
        # A subclass that overrides __init__ without calling this one never sets
        # the attribute. Plain attribute access would then raise AttributeError
        # before the guards in batch_id / single_turn_request could run;
        # getattr lets those guards handle the missing case as intended.
        return getattr(self, "generator_base", None)

    @property
    def batch_id(self):
        """Batch id of the composed generator_base (None when there is none)."""
        base = self._composed_generator_base()
        return base.batch_id if base else None

    @batch_id.setter
    def batch_id(self, value):
        """Forward the batch id to the composed generator_base, if there is one."""
        base = self._composed_generator_base()
        if base:
            base.batch_id = value

    async def single_turn_request(self, *args, **kwargs):
        """
        Forward a single-turn request to the composed generator_base.

        Raises:
            ValueError: if no generator_base has been set up.
        """
        base = self._composed_generator_base()
        if not base:
            raise ValueError("generator_base not initialized. Make sure to call get_generator_base in __init__")
        return await base.single_turn_request(*args, **kwargs)

    async def _generate(
        self,
        problem: str,
        proof: str,
        problem_index: int,
        proof_index: int,
        shared_semaphore: asyncio.Semaphore,
        num_returns: int = 1,
        max_tokens: int = 64_000,
    ) -> list[tuple[str, str] | None]:
        """
        Request ``num_returns`` completions for one problem/proof pair.

        The prompts are built from ``self._system_prompt`` and
        ``self._user_prompt(problem, proof)``. They are sent through
        ``single_turn_request``, i.e. to whichever provider/model the composed
        generator_base was configured with (not a fixed model).

        Args:
            problem: Problem statement.
            proof: Reference proof for the problem.
            problem_index: Index used only for the step log messages here.
            proof_index: Index used only for the step log messages here.
            shared_semaphore: Forwarded to ``single_turn_request``.
            num_returns: Number of completions to request.
            max_tokens: Forwarded to ``single_turn_request``.

        Returns:
            The list returned by ``single_turn_request``. Per the annotation,
            each entry is a ``(str, str)`` pair or None. Presumably the pair is
            (thinking, answer); confirm against GeneratorBase.
        """
        logger.debug(self.log_step_start(problem_index, proof_index))

        system_prompt = self._system_prompt
        user_prompt = self._user_prompt(problem, proof)
        return_list = await self.single_turn_request(
            system_prompt=system_prompt,
            user_prompt=user_prompt,
            shared_semaphore=shared_semaphore,
            num_returns=num_returns,
            max_tokens=max_tokens,
        )

        logger.debug(self.log_step_finish(problem_index, proof_index))

        return return_list

    async def process(
        self,
        file: str,
        lines: int | None = None,
        indexes: Sequence[int] | None = None,
        num_returns: int = 1,
        save_path: str | None = None,
        num_worker: int = 1,
        resume: bool = False,
        max_tokens: int = 64_000,
        async_mode: bool = True,
        only_solve_id_conflict: bool = False,
        batch_id: str | None = None,
    ):
        """
        Generate (and save) outputs for every problem/proof pair in ``file``.

        Args:
            file: Input file, loaded via ``load_problems_and_proofs``.
            lines: Forwarded to ``load_problems_and_proofs``.
            indexes: Forwarded to ``load_problems_and_proofs``.
            num_returns: Number of generations wanted per pair.
            save_path: Output directory; defaults to ``self._default_save_path``.
                Created if missing.
            num_worker: Size of the semaphore shared by all requests.
            resume: If False, every generation slot is redone. If True, slots
                already accepted by ``check_resume`` are skipped.
            max_tokens: Forwarded to each request.
            async_mode: If True, all pairs are scheduled concurrently. If False,
                pairs are processed one after another.
            only_solve_id_conflict: If True, only the resume check runs for each
                pair (the step that may move files to fix id conflicts); nothing
                is generated. This has no effect unless ``resume`` is True.
            batch_id: Forwarded to the composed generator_base.
        """
        self.batch_id = batch_id
        save_path = save_path if save_path is not None else self._default_save_path
        logger.info(self.log_start(file, num_worker, num_returns))

        pairs = self.load_problems_and_proofs(
            Path(file), lines, indexes=indexes, keep_original_question_index=False
        )

        # One semaphore shared by every request; it is passed down to the
        # generator_base, which presumably acquires it per request (confirm).
        semaphore = asyncio.Semaphore(num_worker)

        async def process_pair(pair: ProblemProofDataSample):
            """Run the resume check for one pair and, if needed, generate and save."""
            problem = pair["problem"]
            proof = pair["proof"]
            problem_idx = pair["problem_index"]
            proof_idx = pair["proof_index"]

            # Runs even in only_solve_id_conflict mode: the resume check is where
            # existing files may be moved to resolve id conflicts.
            need_generation = self.need_generation_idx(
                resume=resume,
                problem=problem,
                answer=proof,
                problem_index=problem_idx,
                proof_index=proof_idx,
                save_path=save_path,
                num_returns=num_returns,
                pairs=pairs,
            )

            if only_solve_id_conflict:
                return

            if not need_generation:
                logger.info(
                    f"Problem {problem_idx} proof {proof_idx} already generated all {num_returns} proofs, skipping."
                )
                return

            response = await self._generate(
                problem,
                proof,
                shared_semaphore=semaphore,
                num_returns=len(need_generation),
                problem_index=problem_idx,
                proof_index=proof_idx,
                max_tokens=max_tokens,
            )

            self.save_response(
                input_data=pair,
                response=response,
                save_path=save_path,
                generation_index=need_generation,
                without_proof=False,
            )

        os.makedirs(save_path, exist_ok=True)
        if not async_mode:
            for pair in tqdm(pairs, desc="Processing pairs"):
                await process_pair(pair)
        else:
            tasks = [process_pair(pair) for pair in pairs]
            # as_completed schedules every coroutine at once; the progress bar
            # advances as each finishes, in completion order.
            for task in tqdm(
                asyncio.as_completed(tasks), total=len(tasks), desc="Processing pairs"
            ):
                await task

        logger.info(self.log_finish(file, len(pairs), num_returns, save_path))
