from abc import ABC, abstractmethod
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
from typing import Dict, Any

def parse_answer(sol):
    r"""Extract the final boxed answer from a solution string.

    The last ``\boxed`` occurrence in ``sol`` is used; ``\fbox`` is consulted
    only when ``sol`` contains no ``\boxed`` at all. The macro name must be
    followed immediately by ``{``; the answer is the text up to the matching
    ``}``, so nested braces such as ``\boxed{\frac{1}{2}}`` are kept intact.

    Args:
        sol: Solution text to search.

    Returns:
        The text between the outer braces, or ``None`` when no macro is
        found, the macro is not directly followed by ``{``, or the braces
        are never balanced.
    """
    macro = "\\boxed"
    idx = sol.rfind(macro)
    if idx < 0:
        macro = "\\fbox"
        idx = sol.rfind(macro)
        if idx < 0:
            return None

    # Forms such as "\boxed 5" (no brace right after the macro) are rejected.
    open_idx = idx + len(macro)
    if sol[open_idx:open_idx + 1] != "{":
        return None

    # Walk forward tracking brace depth until the opening brace is closed.
    depth = 0
    for i in range(open_idx, len(sol)):
        ch = sol[i]
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                # Strip the wrapper of whichever macro matched, so "\fbox{...}"
                # answers are returned instead of being discarded.
                return sol[open_idx + 1:i]

    # Ran off the end of the string: unbalanced braces.
    return None

def parse_tokenizer_mode(model_path):
    """Choose the tokenizer mode string for a model path.

    Returns "mistral" when the lowercase substring "mistral" occurs in
    ``model_path`` (the check is case-sensitive), otherwise "auto".
    """
    mode = "auto"
    if "mistral" in model_path:
        mode = "mistral"
    return mode

class BaseSolver(ABC):
    """
    Abstract base class for generating counter examples using LLMs.

    Subclasses implement ``build_prompt`` and ``postprocess``; calling an
    instance builds one prompt per data item, runs them through the vLLM
    engine in a single ``generate`` call, and post-processes each output.
    """
    def __init__(self, model_path, gpu=1, max_model_len=4096, temperature=1.0, max_tokens=2048, top_p=0.95, n=32, seed=0, tokenizer_mode="auto", **kwargs):
        """
        Load the tokenizer and the vLLM engine and set up sampling parameters.

        Args:
            model_path: Model identifier or path handed to the tokenizer and to vLLM.
            gpu: Passed to vLLM as ``tensor_parallel_size``.
            max_model_len: Passed to vLLM as ``max_model_len``.
            temperature: Sampling temperature.
            max_tokens: Maximum number of generated tokens per completion.
            top_p: Nucleus sampling threshold.
            n: Number of completions sampled per prompt.
            seed: Seed passed to the vLLM engine.
            tokenizer_mode: vLLM tokenizer mode. When left at "auto" it is
                inferred from ``model_path``; any other value is used as given.
            **kwargs: Accepted so subclasses can forward extra options; unused here.
        """
        self.model_path = model_path
        # Only infer the mode when the caller did not pick one explicitly;
        # previously an explicit value was silently overwritten.
        if tokenizer_mode == "auto":
            tokenizer_mode = parse_tokenizer_mode(model_path)
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        print(self.tokenizer)
        self.model = LLM(
            model=model_path,
            seed=seed,
            swap_space=8,
            tensor_parallel_size=gpu,
            max_model_len=max_model_len,
            download_dir="/data0/Anony/hub",  # Specify download directory
            tokenizer_mode=tokenizer_mode,
            gpu_memory_utilization=0.9,
            trust_remote_code=True,
        )
        self.sampling_params = SamplingParams(
            temperature=temperature,
            max_tokens=max_tokens,
            top_p=top_p,
            n=n
        )

    @abstractmethod
    def build_prompt(self, data: Dict[str, Any]) -> str:
        """
        Build the prompt string for the model from one input data item.
        """
        pass

    @abstractmethod
    def postprocess(self, model_input, model_output) -> Dict[str, Any]:
        """
        Turn one prompt and its model output into a result dictionary.
        """
        pass

    def __call__(self, data_list: list, use_tqdm: bool = True):
        """
        Generate results for a list of data items.

        Args:
            data_list: Items passed one by one to ``build_prompt``.
            use_tqdm: Forwarded to the engine's ``generate`` call.

        Returns:
            A list with one ``postprocess`` result per item, in input order.
        """
        model_inputs = [self.build_prompt(data) for data in data_list]

        # Show a sample prompt, but only when there is one to show.
        if model_inputs:
            print("Example of solver prompt: ", model_inputs[0])

        model_outputs = self.model.generate(
            model_inputs,
            self.sampling_params,
            use_tqdm=use_tqdm,
        )
        # Sanity check: outputs are paired with prompts by position below.
        assert len(model_outputs) == len(model_inputs)
        return [
            self.postprocess(model_input, model_output)
            for model_input, model_output in zip(model_inputs, model_outputs)
        ]
    
