import requests
import json
import time
import os
import logging
from datetime import datetime
from tqdm import tqdm
from transformers import AutoTokenizer
from typing import Optional, List, Dict, Union, Generator, Any
import re
from verify_prover_v2_solutions_api import batch_verify_lean_proofs, get_sandbox_result

def parse_tool_call(text: str) -> dict:
    """Decode the JSON payload inside the first ``<tool_calls>...</tool_calls>`` pair.

    The text between the tags may span several lines and is stripped of
    surrounding whitespace before decoding. An empty dict is returned when no
    tag pair is found or when the payload is not valid JSON. The decoded value
    is returned unchanged, so a JSON array or scalar payload comes back as-is.
    """
    found = re.search(r'<tool_calls>(.*?)</tool_calls>', text, re.DOTALL)
    if found is None:
        return {}
    payload = found.group(1).strip()
    try:
        return json.loads(payload)
    except json.JSONDecodeError:
        return {}

def format_tool_results_as_user_message(tool_name, tool_result):
    """Render a tool's name and result as a ``<tool_results>`` text block.

    The result is serialized as pretty-printed JSON (2-space indent) with
    non-ASCII characters kept as-is. Every line, including the closing tag,
    ends with a newline.
    """
    rendered_result = json.dumps(tool_result, indent=2, ensure_ascii=False)
    pieces = [
        "<tool_results>\n",
        f"Function: {tool_name}\n",
        f"Output: {rendered_result}\n",
        "</tool_results>\n",
    ]
    return "".join(pieces)

def parse_consistency_response(response):
    """Turn the consistency judge's raw reply into a pass/fail verdict.

    Only the text after the last ``</think>`` tag is considered. Within it, the
    last fenced ```json block is decoded; it must be a JSON object with an
    ``is_assistant_correct`` string and a ``reasons`` entry.

    Args:
        response: Raw text generated by the judge model.

    Returns:
        ``{"pass": bool, "explanations": ...}``. ``pass`` is True unless the
        verdict contains "incorrect" (case-insensitive), and ``explanations``
        is the judge's ``reasons`` value. A malformed reply (no JSON block,
        invalid JSON, missing keys, wrong value types) yields ``pass=False``
        with the explanation ``'invalid response from judge'``.
    """
    try:
        final_answer = response.split('</think>')[-1]
        answer_json = json.loads(extract_json_blocks(final_answer)[-1])
        verdict = answer_json['is_assistant_correct'].lower()
        explanations = answer_json["reasons"]
    except Exception:
        # Best-effort parsing of model output: any malformed reply counts as a
        # failure. Unlike a bare ``except:``, this no longer swallows
        # KeyboardInterrupt / SystemExit.
        return {
            "pass": False,
            "explanations": 'invalid response from judge'
        }
    return {
        "pass": 'incorrect' not in verdict,
        "explanations": explanations,
    }

def get_input_prompt(tokenizer, informal_statement):
    """Build the autoformalization prompt for *informal_statement* via the chat template.

    The conversation is the ``qwen3_template`` system message followed by a
    user turn. That turn asks for a Lean 4 formalization with a header, using
    the theorem name ``my_favorite_theorem``. The template is rendered as text
    (``tokenize=False``) with the generation prompt appended.
    """
    instruction = 'Please autoformalize the following problem in Lean 4 with a header. Use the following theorem names: my_favorite_theorem.\n\n'
    conversation = [
        {'role': 'system', 'content': qwen3_template},
        {'role': 'user', 'content': instruction + informal_statement},
    ]
    return tokenizer.apply_chat_template(
        conversation, add_generation_prompt=True, tokenize=False
    )

def get_consistency_prompt(tokenizer, informal_statement, formal_statement):
    """Build the judge prompt that checks whether a Lean statement formalizes a problem.

    The ``{informal_statement}`` and ``{formal_statement}`` placeholders in
    ``consistency_template`` are filled in a single pass. Text inside either
    statement that happens to look like a placeholder is therefore inserted
    verbatim rather than substituted again, as it could be with chained
    ``str.replace`` calls.

    Args:
        tokenizer: Object providing ``apply_chat_template``.
        informal_statement: The natural-language problem.
        formal_statement: The candidate Lean 4 statement.

    Returns:
        The rendered single-user-turn chat prompt (``tokenize=False``) with
        the generation prompt appended.
    """
    values = {
        'informal_statement': informal_statement,
        'formal_statement': formal_statement,
    }
    # Use a replacement *function*, not a string: re.sub would otherwise
    # interpret backslashes (common in Lean code) as escape sequences.
    content = re.sub(
        r'\{(informal_statement|formal_statement)\}',
        lambda m: values[m.group(1)],
        consistency_template,
    )
    msg = [{'role': 'user', 'content': content}]
    return tokenizer.apply_chat_template(
        msg,
        tokenize=False,
        add_generation_prompt=True
    )

