#!/usr/bin/env python3

import re
import os
import sys
from sqlitedict import SqliteDict

# Regexes whose single capture group is the name of a lemma in Why3 (.mlw)
# source. ``\b`` keeps ``let`` from matching inside a longer identifier
# (e.g. ``outlet``); the ``^`` pattern is applied with re.MULTILINE in collect().
lemma_re = [
    r'\blet\s+lemma\s+([\w\']+)',
    r'\blet\s+rec\s+lemma\s+([\w\']+)',
    r'^\s*lemma\s+([\w\']+)',
]
# Regex for ``let`` bindings with optional ``rec`` / ``ghost`` / ``function``
# modifiers; the LAST capture group is the bound name.
decl_re = [r'\blet(\s+rec)?(\s+ghost)?(\s+function)?\s+([\w\']+)']

# Directories (path-normalised, so "./x" and "x" compare equal) whose files
# skip the per-directory lemma.db bookkeeping in collect().
_SHARED_DIRS = frozenset(
    os.path.normpath(d) for d in ('data/why3/common', 'data/why3/no-lemma4')
)


def collect(text, path):
    """Extract lemma names from Why3 source *text* and report name clashes.

    Args:
        text: Contents of a ``.mlw`` file.
        path: Path of that file. Used in printed messages and to locate the
            ``lemma.db`` file in the same directory.

    Returns:
        The lemma names found in *text*, grouped by pattern in ``lemma_re``
        order (duplicates are kept).

    Side effects:
        Prints ``<path> OVERLAP: {...}`` when a lemma name in this file is
        also bound by a ``let`` declaration in this file. Unless the file's
        directory is in ``_SHARED_DIRS``, it also adds the lemma names to
        ``<directory>/lemma.db``. It then prints ``<path> OVERLAP: <name>``
        for every name stored in that database that this file binds with
        ``let``.
    """
    lemma_names = []
    for pattern in lemma_re:
        # MULTILINE so that ``^`` anchors at every line start, not only at
        # the beginning of the file.
        lemma_names.extend(re.findall(pattern, text, re.MULTILINE))

    # decl_re has several groups, so findall yields tuples; the declared
    # name is the last group.
    decl_names = [
        match[-1] for pattern in decl_re for match in re.findall(pattern, text)
    ]

    # Check for overlap between lemma_names and decl_names
    overlap = set(lemma_names) & set(decl_names)
    if overlap:
        print(f"{path} OVERLAP: {overlap}")

    directory = os.path.dirname(path)
    if os.path.normpath(directory) not in _SHARED_DIRS:
        # os.path.join yields "lemma.db" (not "/lemma.db") when *path* has no
        # directory component.
        with SqliteDict(os.path.join(directory, 'lemma.db'), autocommit=True) as db:
            for lemma_name in lemma_names:
                if lemma_name not in db:
                    db[lemma_name] = ()
            # autocommit=True already commits each write; one commit after
            # the batch is sufficient.
            db.commit()
            decl_set = set(decl_names)
            for known_name in db.keys():
                if known_name in decl_set:
                    print(f"{path} OVERLAP: {known_name}")

    return lemma_names

def collect_all(path):
    """Recursively run collect() on every ``.mlw`` file under *path*.

    For each file whose collected lemma names include any of five characters
    or fewer, prints ``<file path>: [those names]``.
    """
    for root, _subdirs, filenames in os.walk(path):
        for filename in filenames:
            if not filename.endswith('.mlw'):
                continue
            file_path = os.path.join(root, filename)
            with open(file_path, 'r', encoding='utf-8') as handle:
                source = handle.read()
            found = collect(source, file_path)
            short_names = [name for name in found if len(name) <= 5]
            if short_names:
                print(f"{file_path}: {short_names}")


if __name__ == "__main__":
    # Expect exactly one argument: the root directory to scan for .mlw files.
    if len(sys.argv) != 2:
        # Diagnostics go to stderr so they don't mix with the script's report on stdout.
        print("Usage: python collect_lemma.py <directory_path>", file=sys.stderr)
        sys.exit(1)

    directory_path = sys.argv[1]

    if not os.path.isdir(directory_path):
        print(f"Error: {directory_path} is not a valid directory", file=sys.stderr)
        sys.exit(1)

    collect_all(directory_path)

