import os
import re
from typing import List, Optional

import backoff
import openai

# Prompt asking the model to restate a Lean statement in natural language.
# It is a str.format template with a single field, ``formal_statement``, which
# is placed inside a ```lean fenced block.
BACKTRANS_PROMPT = (
    "Rewrite the following Lean statement in natural language. "
    "Output only the rewritten statement and nothing else.\n"
    "\n"
    "```lean\n"
    "{formal_statement}\n"
    "```\n"
)

# Judge prompt: asks the model whether a back-translated statement is logically
# equivalent to the original one and to answer with a single True/False word.
# It is filled by replacing the literal placeholders "<original>" and
# "<backtranslated>"; it contains no str.format fields.
VERIFICATION_PROMPT = (
    "\n"
    "Your job is to determine whether the Backtranslated Statement is logically "
    "equivalent to the Original Statement.\n"
    "\n"
    "Logical equivalence means:\n"
    "- Both statements assert the same final mathematical claims,\n"
    "- Extra intermediate steps, explanations, or justifications **DO NOT** "
    "affect equivalence,\n"
    "- Different wording, order, or level of detail also **DO NOT** affect "
    "equivalence.\n"
    "\n"
    "They are **NOT** equivalent only if the final mathematical claims differ.\n"
    "\n"
    "Return exactly one word: True or False. Do **NOT** provide explanations.\n"
    "\n"
    "Original Statement:\n"
    "<original>\n"
    "\n"
    "Backtranslated Statement:\n"
    "<backtranslated>\n"
    "\n"
    "Does the backtranslated statement match the original statement?:\n"
)

class BackTranslator:
    """Back-translate Lean statements into natural language and judge equivalence.

    Both steps call the ``deepseek-chat`` model through an OpenAI-compatible
    chat-completions client configured from the environment:

    * ``DEEPSEEK_API_KEY``  - API key (empty string when unset).
    * ``DEEPSEEK_BASE_URL`` - endpoint URL (``DEFAULT_BASE_URL`` when unset or empty).
    """

    # Fallback endpoint (DeepSeek's OpenAI-compatible API). An empty base_url
    # would not address any server, so every request would fail.
    DEFAULT_BASE_URL = "https://api.deepseek.com"

    def __init__(self):
        """Create the chat client from the DEEPSEEK_* environment variables."""
        # The key deliberately defaults to "" rather than None. With None, the
        # openai client falls back to OPENAI_API_KEY, which would then be sent
        # to the DeepSeek endpoint.
        api_key = os.getenv("DEEPSEEK_API_KEY", "")
        base_url = os.getenv("DEEPSEEK_BASE_URL") or self.DEFAULT_BASE_URL

        self.client = openai.OpenAI(api_key=api_key, base_url=base_url)

    @backoff.on_exception(
        backoff.expo,
        openai.APIError,
        max_tries=8,
        # Client errors (bad request, auth, not found, ...) fail identically on
        # every retry, so re-raise them at once. Request timeout (408),
        # conflict (409) and rate limiting (429) are transient and still retried.
        giveup=lambda exc: (
            isinstance(exc, openai.APIStatusError)
            and 400 <= exc.status_code < 500
            and exc.status_code not in (408, 409, 429)
        ),
    )
    def completions_with_backoff(self, model: str, **kwargs):
        """Call ``chat.completions.create``, retrying transient API errors.

        Uses exponential backoff for at most 8 attempts, then re-raises the last
        ``openai.APIError``. Non-transient 4xx errors are raised immediately.
        """
        return self.client.chat.completions.create(model=model, **kwargs)

    def llm_checker(self, prompt: str, sys_prompt: Optional[str] = None) -> Optional[str]:
        """Send a single-turn chat request and return the reply text.

        Args:
            prompt: Content of the user message.
            sys_prompt: Optional system message placed before the user message.

        Returns:
            The content of the first (and only, ``n=1``) choice. The API types
            this as optional, so it may be ``None``.
        """
        messages = []
        if sys_prompt is not None:
            messages.append({'role': 'system', 'content': sys_prompt})

        messages.append({"role": "user", "content": prompt})
        res = self.completions_with_backoff(
            model='deepseek-chat',
            messages=messages,
            temperature=0.8,
            max_tokens=8192,
            n=1,
            stop=[],
            top_p=0.95,
            timeout=300,
        )

        return res.choices[0].message.content

    def backtranslate(self, formal_statement: str) -> Optional[str]:
        """Ask the model to restate ``formal_statement`` (Lean) in natural language.

        Returns the model's reply text, which may be ``None`` (see ``llm_checker``).
        """
        prompt = BACKTRANS_PROMPT.format(formal_statement=formal_statement)
        return self.llm_checker(prompt)

    def check_answer(self, original_statement: str, backtrans_statement: Optional[str]) -> bool:
        """Ask the model whether the back-translation matches the original statement.

        Fills ``VERIFICATION_PROMPT`` with both statements, queries the model and
        parses its True/False verdict.

        Args:
            original_statement: The natural-language statement that was formalised.
            backtrans_statement: The model's natural-language rendering of the
                formal statement. ``None`` or blank counts as a failed back-translation.

        Returns:
            True if the first standalone "true"/"false" word in the reply
            (case-insensitive) is "true". Otherwise False, including when the
            back-translation or the reply is empty.
        """
        if not backtrans_statement or not backtrans_statement.strip():
            return False

        # Substitute both placeholders in one pass. With two chained .replace
        # calls, a statement that itself contains "<backtranslated>" would be
        # rewritten by the second call. A callable replacement also keeps
        # backslashes in the statements literal.
        replacements = {
            "<original>": original_statement,
            "<backtranslated>": backtrans_statement,
        }
        prompt = re.sub(
            "<original>|<backtranslated>",
            lambda match: replacements[match.group(0)],
            VERIFICATION_PROMPT,
        )

        response = self.llm_checker(prompt)
        # llm_checker returns a string, but a list of candidate replies is also
        # accepted; the first candidate is used.
        if isinstance(response, list):
            response = response[0] if response else None

        return self._parse_verdict(response)

    @staticmethod
    def _parse_verdict(response: Optional[str]) -> bool:
        """Return True iff the first whole-word "true"/"false" in ``response`` is "true"."""
        if not response:
            return False
        # A bare substring test would accept replies such as "False, not true"
        # or "untrue". Take the first whole-word verdict instead.
        match = re.search(r"\b(true|false)\b", response, flags=re.IGNORECASE)
        return match is not None and match.group(1).lower() == "true"

    def backtranslate_and_verify(self, orig: str, lean_statement: str):
        """Back-translate ``lean_statement`` and check it against ``orig``.

        Returns:
            A ``(verdict, backtranslated_statement)`` tuple, where ``verdict`` is
            the boolean from ``check_answer``.
        """
        backtranslated_statement = self.backtranslate(lean_statement)

        response = self.check_answer(original_statement=orig, backtrans_statement=backtranslated_statement)

        return response, backtranslated_statement

    
            


