#!/usr/bin/env python3
import os
import subprocess
import anthropic
import csv
from sqlitedict import SqliteDict

# Module-level Anthropic API client used by check() for every model call.
# NOTE(review): constructed with no arguments, so it presumably picks up
# credentials (e.g. ANTHROPIC_API_KEY) from the environment — confirm.
client = anthropic.Anthropic()

def norm(text):
    """Remove a Markdown code fence wrapped around *text*, if there is one.

    The fence is detected on the whitespace-stripped text: its first and last
    lines must both start with three backticks. When fenced, the inner lines
    are returned joined with newlines plus one trailing newline. Otherwise
    the ORIGINAL (unstripped) text is returned with a newline appended.
    """
    fence = '```'
    body = text.strip().split('\n')
    is_fenced = len(body) > 1 and body[0].startswith(fence) and body[-1].startswith(fence)
    if not is_fenced:
        return text + '\n'
    inner = body[1:-1]
    return '\n'.join(inner) + '\n'

def diff(path, path2):
    """Return the unified diff of *path* against *path2*.

    Args:
        path: The "from" file passed to ``diff -u``.
        path2: The "to" file passed to ``diff -u``.

    Returns:
        "CORRECT" when the files are identical (diff exit status 0), the
        ``diff -u`` output when they differ (exit status 1), or an
        "ERROR: ..." string carrying diff's stderr for any other status
        (e.g. 2 for a missing/unreadable file). Previously that case
        silently returned an empty string.

    Raises:
        FileNotFoundError: if the ``diff`` executable itself is not found
            (propagated from ``subprocess.run``).
    """
    proc = subprocess.run(['diff', '-u', path, path2], capture_output=True, text=True)
    # diff exit status: 0 = no differences, 1 = differences found, 2 = trouble.
    # A negative status (killed by a signal) also lands in the error branch.
    if proc.returncode == 0:
        return "CORRECT"
    if proc.returncode == 1:
        return proc.stdout
    return f"ERROR: diff exited with status {proc.returncode}: {proc.stderr.strip()}"

def check(path, path2):
    """Ask Claude whether *path2* is a correct lemma/assertion-stripped *path*.

    Args:
        path: The original Why3 (.mlw) file.
        path2: The processed file, which should equal *path* with lemma
            declarations and assertion annotations removed.

    Returns:
        "CORRECT" immediately (no API call) when the two files are identical.
        Otherwise, the model's whitespace-stripped reply, which the prompt asks
        to be "CORRECT" or "INCORRECT: <specific issues>". An "ERROR: <reason>"
        string is returned when ``diff`` fails (e.g. a missing file), when the
        API call raises, or when the model returns no text.
    """
    # Run the diff first: identical files need neither the original text nor
    # an API call. diff exit status: 0 = identical, 1 = differ, other = trouble.
    proc = subprocess.run(['diff', '-u', path, path2], capture_output=True, text=True)
    if proc.returncode == 0:
        return "CORRECT"
    if proc.returncode != 1:
        # e.g. missing/unreadable file: never send an empty diff to the model.
        return f"ERROR: diff exited with status {proc.returncode}: {proc.stderr.strip()}"
    # Not named `diff`, which would shadow the module-level diff() helper.
    diff_text = proc.stdout

    with open(path, 'r', encoding='utf-8') as file:
        original_text = file.read()

    # Prompt sections. They are all concatenated into a single user message
    # below; no separate `system=` parameter is passed to the API.
    system_prompt = """You are an expert in formal verification, especially in the verifier Why3. You are helping to review and validate code transformations to ensure they preserve program correctness."""

    task_description = """You are given a DIFF between original Why3 code and processed code where lemma declarations and assertion annotations were supposed to be removed. Your task is to check if the transformation was done correctly - that is, whether only lemmas and assertions were removed while preserving all essential program constructs and specifications."""

    what_should_be_removed = """
    What SHOULD be removed (and is correct to remove):
    1. Lemma declarations:
       - Simple lemmas: "lemma <name>: <formula>"
       - Let lemmas: "let lemma <name> (<params>) : <type> = ..."
       - Recursive lemmas: "let rec lemma <name> (<params>) : <type> = ..."
    
    2. Assertion annotations:
       - Basic assertions: "assert { <formula> }"
       - Assertions with reasons: "assert { <formula> by <reason> }"
       - Multi-line assertions within program code

    What CAN be removed (and is okay to remove):
    Any comments, including doc comments.
    """

    what_should_be_preserved = """
    What MUST be preserved (and is wrong to remove):
    1. Function/procedure specifications:
       - requires clauses of programs: "let fun_name ... requires { <precondition> }"
       - ensures clauses of programs: "let fun_name ... ensures { <postcondition> }"
       - variant clauses of programs: "let fun_name ... variant { <termination_measure> }"
    
    2. Loop specifications:
       - invariant clauses: "invariant { <loop_invariant> }"
       - variant clauses in loops
    
    3. All functional code:
       - Function definitions, implementations
       - Type definitions, module declarations
       - Import statements, clones
       - Any program logic and control structures
    """

    requirements = """
    Your task:
    1. Analyze the provided diff between original and processed Why3 code
    2. Check if the transformation correctly removed ONLY lemmas and assertions
    3. Verify that all program specifications (requires/ensures/invariants) are preserved
    4. Identify any incorrectly removed content that should have been kept
    5. Identify any content that should have been removed but wasn't
    6. Return your analysis as: "CORRECT" if transformation is valid, or "INCORRECT: <specific issues>" if there are problems
    7. Be specific about what was wrongly removed or what should have been removed
    """

    # Assemble the full prompt.
    full_prompt = f"""{system_prompt}

{task_description}

{requirements}

{what_should_be_removed}

{what_should_be_preserved}

DIFF TO ANALYZE:
```
{diff_text}
```

Original Why3 code:
```
{original_text}
```

Please analyze the diff and return either "CORRECT" or "INCORRECT: <specific issues>". Do not output any other text."""

    try:
        # Streamed so a long reply does not hit the SDK's non-streaming
        # request-duration limits.
        response = client.messages.create(
            model="claude-sonnet-4-20250514",
            max_tokens=40000,
            temperature=0,
            stream=True,
            messages=[
                {
                    "role": "user",
                    "content": full_prompt
                }
            ]
        )

        parts = []
        for event in response:
            # Only text deltas carry reply text; other delta kinds
            # (e.g. tool-input JSON) have no `.text` attribute.
            if event.type == "content_block_delta" and event.delta.type == "text_delta":
                parts.append(event.delta.text)
        reply = "".join(parts).strip()
    except Exception as e:
        # API boundary: report and return an ERROR verdict instead of aborting
        # the caller's batch loop.
        print(f"Error: {e}")
        return f"ERROR: {e}"

    if not reply:
        return "ERROR: empty response from model"
    return reply

