import json
from typing import Any, Callable, TypedDict
from olym_gen.utils.utils import get_logger, UNKNOWN_FIELD, retrieve_id_from_name, get_generator_base

from pathlib import Path

import os
import asyncio
from tqdm import tqdm

from olym_gen.generator.base_generator import GeneratorBase, SystemPromptMixin
from olym_gen.generator.json_format_generator import (
    JsonFormatOutputMixin,
    JsonFormatError,
)
from olym_gen.generator.problem_proof_generator import (
    LogProblemProofMixin,
    ProblemProofSaveMixin,
)

logger = get_logger()


# One comparison unit: a problem statement, its grouped ground-truth
# ("original") snippets, the completion snippets to compare against them, and
# the problem/proof indices recovered from the source file name.
SnippetPair = TypedDict(
    "SnippetPair",
    {
        "problem": str,
        "original_snippet": list[str],
        "completion_snippet": list[str],
        "problem_index": int,
        "proof_index": int,
    },
)


class CheckReferenceGeneratorAPI(
    JsonFormatOutputMixin[dict[str, Any]],
    LogProblemProofMixin,
    SystemPromptMixin,
    ProblemProofSaveMixin,
):
    """Given a problem and two proofs in snippet form, check the logical equivalence of the proofs.

    Each input record provides a problem, a ground-truth proof and the indices
    of its masked steps. Contiguous masked steps are grouped into "original"
    snippets and sent, together with the model-produced "completion" snippets,
    to a language model via ``single_turn_request``. The model must answer
    with a JSON object of the form::

        {"proof_snippets_equivalent": [bool, ...], "proof_correct": bool}

    which ``_check_json_format`` validates before results are written to disk.
    """

    def __init__(
        self,
        provider: str = "dummy",
        model: str | None = None,
        extra_model_paras: dict[str, Any] | None = None,
    ) -> None:
        """Build the underlying generator for ``provider``/``model``.

        Args:
            provider: Provider name forwarded to ``get_generator_base``.
            model: Optional model identifier forwarded to ``get_generator_base``.
            extra_model_paras: Optional extra model parameters, forwarded unchanged.
        """
        # Use composition instead of inheritance for GeneratorBase
        self.generator_base: GeneratorBase = get_generator_base(provider, model, extra_model_paras)
        logger.debug(f"Initialized {self.__class__.__name__} with generator_base id: {id(self.generator_base)}")

    @property
    def system_prompt_file(self) -> str:
        """Relative path of the system prompt used for this task."""
        return "prompts/check_reference_system_prompt.txt"

    @staticmethod
    def _format_snippets(snippets: list[str]) -> str:
        """Render snippets as 1-based ``**Snippet k**:`` sections separated by blank lines."""
        return "\n\n".join(
            f"**Snippet {i + 1}**:\n {snippet}" for i, snippet in enumerate(snippets)
        )

    def _user_prompt(
        self, problem: str, original_snippets: list[str], completion_snippets: list[str]
    ) -> str:
        """Build the user prompt containing the problem and both snippet lists.

        Args:
            problem: Problem statement.
            original_snippets: Grouped ground-truth snippets.
            completion_snippets: Completion snippets to compare against them.

        Returns:
            The prompt text, with each snippet list numbered from 1.
        """
        original_snippets_str = self._format_snippets(original_snippets)
        completion_snippets_str = self._format_snippets(completion_snippets)
        return f"Problem: {problem}\n\nOriginal proof snippets:\n{original_snippets_str}\n\nCompletion proof snippets:\n{completion_snippets_str}\n\n"

    @property
    def _default_save_path(self) -> str:
        """Output directory used when ``process`` is given no ``save_path``."""
        return "save/check_reference"

    def log_step_start(self, problem_index: int, proof_index: int) -> str:
        """Message logged before checking one problem/proof pair."""
        return f"Checking proof equivalence for problem {problem_index} and proof {proof_index}..."

    def log_step_finish(self, problem_index: int, proof_index: int) -> str:
        """Message logged after checking one problem/proof pair."""
        return f"Finished checking proof equivalence for problem {problem_index} and proof {proof_index}."

    def log_start(self, file: str, num_worker: int, num_returns: int) -> str:
        """Message logged when ``process`` starts."""
        return f"Started to process the file {file} with {num_worker} workers and {num_returns} returns per problem. This will check the logical equivalence of proof snippets."

    def log_finish(
        self, file: str, num_pairs: int, num_returns: int, save_path: str
    ) -> str:
        """Message logged when ``process`` finishes."""
        return f"Finished processing the file {file}. Verified {num_pairs} proof pairs with {num_returns} checks each. The equivalence results are saved to {save_path}."

    def _check_json_format(
        self, json_object: Any, other_info: dict[str, Any] | None = None
    ) -> None:
        """
        Check if the given dictionary is in the correct JSON format.

        The answer must be a dict with exactly the string keys
        ``proof_snippets_equivalent`` (a list of bools) and ``proof_correct``
        (a bool). When ``other_info`` carries ``num_snippets``, the list must
        have exactly that many entries. ``proof_correct`` may only be True when
        every entry of ``proof_snippets_equivalent`` is True.

        Args:
            json_object: The decoded model answer.
            other_info: Extra validation context, e.g. from
                ``prepare_info_for_json_check``.

        Raises:
            JsonFormatError: If any of the constraints above is violated.
        """
        if not isinstance(json_object, dict):
            raise JsonFormatError(f"Expected a dictionary, got {type(json_object)}")
        if not all(isinstance(key, str) for key in json_object):
            raise JsonFormatError("All keys in the dictionary should be strings")
        if set(json_object.keys()) != {"proof_snippets_equivalent", "proof_correct"}:
            raise JsonFormatError(
                f"The dictionary should contain exactly these keys: `proof_snippets_equivalent` and `proof_correct`, but got {json_object.keys()}"
            )

        equivalent = json_object["proof_snippets_equivalent"]
        proof_correct = json_object["proof_correct"]

        # Check proof_snippets_equivalent is a list of booleans
        if not isinstance(equivalent, list):
            raise JsonFormatError(
                f"Expected `proof_snippets_equivalent` to be a list, got {type(equivalent)}"
            )
        for i, val in enumerate(equivalent):
            if not isinstance(val, bool):
                raise JsonFormatError(
                    f"Expected element {i} of `proof_snippets_equivalent` to be a boolean, got {type(val)}"
                )

        # Enforce the length only when the caller supplied ``num_snippets``;
        # defaulting a missing value to 0 would reject every non-empty answer.
        expected_len = None if other_info is None else other_info.get("num_snippets")
        if expected_len is not None and len(equivalent) != expected_len:
            raise JsonFormatError(
                f"Expected `proof_snippets_equivalent` to have {expected_len} elements, but got {len(equivalent)}"
            )

        # Check proof_correct is a boolean
        if not isinstance(proof_correct, bool):
            raise JsonFormatError(
                f"Expected `proof_correct` to be a boolean, got {type(proof_correct)}"
            )

        # Check logical consistency: proof_correct should be True only if all snippets are equivalent
        if proof_correct and not all(equivalent):
            raise JsonFormatError(
                "`proof_correct` is True, but not all proof snippets are equivalent."
            )

    def _group_continuous_masked_steps(
        self, masked_steps: list[int], proof: list[str]
    ) -> list[str]:
        """
        Group continuous masked steps into blocks.

        Consecutive indices (``k, k+1, ...``) in ``masked_steps`` are merged
        into one snippet by concatenating ``proof[k] + proof[k+1] + ...`` with
        no separator. A new snippet starts whenever two adjacent entries are
        not consecutive. NOTE(review): ``masked_steps`` is presumably sorted
        ascending; unsorted input simply yields more, smaller groups.

        Args:
            masked_steps: Indices into ``proof`` of the masked steps.
            proof: Proof steps as strings.

        Returns:
            One concatenated string per run of consecutive indices; ``[]`` when
            ``masked_steps`` is empty.

        Raises:
            IndexError: If an index is out of range for ``proof``.
        """
        if not masked_steps:
            return []

        groups: list[str] = []
        current_group = proof[masked_steps[0]]

        for prev_step, step in zip(masked_steps, masked_steps[1:]):
            if step == prev_step + 1:
                # Continuous step, add to current group
                current_group += proof[step]
            else:
                # Non-continuous step, start new group
                groups.append(current_group)
                current_group = proof[step]

        # Add the last group
        groups.append(current_group)

        return groups

    def _preprocess_completion_data(
        self, data: dict[str, Any]
    ) -> tuple[str, list[str], list[str]]:
        """
        Preprocess the completion data to extract problem and proof snippets.
        Groups continuous masked steps and matches them with completions.

        Reads ``question``, ``groundtruth_proof`` (only element ``[0]`` of each
        entry is used, presumably the step text), ``masked_steps`` and
        ``completion["completion"][*]["completion"]`` from ``data``. Missing
        keys fall back to empty values.

        Returns:
            ``(problem, original_snippets, completion_snippets)``.
        """
        problem = data.get("question", "")
        original_proof = [d[0] for d in data.get("groundtruth_proof", [])]
        masked_index = data.get("masked_steps", [])
        original_snippets = self._group_continuous_masked_steps(
            masked_index, original_proof
        )
        completion_snippets = [
            step["completion"]
            for step in data.get("completion", {}).get("completion", [])
        ]

        return problem, original_snippets, completion_snippets

    def load_completion_data(
        self, file: str, lines: int | None
    ) -> list[tuple[dict[str, Any], SnippetPair]]:
        """
        Load the completion data from the given directory.

        Every ``*.json`` file directly inside ``file`` is read in sorted path
        order, so both the processing order and the subset selected by
        ``lines`` are deterministic.

        Args:
            file: Directory containing the completion JSON files.
            lines: Maximum number of files to load; ``None`` loads all of them.

        Returns:
            A list of ``(raw_record, snippet_pair)`` tuples.
        """
        data_list: list[tuple[dict[str, Any], SnippetPair]] = []
        # ``Path.glob`` yields entries in arbitrary, filesystem-dependent order;
        # sort so runs are reproducible.
        for json_file in sorted(Path(file).glob("*.json")):
            # Check the limit before reading so that ``lines=0`` loads nothing.
            if lines is not None and len(data_list) >= lines:
                break
            with open(json_file, "r", encoding="UTF-8") as f:
                data = json.load(f)
            # NOTE: a filter skipping records whose masked_steps contain step 0
            # existed here and is currently disabled.
            problem, original_snippets, completion_snippets = (
                self._preprocess_completion_data(data)
            )
            problem_index, proof_index, _ = retrieve_id_from_name(json_file.name)

            snippet_pair: SnippetPair = {
                "problem": problem,
                "original_snippet": original_snippets,
                "completion_snippet": completion_snippets,
                "problem_index": problem_index,
                "proof_index": proof_index,
            }
            data_list.append((data, snippet_pair))

        logger.info(f"Loaded {len(data_list)} data from {file}.")

        return data_list

    def prepare_info_for_json_check(
        self, problem: str, original_snippets: list[str], completion_snippets: list[str]
    ) -> dict[str, Any]:
        """
        Prepare additional information for JSON format checking.

        Returns:
            ``{"num_snippets": len(original_snippets)}``: the model must return
            one equivalence verdict per original snippet.
        """
        return {"num_snippets": len(original_snippets)}

    async def _generate(
        self,
        problem: str,
        original_snippets: list[str],
        completion_snippets: list[str],
        problem_index: int,
        proof_index: int,
        shared_semaphore: asyncio.Semaphore,
        num_returns: int = 1,
        max_tokens: int = 64_000,
    ) -> list[tuple[str, str | dict[str, Any], bool] | None]:
        """
        Generate a response for the given problem and proof snippets, checking the JSON format of the response.

        Returns:
            One entry per generation returned by ``single_turn_request``.
            ``None`` generations stay ``None``. Every other entry is
            ``(thinking, answer, pass_check)``, where ``answer`` is the parsed
            JSON object when ``_response_to_json`` succeeds (``pass_check`` is
            True) and the raw answer text otherwise.

        Raises:
            ValueError: If a generation is not a ``(thinking, answer)`` 2-tuple.
        """
        logger.debug(self.log_step_start(problem_index, proof_index))

        system_prompt = self._system_prompt
        user_prompt = self._user_prompt(problem, original_snippets, completion_snippets)

        return_list = await self.single_turn_request(
            system_prompt=system_prompt,
            user_prompt=user_prompt,
            shared_semaphore=shared_semaphore,
            num_returns=num_returns,
            max_tokens=max_tokens,
            use_json=True,
        )

        # The validation context is identical for every generation of this
        # request, so compute it once instead of inside the loop.
        info_for_json_check = self.prepare_info_for_json_check(
            problem=problem,
            original_snippets=original_snippets,
            completion_snippets=completion_snippets,
        )

        json_decode_return_list: list[tuple[str, str | dict[str, Any], bool] | None] = []

        # Check for the return list and decode JSON
        for generation in return_list:
            if generation is None:
                json_decode_return_list.append(None)
                continue
            if not isinstance(generation, tuple) or len(generation) != 2:
                raise ValueError(
                    f"Expected a tuple of (thinking, answer), got {generation}"
                )

            thinking, answer = generation
            # Presumably returns None when decoding or validation fails.
            json_object = self._response_to_json(answer, info_for_json_check)
            pass_check = json_object is not None

            json_decode_return_list.append(
                (thinking, json_object if pass_check else answer, pass_check)
            )

        logger.debug(self.log_step_finish(problem_index, proof_index))
        return json_decode_return_list

    async def process(
        self,
        file: str,
        lines: int | None = None,
        num_returns: int = 1,
        save_path: str | None = None,
        num_worker: int = 1,
        max_tokens: int = 64_000,
        async_mode: bool = True,
        batch_id: str | None = None,
    ) -> None:
        """
        Process the input file and generate responses for each problem-proof pair.

        Every generation that passes the JSON check is written as its own JSON
        file (path from ``save_name``). The file holds the loaded record plus
        ``proof_snippets_equivalent``, ``proof_correct``, ``pass_check`` and
        ``check_reference_thinking``. Missing or invalid generations are logged
        and skipped.

        Args:
            file: Directory containing the ``*.json`` completion records.
            lines: Optional maximum number of records to load.
            num_returns: Number of generations requested per record.
            save_path: Output directory; defaults to ``_default_save_path``.
            num_worker: Size of the semaphore shared by all requests
                (presumably bounds concurrent model calls).
            max_tokens: Token budget forwarded to the generator.
            async_mode: Run records concurrently when True, sequentially otherwise.
            batch_id: Stored on ``self.batch_id`` before processing.
        """
        self.batch_id = batch_id
        save_path = save_path if save_path is not None else self._default_save_path
        logger.info(self.log_start(file, num_worker, num_returns))

        pairs = self.load_completion_data(file, lines)

        semaphore = asyncio.Semaphore(num_worker)

        async def process_pair(item: tuple[dict[str, Any], SnippetPair]) -> None:
            """Check one record and save every generation that passes validation."""
            data, snippet_pair = item

            problem_idx = snippet_pair["problem_index"]
            proof_idx = snippet_pair["proof_index"]

            response = await self._generate(
                snippet_pair["problem"],
                snippet_pair["original_snippet"],
                snippet_pair["completion_snippet"],
                shared_semaphore=semaphore,
                num_returns=num_returns,
                problem_index=problem_idx,
                proof_index=proof_idx,
                max_tokens=max_tokens,
            )

            for j, res in enumerate(response):
                if res is None:
                    logger.warning(
                        f"Response for problem {problem_idx}, proof {proof_idx} is None, skipping."
                    )
                    continue
                thinking, solution, pass_check = res
                if not pass_check:
                    logger.warning(
                        f"Response for problem {problem_idx}, proof {proof_idx} did not pass JSON format check, skipping."
                    )
                    continue
                if isinstance(solution, str):
                    logger.warning(
                        f"Response for problem {problem_idx}, proof {proof_idx} is a string, expected JSON object."
                    )
                    continue
                # ``data`` is updated in place and dumped right away, so each
                # saved file reflects only the current generation's result.
                data["proof_snippets_equivalent"] = solution.get(
                    "proof_snippets_equivalent", []
                )
                data["proof_correct"] = solution.get("proof_correct", False)
                data["pass_check"] = pass_check
                data["check_reference_thinking"] = thinking
                save_name = self.save_name(
                    save_path,
                    problem_index=problem_idx,
                    proof_index=proof_idx,
                    generation_index=j,
                )
                logger.debug(f"Saving response to {save_name}")
                with open(save_name, "w", encoding="UTF-8") as f:
                    json.dump(data, f, indent=4, ensure_ascii=False)

        os.makedirs(save_path, exist_ok=True)
        if not async_mode:
            for item in tqdm(pairs, desc="Processing pairs"):
                await process_pair(item)
        else:
            tasks = [process_pair(item) for item in pairs]
            for task in tqdm(
                asyncio.as_completed(tasks), total=len(tasks), desc="Processing pairs"
            ):
                await task

        logger.info(self.log_finish(file, len(pairs), num_returns, save_path))


async def main(sys_argv: list[str] | None = None):
    """Command-line entry point: parse arguments and run the reference check.

    Args:
        sys_argv: Argument list to parse; ``None`` lets argparse read ``sys.argv``.
    """
    from argparse import ArgumentParser
    from olym_gen.generator.base_generator import common_parse_args

    # The shared options come from the common parent parser; this script adds
    # only its own description on top.
    cli_parser = ArgumentParser(
        parents=[common_parse_args()],
        description="Check solutions with referenced steps for problems using a language model.",
    )
    cli = cli_parser.parse_args(sys_argv)

    checker = CheckReferenceGeneratorAPI(provider=cli.provider, model=cli.model)
    run_options = {
        "lines": cli.lines,
        "num_returns": cli.num_returns,
        "save_path": cli.save_path,
        "num_worker": cli.num_worker,
        "max_tokens": cli.max_tokens,
        "async_mode": not cli.no_async,
        "batch_id": cli.batch_id,
    }
    await checker.process(cli.file, **run_options)
    logger.info("Finished processing all problems.")


if __name__ == "__main__":
    # Run the async CLI entry point; passing None makes argparse read sys.argv.
    asyncio.run(main(None))
