#!/usr/bin/env python3
import os
import anthropic
import re
from sqlitedict import SqliteDict
import textwrap
import time

# Regexes whose capture group is the name of a declared lemma:
# `let lemma f`, `let rec lemma f`, and a bare `lemma f` at the start of a line.
lemma_re = [r'let\s+lemma\s+([\w\']+)', r'let\s+rec\s+lemma\s+([\w\']+)', r'^\s*lemma\s+([\w\']+)']
# Pattern for `let` declarations with optional rec/ghost/function keywords.
# Not referenced by collect().
decl_re = [r'let(\s+rec)?(\s+ghost)?(\s+function)?\s+([\w\']+)']

def collect(text, path):
    """Return the set of lemma names relevant to the Why3 source `text`.

    Names come from two places:
      * every match of the `lemma_re` patterns inside `text`;
      * every key of the `lemma.db` SqliteDict stored next to `path`, unless
        the file sits directly in one of the top-level data directories
        listed below (those are skipped and no database is opened).

    Args:
        text: Contents of the source file.
        path: Path of the source file; only its directory part is used.

    Returns:
        A set of lemma-name strings (possibly empty).
    """
    lemma_names = set()

    for pattern in lemma_re:
        # MULTILINE makes '^' anchor at every line start; without it the
        # bare-`lemma` pattern could only match at the very start of the file.
        lemma_names.update(re.findall(pattern, text, flags=re.MULTILINE))

    directory = os.path.dirname(path)
    if directory not in ['./data/why3/common', 'data/why3/common', './data/why3/no-lemma5', 'data/why3/no-lemma5']:
        # os.path.join keeps the database next to the file even when `path`
        # has no directory component (plain string formatting gave '/lemma.db').
        with SqliteDict(os.path.join(directory, 'lemma.db'), autocommit=True) as db:
            lemma_names.update(db.keys())

    return lemma_names


# Module-wide Anthropic API client, created once with default arguments and reused by clean().
# NOTE(review): no API key is passed here, so credentials presumably come from the environment — confirm.
client = anthropic.Anthropic()

def norm(text):
    """Strip an enclosing markdown code fence from `text` and end it with a newline.

    The text counts as fenced when, after trimming surrounding whitespace, it
    has at least two lines and both the first and the last line start with
    three backticks. In that case only the lines in between are returned.
    Otherwise `text` is returned as given (not trimmed) with a newline appended.
    """
    rows = text.strip().split('\n')
    fenced = len(rows) >= 2 and rows[0].startswith('```') and rows[-1].startswith('```')
    if not fenced:
        return text + '\n'
    inner = rows[1:-1]
    return '\n'.join(inner) + '\n'

