import json
import os
import re
# import concurrent.futures
from functools import partial
from typing import List, Optional

import backoff
import openai

# from multithread_support import timeout_handler

class AnswerVerifier:
    """Judge answers with a chat-completions model.

    The model is reached through an OpenAI-compatible client that is configured
    from the ``DEEPSEEK_API_KEY`` and ``DEEPSEEK_BASE_URL`` environment variables.
    """

    # Transient failures that are worth retrying. Other API errors (bad request,
    # authentication, unknown model, ...) would fail identically on every
    # attempt, so they are raised immediately instead of being retried forever.
    _RETRYABLE_ERRORS = (
        openai.RateLimitError,
        openai.APIConnectionError,
        openai.APITimeoutError,
        openai.InternalServerError,
    )
    # Upper bound on attempts per request, so a persistent outage surfaces as an
    # exception instead of an endless exponential-backoff loop.
    _MAX_API_TRIES = 8

    def __init__(self):
        # NOTE(review): unset env vars become empty strings (not None) — confirm
        # that this is what the client should receive.
        api_key = os.getenv("DEEPSEEK_API_KEY", "")
        base_url = os.getenv("DEEPSEEK_BASE_URL", "")

        self.client = openai.OpenAI(api_key=api_key, base_url=base_url)

    @backoff.on_exception(backoff.expo, _RETRYABLE_ERRORS, max_tries=_MAX_API_TRIES)
    def completions_with_backoff(self, model: str, **kwargs):
        """Call ``client.chat.completions.create``.

        Transient API errors are retried with exponential backoff, for at most
        ``_MAX_API_TRIES`` attempts. The last error is re-raised when the retries
        are exhausted, and non-transient errors are raised immediately.
        """
        return self.client.chat.completions.create(model=model, **kwargs)

    def llm_checker(self, prompt: str) -> Optional[str]:
        """Send *prompt* as a single user message to ``deepseek-chat``.

        Returns:
            The text of the first (and only, ``n=1``) choice. This may be None if
            the API returns no message content.
        """
        messages = [{"role": "user", "content": prompt}]
        res = self.completions_with_backoff(
            model='deepseek-chat',
            messages=messages,
            temperature=0.7,
            max_tokens=8192,
            n=1,
            stop=[],
            top_p=0.95,
            timeout=300,
        )

        return res.choices[0].message.content

    def check_answer(self, question: str, answer, ground_truth: str) -> bool:
        """Ask the LLM whether *answer* matches *ground_truth* for *question*.

        The method fills the ``<question>``, ``<answer>`` and ``<ground_truth>``
        placeholders of ``prompts/check_answer.txt``. It then sends the prompt
        through :meth:`llm_checker`.

        Args:
            question: The problem statement.
            answer: The candidate answer. This is either a string or a list of
                strings, which are joined with newlines.
            ground_truth: The reference answer.

        Returns:
            False for a None or blank answer, in which case no API call is made.
            Otherwise True iff the model reply contains "true" (case-insensitive).
            This is a substring match, so a reply like "not true" also counts.
        """
        # Normalize list answers *before* the blank check: calling .strip() on a
        # list raises AttributeError.
        if isinstance(answer, list):
            answer = '\n'.join(answer)
        if answer is None or answer.strip() == '':
            return False

        with open("prompts/check_answer.txt", "r", encoding="utf-8") as f:
            prompt = f.read()

        prompt = (
            prompt.replace("<question>", question)
            .replace("<answer>", answer)
            .replace("<ground_truth>", ground_truth)
        )

        response = self.llm_checker(prompt)
        # Defensive: tolerate an llm_checker that returns a list of candidates.
        if isinstance(response, list):
            response = response[0] if response else None

        return bool(response) and "true" in response.lower()
    
