#!/usr/bin/env python3
import os
import json
import anthropic

# Shared Anthropic SDK client, created once when the module is loaded and
# used by clean(). No arguments are passed, so credentials are presumably
# taken from the SDK's default configuration (environment) -- TODO confirm.
client = anthropic.Anthropic()

def norm(text):
    """Strip a surrounding Markdown code fence from *text*, if there is one.

    The text is trimmed and split into lines. When there are at least two
    lines and both the first and the last start with three backticks, only
    the lines between them are kept. Otherwise the original, untrimmed text
    is used. In both cases a single newline is appended to the result.
    """
    rows = text.strip().split('\n')
    opens_fence = rows[0].startswith('```')
    closes_fence = rows[-1].startswith('```')
    # A lone fence line is not treated as a fenced block: two lines minimum.
    is_fenced = len(rows) >= 2 and opens_fence and closes_fence
    body = '\n'.join(rows[1:-1]) if is_fenced else text
    return body + '\n'

def clean(path):
    """Ask the model for the names of all lemmas declared in a Why3 file.

    The file at ``path`` is read as UTF-8, embedded in a fixed extraction
    prompt and sent to the module-level ``client`` as one streamed request.
    The streamed text is concatenated, passed through ``norm`` to drop an
    optional Markdown code fence, and parsed as JSON.

    Args:
        path: Path of the Why3 source file to analyse.

    Returns:
        list[str]: The lemma names reported by the model (possibly empty).

    Raises:
        OSError: If ``path`` cannot be opened or read.
        json.JSONDecodeError: If the model's reply is not valid JSON.
        ValueError: If the reply is valid JSON but not an array of strings.
    """
    with open(path, 'r', encoding='utf-8') as file:
        text = file.read()

    # Keep the role description separate from the concrete instructions.
    system_prompt = """You are an expert in formal verification, especially in the verifier Why3."""

    task_description = """Extract the names of all lemmas from the given Why3 code."""

    syntax_patterns = """
    Lemma patterns to recognize:
    - "lemma <name>: <formula>" → extract "<name>"
    - "let lemma <name> (<params>) : <type>" → extract "<name>" 
    - "let rec lemma <name> (<params>) : <type>" → extract "<name>"
    - "lemma <name> (<params>) : <formula>" → extract "<name>"
    - Any variation with "lemma" keyword followed by an identifier
    """

    requirements = """
    Output format:
    1. Return ONLY the lemma names, in JSON format
    2. If no lemmas found, return empty JSON array
    3. Do not include any explanatory text, comments, or formatting
    4. Do not include any other text or formatting
    
    Example output:
    ["lemma_name_1", "lemma_name_2", "another_lemma"]
    """

    # Compose the full prompt; everything is sent as a single user message.
    full_prompt = f"{system_prompt}\n\n{task_description}\n\n{syntax_patterns}\n\n{requirements}\n\nWhy3 code to process:\n\n{text}"

    # stream=True makes the call return an iterable of events instead of a
    # finished message, so the reply has to be assembled from the deltas.
    response = client.messages.create(
        model="claude-sonnet-4-20250514",
        max_tokens=40000,
        stream=True,
        temperature=0,
        messages=[
            {
                "role": "user",
                "content": full_prompt
            }
        ]
    )

    # Only text deltas have a ``text`` attribute; skipping any other delta
    # type avoids an AttributeError in the middle of a stream.
    pieces = []
    for chunk in response:
        if chunk.type != "content_block_delta":
            continue
        if chunk.delta.type == "text_delta":
            pieces.append(chunk.delta.text)
    reply = "".join(pieces)

    lemmas = json.loads(norm(reply.strip()))

    # The caller joins the result with newlines, so anything other than a
    # list of strings would either crash later or silently write garbage
    # (e.g. the keys of a JSON object). Fail here with a clear message.
    is_name_list = isinstance(lemmas, list) and all(
        isinstance(name, str) for name in lemmas
    )
    if not is_name_list:
        raise ValueError(
            f"Expected a JSON array of lemma names for {path}, got: {lemmas!r}"
        )
    return lemmas

# Directory scanned for Why3 sources, and the directory that mirrors its
# layout with one ``.lemmas`` file per source file.
SOURCE_DIR = './data/why3/common'
OUTPUT_DIR = './data/why3/no-lemma3'

# Collect all .mlw file paths first so the total is known for progress output.
file_paths = [
    os.path.join(root, name)
    for root, _dirs, names in os.walk(SOURCE_DIR)
    for name in names
    if name.endswith('.mlw')
]

total = len(file_paths)
print(f"Found {total} .mlw files to process")

# Process each file. Progress is counted from 1 so the last file reads
# [total/total] rather than [total-1/total].
for index, file_path in enumerate(file_paths, start=1):
    progress = f"[{index}/{total}]"
    print(f"{progress} Processing: {file_path}")
    rel_path = os.path.relpath(file_path, SOURCE_DIR)
    output_path = os.path.join(OUTPUT_DIR, rel_path + '.lemmas')

    # Existing outputs are left untouched, so an interrupted run can be
    # restarted without repeating the requests that already succeeded.
    if os.path.exists(output_path):
        print(f"{progress} Skipping (already exists): {output_path}")
        continue

    lemmas = clean(file_path)
    print(lemmas)
    print('\n\n\n')

    # Create output directory if it doesn't exist
    os.makedirs(os.path.dirname(output_path), exist_ok=True)

    # One lemma name per line, no trailing newline.
    with open(output_path, 'w', encoding='utf-8') as file:
        file.write('\n'.join(lemmas))

    print(f"{progress} Saved: {output_path}")

print("Processing complete!")


