#!/usr/bin/env python3
import os
import subprocess
import anthropic
import csv
from sqlitedict import SqliteDict

# Module-level Anthropic client used by check(). No API key is passed here, so
# the SDK presumably picks up credentials from the environment
# (e.g. ANTHROPIC_API_KEY) — confirm for your setup.
client = anthropic.Anthropic()

def norm(text):
    """Normalize model output into file content ending in exactly one newline.

    If the whitespace-stripped text is wrapped in a Markdown code fence (its
    first and last lines both start with three backticks, optionally followed
    by a language tag), the fence lines are dropped and only the enclosed
    lines are kept. Otherwise the text is kept as-is apart from trailing
    whitespace.

    Args:
        text: Raw text, possibly a fenced code block.

    Returns:
        The unfenced content, terminated by a single newline character.
    """
    lines = text.strip().split('\n')
    if len(lines) >= 2 and lines[0].startswith('```') and lines[-1].startswith('```'):
        return '\n'.join(lines[1:-1]) + '\n'
    # Strip trailing whitespace first so input that already ends in a newline
    # does not come back with a doubled one (matching the fenced branch, which
    # always yields exactly one trailing newline).
    return text.rstrip() + '\n'

def diff(path, path2):
    """Return the unified diff between two files.

    Args:
        path: Original file.
        path2: File compared against *path*.

    Returns:
        "CORRECT" if `diff` reports the files identical, the `diff -u` output
        if they differ, or a string starting with "ERROR: " when `diff`
        reports trouble (e.g. a missing file).
    """
    proc = subprocess.run(['diff', '-u', path, path2], capture_output=True, text=True)
    if proc.returncode == 0:
        return "CORRECT"
    if proc.returncode == 1:
        return proc.stdout
    # diff exits with status > 1 on trouble; stdout is then empty and the
    # reason is on stderr, so surface it instead of returning "".
    return f"ERROR: diff exited with status {proc.returncode}: {proc.stderr.strip()}"

