from __future__ import annotations
from typing import List, Optional, Union
from pathlib import Path

import os, signal, subprocess, time

import backoff
import openai
import re

from openai import (
    APITimeoutError,
    APIConnectionError,
    RateLimitError,
    BadRequestError,
    AuthenticationError,
    NotFoundError,
    ConflictError,
    UnprocessableEntityError,
    InternalServerError,
    APIStatusError,
)

class ProverModel:
    """Lean 4 prover backed by an OpenAI-compatible chat-completions endpoint.

    A prompt is rendered from a template and sent to the model. The last
    ```lean4 fenced block of each completion is then turned into a full Lean
    source by prefixing a Lean header (``self.header`` by default).
    """

    # Matches a line (after leading whitespace is removed) that begins with a
    # Lean header command. The trailing ``\b`` keeps identifiers such as
    # ``openSegment_subset`` or ``importantLemma`` from being mistaken for
    # ``open`` / ``import`` commands and silently dropped from a proof.
    _HEADER_LINE_RE = re.compile(r"(?:import|set_option|open)\b")

    # Tactic chain substituted for every ``sorry`` when no model is queried.
    _AUTO_SOLVER = 'try omega ; try decide ; try norm_cast ; try norm_num ; try simp_all ; try ring_nf at * ; try native_decide ; try linarith ; try nlinarith'

    def __init__(self, model_path: str,
                 sampling_params,
                 template: Optional[Union[str, Path]] = None,
                 api_key = None,
                 base_url = None,
                 port = None) -> None:
        """
        Parameters
        ----------
        model_path      : model identifier sent to the API; ``'deepseek-chat'``
                          selects the hosted endpoint in ``_load_model``.
        sampling_params : object exposing ``temperature``, ``max_tokens``,
                          ``top_p`` and ``timeout`` attributes (read by ``run_llm``).
        template        : either a filesystem path to a prompt template, or a raw
                          template string. ``render_prompt`` substitutes the
                          placeholders ``<header>`` and ``<body>``.
        api_key         : API key passed to the OpenAI client.
        base_url        : base URL of the hosted endpoint, or the host name of a
                          local server (combined with ``port``).
        port            : port of the local server.
        """
        self.model_path = model_path

        self.sampling_params = sampling_params

        self.api_key = api_key
        self.base_url = base_url
        self.port = port

        self.template_text = self._read_template(template)

        # Hook point: subclasses may override to build a different client.
        self._load_model()

        # Default header if not passed
        self.header = "import Mathlib\nimport Aesop\n\nset_option maxHeartbeats 0\n\nopen BigOperators Real Nat Topology Rat\n\n"

        # Human-readable reason for the most recent prover failure ('' if none).
        self.error_feedback = ''

    @staticmethod
    def _read_template(template) -> str:
        """Return the template text, loading it from disk when ``template`` names an existing path."""
        # An empty string would resolve to the current directory; treat it as "no template".
        if template is None or template == "":
            return ""
        if isinstance(template, (str, Path)):
            candidate = Path(str(template))
            try:
                exists = candidate.exists()
            except OSError:
                # A long raw template string is not a usable path. On some
                # Python versions exists() raises ENAMETOOLONG instead of
                # returning False.
                exists = False
            if exists:
                return candidate.read_text(encoding="utf-8")
        return str(template)

    # ---------- Overridables / hooks ----------
    def completions_with_backoff(self, model: str, **kwargs):
        """Send a single chat-completions request through ``self.client``.

        Despite the name, no retry or backoff is applied here. Exceptions
        propagate to the caller; ``prove`` turns them into ``error_feedback``.
        """
        return self.client.chat.completions.create(model=model, **kwargs)

    def run_llm(self, prompt: str, n : int):
        """Ask the model for ``n`` completions of ``prompt``.

        Returns the ``message.content`` of each returned choice, in order.
        """
        messages = [{"role": "user", "content": prompt}]
        res = self.completions_with_backoff(
                    model=self.model_path,
                    messages=messages,
                    temperature=self.sampling_params.temperature,
                    max_tokens=self.sampling_params.max_tokens,
                    n=n,
                    stop=[],
                    top_p=self.sampling_params.top_p,
                    timeout=self.sampling_params.timeout
                )

        return [choice.message.content for choice in res.choices]

    def _load_model(self) -> None:
        """Create ``self.client``, the OpenAI-compatible client used by ``run_llm``.

        ``'deepseek-chat'`` talks to the endpoint at ``base_url`` directly. Any
        other model is reached at ``http://{base_url}:{port}/v1``, presumably a
        locally hosted server such as vLLM.
        """
        if self.model_path in ['deepseek-chat']:
            self.client = openai.OpenAI(api_key=self.api_key, base_url=self.base_url)
        else:
            self.client = openai.OpenAI(base_url=f"http://{self.base_url}:{self.port}/v1", api_key=self.api_key)

    # ---------- Core functionality ----------

    def render_prompt(self, question: str, header : str) -> str:
        """
        Fill the template, replacing ``<header>`` with ``header`` and ``<body>``
        with ``question``.
        If no template was provided, falls back to a minimal inline template.
        """
        tpl = self.template_text or (
            "Question:\n<header>\n<body>\n"
            "Prove this theorem in Lean4. Lean4 code should be wrapped as follows:\n```lean4\n[your code goes here]\n```"
        )
        return tpl.replace("<header>", header).replace("<body>", question)

    def _strip_header(self, code: str) -> str:
        """Drop blank lines and Lean ``import``/``set_option``/``open`` command lines."""
        kept = [
            line for line in code.splitlines()
            if line.strip() and not self._HEADER_LINE_RE.match(line.lstrip())
        ]
        return '\n'.join(kept)

    def extract_code(self, text_input):
        """Extract the last ```lean4 block from model output, minus header lines.

        Header command lines and blank lines are removed, because ``prove``
        re-adds its own header.

        Returns
        -------
        str
            The cleaned code. Returns ``"No Lean 4 code block found."`` when no
            fenced block exists, or ``"Error during code extraction."`` when
            ``text_input`` is not a string (e.g. a completion with no content).
        """
        if not isinstance(text_input, str):
            return "Error during code extraction."

        matches = re.findall(r'```lean4\n(.*?)\n```', text_input, re.DOTALL)
        if not matches:
            return "No Lean 4 code block found."
        return self._strip_header(matches[-1].strip())

    def prove(self, question, n, header=None, without_prover_model=False):
        """Generate candidate Lean proofs for ``question``.

        Parameters
        ----------
        question             : Lean statement to prove (inserted as ``<body>``).
        n                    : number of completions to request from the model.
        header               : Lean header to prepend; defaults to ``self.header``.
        without_prover_model : if True, skip the model and replace each ``sorry``
                               in ``question`` with an automation tactic chain.

        Returns
        -------
        list
            Full Lean sources (header + code). If the model call raises, returns
            ``n`` ``None`` values and stores the reason in ``self.error_feedback``.
        """
        if header is None:
            header = self.header
        # Clear feedback left by an earlier failed call so it is not mistaken
        # for the outcome of this one.
        self.error_feedback = ''

        if without_prover_model:
            solved = question.replace('sorry', self._AUTO_SOLVER)
            # The question is normally raw Lean rather than a fenced model reply,
            # so strip its header directly. extract_code would otherwise find no
            # ```lean4 block and return a placeholder instead of the proof.
            if '```lean4' in solved:
                code = self.extract_code(solved)
            else:
                code = self._strip_header(solved)
            return [f'{header}\n\n{code}']

        prompt = self.render_prompt(question, header)
        try:
            completions = self.run_llm(prompt, n)
        except APITimeoutError:
            self.error_feedback = 'Prover timed out, the step is too complex. Try to break it down and prove smaller steps.'
            return [None] * n
        except APIConnectionError:
            self.error_feedback = 'Network error, try again.'
            return [None] * n
        except Exception as e:
            # Deliberate best-effort: any other API failure is reported, not raised.
            self.error_feedback = f'{e}, try again.'
            return [None] * n

        return [f'{header}\n\n{self.extract_code(text)}' for text in completions]