def remove_boxed(s):
    r"""Strip the ``\boxed{`` prefix and trailing ``}`` from ``s``.

    Args:
        s: Candidate string, expected to look like ``\boxed{...}``.

    Returns:
        The text between the wrapper's braces, or ``None`` when ``s`` is not
        a string or does not both start with ``\boxed{`` and end with ``}``.
    """
    left = "\\boxed{"
    # Explicit checks instead of assert + bare except: asserts vanish under
    # ``python -O``, which would let malformed input be sliced and returned.
    if not isinstance(s, str):
        return None
    if not (s.startswith(left) and s.endswith("}")):
        return None
    return s[len(left):-1]


class MistralSolver(BaseSolver):
    """Solver whose default model is ``mistralai/Mistral-7B-Instruct-v0.3``."""

    def __init__(self, model_path="mistralai/Mistral-7B-Instruct-v0.3", gpu=1, max_model_len=4096, temperature=1.0, max_tokens=2048, top_p=0.95, n=32, seed=0, **kwargs):
        """Forward all settings unchanged to ``BaseSolver.__init__``."""
        super().__init__(model_path, gpu, max_model_len, temperature, max_tokens, top_p, n, seed, **kwargs)

    def build_prompt(self, data):
        """Build the prompt text from ``data['formal_statement']``."""
        # Only the last literal is an f-string. The format line is a plain
        # literal, so its braces must be single: doubled braces there are not
        # an escape and would be sent to the model verbatim.
        return ("Find a concrete example for proving the following existential problem.\n"
                "Note that:\n"
                "1. Please reason the problem and give the final answer in Natural Language\n"
                "2. The final answer should be in the format `\\boxed{...}`\n\n"
                f"The problem is: {data['formal_statement']}")

    def postprocess(self, model_input, model_output):
        """Collect every sampled completion and the answer parsed from each.

        Returns a dict with the prompt ("solver_input"), the completion texts
        ("solver_outputs") and, aligned with them, the ``parse_answer`` result
        for each completion ("counter_examples").
        """
        # Note: there are multiple outputs, we need to parse the answer from each output
        outputs = [output.text for output in model_output.outputs]
        counter_examples = [parse_answer(out) for out in outputs]
        return {
            "solver_input": model_input,
            "solver_outputs": outputs, # full outputs from the model
            "counter_examples": counter_examples, # counter examples from the model
        }
        
