from abc import ABC, abstractmethod
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
import re

# Default Lean 4 preamble: imports Mathlib and Aesop, disables the heartbeat
# limit, and opens a handful of namespaces. The provers in this module fall
# back to it when a problem dict has no 'header' entry.
LEAN4_DEFAULT_HEADER = (
    "import Mathlib\nimport Aesop\n\nset_option maxHeartbeats 0\n\nopen BigOperators Real Nat Topology Rat\n\n"
)

class BaseProver(ABC):
    """
    Abstract base class for Lean 4 code provers using LLMs.

    Construction loads a tokenizer and a vLLM engine for ``model_path``.
    Calling the instance turns a list of problem dicts into per-problem
    results by chaining ``build_prompt`` -> ``model.generate`` ->
    ``postprocess``; subclasses implement the first and last step.
    """
    def __init__(self, model_path, gpu=1, max_model_len=4096, temperature=1.0, max_tokens=2048, top_p=0.95, n=32, seed=0, **kwargs):
        """
        Load the tokenizer and the vLLM engine, and fix the sampling settings.

        Args:
            model_path: Model identifier/path given to both the tokenizer and vLLM.
            gpu: Passed to vLLM as ``tensor_parallel_size``.
            max_model_len: Passed to vLLM as ``max_model_len``.
            temperature, max_tokens, top_p, n: Forwarded to ``SamplingParams``.
            seed: Passed to vLLM as ``seed``.
            **kwargs: Only ``download_dir`` is consulted (directory handed to
                vLLM for downloads; defaults to "/data0/Anony/hub"). Any other
                extra keyword is accepted and ignored.
        """
        self.model_path = model_path
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        self.model = LLM(
            model=model_path,
            seed=seed,
            swap_space=8,
            tensor_parallel_size=gpu,
            max_model_len=max_model_len,
            # Overridable so the class is usable outside the original machine.
            download_dir=kwargs.get("download_dir", "/data0/Anony/hub"),
        )
        self.sampling_params = SamplingParams(
            temperature=temperature,
            max_tokens=max_tokens,
            top_p=top_p,
            n=n
        )

    @abstractmethod
    def build_prompt(self, data, example=None):
        """
        Build the prompt for the model based on the input data and optional example.

        ``__call__`` invokes this with the data dict only.
        """
        pass

    @abstractmethod
    def postprocess(self, header, model_input, model_output):
        """
        Postprocess the model output to extract the relevant Lean 4 code.

        ``__call__`` invokes this as ``postprocess(header, model_input,
        model_output)``, where ``header`` is the item's 'header' entry (or
        ``LEAN4_DEFAULT_HEADER``), ``model_input`` is what ``build_prompt``
        returned, and ``model_output`` is the matching element of the list
        returned by ``self.model.generate``.
        """
        pass

    def __call__(self, data_list: "list | None" = None, use_tqdm: bool = True):
        """
        Generate results for a list of data items.

        Items equal to ``{}`` are skipped during generation and reported as
        ``{}`` at the same position in the returned list. Every other item is
        turned into a prompt with ``build_prompt``; the prompts are passed
        unchanged to ``self.model.generate`` in one batch, and each output is
        converted with ``postprocess``.

        Args:
            data_list: Problem dicts; ``None`` (the default) means no items.
            use_tqdm: Forwarded to ``self.model.generate``.

        Returns:
            A list with one entry per input item, in input order.
        """
        if data_list is None:
            data_list = []

        # 1. mark which positions hold empty dicts; they bypass the model
        is_empty = [data == {} for data in data_list]
        non_empty_data = [data for data, empty in zip(data_list, is_empty) if not empty]

        # nothing to generate: every slot is an empty result
        if not non_empty_data:
            return [{} for _ in data_list]

        # 2. build prompts for non-empty data
        headers = [data.get('header', LEAN4_DEFAULT_HEADER) for data in non_empty_data]
        model_inputs = [self.build_prompt(data) for data in non_empty_data]

        print("Example of model prompt: ", model_inputs[0])

        # 3. generate model outputs
        prover_outputs = self.model.generate(
            model_inputs,
            self.sampling_params,
            use_tqdm=use_tqdm,
        )
        assert len(prover_outputs) == len(model_inputs)

        # 4. process non-empty data results
        non_empty_results = [
            self.postprocess(header, model_input, prover_output)
            for header, model_input, prover_output in zip(headers, model_inputs, prover_outputs)
        ]

        # 5. re-insert empty dicts at their original positions (single pass)
        remaining = iter(non_empty_results)
        return [{} if empty else next(remaining) for empty in is_empty]