def _build_check_prompt(diff_text, original_text):
    """Assemble the review prompt sent to the model for one file pair.

    Args:
        diff_text: `diff -u` output between the original and processed files.
        original_text: Full text of the original Why3 file.

    Returns:
        The complete prompt: role description, task, requirements, allowed
        transformations, must-preserve rules, then the diff and the original
        file, each inside a code fence.
    """
    # The prompt is kept as separate sections for readability; they are all
    # concatenated into a single user message (no separate system prompt is
    # passed to the API).
    system_prompt = """You are an expert in formal verification, especially in the verifier Why3. You are helping to review and validate code transformations to ensure they preserve program correctness."""

    task_description = """You are given a DIFF between original Why3 code and processed code where lemmas and assertion annotations were supposed to be removed or transformed. Your task is to check if the transformation was done correctly - that is, whether only lemmas and assertions were removed or transformed while preserving all essential program constructs and specifications."""

    # Raw string: the Why3 examples contain the disjunction operator \/,
    # which is an invalid escape sequence in a normal string literal.
    syntax_patterns = r"""
    What SHOULD be removed or transformed (and is correct to remove or transform):
    
    1. Remove assertion annotations:
       - Basic assertions: "assert { <formula> }"
       - Assertions with reasons: "assert { <formula> by <reason> }"
       - Multi-line assertions with complex formulas
       Examples:
         assert { s = 22 }
         assert { exchange (old a) a i j }
         assert { !j < !i \/ !j = !i = m \/ !j = !i = n }
         assert { forall i. 0 <= i < i0 -> ... }

    2. Transform short lemma declarations into ghost declarations:
       - "lemma <name>: <formula>" -> "let ghost <name> <forall params> : unit <requires> <ensures> = ()"
       Examples:
       (a).
         Input Code:
         ```
         lemma grid_eq_sub:
           forall g1 g2 a b. 0 <= a <= 81 -> 0 <= b <= 81 ->
             grid_eq g1 g2 -> grid_eq_sub g1 g2 a b
         ```
         Output:
         ```
         let ghost grid_eq_sub g1 g2 a b : unit
           requires { 0 <= a <= 81 }
           requires { 0 <= b <= 81 }
           requires { grid_eq g1 g2 }
           ensures  { grid_eq_sub g1 g2 a b }
         = ()
         ```
       (b).
         Input Code:
         ```
         lemma gcd_sub: forall a b: int. gcd a b = gcd a (b - a)
         ```
         Output:
         ```
         let ghost gcd_sub (a b: int) : unit
           ensures { gcd a b = gcd a (b - a) }
         = ()
         ```
        (c).
         Input Code:
         ```
         lemma closest :
           forall a b c: int.
           abs (2 * a * b - 2 * c) <= a ->
           forall b': int. abs (a * b - c) <= abs (a * b' - c)
         ```
         Output:
         ```
         let ghost closest (a b c: int) : unit
           requires { abs (2 * a * b - 2 * c) <= a }
           ensures  { forall b': int. abs (a * b - c) <= abs (a * b' - c) }
         = ()
         ```
        (d).
         Input Code:
         ```
         lemma while_rule_ext:
           forall e:term, inv inv':fmla, i:stmt.
           valid_fmla (Fimplies inv' inv) ->
           valid_triple (Fand (Fterm e) inv') i inv' ->
           valid_triple inv' (Swhile e inv i) (Fand (Fnot (Fterm e)) inv')
         ```
         Output:
         ```
         let ghost while_rule_ext (e:term) (inv inv':fmla) (i:stmt) : unit
           requires { valid_fmla (Fimplies inv' inv) }
           requires { valid_triple (Fand (Fterm e) inv') i inv' }
           ensures  { valid_triple inv' (Swhile e inv i) (Fand (Fnot (Fterm e)) inv') }
         = ()
         ```

    3. Remove proofs of long lemmas:
       Transform the code based on the following rules.
       - "let lemma <name> <params> <spec> = <proof>" -> "let ghost <name> <params> : unit <spec without variant> = ()"
       - "let rec lemma <name> <params> <spec> = <proof>" -> "let ghost <name> <params> : unit <spec without variant> = ()"
       - "let lemma <name> <params> : unit <spec> = <proof>" -> "let ghost <name> <params> : unit <spec without variant> = ()"
       - "let rec lemma <name> <params> : unit <spec> = <proof>" -> "let ghost <name> <params> : unit <spec without variant> = ()"
       Zero or more requires/ensures/variant clauses can occur in <spec>. You should preserve all requires/ensures clauses, but remove the variant clauses.
       Examples:
       (a).
         Input Code:
         ```
         let lemma full_up_to_frame_all (g1 g2:grid) (i:int)
            requires { 0 <= i <= 81 }
            requires { grid_eq g1 g2 }
            requires { full_up_to g1 i }
            ensures  { full_up_to g2 i }
          = assert { grid_eq_sub g1 g2 0 i }
         ```
         Output:
         ```
         let ghost full_up_to_frame_all (g1 g2:grid) (i:int) : unit
            requires { 0 <= i <= 81 }
            requires { grid_eq g1 g2 }
            requires { full_up_to g1 i }
            ensures  { full_up_to g2 i }
          = ()
         ```
       (b).
         Input Code:
         ```
         let rec lemma map_eq_shift_zero (x y: map int 'a) (n m: int)
           requires { map_eq_sub_shift x y n n (m-n) }
           variant { m - n }
           ensures { MapEq.map_eq_sub x y n m }
          =
           if n < m then
           begin
           assert { forall i. 0 <= i < m-n -> x[n+i] = y[n+i] };
           assert { forall i. n <= i < m ->
                       let j = i - n in 0 <= j < m-n ->
                           x[n+j] = y[n+j] -> x[i] = y[i]};
           map_eq_shift_zero x y (n+1) m;
           end
           else ()
         ```
         Output:
         ```
         let ghost map_eq_shift_zero (x y: map int 'a) (n m: int) : unit
           requires { map_eq_sub_shift x y n n (m-n) }
           ensures { MapEq.map_eq_sub x y n m }
         = ()
         ```
       (c).
         Input Code:
         ```
         let lemma auxT1 (v i j:int) : unit
           requires { 0 <= i < zeros v }
           requires { 0 <= j < c }
           ensures { cl2 j i <= v }
          = if cl2 j i > v then begin
            permut_numoff (cl1 j) (cl2 j) 0 n ((>=) v);
            assert {
              let a = numof (below_column e j v) 0 i in
              let b = numof (below_column e j v) i n in
              numof (below_column e j v) 0 n = a + b
              && a <= i
              && b > 0
            };
            auxT v i j
            end
         ```
         Output:
         ```
         let ghost auxT1 (v i j:int) : unit
           requires { 0 <= i < zeros v }
           requires { 0 <= j < c }
           ensures { cl2 j i <= v }
         = ()
         ```
       (d).
         Input Code:
         ```
         let rec lemma sum_mult (f:int -> int) (a b l:int) : unit
           ensures { sum (smulf f l) a b = l * sum f a b }
           variant { b - a }
         = if b > a then sum_mult f a (b-1) l
         ```
         Output:
         ```
         let ghost sum_mult (f:int -> int) (a b l:int) : unit
           ensures { sum (smulf f l) a b = l * sum f a b }
         = ()
         ```

    4. Transform lemma declarations that return non-unit:
       Some lemma may have return values other than unit. For those lemmas, you should only replace `lemma` keyword to `ghost` keyword.
       - "let lemma <name> <params> : <type> <spec> = <proof>" -> "let ghost <name> <params> : <type> <spec> = <proof>"
       - "let rec lemma <name> <params> : <type> <spec> = <proof>" -> "let rec ghost <name> <params> : <type> <spec> = <proof>"
       Note: Specially for this case only, preserve the 'rec' keyword in the output.
       Examples:
       (a).
         Input Code:
         ```
         let lemma prime_factor (n: int) : (p: int)
           requires { n >= 2 }
           ensures  { prime p }
           ensures  { divides p n }
         = for p = 2 to n do
             invariant { forall d. 2 <= d < p -> not (divides d n) }
             if mod n p = 0 then return p
           done;
           return n
         ```
         Output:
         ```
         let ghost prime_factor (n: int) : (p: int)
           requires { n >= 2 }
           ensures  { prime p }
           ensures  { divides p n }
         = for p = 2 to n do
             invariant { forall d. 2 <= d < p -> not (divides d n) }
             if mod n p = 0 then return p
           done;
           return n
         ```
       (b).
         Input Code:
         ```
         let lemma lt_total (p q: permutation) : bool
           requires { length p = length q }
           ensures  { if result then lt p q else p = q \/ lt q p }
         = let n = length p in
           for i = 0 to n-1 do
             invariant { forall j. 0 <= j < i -> p[j] = q[j] }
             if p[i] < q[i] then return true;
             if p[i] > q[i] then return false;
           done;
           return false
         ```
         Output:
         ```
         let ghost lt_total (p q: permutation) : bool
           requires { length p = length q }
           ensures  { if result then lt p q else p = q \/ lt q p }
         = let n = length p in
           for i = 0 to n-1 do
             invariant { forall j. 0 <= j < i -> p[j] = q[j] }
             if p[i] < q[i] then return true;
             if p[i] > q[i] then return false;
           done;
           return false
         ```

    5. Remove lemma applications:
       You will be given a list of lemma names. Remove all the applications of these lemmas from the code.
       Examples:
       (a).
         Input Code:
         ```
         for l = i0+1 to n - 1 do
           invariant { forall k. i0+1 <= k < l -> p[k] = a1[k] }
           if p[l] <> a1[l] then (
             assert { a1[l] < p[l] };
             occ_id a1 p 0 l; occ_split a1 0 l n; occ_split p 0 l n;
             call_function(p);
             occ_at a1 0 l n; occ_at p 0 l n;
             assert { occ p[l] a1 l n = occ p[l] p l n > 0 };
             return 3;
           )
         done
         ```
         Lemma Names:
         ```
         occ_id
         occ_split
         occ_at
         ```
         Output:
         ```
         for l = i0+1 to n - 1 do
           invariant { forall k. i0+1 <= k < l -> p[k] = a1[k] }
           if p[l] <> a1[l] then (
             call_function(p);
             return 3;
           )
         done
         ```
       (b).
         Input Code:
         ```
         while i <= n - m do
           invariant { 0 <= i <= n }
           invariant { forall j. 0 <= j < i -> j <= n - m ->
                       substring text j m <> pat }
           variant   { n - m - i }
           if occurs pat text i then return i;
           if i = n - m then break;
           let c = text[i + m] in
           i <- i + if M.mem c bst.sht then (shift_lemma bst text i; M.find c bst.sht)
                                       else (no_shift    bst text i; m + 1)
         done
         ```
         Lemma Names:
         ```
         shift_lemma
         no_shift
         ```
         Output:
         ```
         while i <= n - m do
           invariant { 0 <= i <= n }
           invariant { forall j. 0 <= j < i -> j <= n - m ->
                       substring text j m <> pat }
           variant   { n - m - i }
           if occurs pat text i then return i;
           if i = n - m then break;
           let c = text[i + m] in
           i <- i + if M.mem c bst.sht then M.find c bst.sht
                                       else m + 1
         done
         ```
    """
    
    what_should_be_preserved = """
    What MUST be preserved (and is wrong to remove):
    1. Function/procedure specifications:
       - requires clauses of programs: "let fun_name ... requires { <precondition> }"
       - ensures clauses of programs: "let fun_name ... ensures { <postcondition> }"
       - variant clauses of programs: "let fun_name ... variant { <termination_measure> }"
    
    2. Loop specifications:
       - invariant clauses: "invariant { <loop_invariant> }"
       - variant clauses in loops
    
    3. All functional code:
       - Function definitions, implementations
       - Type definitions, module declarations
       - Import statements, clones
       - Any program logic and control structures
    
    4. Specifications:
       - The specification of every function/procedure:
       - Even though the syntax of lemma declarations is changed, the specification of every lemma should be semantically preserved.
    """
    
    requirements = """
    Your task:
    1. Analyze the provided diff between original and processed Why3 code
    2. Check if the transformation correctly removed or tranformed ONLY lemmas and assertions
    3. Verify that all program specifications (requires/ensures/invariants) are preserved
    4. Identify any incorrect removal and transformation that should have been kept
    5. Identify any content that should have been removed but wasn't
    6. Return your analysis as: "CORRECT" if transformation is valid, or "INCORRECT: <specific issues>" if there are problems
    7. Be specific about what was wrongly removed or what should have been removed
    """
    
    # Assemble the complete prompt.
    full_prompt = f"""{system_prompt}

{task_description}

{requirements}

{syntax_patterns}

{what_should_be_preserved}

DIFF TO ANALYZE:
```
{diff_text}
```

Original Why3 code:
```
{original_text}
```

Please analyze the diff and return either "CORRECT" or "INCORRECT: <specific issues>". Do not output any other text."""
    return full_prompt