class DeepSeekQwen8BSolver(BaseSolver):
    """Solver whose default model is ``deepseek-ai/DeepSeek-R1-0528-Qwen3-8B``."""

    def __init__(self, model_path="deepseek-ai/DeepSeek-R1-0528-Qwen3-8B", gpu=1, max_model_len=4096, temperature=1.0, max_tokens=2048, top_p=0.95, n=32, seed=0, **kwargs):
        """Forward all settings unchanged to ``BaseSolver.__init__``."""
        super().__init__(model_path, gpu, max_model_len, temperature, max_tokens, top_p, n, seed, **kwargs)

    def build_prompt(self, data):
        """Build the prompt text from ``data['formal_statement']``."""
        # Only the last literal is an f-string. The format line is a plain
        # literal, so its braces must be single: doubled braces there are not
        # an escape and would be sent to the model verbatim.
        return ("Find a concrete example for proving the following existential problem.\n"
                "Note that:\n"
                "1. Please reason the problem and give the final answer in Natural Language\n"
                "2. The final answer should be in the format `\\boxed{...}`\n\n"
                f"The problem is: {data['formal_statement']}")

    def postprocess(self, model_input, model_output):
        """Collect every sampled completion and the answer parsed from each.

        Returns a dict with the prompt ("solver_input"), the completion texts
        ("solver_outputs") and, aligned with them, the ``parse_answer`` result
        for each completion ("counter_examples").
        """
        # Note: there are multiple outputs, we need to parse the answer from each output
        outputs = [output.text for output in model_output.outputs]
        counter_examples = [parse_answer(out) for out in outputs]
        return {
            "solver_input": model_input,
            "solver_outputs": outputs, # full outputs from the model
            "counter_examples": counter_examples, # counter examples from the model
        }
        
class Qwen3Solver(BaseSolver):
    """Solver whose default model is ``Qwen/Qwen3-8B`` (default ``max_model_len`` 32768)."""

    def __init__(self, model_path="Qwen/Qwen3-8B", gpu=1, max_model_len=32768, temperature=1.0, max_tokens=2048, top_p=0.95, n=32, seed=0, **kwargs):
        """Forward all settings unchanged to ``BaseSolver.__init__``."""
        super().__init__(model_path, gpu, max_model_len, temperature, max_tokens, top_p, n, seed, **kwargs)

    def build_prompt(self, data):
        """Build the prompt text from ``data['formal_statement']``."""
        # Only the last literal is an f-string. The format line is a plain
        # literal, so its braces must be single: doubled braces there are not
        # an escape and would be sent to the model verbatim.
        return ("Find a concrete example for proving the following existential problem.\n"
                "Note that:\n"
                "1. Please reason the problem and give the final answer in Natural Language\n"
                "2. The final answer should be in the format `\\boxed{...}`\n\n"
                f"The problem is: {data['formal_statement']}")

    def postprocess(self, model_input, model_output):
        """Collect every sampled completion and the answer parsed from each.

        Returns a dict with the prompt ("solver_input"), the completion texts
        ("solver_outputs") and, aligned with them, the ``parse_answer`` result
        for each completion ("counter_examples").
        """
        # Note: there are multiple outputs, we need to parse the answer from each output
        outputs = [output.text for output in model_output.outputs]
        counter_examples = [parse_answer(out) for out in outputs]
        return {
            "solver_input": model_input,
            "solver_outputs": outputs, # full outputs from the model
            "counter_examples": counter_examples, # counter examples from the model
        }
        
class GPTOSSSolver(BaseSolver):
    """Solver whose default model is ``openai/gpt-oss-20b`` (default ``max_model_len`` 32768)."""

    def __init__(self, model_path="openai/gpt-oss-20b", gpu=1, max_model_len=32768, temperature=1.0, max_tokens=2048, top_p=0.95, n=32, seed=0, **kwargs):
        """Forward all settings unchanged to ``BaseSolver.__init__``."""
        super().__init__(model_path, gpu, max_model_len, temperature, max_tokens, top_p, n, seed, **kwargs)

    def build_prompt(self, data):
        """Build the prompt text from ``data['formal_statement']``."""
        # Only the last literal is an f-string. The format line is a plain
        # literal, so its braces must be single: doubled braces there are not
        # an escape and would be sent to the model verbatim.
        return ("Find a concrete example for proving the following existential problem.\n"
                "Note that:\n"
                "1. Please reason the problem and give the final answer in Natural Language\n"
                "2. The final answer should be in the format `\\boxed{...}`\n\n"
                f"The problem is: {data['formal_statement']}")

    def postprocess(self, model_input, model_output):
        """Collect every sampled completion and the answer parsed from each.

        Returns a dict with the prompt ("solver_input"), the completion texts
        ("solver_outputs") and, aligned with them, the ``parse_answer`` result
        for each completion ("counter_examples").
        """
        # Note: there are multiple outputs, we need to parse the answer from each output
        outputs = [output.text for output in model_output.outputs]
        counter_examples = [parse_answer(out) for out in outputs]
        return {
            "solver_input": model_input,
            "solver_outputs": outputs, # full outputs from the model
            "counter_examples": counter_examples, # counter examples from the model
        }
        
