import os
import re
from typing import Optional, List, Dict, Any, Tuple

from leanclient.client import Lean4Client

# Connection settings for the Lean verification server. They are read from the
# environment so that credentials do not have to be committed to source; both
# fall back to "" (the previous hard-coded values) when the variable is unset.
base_url = os.environ.get("LEAN_SERVER_BASE_URL", "")
api_key = os.environ.get("LEAN_SERVER_API_KEY", "")

client = Lean4Client(base_url=base_url, api_key=api_key)

def process_verify_result(verify_response: Dict[str, Any]) -> Tuple[bool, List[str]]:
    """Interpret a single verification result returned by the Lean server.

    Args:
        verify_response: One entry of the verifier's results. A truthy
            ``'error'`` value marks a request-level failure; otherwise the
            messages under ``verify_response['response']['messages']`` are
            scanned for entries whose ``'severity'`` is ``'error'``.

    Returns:
        ``(passed, error_messages)``. ``passed`` is False when a request-level
        error is present or any message has severity ``'error'``.
        ``error_messages`` holds the stringified request-level error, or the
        non-empty ``'data'`` fields of the error-severity messages.
    """
    if verify_response.get('error'):
        return False, [str(verify_response['error'])]

    # ``or`` also covers keys that are present but explicitly null
    # (e.g. ``"response": null`` in the JSON payload), which ``.get(key, {})``
    # would pass through as None.
    response = verify_response.get('response') or {}
    messages = response.get('messages') or []

    error_messages: List[str] = []
    has_error = False

    for msg in messages:
        if msg.get('severity') != 'error':
            continue
        has_error = True
        # An error message without a 'data' payload still fails the check;
        # it just contributes no text.
        error_data = msg.get('data', '')
        if error_data:
            error_messages.append(error_data)

    return (not has_error), error_messages


def process_batch_results(verify_results: List[Dict[str, Any]],
                          start_idx: int,
                          batch_items: Optional[List[Dict[str, Any]]] = None) -> List[Dict[str, Any]]:
    """Match verification results back to the items of one batch.

    Item ``i`` of the batch is looked up under ``custom_id == str(start_idx + i)``.

    Args:
        verify_results: Results returned by the verifier; each is expected to
            carry a ``'custom_id'``. Results without one are ignored.
        start_idx: Global index of the first item in this batch.
        batch_items: The original items of the batch. When given, one output
            entry is produced per item (a shallow copy of the item with the
            verification fields added), so items whose result is missing are
            reported as failures. When omitted, one empty record per result is
            used instead, matching the previous behavior.

    Returns:
        One dict per item with ``'pass'`` and ``'verify_result'`` and, when the
        check failed, ``'errors'``.
    """
    # Normalise ids to str so a result whose id is not echoed back as a string
    # still matches the str(start_idx + idx) key built below.
    result_map = {
        str(r['custom_id']): r for r in verify_results if 'custom_id' in r
    }

    if batch_items is None:
        # Without the original items, fall back to one empty record per result.
        items: List[Dict[str, Any]] = [{} for _ in verify_results]
    else:
        items = batch_items

    results = []
    for idx, item in enumerate(items):
        custom_id = str(start_idx + idx)
        item_copy = dict(item)  # shallow copy: never mutate the caller's item

        verify_result = result_map.get(custom_id)
        if verify_result is None:
            item_copy.update({
                'pass': False,
                'errors': ["No result found"],
                'verify_result': {"error": "No result found"}
            })
        else:
            passed, errors = process_verify_result(verify_result)
            item_copy.update({
                'pass': passed,
                'verify_result': verify_result
            })
            if not passed:
                item_copy['errors'] = errors

        results.append(item_copy)
    return results

def process_formal_statement(statement: Optional[str]) -> Optional[str]:
    """Normalise a Lean 4 formal statement so that it ends with a ``sorry`` proof.

    Markdown code fences (```` ```lean4 ````, ```` ```lean ````, ```` ``` ````)
    are removed and the text is stripped. Then:

    * a statement that already ends with a ``sorry`` proof (``:= sorry``,
      ``:= by sorry``, or ``:= by`` with ``sorry`` on a following line) is
      returned unchanged;
    * a trailing ``:= by`` or ``:=`` is completed to ``:= by sorry``;
    * if the last line contains no ``:=`` at all, ``:= by sorry`` is appended.

    Args:
        statement: Raw statement text, possibly wrapped in a markdown fence.

    Returns:
        The normalised statement. Falsy input (``None`` or ``""``) is returned
        unchanged.
    """
    if not statement:
        return statement

    statement = str(statement)
    statement = re.sub(r'```(?:lean4|lean)?\s*', '', statement)
    statement = statement.strip()

    # Already closed with a sorry proof. ``\s`` also spans newlines, so the
    # common "... := by\n  sorry" layout is recognised too.
    if re.search(r':=\s*(?:by\s+)?sorry$', statement):
        return statement

    # Complete a dangling ``:= by`` / ``:=``. The leading ``\s*`` absorbs the
    # whitespace before ``:=`` so the result has a single space.
    statement = re.sub(r'\s*:=\s*by\s*$', ' := by sorry', statement)
    statement = re.sub(r'\s*:=\s*$', ' := by sorry', statement)

    # Only a ``:=`` on the last line counts as the start of the proof; one on an
    # earlier line (e.g. a ``let`` inside the type) does not.
    last_line = statement.rsplit('\n', 1)[-1]
    if ':=' not in last_line:
        statement = statement + ' := by sorry'

    return statement

# Lean preamble prepended to statements that do not bring their own imports.
default_header = ("import Mathlib\nimport Aesop\nimport Mathlib.Tactic\n"
                     "set_option maxHeartbeats 0\nopen BigOperators Real Nat Topology Rat\n\n")

header_content = default_header

# Sample statement, wrapped in a markdown fence the way model output usually is.
formal_statement = "```lean4\ntheorem sum_fraction_inequality {n : ℕ} (a : Fin n → ℝ) :\n  (∀ i, 1 + a i > 0) →  -- ensure denominators are positive\n  (∑ i : Fin n, (a i)/(1 + a i)) ≥ \n  (∑ i : Fin n, a i)/(1 + ∑ i : Fin n, a i)\n  := by sorry\n```"

if __name__ == "__main__":
    # Guarded so that importing this module does not hit the verification server.
    formal_statement_content = process_formal_statement(formal_statement) or ""

    # Skip the default header only when the statement has its own Lean
    # ``import`` line; a bare substring test would also fire on identifiers
    # such as ``important_lemma``.
    if re.search(r'^\s*import\s', formal_statement_content, re.MULTILINE):
        check_content = formal_statement_content
    else:
        check_content = header_content + formal_statement_content

    # The same content is submitted twice under the ids "0" and "1", which
    # process_batch_results maps back by position (start_idx=0).
    verify_requests = [
        {"proof": check_content, "custom_id": str(i)} for i in range(2)
    ]

    response = client.verify(verify_requests, timeout=100)
    # Accept either a {"results": [...]} wrapper or a bare list of results.
    verify_results = (response['results']
                      if isinstance(response, dict) and 'results' in response
                      else response)

    batch_results = process_batch_results(verify_results, 0)

    print(batch_results)