from __future__ import annotations
from typing import List, Optional, Union
from pathlib import Path

import os, signal, subprocess, time

import backoff
import openai
import re

from openai import (
    APITimeoutError,
    APIConnectionError,
    RateLimitError,
    BadRequestError,
    AuthenticationError,
    NotFoundError,
    ConflictError,
    UnprocessableEntityError,
    InternalServerError,
    APIStatusError,
)

import re
from tqdm import tqdm

class LemmaExtractor:
    """Build a standalone Lean 4 ``lemma`` from a REPL proof state.

    The generated lemma is prefixed with the *header* of the original code:
    every line before the first line whose stripped text starts with
    ``theorem`` (imports, ``open``/``set_option`` lines, and anything else
    that precedes the theorem).
    """

    def __init__(self, code: str, name='test'):
        """
        Parameters
        ----------
        code : Lean 4 source; only used here to extract the header.
        name : identifier given to the generated lemma.
        """
        self.code = code

        self.header = self.get_header(self.code)

        # Not referenced by this class's methods; presumably consumed by callers.
        self.possible_tactics = ['nlinarith', 'norm_cast', 'norm_num', 'ring_nf', 'ring']

        self.name = name

    def get_lemma(self, state, negate_statement=False):
        """Render the first ``sorry`` goal of a REPL response as a lemma.

        Parameters
        ----------
        state : REPL output; ``state['sorries'][0]['goal']`` is read.
        negate_statement : when True the goal is wrapped as ``¬(...)``.

        Returns
        -------
        str
            ``header + 'lemma <name>' + hypotheses + goal + ':= by\\n  sorry'``.
        """
        # Handle only the first sorry, since the translator does not return any other.
        given_block, goal_block = self.get_statement(state['sorries'][0]['goal'], negate_statement)

        if given_block == ':':
            # No hypotheses: keep the colon on the lemma line for nicer formatting.
            return self.header + f'lemma {self.name}' + given_block + '\n' + goal_block
        return self.header + f'lemma {self.name}\n' + given_block + goal_block

    def get_statement(self, state, negate_statement=False):
        """Convert a pretty-printed goal into lemma binders and a goal line.

        Parameters
        ----------
        state : goal text of the form ``hyp₁\\nhyp₂\\n...⊢ goal``.
        negate_statement : when True the goal is wrapped as ``¬(...)``.

        Returns
        -------
        (given_block, goal_block)
            ``given_block`` holds one ``  (hyp)`` line per hypothesis followed
            by `` :\\n``, or just ``':'`` when there are no hypotheses.
            ``goal_block`` is ``'  <goal> := by\\n  sorry'``, or ``''`` when
            the goal is empty.
        """
        # Split on the LAST '⊢' only; without one, everything is hypotheses.
        if '⊢' in state:
            given, goal = state.rsplit('⊢', 1)
        else:
            given, goal = state, ''

        # Merge multi-line hypotheses: an indented, non-blank line is treated
        # as a continuation of the previous entry (presumably a long hypothesis
        # wrapped by the pretty-printer). A leading indented line has nothing
        # to continue, so it starts its own entry instead of crashing.
        merged_lines = []
        for line in given.splitlines():
            if merged_lines and line.strip() and line[0].isspace():
                # Keep the newline so the original indentation survives.
                merged_lines[-1] += '\n' + line
            else:
                merged_lines.append(line)

        # Wrap each non-empty hypothesis in parentheses to form a binder.
        wrapped = [f"  ({ml.strip()})" for ml in merged_lines if ml.strip()]

        given_block = '\n'.join(wrapped) + ' :\n' if wrapped else ':'

        goal_text = goal.strip()
        if not goal_text:
            goal_block = ''
        elif negate_statement:
            goal_block = '  ¬(' + goal_text + ') := by\n  sorry'
        else:
            goal_block = '  ' + goal_text + ' := by\n  sorry'

        return given_block, goal_block

    def get_header(self, code):
        """Return every line of *code* before the first ``theorem`` line.

        The lines are joined with newlines and followed by a blank line
        (``'\\n\\n'``) so that a declaration can be appended directly.
        """
        header = []
        for line in code.splitlines():
            if line.lstrip().startswith('theorem'):
                break
            header.append(line)
        return '\n'.join(header) + '\n\n'