class StepVerifier(AnswerVerifier):
    """Split a solution into steps and ask the LLM whether each step requires
    formal (Lean4) verification.

    NOTE: ``check_answer`` here overrides ``AnswerVerifier.check_answer`` with a
    different signature and return type: it returns a list of per-step records,
    not a bool.
    """

    def divide_step(self, solution: str, max_attempts: int = 3) -> List[str]:
        """Split *solution* into individual steps using the LLM.

        The method fills ``prompts/divide_steps.txt`` with the solution. It
        expects the model to reply with a JSON array of step strings, optionally
        wrapped in a ```/```json fence.

        Args:
            solution: The full solution text.
            max_attempts: The maximum number of LLM calls. A call that raises,
                returns malformed JSON, or returns anything other than a list of
                strings uses up one attempt.

        Returns:
            The parsed steps. If every attempt fails, or the model returns an
            empty list, the non-blank lines of *solution* are returned instead
            (``[""]`` if there are none). The result therefore always has at
            least one element.
        """
        with open("prompts/divide_steps.txt", "r", encoding="utf-8") as f:
            prompt = f.read()
        prompt = prompt.replace("<solution>", solution)

        steps = None  # stays None until a reply parses as a JSON list of strings
        for _ in range(max_attempts):
            try:
                response = self.llm_checker(prompt)
                # The response may be surrounded by ``` or ```json
                response = (
                    re.sub(r"```(json|JSON)", "", response).replace("```", "").strip()
                )
                parsed = json.loads(response)
            except Exception as e:
                # Best effort: an API error, an empty (None) reply or malformed
                # JSON just uses up one attempt.
                print(e)
                continue
            # Anything other than a list of strings (e.g. a JSON object, or a
            # list of objects) would break the '<step>' substitution in
            # check_answer, so treat it as a failed attempt too.
            if isinstance(parsed, list) and all(isinstance(s, str) for s in parsed):
                steps = parsed
                break
            print(f"Unexpected step format: {type(parsed).__name__}")

        if steps:
            return steps
        if steps is None:
            print("Warning: Failed to divide the solution into steps.")
        # Fall back to the solution's non-blank lines; never return an empty list.
        fallback_steps = [line for line in solution.split("\n") if line.strip() != ""]
        return fallback_steps or [""]

    def check_answer(self, question : str, answer: str, divide_into_steps=True):
        """Ask the LLM, step by step, which parts of *answer* need formal verification.

        Args:
            question: The problem statement, substituted for ``<question>``.
            answer: The solution text.
            divide_into_steps: If True, split *answer* with :meth:`divide_step`
                and use ``prompts/formalize.txt``. Otherwise, treat the whole
                answer as one step and use ``prompts/formalize_single_step.txt``.

        Returns:
            One dict per step, in order, with these keys:

            - ``question``
            - ``previous_steps``: the earlier steps, each followed by a newline,
              as substituted for ``<buffer>``.
            - ``step``
            - ``requires_verification``: the model's verdict, as interpreted by
              :meth:`_response_says_true`.
        """
        if divide_into_steps:
            template_path = "prompts/formalize.txt"
        else:
            template_path = "prompts/formalize_single_step.txt"

        with open(template_path, "r", encoding="utf-8") as f:
            prompt_base = f.read()

        steps = self.divide_step(answer) if divide_into_steps else [answer]

        checked_steps = []
        buffer = ''
        for step in steps:
            prompt = (
                prompt_base.replace("<question>", question)
                .replace('<step>', step)
                .replace('<buffer>', buffer)
            )
            res = self._response_says_true(self.llm_checker(prompt))

            checked_steps.append({'question' : question, 'previous_steps' : buffer, 'step' : step, 'requires_verification' : res})
            # Each later step sees all earlier steps as its context.
            buffer += f'{step}\n'
        return checked_steps

    def check_single_step(self, step : str, question: str = ""):
        """Ask the LLM whether a single *step* requires formal verification.

        The method uses ``prompts/formalize_single_step.txt`` and fills its
        placeholders the same way ``check_answer(..., divide_into_steps=False)``
        does. ``<question>`` is replaced by *question* (empty by default),
        ``<step>`` by *step*, and ``<buffer>`` by an empty string, because there
        are no previous steps.

        Returns:
            ``{'step': step, 'requires_verification': bool}``.
        """
        with open("prompts/formalize_single_step.txt", "r", encoding="utf-8") as f:
            prompt_base = f.read()

        prompt = (
            prompt_base.replace("<question>", question)
            .replace('<step>', step)
            .replace('<buffer>', '')
        )
        res = self._response_says_true(self.llm_checker(prompt))

        return {'step' : step, 'requires_verification' : res}

    @staticmethod
    def _response_says_true(response) -> bool:
        """Interpret an LLM reply as a boolean verdict.

        A list reply is reduced to its first element (an empty list counts as no
        reply). The verdict is True iff the reply is non-empty and contains
        "true" case-insensitively. This is a substring match, so "not true" also
        yields True.
        """
        if isinstance(response, list):
            response = response[0] if response else None
        return bool(response) and "true" in response.lower()
            


if __name__ == "__main__":
    # Demo: split a worked solution into steps, then print each step together
    # with its accumulated context and the model's verification verdict.
    question = "If $f(x)=5x^2+3x+4$, what is the value of $f(-2)$?"
    llm_answer = "To find the value of \\( f(-2) \\) for the function \\( f(x) = 5x^2 + 3x + 4 \\), follow these steps:\n\n1. **Substitute \\( x = -2 \\) into the function:**\n   \\[\n   f(-2) = 5(-2)^2 + 3(-2) + 4\n   \\]\n\n2. **Calculate each term separately:**\n   - **First term:** \\( 5(-2)^2 = 5 \\times 4 = 20 \\)\n   - **Second term:** \\( 3(-2) = -6 \\)\n   - **Third term:** \\( 4 \\) remains as is.\n\n3. **Add the calculated terms together:**\n   \\[\n   f(-2) = 20 - 6 + 4\n   \\]\n   \\[\n   f(-2) = (20 - 6) + 4 = 14 + 4 = 18\n   \\]\n\n**Final Answer:**\n\\[\n\\boxed{18}\n\\]"

    step_verifier = StepVerifier()
    checked = step_verifier.check_answer(question, llm_answer, divide_into_steps=True)

    separator = "-" * 100
    for record in checked:
        context_and_step = f"{record['previous_steps']}{record['step']}"
        print(context_and_step)
        print(record["requires_verification"])
        print(separator)

