#!/usr/bin/env python3
import os
import subprocess
import anthropic
import csv
import re
from sqlitedict import SqliteDict

# Patterns whose single capture group is the name of a declared lemma:
# "let lemma NAME", "let rec lemma NAME" and the short form "lemma NAME"
# at the beginning of a line.
lemma_re = [r'let\s+lemma\s+([\w\']+)', r'let\s+rec\s+lemma\s+([\w\']+)', r'^\s*lemma\s+([\w\']+)']
# Pattern for "let" declarations with optional "rec", "ghost" and "function"
# modifiers; the declared name is the fourth capture group.
decl_re = [r'let(\s+rec)?(\s+ghost)?(\s+function)?\s+([\w\']+)']

def collect(text, path):
    """Collect the lemma names relevant to a Why3 source file.

    Args:
        text: Contents of the Why3 file.
        path: Path of that file; only its directory component is used.

    Returns:
        A set with every name matched by the ``lemma_re`` patterns in
        ``text``. Unless the file sits directly in one of the
        ``data/why3/common`` or ``data/why3/no-lemma5`` directories, the keys
        of the ``lemma.db`` SqliteDict in the file's directory are added too.
    """
    lemma_names = set()

    for pattern in lemma_re:
        # re.MULTILINE is required so that the "^" anchor of the short-form
        # pattern matches at the start of every line, not only at the very
        # beginning of the file.
        lemma_names.update(re.findall(pattern, text, flags=re.MULTILINE))

    directory = os.path.dirname(path)
    if directory not in ('./data/why3/common', 'data/why3/common', './data/why3/no-lemma5', 'data/why3/no-lemma5'):
        # os.path.join keeps the database next to the file even when the path
        # has no directory component (plain concatenation gave "/lemma.db").
        with SqliteDict(os.path.join(directory, 'lemma.db'), autocommit=True) as db:
            lemma_names.update(db.keys())

    return lemma_names

# Module-level Anthropic API client, created once with no explicit arguments
# and used by check() for every request.
client = anthropic.Anthropic()

def norm(text):
    """Strip a surrounding Markdown code fence and end the text with a newline.

    When the whitespace-stripped text has at least two lines and both its
    first and last line start with three backticks, the lines between them
    are returned joined by newlines and followed by a newline. Otherwise the
    unmodified input plus a trailing newline is returned.
    """
    rows = text.strip().split('\n')
    opens_fence = rows[0].startswith('```')
    closes_fence = rows[-1].startswith('```')
    if len(rows) < 2 or not (opens_fence and closes_fence):
        return text + '\n'
    inner = rows[1:-1]
    return '\n'.join(inner) + '\n'

def diff(path, path2):
    """Compare two files with the external ``diff -u`` command.

    Returns the string "CORRECT" when the command exits with status 0,
    otherwise whatever the command wrote to standard output.
    """
    command = ['diff', '-u', path, path2]
    completed = subprocess.run(command, capture_output=True, text=True)
    return completed.stdout if completed.returncode != 0 else "CORRECT"