class TranslatorModel:
    """Autoformalizer backed by an OpenAI-compatible chat-completions API.

    Translates informal statements into Lean 4 code and builds standalone
    Lean lemmas (optionally negated) from REPL proof states.
    """

    # Last fenced Lean block in a completion; accepts ```lean4 and ```lean.
    _CODE_BLOCK_RE = re.compile(r'```lean4?\n(.*?)\n```', re.DOTALL)

    def __init__(self, model_path: str,
                 sampling_params,
                 template: Optional[Union[str, Path]] = None,
                 api_key = None,
                 base_url = None,
                 port = None) -> None:
        """
        Parameters
        ----------
        model_path      : model identifier sent as ``model`` in each request.
                          ``'deepseek-chat'`` uses ``base_url`` as given; any
                          other value targets ``http://{base_url}:{port}/v1``.
        sampling_params : object exposing ``temperature``, ``max_tokens``,
                          ``top_p`` and ``timeout`` attributes.
        template        : either a filesystem path to a prompt template, or a
                          raw template string. The only placeholder that is
                          substituted is ``<question>``.
        api_key, base_url, port : connection settings for the OpenAI client.
        """
        self.model_path = model_path

        self.sampling_params = sampling_params

        self.api_key = api_key
        self.base_url = base_url
        self.port = port

        self.template_text = self._resolve_template(template)

        # Hook point: subclasses may override to load a different client.
        self._load_model()

        # Human-readable description of the last API failure ('' if none).
        self.error_feedback = ''

    @staticmethod
    def _resolve_template(template) -> str:
        """Return file contents if *template* names an existing file, else its text.

        ``None`` yields ``''``. Raw prompt text is often not a valid path (e.g.
        a path component longer than the OS limit makes ``Path.is_file`` raise
        ``OSError`` on Python <= 3.12), so such errors mean "not a file".
        """
        if template is None:
            return ""
        if isinstance(template, (str, Path)):
            try:
                is_file = Path(str(template)).is_file()
            except (OSError, ValueError):
                is_file = False
            if is_file:
                return Path(str(template)).read_text(encoding="utf-8")
        return str(template)

    # ---------- Overridables / hooks ----------
    def completions_with_backoff(self, model: str, **kwargs):
        """Issue one chat-completions request.

        Despite the name, no retry/backoff is applied here; errors from the
        client propagate to the caller.
        """
        return self.client.chat.completions.create(model=model, **kwargs)

    def run_llm(self, prompt: str, n: int):
        """Sample *n* completions for a single user-message *prompt*.

        Returns the list of ``message.content`` values, one per choice.
        """
        res = self.completions_with_backoff(
            model=self.model_path,
            messages=[{"role": "user", "content": prompt}],
            temperature=self.sampling_params.temperature,
            max_tokens=self.sampling_params.max_tokens,
            n=n,
            stop=[],
            top_p=self.sampling_params.top_p,
            timeout=self.sampling_params.timeout,
        )
        return [choice.message.content for choice in res.choices]

    def _load_model(self) -> None:
        """Create ``self.client``; override to use a different backend."""
        if self.model_path in ['deepseek-chat']:
            self.client = openai.OpenAI(api_key=self.api_key, base_url=self.base_url)
        else:
            # Locally served OpenAI-compatible endpoint (host/port supplied).
            self.client = openai.OpenAI(base_url=f"http://{self.base_url}:{self.port}/v1", api_key=self.api_key)

    # ---------- Core functionality ----------

    def render_prompt(self, question: str) -> str:
        """Fill ``<question>`` in the template.

        If no template was provided, falls back to a minimal inline template.
        """
        tpl = self.template_text or (
            "Question:\n<question>\n"
            "Translate this informal mathematics statement into an equivalent Lean4 statement. Lean4 code should be wrapped as follows:\n```lean4\n[your code goes here]\n```"
        )
        return tpl.replace("<question>", question)

    def extract_code(self, text_input):
        """Return the last fenced Lean block in *text_input*.

        Falls back to the whole text when no fenced block is present, and to
        ``"Error during code extraction."`` when the input is not a ``str``
        (e.g. a ``None`` message content).
        """
        try:
            matches = self._CODE_BLOCK_RE.findall(text_input)
        except TypeError:
            return "Error during code extraction."
        return matches[-1].strip() if matches else text_input

    def autoformalize(self, question, n):
        """Request *n* translations of *question* and extract their Lean code.

        Returns a list of extracted code strings. On a timeout or connection
        error, ``self.error_feedback`` is set and ``[]`` is returned. The
        feedback is cleared at the start of every call so it never reports
        a stale failure.
        """
        self.error_feedback = ''
        prompt = self.render_prompt(question)

        try:
            translations = self.run_llm(prompt, n)
        except APITimeoutError:
            # Must precede APIConnectionError, of which it is a subclass.
            self.error_feedback = 'Autoformalizer timed out, the step is too complex. Try to break it down and prove smaller steps.'
            return []
        except APIConnectionError:
            self.error_feedback = 'Network error, try again.'
            return []

        return [self.extract_code(t) for t in translations]

    def split_header_and_body(self, statement):
        """Split Lean code into ``(header, body)``.

        Lines whose stripped text starts with ``open``, ``import`` or
        ``set_option`` form the header (joined, followed by ``'\\n\\n'``).
        All other non-blank lines form the body.
        """
        header, body = [], []
        for line in statement.splitlines():
            if line.lstrip().startswith(('open', 'import', 'set_option')):
                header.append(line)
            elif line.strip():
                body.append(line)
        return '\n'.join(header) + '\n\n', '\n'.join(body)

    def negate_statement(self, statement, repl_output, scheduler=None):
        """Build the negated lemma for the first sorry in *repl_output*.

        *statement* supplies the header. The lemma's closing ``sorry`` is
        prefixed with ``push_neg ;``. *scheduler* is accepted but unused.
        """
        lemma_extractor = LemmaExtractor(statement)
        opp = lemma_extractor.get_lemma(repl_output, negate_statement=True)
        return self._insert_push_neg(opp)

    @staticmethod
    def _insert_push_neg(lemma: str) -> str:
        """Prefix the trailing ``sorry`` of *lemma* with ``push_neg ;``.

        Only the final ``sorry`` is rewritten, so ``sorry`` occurrences in the
        copied header (e.g. earlier auxiliary lemmas) are left untouched.
        """
        if lemma.endswith('sorry'):
            return lemma[:-len('sorry')] + 'push_neg ; sorry'
        return lemma

    def construct_from_repl(self, statement, repl_output):
        """Build the (non-negated) lemma for the first sorry in *repl_output*.

        *statement* is needed only for header extraction.
        """
        lemma_extractor = LemmaExtractor(statement)
        return lemma_extractor.get_lemma(repl_output, negate_statement=False)