def clean(path):
    """Ask the model to strip assertions, lemma proofs and lemma applications from a Why3 file.

    The file at `path` is read as UTF-8 and embedded in a prompt together with
    the transformation rules, worked examples and (when any are found) the
    lemma names returned by `collect`. The prompt is sent as a single user
    message through the module-level `client`, and the streamed reply is
    concatenated and passed through `norm`.

    Any exception raised while calling or streaming from the API is printed
    and the request is retried after 10 seconds, indefinitely.

    Args:
        path: Path of the `.mlw` source file to process.

    Returns:
        The model's reply with an enclosing code fence removed, ending in a
        newline.
    """
    with open(path, 'r', encoding='utf-8') as file:
        text = file.read()

    # Sorted so the prompt is identical from run to run; iterating the set
    # directly would give an unstable order.
    lemma_names = '\n'.join(sorted(collect(text, path)))

    # Keep the system prompt separate from the concrete instructions
    system_prompt = """You are an expert in formal verification, especially in the verifier Why3."""

    task_description = """You need to remove all the assertion annotations, lemma proofs, and lemma applications. Then you should return the obtained code."""

    # NOTE: the Why3 disjunction `\/` is written `\\/` below so the literal
    # contains no invalid escape sequence; the resulting text is unchanged.
    syntax_patterns = """
    Syntax Patterns to Remove or Transform:
    
    1. Remove assertion annotations:
       - Basic assertions: "assert { <formula> }"
       - Assertions with reasons: "assert { <formula> by <reason> }"
       - Multi-line assertions with complex formulas
       Examples:
         assert { s = 22 }
         assert { exchange (old a) a i j }
         assert { !j < !i \\/ !j = !i = m \\/ !j = !i = n }
         assert { forall i. 0 <= i < i0 -> ... }

    2. Transform short lemma declarations into ghost declarations:
       - "lemma <name>: <formula>" -> "let ghost <name> <forall params> : unit <requires> <ensures> = ()"
       Examples:
       (a).
         Input Code:
         ```
         lemma grid_eq_sub:
           forall g1 g2 a b. 0 <= a <= 81 -> 0 <= b <= 81 ->
             grid_eq g1 g2 -> grid_eq_sub g1 g2 a b
         ```
         Output:
         ```
         let ghost grid_eq_sub g1 g2 a b : unit
           requires { 0 <= a <= 81 }
           requires { 0 <= b <= 81 }
           requires { grid_eq g1 g2 }
           ensures  { grid_eq_sub g1 g2 a b }
         = ()
         ```
       (b).
         Input Code:
         ```
         lemma gcd_sub: forall a b: int. gcd a b = gcd a (b - a)
         ```
         Output:
         ```
         let ghost gcd_sub (a b: int) : unit
           ensures { gcd a b = gcd a (b - a) }
         = ()
         ```
        (c).
         Input Code:
         ```
         lemma closest :
           forall a b c: int.
           abs (2 * a * b - 2 * c) <= a ->
           forall b': int. abs (a * b - c) <= abs (a * b' - c)
         ```
         Output:
         ```
         let ghost closest (a b c: int) : unit
           requires { abs (2 * a * b - 2 * c) <= a }
           ensures  { forall b': int. abs (a * b - c) <= abs (a * b' - c) }
         = ()
         ```
        (d).
         Input Code:
         ```
         lemma while_rule_ext:
           forall e:term, inv inv':fmla, i:stmt.
           valid_fmla (Fimplies inv' inv) ->
           valid_triple (Fand (Fterm e) inv') i inv' ->
           valid_triple inv' (Swhile e inv i) (Fand (Fnot (Fterm e)) inv')
         ```
         Output:
         ```
         let ghost while_rule_ext (e:term) (inv inv':fmla) (i:stmt) : unit
           requires { valid_fmla (Fimplies inv' inv) }
           requires { valid_triple (Fand (Fterm e) inv') i inv' }
           ensures  { valid_triple inv' (Swhile e inv i) (Fand (Fnot (Fterm e)) inv') }
         = ()
         ```

    3. Remove proofs of long lemmas:
       Transform the code based on the following rules.
       - "let lemma <name> <params> <spec> = <proof>" -> "let ghost <name> <params> : unit <spec without variant> = ()"
       - "let rec lemma <name> <params> <spec> = <proof>" -> "let ghost <name> <params> : unit <spec without variant> = ()"
       - "let lemma <name> <params> : unit <spec> = <proof>" -> "let ghost <name> <params> : unit <spec without variant> = ()"
       - "let rec lemma <name> <params> : unit <spec> = <proof>" -> "let ghost <name> <params> : unit <spec without variant> = ()"
       Zero or more requires/ensures/variant clauses can occur in <spec>. You should preserve all requires/ensures clauses, but remove the variant clauses.
       Examples:
       (a).
         Input Code:
         ```
         let lemma full_up_to_frame_all (g1 g2:grid) (i:int)
            requires { 0 <= i <= 81 }
            requires { grid_eq g1 g2 }
            requires { full_up_to g1 i }
            ensures  { full_up_to g2 i }
          = assert { grid_eq_sub g1 g2 0 i }
         ```
         Output:
         ```
         let ghost full_up_to_frame_all (g1 g2:grid) (i:int) : unit
            requires { 0 <= i <= 81 }
            requires { grid_eq g1 g2 }
            requires { full_up_to g1 i }
            ensures  { full_up_to g2 i }
          = ()
         ```
       (b).
         Input Code:
         ```
         let rec lemma map_eq_shift_zero (x y: map int 'a) (n m: int)
           requires { map_eq_sub_shift x y n n (m-n) }
           variant { m - n }
           ensures { MapEq.map_eq_sub x y n m }
          =
           if n < m then
           begin
           assert { forall i. 0 <= i < m-n -> x[n+i] = y[n+i] };
           assert { forall i. n <= i < m ->
                       let j = i - n in 0 <= j < m-n ->
                           x[n+j] = y[n+j] -> x[i] = y[i]};
           map_eq_shift_zero x y (n+1) m;
           end
           else ()
         ```
         Output:
         ```
         let ghost map_eq_shift_zero (x y: map int 'a) (n m: int) : unit
           requires { map_eq_sub_shift x y n n (m-n) }
           ensures { MapEq.map_eq_sub x y n m }
         = ()
         ```
       (c).
         Input Code:
         ```
         let lemma auxT1 (v i j:int) : unit
           requires { 0 <= i < zeros v }
           requires { 0 <= j < c }
           ensures { cl2 j i <= v }
          = if cl2 j i > v then begin
            permut_numoff (cl1 j) (cl2 j) 0 n ((>=) v);
            assert {
              let a = numof (below_column e j v) 0 i in
              let b = numof (below_column e j v) i n in
              numof (below_column e j v) 0 n = a + b
              && a <= i
              && b > 0
            };
            auxT v i j
            end
         ```
         Output:
         ```
         let ghost auxT1 (v i j:int) : unit
           requires { 0 <= i < zeros v }
           requires { 0 <= j < c }
           ensures { cl2 j i <= v }
         = ()
         ```
       (d).
         Input Code:
         ```
         let rec lemma sum_mult (f:int -> int) (a b l:int) : unit
           ensures { sum (smulf f l) a b = l * sum f a b }
           variant { b - a }
         = if b > a then sum_mult f a (b-1) l
         ```
         Output:
         ```
         let ghost sum_mult (f:int -> int) (a b l:int) : unit
           ensures { sum (smulf f l) a b = l * sum f a b }
         = ()
         ```

    4. Transform lemma declarations that return non-unit:
       Some lemma may have return values other than unit. For those lemmas, you should only replace `lemma` keyword to `ghost` keyword.
       - "let lemma <name> <params> : <type> <spec> = <proof>" -> "let ghost <name> <params> : <type> <spec> = <proof>"
       - "let rec lemma <name> <params> : <type> <spec> = <proof>" -> "let rec ghost <name> <params> : <type> <spec> = <proof>"
       Note: Specially for this case only, preserve the 'rec' keyword in the output.
       Examples:
       (a).
         Input Code:
         ```
         let lemma prime_factor (n: int) : (p: int)
           requires { n >= 2 }
           ensures  { prime p }
           ensures  { divides p n }
         = for p = 2 to n do
             invariant { forall d. 2 <= d < p -> not (divides d n) }
             if mod n p = 0 then return p
           done;
           return n
         ```
         Output:
         ```
         let ghost prime_factor (n: int) : (p: int)
           requires { n >= 2 }
           ensures  { prime p }
           ensures  { divides p n }
         = for p = 2 to n do
             invariant { forall d. 2 <= d < p -> not (divides d n) }
             if mod n p = 0 then return p
           done;
           return n
         ```
       (b).
         Input Code:
         ```
         let lemma lt_total (p q: permutation) : bool
           requires { length p = length q }
           ensures  { if result then lt p q else p = q \\/ lt q p }
         = let n = length p in
           for i = 0 to n-1 do
             invariant { forall j. 0 <= j < i -> p[j] = q[j] }
             if p[i] < q[i] then return true;
             if p[i] > q[i] then return false;
           done;
           return false
         ```
         Output:
         ```
         let ghost lt_total (p q: permutation) : bool
           requires { length p = length q }
           ensures  { if result then lt p q else p = q \\/ lt q p }
         = let n = length p in
           for i = 0 to n-1 do
             invariant { forall j. 0 <= j < i -> p[j] = q[j] }
             if p[i] < q[i] then return true;
             if p[i] > q[i] then return false;
           done;
           return false
         ```

    5. Remove lemma applications:
       You will be given a list of lemma names. Remove all the applications of these lemmas from the code.
       Examples:
       (a).
         Input Code:
         ```
         for l = i0+1 to n - 1 do
           invariant { forall k. i0+1 <= k < l -> p[k] = a1[k] }
           if p[l] <> a1[l] then (
             assert { a1[l] < p[l] };
             occ_id a1 p 0 l; occ_split a1 0 l n; occ_split p 0 l n;
             call_function(p);
             occ_at a1 0 l n; occ_at p 0 l n;
             assert { occ p[l] a1 l n = occ p[l] p l n > 0 };
             return 3;
           )
         done
         ```
         Lemma Names:
         ```
         occ_id
         occ_split
         occ_at
         ```
         Output:
         ```
         for l = i0+1 to n - 1 do
           invariant { forall k. i0+1 <= k < l -> p[k] = a1[k] }
           if p[l] <> a1[l] then (
             call_function(p);
             return 3;
           )
         done
         ```
       (b).
         Input Code:
         ```
         while i <= n - m do
           invariant { 0 <= i <= n }
           invariant { forall j. 0 <= j < i -> j <= n - m ->
                       substring text j m <> pat }
           variant   { n - m - i }
           if occurs pat text i then return i;
           if i = n - m then break;
           let c = text[i + m] in
           i <- i + if M.mem c bst.sht then (shift_lemma bst text i; M.find c bst.sht)
                                       else (no_shift    bst text i; m + 1)
         done
         ```
         Lemma Names:
         ```
         shift_lemma
         no_shift
         ```
         Output:
         ```
         while i <= n - m do
           invariant { 0 <= i <= n }
           invariant { forall j. 0 <= j < i -> j <= n - m ->
                       substring text j m <> pat }
           variant   { n - m - i }
           if occurs pat text i then return i;
           if i = n - m then break;
           let c = text[i + m] in
           i <- i + if M.mem c bst.sht then M.find c bst.sht
                                       else m + 1
         done
         ```
    """

    requirements = """
    Requirements:
    1. The aim is to test Why3's automation capability without any lemma hints or assertion annotations.
    2. You must ensure that the code is still syntactic correct and any program inside it is semantically unchanged after removing the lemma declarations and assertion annotations. Removing assertion does not change the semantics.
    3. You must preserve all ensures/requires clauses.
    4. You must return the obtained code only, without any other text.
    5. You definitely cannot output any text other than the code.
    """

    # Assemble the complete prompt; the lemma-name section is omitted when no names were found
    if lemma_names:
        lemma_name_opt = f"Lemma Names:\n```\n{textwrap.dedent(lemma_names)}\n```\n\n\n"
    else:
        lemma_name_opt = ""
    full_prompt = f"{textwrap.dedent(system_prompt)}\n\n{textwrap.dedent(task_description)}\n{textwrap.dedent(syntax_patterns)}\n{textwrap.dedent(requirements)}\n\nInput Code:\n\n```\n{textwrap.dedent(text)}\n```\n\n{lemma_name_opt}Generate the Output."

    # Best-effort: keep retrying until a complete reply has been received
    while True:
        try:
            response = client.messages.create(
                model="claude-sonnet-4-20250514",
                max_tokens=40000,
                temperature=0,
                stream=True,
                messages=[
                    {
                        "role": "user",
                        "content": full_prompt
                    }
                ]
            )

            # Collect the streamed text fragments and join them once at the end
            pieces = []
            for chunk in response:
                if chunk.type == "content_block_delta":
                    # A delta without a `text` attribute is skipped instead of
                    # raising into the retry loop below.
                    piece = getattr(chunk.delta, "text", None)
                    if piece:
                        pieces.append(piece)

            return norm("".join(pieces).strip())
        except Exception as e:
            print(f"Error: {e}")
            print("Retrying in 10 seconds...")
            time.sleep(10)

# Walk the input tree once and gather every .mlw file before any processing starts
file_paths = [
    os.path.join(root, name)
    for root, _subdirs, names in os.walk('./data/why3/common')
    for name in names
    if name.endswith('.mlw')
]
total = len(file_paths)

print(f"Found {total} .mlw files to process")

# Handle the files one at a time, mirroring the input layout under the output tree
for i, file_path in enumerate(file_paths):
    progress = f"[{i}/{total}]"
    print(f"{progress} Processing: {file_path}")

    output_path = os.path.join(
        './data/why3/no-lemma5',
        os.path.relpath(file_path, './data/why3/common'),
    )
    # An existing output file is left untouched and the input is not re-processed
    if os.path.exists(output_path):
        print(f"{progress} Skipping (already exists): {output_path}")
        continue

    obtained = clean(file_path)
    print(obtained)
    print('\n\n\n')

    # Make sure the destination folder is there before writing
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    with open(output_path, 'w', encoding='utf-8') as out:
        out.write(obtained)

    print(f"{progress} Saved: {output_path}")

print("Processing complete!")