def _ask_model(prompt):
    """Send *prompt* as a single user message and return the streamed reply.

    Args:
        prompt: Full prompt text.

    Returns:
        The concatenated reply text with surrounding whitespace stripped, or
        "ERROR: <exception>" if the API call or the stream fails.
    """
    try:
        response = client.messages.create(
            model="claude-sonnet-4-20250514",
            max_tokens=40000,
            temperature=0,
            stream=True,
            messages=[
                {
                    "role": "user",
                    "content": prompt,
                }
            ],
        )

        # The stream yields typed events; only text deltas carry reply text
        # (other delta kinds have no `.text` attribute).
        parts = []
        for event in response:
            if event.type == "content_block_delta" and event.delta.type == "text_delta":
                parts.append(event.delta.text)

        return "".join(parts).strip()
    except Exception as e:
        # Batch boundary: report the failure as a result string so the
        # caller can record it and move on to the next file.
        print(f"Error: {e}")
        return f"ERROR: {e}"


def check(path, path2):
    """Validate that *path2* is a correct lemma/assertion-stripped version of *path*.

    Args:
        path: Original Why3 file.
        path2: Processed Why3 file to validate against *path*.

    Returns:
        "CORRECT" when `diff` reports the files identical (no model call is
        made). Otherwise the model's stripped reply, which the prompt asks to
        be "CORRECT" or "INCORRECT: <specific issues>". Failures (unreadable
        original, `diff` trouble, API errors) are returned as a string
        starting with "ERROR: " instead of raising.
    """
    try:
        with open(path, 'r', encoding='utf-8') as file:
            original_text = file.read()
    except (OSError, UnicodeDecodeError) as e:
        return f"ERROR: {e}"

    # Calculate the diff between the original and processed files.
    try:
        proc = subprocess.run(['diff', '-u', path, path2], capture_output=True,
                              text=True, encoding='utf-8', errors='replace')
    except OSError as e:  # e.g. the `diff` executable cannot be launched
        return f"ERROR: {e}"
    if proc.returncode == 0:
        return "CORRECT"
    if proc.returncode != 1:
        # diff exits with status > 1 on trouble and leaves stdout empty; do
        # not send an empty diff to the model.
        return f"ERROR: diff exited with status {proc.returncode}: {proc.stderr.strip()}"
    diff_text = proc.stdout

    return _ask_model(_build_check_prompt(diff_text, original_text))