# Driver: compare every processed .mlw file under NO_LEMMA_DIR with its
# original under COMMON_DIR via check(), cache each verdict in a SqliteDict,
# then export every verdict that does not start with "CORRECT" to a CSV file.
NO_LEMMA_DIR = './data/why3/no-lemma3'
COMMON_DIR = './data/why3/common'

# Collect all .mlw file paths first. Sorted so the processing order (and the
# progress numbering) is reproducible; os.walk order is filesystem-dependent.
file_paths = sorted(
    os.path.join(root, name)
    for root, _dirs, files in os.walk(NO_LEMMA_DIR)
    for name in files
    if name.endswith('.mlw')
)
total = len(file_paths)

print(f"Found {total} .mlw files to process")

# Process each file. Cached verdicts are skipped, except cached "ERROR:"
# results (failed diff/API calls), which are retried on every run.
with SqliteDict(os.path.join(NO_LEMMA_DIR, 'check.db'), autocommit=True) as db:
    for i, file_path in enumerate(file_paths, start=1):
        print(f"[{i}/{total}] Processing: {file_path}")
        rel_path = os.path.relpath(file_path, NO_LEMMA_DIR)
        original_path = os.path.join(COMMON_DIR, rel_path)

        cached = db.get(rel_path)
        if cached is not None and not cached.startswith("ERROR:"):
            print(f"[{i}/{total}] Skipping (already exists): {rel_path}")
            continue

        if not os.path.isfile(original_path):
            # Nothing to compare against; report it and move on so one missing
            # original does not abort the whole run.
            print(f"[{i}/{total}] Skipping (no original file): {original_path}")
            continue

        result = check(original_path, file_path)
        print(result)
        print('\n\n')

        # "ERROR:" results are stored too, so they appear in the CSV, but the
        # cache check above retries them on the next run.
        db[rel_path] = result
        db.commit()
        print(f"[{i}/{total}] Processed: {rel_path}")

    # Export every non-CORRECT verdict (INCORRECT, ERROR, or anything else).
    with open(os.path.join(NO_LEMMA_DIR, 'check.csv'), 'w', encoding='utf-8', newline='') as csv_file:
        writer = csv.writer(csv_file)
        writer.writerows(
            [rel_path, result]
            for rel_path, result in db.items()
            if not result.startswith("CORRECT")
        )

print("Processing complete!")