# Utility function for extracting Lean 4 code blocks from text

def extract_lean4_code(text):
    """
    Extract the last Lean 4 code block from the given text.

    If the text contains one or more "```lean4" opening fences, only what
    follows the final one is considered; otherwise the whole text is treated
    as the inside of a code block (e.g. a continuation of a prompt that
    already opened the fence). The code ends at the first closing fence
    (a newline followed by "```"), or at the end of the text when no closing
    fence is present.

    Args:
        text: Raw text, typically a prompt plus completion or a completion alone.

    Returns:
        The extracted code with leading/trailing whitespace stripped.
    """
    opening_fence = "```lean4"
    if opening_fence in text:
        # Drop everything up to and including the final opening fence.
        text = text.rsplit(opening_fence, 1)[1]
    # Cut at the first closing fence; without one, keep everything.
    code, _, _ = text.partition("\n```")
    return code.strip()

class GoedelProver(BaseProver):
    """Prover whose default checkpoint is ``Goedel-LM/Goedel-Prover-SFT``."""

    def __init__(self, model_path="Goedel-LM/Goedel-Prover-SFT", gpu=1, max_model_len=4096, temperature=1.0, max_tokens=2048, top_p=0.95, n=32, seed=0, **kwargs):
        # All arguments are handed to BaseProver unchanged.
        super().__init__(
            model_path,
            gpu,
            max_model_len,
            temperature,
            max_tokens,
            top_p,
            n,
            seed,
            **kwargs,
        )

    def build_prompt(self, data):
        """
        Assemble the completion prompt for one problem dict.

        Reads 'counter_example' (default "None"), 'header' (default
        LEAN4_DEFAULT_HEADER), 'informal_prefix' (default '') and the
        required 'formal_statement'. The returned text ends inside an open
        lean4 code fence.
        """
        counter_example = data.get('counter_example', "None")
        lean_header = data.get('header', LEAN4_DEFAULT_HEADER)
        informal = data.get('informal_prefix', '')
        statement = data['formal_statement']

        instruction = f"Complete the following Lean 4 code using the given concrete example {counter_example}:\n\n"
        code_start = f"```lean4\n{lean_header}\n{informal}{statement}"
        return instruction + code_start

    def postprocess(self, header, model_input, model_output):
        # Output extraction is not implemented for this prover.
        raise NotImplementedError("GoedelProver does not support postprocess")

class DeepSeekProverV15RL(BaseProver):
    """Prover whose default checkpoint is ``deepseek-ai/DeepSeek-Prover-V1.5-RL``."""

    def __init__(self, model_path="deepseek-ai/DeepSeek-Prover-V1.5-RL", gpu=1, max_model_len=4096, temperature=1.0, max_tokens=2048, top_p=0.95, n=32, seed=0, **kwargs):
        # Pure pass-through to the shared BaseProver initialisation.
        init_args = (model_path, gpu, max_model_len, temperature, max_tokens, top_p, n, seed)
        super().__init__(*init_args, **kwargs)

    def build_prompt(self, data):
        """
        Return the completion prompt for one problem dict.

        The prompt names the item's 'counter_example' (default "None") and
        then opens a lean4 fence containing the 'header' (default
        LEAN4_DEFAULT_HEADER), the optional 'informal_prefix' and the
        required 'formal_statement'. The fence is left open.
        """
        concrete_example = data.get('counter_example', "None")
        pieces = [
            f"Complete the following Lean 4 code using the given concrete example {concrete_example}:\n\n",
            f"```lean4\n{data.get('header', LEAN4_DEFAULT_HEADER)}\n",
            f"{data.get('informal_prefix', '')}{data['formal_statement']}",
        ]
        return "".join(pieces)

    def postprocess(self, header, model_input, model_output):
        # Output extraction is not implemented for this prover.
        raise NotImplementedError("DeepSeekProverV15RL does not support postprocess")