def check(path, path2):
    """Judge whether a processed Why3 file is a valid lemma-removal of the original.

    The two files are compared with ``diff -u``. If they differ, the diff and
    the lemma names gathered by ``collect`` are embedded in a prompt that is
    sent to the Anthropic model, and the model's reply is returned.

    Args:
        path: Path of the original Why3 file.
        path2: Path of the processed Why3 file.

    Returns:
        ``"CORRECT"`` when the files are identical; otherwise the model's
        reply with surrounding whitespace stripped (the prompt asks for
        ``"CORRECT"`` or a message starting with ``"INCORRECT"``); or a
        string starting with ``"ERROR: "`` when ``diff`` itself fails or the
        API request raises.

    Raises:
        OSError: If ``path`` cannot be opened for reading.
    """
    with open(path, 'r', encoding='utf-8') as file:
        original_text = file.read()

    lemma_names = collect(original_text, path)

    # Calculate diff between original and processed files.
    # diff exits with 0 for identical files, 1 when they differ, and a
    # higher status when it could not compare them at all.
    proc = subprocess.run(['diff', '-u', path, path2], capture_output=True, text=True)
    if proc.returncode == 0:
        return "CORRECT"
    if proc.returncode != 1:
        # Do not spend an API call on an empty diff; report the failure in
        # the same "ERROR: ..." form used for API failures below.
        return f"ERROR: diff failed with exit status {proc.returncode}: {proc.stderr.strip()}"
    # Named diff_text so the module-level diff() helper is not shadowed.
    diff_text = proc.stdout

    # Improved prompt structure
    system_prompt = """You are an expert in formal verification and Why3. You validate code transformations to ensure correctness."""

    task_description = """Analyze the DIFF between original and processed Why3 code. Check if lemmas/assertions were correctly removed/transformed while preserving essential program constructs."""

    # Simplified description of the transformation rules
    transformation_rules = """
CORRECT transformations (what SHOULD be changed):

1. REMOVE assertion annotations: "assert { formula }"
2. TRANSFORM lemma declarations:
   - Short: "lemma name: formula" → "let ghost name ... : unit ... = ()"
   - Long: "let lemma name ... = proof" → "let ghost name ... : unit ... = ()"
3. REMOVE lemma calls/applications
4. REMOVE variant clauses from transformed lemmas

Key transformation example:
```
// Original
let lemma example (x: int)
  requires { x > 0 }
  variant { x }
  ensures { x > 0 }
= assert { x > 0 }

// Correct transformation
let ghost example (x: int) : unit
  requires { x > 0 }
  ensures { x > 0 }
= ()
```
"""

    preservation_rules = """
MUST BE PRESERVED (wrong to change or remove):

1. Function/procedure specifications: requires, ensures, variant clauses
2. Loop specifications: invariant, variant clauses  
3. All functional code: function bodies, type definitions, imports
4. Program logic and control structures
5. Semantic meaning of specifications (even if syntax changes)
"""

    analysis_instructions = """
ANALYSIS PRIORITIES (check in this order):

1. HIGH PRIORITY: Are function/procedure specifications preserved?
2. HIGH PRIORITY: Are loop invariants and variants preserved?
3. MEDIUM PRIORITY: Are lemma transformations syntactically correct?
4. LOW PRIORITY: Are assertions properly removed?

OUTPUT FORMAT:
- If correct: "CORRECT". Print "CORRECT" only. Any other text is strictly forbidden.
- If incorrect: "INCORRECT: [Priority] Issue description". Your response must start with "INCORRECT".

Examples:
- "INCORRECT: HIGH PRIORITY - Function requires clause was removed"
- "INCORRECT: MEDIUM PRIORITY - Lemma transformation syntax error"
"""

    # Assemble the complete prompt
    lemma_list = ""
    if lemma_names:
        lemma_list = f"""
Lemma names that should be removed when called:
{chr(10).join(sorted(lemma_names))}
"""

    full_prompt = f"""{system_prompt}

{task_description}

{transformation_rules}

{preservation_rules}

{analysis_instructions}
{lemma_list}
DIFF TO ANALYZE:
```
{diff_text}
```

Analyze and return CORRECT or INCORRECT with specific priority and issue."""

    try:
        stream = client.messages.create(
            model="claude-sonnet-4-20250514",
            max_tokens=8000,  # reduce token consumption
            temperature=0,
            stream=True,
            messages=[
                {
                    "role": "user",
                    "content": full_prompt
                }
            ]
        )

        # Handle streaming response: keep only text deltas, since other
        # delta kinds carry no ``text`` attribute.
        pieces = []
        for event in stream:
            if event.type == "content_block_delta" and event.delta.type == "text_delta":
                pieces.append(event.delta.text)

        return "".join(pieces).strip()
    except Exception as e:
        # Best effort: report the failure as a result string so the caller
        # can store it and retry later instead of aborting the whole run.
        print(f"Error: {e}")
        return f"ERROR: {e}"

# Script entry point: check every processed .mlw file against its original,
# cache the verdicts in check.db and export the non-CORRECT ones to check.csv.
if __name__ == "__main__":
    processed_root = './data/why3/no-lemma5'
    original_root = './data/why3/common'

    # Collect all .mlw file paths first
    file_paths = [
        os.path.join(root, name)
        for root, _dirs, names in os.walk(processed_root)
        for name in names
        if name.endswith('.mlw')
    ]
    total = len(file_paths)

    print(f"Found {total} .mlw files to process")

    # Process each file
    with SqliteDict(os.path.join(processed_root, 'check.db'), autocommit=True) as db:
        # Progress is reported 1-based so the last file prints [total/total].
        for i, file_path in enumerate(file_paths, start=1):
            print(f"[{i}/{total}] Processing: {file_path}")
            rel_path = os.path.relpath(file_path, processed_root)
            # Reuse stored verdicts; results that start with "ERROR" are retried.
            if rel_path in db and not db[rel_path].startswith("ERROR"):
                print(f"[{i}/{total}] Skipping (already exists): {rel_path}")
                continue

            # The original file lives at the same relative path under original_root.
            original_path = os.path.join(original_root, rel_path)
            result = check(original_path, file_path)
            print(result)
            print('\n\n')

            db[rel_path] = result
            db.commit()
            print(f"[{i}/{total}] Processed: {rel_path}")

        # Export every stored result that does not start with "CORRECT".
        with open(os.path.join(processed_root, 'check.csv'), 'w', encoding='utf-8', newline='') as file:
            writer = csv.writer(file)
            for rel_path, result in db.items():
                if not result.startswith("CORRECT"):
                    writer.writerow([rel_path, result])

    print("Processing complete!")