if __name__ == "__main__":
    # Manual smoke test: back-translate a Lean formalisation and ask the model
    # whether it matches the natural-language step it came from. Needs network
    # access and the DEEPSEEK_* environment variables read by BackTranslator.

    # Raw string: the LaTeX contains \t (in \tan), \f (in \frac) and \r (in
    # \right). A regular string literal would turn these into tab, form-feed
    # and carriage-return characters and corrupt the prompt.
    step = r'''
Given that:
- For \(x > 0\), \(\tan^{-1}(x) = \cot^{-1}\left(\frac{1}{x}\right)\).
- For \(x < 0\), \(\tan^{-1}(x) = \cot^{-1}(1/x)\).

Prove that:
For x = 3 + 2√2, tan⁻¹(x) is defined since x > 0, and cot⁻¹(1/x) is also defined since 1/x > 0. The sum tan⁻¹(x) + cot⁻¹(1/x) = 2tan⁻¹(x) is valid, and sin(2tan⁻¹(x)) = 2x/(1+x²) = 1/3, so this is a valid solution.
'''
    # NOTE(review): the hypotheses below (arctan x = arctan(1/x) ± π/2) are not
    # the identities stated in `step`. This is presumably a deliberately
    # mismatched example for exercising the verifier; confirm.
    lean = '''
import Mathlib
import Aesop

set_option maxHeartbeats 0

open BigOperators Real Nat Topology Rat



theorem test : 
  -- Given conditions
  (∀ x : ℝ, x > 0 → Real.arctan x = Real.arctan (1 / x) + Real.pi / 2) ∧
  (∀ x : ℝ, x < 0 → Real.arctan x = Real.arctan (1 / x) - Real.pi / 2) →
  -- For x = 3 + 2√2
  let x := 3 + 2 * Real.sqrt 2
  -- tan⁻¹(x) is defined since x > 0
  x > 0 ∧
  -- cot⁻¹(1/x) is defined since 1/x > 0
  (1 / x) > 0 ∧
  -- The sum tan⁻¹(x) + cot⁻¹(1/x) = 2tan⁻¹(x) is valid
  Real.arctan x + (Real.arctan (1 / x) - Real.pi / 2) = 2 * Real.arctan x ∧
  -- sin(2tan⁻¹(x)) = 2x/(1+x²) = 1/3
  Real.sin (2 * Real.arctan x) = 2 * x / (1 + x^2) ∧
  2 * x / (1 + x^2) = 1/3 := by sorry
'''

    verifier = BackTranslator()

    res, backtrans = verifier.backtranslate_and_verify(step, lean)

    print(f'Original statement:\n{step}\n\nLean statement:\n{lean}\n{"-"*50}\nBacktranslated statement:\n{backtrans}\n{"-"*50}\nVerdict: {res}')