class DeepSeekProverV2CoT(BaseProver):
    """
    Chain-of-thought prover; default checkpoint ``deepseek-ai/DeepSeek-Prover-V2-7B``.
    """

    def __init__(self, model_path="deepseek-ai/DeepSeek-Prover-V2-7B", gpu=1, max_model_len=8192, temperature=1.0, max_tokens=8192, top_p=0.95, n=32, seed=0, **kwargs):
        # Pure pass-through to the shared BaseProver initialisation.
        super().__init__(
            model_path,
            gpu,
            max_model_len,
            temperature,
            max_tokens,
            top_p,
            n,
            seed,
            **kwargs,
        )

    def build_prompt(self, data):
        """
        Return a prompt holding the full (closed) lean4 block for the problem,
        followed by a request for a proof plan before the formal proof.

        The header goes through ``fix_header`` first.
        """
        counter_example = data.get('counter_example', "None")
        lean_header = fix_header(data.get('header', LEAN4_DEFAULT_HEADER))
        informal = data.get('informal_prefix', '')
        statement = data['formal_statement']

        segments = [
            f"Complete the following Lean 4 code using the provided counter example {counter_example}:\n",
            f"```lean4\n{lean_header}\n{informal}\n{statement}\n```\n",
            "Before producing the Lean 4 code to formally prove the given theorem, provide a detailed proof plan outlining the main proof steps and strategies.\n",
            "The plan should highlight key ideas, intermediate lemmas, and proof structures that will guide the construction of the final formal proof.",
        ]
        return "".join(segments)

    def postprocess(self, header, model_input, model_output):
        """
        Pull Lean 4 code out of every sampled completion.

        Each completion's code gets ``header`` prepended (separated by a
        newline) unless the header text already occurs in it.

        Returns a dict with the prompt ('prover_input'), the raw completions
        ('prover_outputs') and the resulting code strings ('formal_proof').
        """
        def extract_cot_code(text):
            # Prefer the block under '### Complete Lean 4 Proof': first a
            # properly closed fence, then an unterminated one.
            heading_patterns = (
                r'### Complete Lean 4 Proof\s*\n*```lean4\n([\s\S]*?)\n```',
                r'### Complete Lean 4 Proof\s*\n*```lean4\n([\s\S]*)',
            )
            for pattern in heading_patterns:
                found = re.search(pattern, text, re.DOTALL)
                if found:
                    return found.group(1).strip()
            # No heading present: use the generic extractor.
            return extract_lean4_code(text)

        completions = [candidate.text for candidate in model_output.outputs]
        proofs = []
        for completion in completions:
            code = extract_cot_code(completion)
            if header in code:
                proofs.append(code)
            else:
                proofs.append(header + "\n" + code)

        return {
            "prover_input": model_input,
            "prover_outputs": completions,  # full outputs from the model
            "formal_proof": proofs,  # Lean 4 code from the model
        }

class KiminaProver(BaseProver):
    """Prover whose default checkpoint is ``AI-MO/Kimina-Prover-Preview-Distill-7B``."""

    def __init__(self, model_path="AI-MO/Kimina-Prover-Preview-Distill-7B", gpu=1, max_model_len=16384, temperature=0.6, max_tokens=8192, top_p=0.95, n=32, seed=0, **kwargs):
        # Pure pass-through to the shared BaseProver initialisation.
        init_args = (model_path, gpu, max_model_len, temperature, max_tokens, top_p, n, seed)
        super().__init__(*init_args, **kwargs)

    def build_prompt(self, data):
        """
        Return a step-by-step prompt that embeds the problem in a closed
        lean4 block under a '# Formal statement:' heading.

        The header goes through ``fix_header`` and is joined directly (no
        separator) with the optional 'informal_prefix' and 'formal_statement'.
        """
        lean_header = fix_header(data.get('header', LEAN4_DEFAULT_HEADER))
        informal = data.get('informal_prefix', '')
        statement = data['formal_statement']
        lean_source = f"{lean_header}{informal}{statement}"

        counter_example = data.get('counter_example', "None")
        task_line = f"Think about and solve the following problem step by step in Lean 4 using the provided counter example {counter_example}.\n"
        statement_section = f"# Formal statement:\n```lean4\n{lean_source}\n```\n"
        return task_line + statement_section

    def postprocess(self, header, model_input, model_output):
        """
        Run ``extract_lean4_code`` on every sampled completion and prepend
        ``header`` (plus a newline) to any result that does not contain it.

        Returns a dict with 'prover_input', 'prover_outputs' (raw
        completions) and 'formal_proof' (code strings).
        """
        completions = [candidate.text for candidate in model_output.outputs]
        proofs = []
        for completion in completions:
            code = extract_lean4_code(completion)
            proofs.append(code if header in code else header + "\n" + code)
        return {
            "prover_input": model_input,
            "prover_outputs": completions,
            "formal_proof": proofs,
        }