def setup_logger(log_dir):
    """Configure (or reset) the ``'inference_pipeline'`` logger.

    The logger writes INFO-level records to two places: a new file
    ``inference_<YYYYmmdd_HHMMSS>.log`` inside *log_dir* (the directory is
    created if missing), and the console through a ``logging.StreamHandler``.
    Propagation to ancestor loggers is disabled so records are not emitted twice.

    Calling this again replaces the previous handlers. The old handlers are
    also closed, so earlier log files are not left open.

    Args:
        log_dir: Directory that receives the log file.

    Returns:
        Tuple ``(logger, log_path)`` with the configured logger and the path
        of the log file.
    """
    os.makedirs(log_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    log_path = os.path.join(log_dir, f"inference_{timestamp}.log")

    logger = logging.getLogger('inference_pipeline')
    logger.setLevel(logging.INFO)

    # Drop handlers from a previous call to avoid duplicated output, and close
    # them: removeHandler() alone leaves a FileHandler's file open.
    for handler in list(logger.handlers):
        logger.removeHandler(handler)
        handler.close()

    # Prevent propagation to parent loggers to avoid duplicate output.
    logger.propagate = False

    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
    # File handler first, then console handler.
    for handler in (logging.FileHandler(log_path, encoding='utf-8'), logging.StreamHandler()):
        handler.setLevel(logging.INFO)
        handler.setFormatter(formatter)
        logger.addHandler(handler)

    return logger, log_path


# def setup_logger(log_dir, gpu_id):
#     """Setup logger"""
#     os.makedirs(log_dir, exist_ok=True)
#     timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
#     log_filename = f"inference_gpu{gpu_id}_{timestamp}.log"
#     log_path = os.path.join(log_dir, log_filename)
    
#     log_format = '%(asctime)s - %(levelname)s - [GPU-{}] - %(message)s'.format(gpu_id)
#     logger = logging.getLogger(f'inference_gpu_{gpu_id}')
#     logger.setLevel(logging.INFO)
    
#     for handler in logger.handlers[:]:
#         logger.removeHandler(handler)
    
#     file_handler = logging.FileHandler(log_path, encoding='utf-8')
#     file_handler.setLevel(logging.INFO)
#     file_formatter = logging.Formatter(log_format)
#     file_handler.setFormatter(file_formatter)
    
#     console_handler = logging.StreamHandler()
#     console_handler.setLevel(logging.INFO)
#     console_formatter = logging.Formatter(log_format)
#     console_handler.setFormatter(console_formatter)
    
#     logger.addHandler(file_handler)
#     logger.addHandler(console_handler)
    
#     return logger, log_path

def batch_generate(llm, prompts, sampling_params, batch_size=4096):
    """Run ``llm.generate`` over *prompts* in consecutive slices of *batch_size*.

    Each slice is submitted with ``use_tqdm=False``. Only the text of the
    first completion of every result is kept. The returned list follows the
    order of *prompts*.
    """
    def _first_texts(offset):
        # One generate() call for the slice starting at *offset*.
        chunk = prompts[offset:offset + batch_size]
        results = llm.generate(chunk, sampling_params, use_tqdm=False)
        return [result.outputs[0].text for result in results]

    offsets = range(0, len(prompts), batch_size)
    return [text for offset in offsets for text in _first_texts(offset)]

def _extract_syntax_result(raw_result):
    """Pick the pass flag and error list out of a raw sandbox result; raises KeyError/TypeError if malformed."""
    return {"pass": raw_result['info']['pass'], "errors": raw_result['info']['errors']}


def _retry_syntax_check(code, cluster_key, max_retries):
    """Re-check one snippet via ``get_sandbox_result``, returning a failure record if every attempt fails."""
    for _ in range(max_retries):
        try:
            # The single-item call receives the bare source: strip the ```lean4 markdown fence.
            raw_result = get_sandbox_result(
                code=code.replace('```lean4\n', '').replace('\n```', ''),
                cluster_key=cluster_key
            )
            return _extract_syntax_result(raw_result)
        except Exception:
            # Sandbox/network failure or malformed reply: try again. Unlike a bare
            # ``except:``, KeyboardInterrupt / SystemExit can still stop the run.
            continue
    return {'pass': False, 'errors': ['Unsuccessful Lean4 Execution']}


def batch_syntax_check(lean4_code_list, max_retries=3):
    """Syntax-check Lean 4 snippets on the ``'lean-v6'`` sandbox cluster.

    All snippets are first sent together through ``batch_verify_lean_proofs``.
    A snippet is re-submitted on its own through ``get_sandbox_result`` when
    its batch result lacks ``['info']['pass']`` / ``['info']['errors']`` or
    is missing entirely. Each re-submission is attempted up to *max_retries*
    times.

    Args:
        lean4_code_list: Lean 4 sources. The single-item fallback strips a
            ```lean4 markdown fence before submitting.
        max_retries: Attempts per snippet in the fallback path. The default
            of 3 matches the previous hard-coded value.

    Returns:
        One ``{"pass": ..., "errors": ...}`` dict per input snippet, in input
        order. Snippets that still fail after all retries get
        ``{'pass': False, 'errors': ['Unsuccessful Lean4 Execution']}``.
    """
    cluster_key = 'lean-v6'
    batch_results = batch_verify_lean_proofs(lean4_code_list, cluster_key, use_tqdm=False)
    # Materialize so a short or missing batch reply can be handled per snippet:
    # zip() would silently drop trailing snippets and misalign results with inputs.
    batch_results = list(batch_results) if batch_results is not None else []
    syntax_results = []
    for index, code in enumerate(lean4_code_list):
        batch_result = batch_results[index] if index < len(batch_results) else None
        try:
            syntax_results.append(_extract_syntax_result(batch_result))
        except (KeyError, TypeError):
            syntax_results.append(_retry_syntax_check(code, cluster_key, max_retries))
    return syntax_results

def read_jsonl(path):
    """Load a JSON Lines file into a list of decoded records.

    The file is decoded as UTF-8 explicitly, matching the raw non-ASCII
    characters produced by ``json.dumps(..., ensure_ascii=False)``, instead of
    relying on the platform's locale encoding. Blank or whitespace-only lines,
    such as a trailing empty line, are skipped rather than raising
    ``json.JSONDecodeError``. A progress bar is shown over the file's lines.

    Args:
        path: Path of the ``.jsonl`` file.

    Returns:
        List with one decoded JSON value per non-blank line, in file order.
    """
    records = []
    with open(path, 'r', encoding='utf-8') as f:
        # readlines() gives tqdm a known total for the progress bar.
        for line in tqdm(f.readlines()):
            if not line.strip():
                continue
            records.append(json.loads(line))
    return records

def write_jsonl(data_to_write, path, mode):
    """Write each record to *path* as one JSON document per line.

    Args:
        data_to_write: Iterable of JSON-serializable records (progress is
            shown with tqdm).
        path: Destination file path.
        mode: File open mode, e.g. ``'w'`` to overwrite or ``'a'`` to append.
    """
    # ensure_ascii=False emits raw non-ASCII characters, so the file is opened
    # as UTF-8 explicitly instead of relying on the platform's locale encoding.
    with open(path, mode, encoding='utf-8') as f:
        for x in tqdm(data_to_write):
            f.write(json.dumps(x, ensure_ascii=False) + '\n')
def extract_json_blocks(text: str, return_first_only: bool = False) -> Union[Optional[str], List[str]]:
    """
    Collect the bodies of fenced ```json ... ``` blocks found in *text*.

    Args:
        text: Text that may contain ```json fenced code blocks.
        return_first_only: When True, return only the first block body;
            otherwise return every body found.

    Returns:
        With return_first_only=True: the first block body, or None if there is none.
        With return_first_only=False: a list of all block bodies (possibly empty).
        Whitespace directly inside the fences is not part of a body.
    """
    # Lazy body + DOTALL: each block stops at its own closing fence and may span
    # lines. The \s* on either side trims whitespace around the captured body.
    fence = re.compile(r"```json\s*(.*?)\s*```", re.DOTALL)
    bodies = [found.group(1) for found in fence.finditer(text)]

    if return_first_only:
        return bodies[0] if bodies else None
    return bodies

# Prompt for the consistency judge. get_consistency_prompt() substitutes the
# {informal_statement} and {formal_statement} placeholders below, and
# parse_consistency_response() reads "is_assistant_correct" and "reasons" from
# the last ```json block of the judge's reply, as requested under "Output Format".
# NOTE(review): the example "reasons" value contains a stray "(:" after
# "Lean4 Code Analysis". It is left untouched because the judge may depend on
# this exact prompt text; confirm before editing.
consistency_template = '''Role: Lean & Formal Verification Expert

Input:
- Mathematical_Text: A math problem and its answer (no proof).
- Lean4Code: A Lean 4 theorem statement formalizing the problem. Proof is intentionally omitted (e.g., sorry).

Goal:
Determine if the Lean theorem statement is an exact and faithful formalization of the mathematical problem.  
**Do not evaluate or consider the answer or the proof. Your sole task is to verify the correctness of the formalization.**

Evaluation Stages (All required):

1. Mathematical Text Analysis  
   Identify all structurally and semantically relevant components of the mathematical problem, including variables, types, quantifiers, constraints, logic structure, conclusion, and so on. The analysis should be based on the actual content of the text.

2. Lean4 Code Analysis (ignore proof part)  
   Extract all structurally and semantically relevant components from the Lean statement, including variables, types, conditions, quantifiers, constraints, the final claim, and so on. The analysis should reflect the actual content present in the Lean code.

3. Comparative Analysis  
   Check for exact correspondence between the math and Lean statements; you may refer to aspects like:
   - Semantic alignment, logic structure, and quantifier correctness.
   - Preservation of constraints and boundary assumptions.
   - Accurate typing and use of variables.
   - Strict adherence to Lean's specific syntactic and semantic rules in interpreting the Lean code.
   - Syntactic validity and proper Lean usage (free from errors).
   - Use of symbols and constructs without semantic drift.
   - No missing elements, no unjustified additions, and no automatic corrections or completions.

4. Accuracy Confirmation  
   If correct: clearly confirm why all elements match.  
   If incorrect: list all mismatches and explain how each one affects correctness.

Note: While the analysis may be broad and open to interpreting all relevant features, the final judgment must be based only on what is explicitly and formally expressed in the Lean statement.  
**Do not consider or assess any part of the proof. Your judgment should be entirely about the accuracy of the statement formalization.**

Output Format:
Return exactly one JSON object:
```json
{
    "reasons": "1. Mathematical Text Analysis: [...]2.  Lean4 Code Analysis (: [...]3. Comparative Analysis: [...]4. Accuracy Confirmation: [...match confirmation or list of discrepancies...]",
    "is_assistant_correct": "[Correct/Incorrect]"
}
```

— Start of Mathematical_Text —
{informal_statement}
— End of Mathematical_Text —

— Start of Lean4Code —
{formal_statement}
— End of Lean4Code —
'''.strip()

# System prompt for the autoformalization model; get_input_prompt() sends it as
# the system message. It tells the model to call two tools, "syntax_check" and
# "consistency_check". Presumably these are backed by batch_syntax_check() and
# the consistency judge prompt elsewhere in the pipeline (TODO confirm).
# .strip() removes the leading newline that follows the opening quotes.
qwen3_template = '''
You are an expert in mathematics and Lean 4. Your task is to convert natural language problems into valid Lean 4 formal statements (Compatible with Lean 4 v4.9).

Your code must begin with:
```Lean4
import Mathlib
import Aesop
```

You MUST use the provided tools to verify your Lean 4 statements:

- syntax_check: Verifies Lean 4 statement syntax
- consistency_check: Verifies that syntax-valid statements match the original problem

Verification workflow:

- Analyze the problem and create initial Lean 4 statement
- Call syntax_check to verify compilation
- If syntax check passes, call consistency_check
- If any check fails, analyze errors, modify code and restart verification
- Repeat until BOTH checks pass
'''.strip()