# Directory holding the processed .mlw files to validate (and the result db/csv).
PROCESSED_DIR = './data/why3/no-lemma5'
# Directory holding the original .mlw files, mirrored by relative path.
ORIGINAL_DIR = './data/why3/common'


def _is_error(result):
    """Return True if a stored result records a failed check that should be retried."""
    # check() reports failures as "ERROR: ..."; compare case-insensitively so
    # both that and any "Error..." entries are retried.
    return result.upper().startswith("ERROR")


# Collect all .mlw file paths first, sorted so every run visits them in the same order.
file_paths = sorted(
    os.path.join(root, name)
    for root, _dirs, files in os.walk(PROCESSED_DIR)
    for name in files
    if name.endswith('.mlw')
)
total = len(file_paths)

print(f"Found {total} .mlw files to process")

# Process each file, caching verdicts so an interrupted run can resume.
with SqliteDict(os.path.join(PROCESSED_DIR, 'check.db'), autocommit=True) as db:
    for i, file_path in enumerate(file_paths, start=1):
        print(f"[{i}/{total}] Processing: {file_path}")
        rel_path = os.path.relpath(file_path, PROCESSED_DIR)
        original_path = os.path.join(ORIGINAL_DIR, rel_path)
        previous = db.get(rel_path)
        if previous is not None and not _is_error(previous):
            print(f"[{i}/{total}] Skipping (already exists): {rel_path}")
            continue

        if os.path.isfile(original_path):
            result = check(original_path, file_path)
        else:
            # Record the problem instead of letting a missing original abort
            # the whole batch; it is retried on the next run.
            result = f"ERROR: original file not found: {original_path}"
        print(result)
        print('\n\n')

        db[rel_path] = result
        db.commit()
        print(f"[{i}/{total}] Processed: {rel_path}")

    # Export every non-CORRECT verdict (INCORRECT and ERROR entries alike).
    with open(os.path.join(PROCESSED_DIR, 'check.csv'), 'w', encoding='utf-8', newline='') as file:
        writer = csv.writer(file)
        for rel_path, result in db.items():
            if not result.startswith("CORRECT"):
                writer.writerow([rel_path, result])

print("Processing complete!")