class DeepSeekProverV2nonCoT(BaseProver):
    """
    Direct-completion prover; default checkpoint ``deepseek-ai/DeepSeek-Prover-V2-7B``.
    """

    def __init__(self, model_path="deepseek-ai/DeepSeek-Prover-V2-7B", gpu=1, max_model_len=8192, temperature=1.0, max_tokens=8192, top_p=0.95, n=32, seed=0, **kwargs):
        # Pure pass-through to the shared BaseProver initialisation.
        super().__init__(
            model_path,
            gpu,
            max_model_len,
            temperature,
            max_tokens,
            top_p,
            n,
            seed,
            **kwargs,
        )

    def build_prompt(self, data):
        """
        Return a prompt that ends inside an open lean4 fence so the model
        continues the code directly.

        The header goes through ``fix_header``.
        """
        lean_header = fix_header(data.get('header', LEAN4_DEFAULT_HEADER))
        informal = data.get('informal_prefix', '')
        statement = data['formal_statement']
        lean_source = f"{lean_header}\n\n{informal}{statement}\n"

        counter_example = data.get('counter_example', "None")
        task_line = f"Complete the following Lean 4 code using the provided counter example {counter_example}:\n"
        return task_line + f"```lean4\n{lean_source}\n"

    def postprocess(self, header, model_input, model_output):
        """
        Concatenate the prompt with each completion and run
        ``extract_lean4_code`` on the result. ``header`` is not used.

        Returns a dict with 'prover_input', 'prover_outputs' (raw
        completions) and 'formal_proof' (code strings).
        """
        completions = []
        proofs = []
        for candidate in model_output.outputs:
            completions.append(candidate.text)
        for completion in completions:
            proofs.append(extract_lean4_code(model_input + completion))
        return {
            "prover_input": model_input,
            "prover_outputs": completions,  # full outputs from the model
            "formal_proof": proofs,  # Lean 4 code from the model
        }

class STP(BaseProver):
    """Prover whose default checkpoint is ``kfdong/STP_model_Lean_0320``."""

    def __init__(self, model_path="kfdong/STP_model_Lean_0320", gpu=1, max_model_len=4096, temperature=1.0, max_tokens=2048, top_p=0.95, n=32, seed=0, **kwargs):
        super().__init__(model_path, gpu, max_model_len, temperature, max_tokens, top_p, n, seed, **kwargs)

    def build_prompt(self, data):
        """
        Build one prompt per counter example.

        'counter_example' may be an iterable of examples or a single string;
        a string (including the default "None") is treated as ONE example
        rather than being iterated character by character.

        Returns:
            A list of prompt strings, one per example. Each ends inside an
            open lean4 fence containing the ``fix_header``-processed header,
            the optional 'informal_prefix' and the 'formal_statement'.
        """
        examples = data.get('counter_example', "None")
        if isinstance(examples, str):
            # Iterating a str would yield one prompt per character.
            examples = [examples]
        prompts = [(
            f"Complete the following Lean 4 code using the provided counter example {example}:\n\n"
            f"```lean4\n{fix_header(data.get('header', LEAN4_DEFAULT_HEADER))}{data.get('informal_prefix', '')}{data['formal_statement']}"
        ) for example in examples]
        return prompts

    def postprocess(self, header=None, model_input=None, model_output=None):
        """
        Not supported for this prover; always raises NotImplementedError.

        All parameters are optional so that both the two-argument form and
        the three-argument ``(header, model_input, model_output)`` form used
        by ``BaseProver.__call__`` reach the NotImplementedError instead of
        failing with a TypeError on arity.
        """
        raise NotImplementedError("STP does not support postprocess")