class ErdosSolver(BaseSolver):
    """Solver with no default model that sends chat-templated prompts.

    Unlike the inherited ``__call__``, each prompt is wrapped in the
    tokenizer's chat template as a single user turn before generation.
    """

    def __init__(self, model_path="", gpu=1, max_model_len=4096, temperature=1.0, max_tokens=2048, top_p=0.95, n=32, seed=0, **kwargs):
        """Forward all settings to ``BaseSolver.__init__``.

        Raises:
            ValueError: If ``model_path`` is empty or ``None``.
        """
        # Reject None as well as "" so the failure is explicit rather than a
        # TypeError later on.
        if not model_path:
            raise ValueError("model_path is required for ErdosSolver")
        super().__init__(model_path, gpu, max_model_len, temperature, max_tokens, top_p, n, seed, **kwargs)

    def build_prompt(self, data):
        """Build the user-message text from ``data['formal_statement']``."""
        # Only the last literal is an f-string. The format line is a plain
        # literal, so its braces must be single: doubled braces there are not
        # an escape and would be sent to the model verbatim.
        return ("Find a concrete example for proving the following existential problem.\n"
                "Note that:\n"
                "1. Please reason the problem and give the final answer in Natural Language\n"
                "2. The final answer should be in the format `\\boxed{...}`\n\n"
                f"The problem is: {data['formal_statement']}")

    def postprocess(self, model_input, model_output):
        """Collect every sampled completion and the answer parsed from each.

        Returns a dict with the prompt ("solver_input"), the completion texts
        ("solver_outputs") and, aligned with them, the ``parse_answer`` result
        for each completion ("counter_examples").
        """
        # Note: there are multiple outputs, we need to parse the answer from each output
        outputs = [output.text for output in model_output.outputs]
        counter_examples = [parse_answer(out) for out in outputs]
        return {
            "solver_input": model_input,
            "solver_outputs": outputs, # full outputs from the model
            "counter_examples": counter_examples, # counter examples from the model
        }

    def __call__(self, data_list: list, use_tqdm: bool = True):
        """
        Generate results for a list of data items.

        Each prompt from ``build_prompt`` is rendered through the tokenizer's
        chat template (single user message, generation prompt appended,
        ``enable_thinking=False``) before being sent to the engine.

        Args:
            data_list: Items passed one by one to ``build_prompt``.
            use_tqdm: Forwarded to the engine's ``generate`` call.

        Returns:
            A list with one ``postprocess`` result per item, in input order.
        """
        model_inputs = [
            self.tokenizer.apply_chat_template(
                [{"role": "user", "content": self.build_prompt(data)}],
                tokenize=False,
                add_generation_prompt=True,
                enable_thinking=False,  # Set to False to strictly disable thinking
            )
            for data in data_list
        ]
        # Guarded so an empty data_list no longer raises IndexError.
        if model_inputs:
            print("Example of solver prompt: ", model_inputs[0])
        model_outputs = self.model.generate(
            model_inputs,
            self.sampling_params,
            use_tqdm=use_tqdm,
        )
        # Sanity check: outputs are paired with prompts by position below.
        assert len(model_outputs) == len(model_inputs)
        return [
            self.postprocess(model_input, model_output)
            for model_input, model_output in zip(model_inputs, model_outputs)
        ]