class Leana(BaseProver):
    """Prover whose default checkpoint is ``stoney0062/Leanabell-Prover-GD-RL``."""

    def __init__(self, model_path="stoney0062/Leanabell-Prover-GD-RL", gpu=1, max_model_len=4096, temperature=1.0, max_tokens=2048, top_p=0.95, n=32, seed=0, **kwargs):
        super().__init__(model_path, gpu, max_model_len, temperature, max_tokens, top_p, n, seed, **kwargs)

    def build_prompt(self, data):
        """
        Build one prompt per counter example.

        'counter_example' may be an iterable of examples or a single string;
        a string (including the default "None") is treated as ONE example
        rather than being iterated character by character.

        Returns:
            A list of prompt strings, one per example. Each ends inside an
            open lean4 fence containing the header (default
            LEAN4_DEFAULT_HEADER, used as-is), the optional 'informal_prefix'
            and the 'formal_statement'.
        """
        examples = data.get('counter_example', "None")
        if isinstance(examples, str):
            # Iterating a str would yield one prompt per character.
            examples = [examples]
        prompts = [(
            f"Complete the following Lean 4 code using the provided counter example {example}:\n\n"
            f"```lean4\n{data.get('header', LEAN4_DEFAULT_HEADER)}{data.get('informal_prefix', '')}{data['formal_statement']}"
        ) for example in examples]
        return prompts

    def postprocess(self, header=None, model_input=None, model_output=None):
        """
        Not supported for this prover; always raises NotImplementedError.

        All parameters are optional so that both the two-argument form and
        the three-argument ``(header, model_input, model_output)`` form used
        by ``BaseProver.__call__`` reach the NotImplementedError instead of
        failing with a TypeError on arity.
        """
        raise NotImplementedError("Leana does not support postprocess")
        
def fix_header(header):
    """
    Normalise a Lean 4 header so it imports both Mathlib and Aesop.

    Steps, in order:
      1. If the text "import Mathlib" does not occur, prepend "import Mathlib\\n".
      2. If the text "import Aesop" does not occur, replace every occurrence
         of "import Mathlib" with LEAN4_DEFAULT_HEADER (which itself contains
         both imports plus the default options and ``open`` line).
      3. Reject headers that end up with more than one "import Mathlib".

    Matching is by plain substring, so e.g. "import Mathlib.Tactic" counts as
    an occurrence of "import Mathlib" in all three steps.
    NOTE(review): in step 2 that would splice the default header in front of
    ".Tactic" — confirm headers with submodule imports never reach here.

    Args:
        header: Header text of a Lean 4 source file.

    Returns:
        The normalised header.

    Raises:
        ValueError: If "import Mathlib" occurs more than once after the
            adjustments above.
    """
    if "import Mathlib" not in header:
        header = "import Mathlib\n" + header
    if "import Aesop" not in header:
        header = header.replace("import Mathlib", LEAN4_DEFAULT_HEADER)
    if header.count("import Mathlib") > 1:
        raise ValueError("Multiple import Mathlib in header")
    return header

class ErdosProver(BaseProver):
    """Prover with no default checkpoint; ``model_path`` must be supplied."""

    def __init__(self, model_path="", gpu=1, max_model_len=4096, temperature=1.0, max_tokens=2048, top_p=0.95, n=32, seed=0, **kwargs):
        # Refuse the empty-string placeholder before any model loading starts.
        if model_path == "":
            raise ValueError("model_path is required for ErdosProver")
        init_args = (model_path, gpu, max_model_len, temperature, max_tokens, top_p, n, seed)
        super().__init__(*init_args, **kwargs)

    def build_prompt(self, data):
        """
        Return a prompt that ends inside an open lean4 fence.

        Uses 'counter_example' (default "None"), the ``fix_header``-processed
        'header' (default LEAN4_DEFAULT_HEADER), the optional
        'informal_prefix' and the required 'formal_statement'.
        """
        concrete_example = data.get('counter_example', "None")
        lean_header = fix_header(data.get('header', LEAN4_DEFAULT_HEADER))
        informal = data.get('informal_prefix', '')
        statement = data['formal_statement']

        parts = [
            f"Complete the following Lean 4 code using the given concrete example {concrete_example}:\n\n",
            f"```lean4\n{lean_header}\n",
            f"{informal}{statement}\n",
        ]
        return "".join(parts)

    def postprocess(self, header, model_input, model_output):
        """
        Concatenate the prompt with each completion and run
        ``extract_lean4_code`` on the result. ``header`` is not used.

        Returns a dict with 'prover_input', 'prover_outputs' (raw
        completions) and 'formal_proof' (code strings).
        """
        completions = [candidate.text for candidate in model_output.outputs]
        proofs = []
        for completion in completions:
            proofs.append(extract_lean4_code(model_input + completion))
        return {
            "prover_input": model_input,
            "prover_outputs": completions,
            "formal_proof": proofs,
        }