import pandas as pd
import re
import argparse
import os
from tqdm import tqdm

# --- Detection Functions ---

def detect_first_order(text: str) -> bool:
    """Report whether ``text`` contains first-order logic (FOL) notation.

    A hit on any one of the alternatives below is enough: the quantifier
    glyphs, the arrow glyphs for implication/equivalence, the escaped code
    points for ∀ ∃ ¬ ∧ ∨ → ↔, or one of the listed LaTeX commands (which
    must appear with their leading backslash).

    Args:
        text: Value to inspect; anything that is not a ``str`` is rejected.

    Returns:
        ``True`` if at least one alternative occurs anywhere in ``text``,
        ``False`` otherwise (including for non-string input).
    """
    if isinstance(text, str):
        alternatives = (
            r'∀|∃',  # quantifier glyphs
            r'→|⇒|↔|⇔',  # implication / biconditional arrows
            r'\u2200|\u2203|\u00AC|\u2227|\u2228|\u2192|\u2194',  # escapes for ∀ ∃ ¬ ∧ ∨ → ↔
            r'\\forall|\\exists|\\neg|\\land|\\lor|\\cap',  # LaTeX commands
        )
        matcher = re.compile('|'.join(alternatives))
        return matcher.search(text) is not None
    return False

# Built once at import time so the per-row calls only pay for the search.
_PROPOSITIONAL_REGEX = re.compile('|'.join([
    r'∧|∨|¬',                                   # connective glyphs
    r'->|==>|<=>',                              # ASCII arrows
    # Escaped code points for ¬ ∧ ∨ → ↔.  The quantifiers ∀ (U+2200) and
    # ∃ (U+2203) are deliberately absent: they belong to first-order logic,
    # and matching them here would label quantified text as propositional.
    r'\u00AC|\u2227|\u2228|\u2192|\u2194',
    r'\b(P|Q|R|S)\s*(∧|∨|→|⇒)\s*(P|Q|R|S)\b',   # structure like "P ∧ Q"
    r'be the proposition',                      # e.g. "Let P be the proposition ..."
]))


def detect_propositional(text: str) -> bool:
    """Detect propositional-logic notation in ``text``.

    The text matches when it contains a connective glyph (∧ ∨ ¬ → ↔), an
    ASCII arrow (``->``, ``==>``, ``<=>``), a two-variable formula over
    P/Q/R/S such as ``P ⇒ Q``, or the phrase ``be the proposition``.
    Quantifier symbols on their own do not count.

    Args:
        text: Value to inspect; anything that is not a ``str`` is rejected.

    Returns:
        ``True`` if any propositional pattern occurs in ``text``, ``False``
        otherwise (including for non-string input).
    """
    if not isinstance(text, str):
        return False
    return _PROPOSITIONAL_REGEX.search(text) is not None

def detect_semantic_triples(text: str) -> bool:
    """Report whether ``text`` looks like it contains a semantic triple.

    Three alternatives are searched for:

    * a role label (``subject``, ``predicate``, ``object``, ``sub``,
      ``pred`` or ``obj``) followed by ``:`` or ``=``;
    * a parenthesised tuple of three quoted items, e.g.
      ``('Socrates', 'is a', 'man')``;
    * the bare words ``subject``, ``predicate`` or ``object``.

    The ``(?i)`` flag opens the joined expression, so it applies to every
    alternative; the bare words therefore match in any letter case and also
    inside longer words.

    Args:
        text: Value to inspect; anything that is not a ``str`` is rejected.

    Returns:
        ``True`` if any alternative occurs in ``text``, ``False`` otherwise
        (including for non-string input).
    """
    # NOTE: the labelled form must stay first so that "(?i)" remains at the
    # very start of the joined expression.
    labelled_role = r'(?i)\b(subject|predicate|object|sub|pred|obj)\s*[:=]'
    quoted_three_tuple = r'\(\s*["\'].+?["\']\s*,\s*["\'].+?["\']\s*,\s*["\'].+?["\']\s*\)'
    bare_role_word = r'subject|predicate|object'
    expression = '|'.join((labelled_role, quoted_three_tuple, bare_role_word))
    return isinstance(text, str) and re.search(expression, text) is not None

# Brace-delimited "{key: body}" structure:
#   \{\s*(?:\w+|[\$\?]\w+)\s*:   opening brace and a key, optionally prefixed
#                                by "$" or "?", followed by a colon
#   (?:[^{}]|\{[^{}]*\})*?       body: non-brace characters and/or groups
#                                nested one level deep
#   \}                           closing brace
# "[^{}]" already matches newlines and whitespace, so no separate "\n" or
# "\s*" alternative is used alongside it.  Such overlapping alternatives make
# the backtracking matcher explore exponentially many paths on an unclosed
# "{" followed by many newlines; this form accepts exactly the same strings.
_ALIST_REGEX = re.compile(r'\{\s*(?:\w+|[\$\?]\w+)\s*:(?:[^{}]|\{[^{}]*\})*?\}')


def detect_alists(text: str) -> bool:
    """Detect A-Lists, i.e. brace-delimited ``{key: ...}`` structures, in ``text``.

    A match needs an opening brace, a key made of word characters
    (optionally prefixed by ``$`` or ``?``), a colon, a body that may contain
    brace groups nested one level deep, and a closing brace.  The body may
    span multiple lines.

    Args:
        text: Value to inspect; anything that is not a ``str`` is rejected.

    Returns:
        ``True`` if such a structure occurs anywhere in ``text``, ``False``
        otherwise (including for non-string input).
    """
    if not isinstance(text, str):
        return False
    return _ALIST_REGEX.search(text) is not None

def main():
    """Command-line entry point: flag CSV rows whose text contains a logic notation.

    Reads ``--input_csv``, applies the detector selected by ``--logic_type``
    to every value of ``--column``, stores the boolean result in a new
    ``<logic_type>_detected`` column, prints a short summary and writes the
    augmented table to ``--output_csv``.

    Returns ``None`` in all cases.  A missing input file, an input file with
    no parsable content, or a missing column is reported with a printed
    message and ends the run early without writing any output.
    """
    parser = argparse.ArgumentParser(
        description="Analyze a CSV file for specified logical patterns in model outputs."
    )
    parser.add_argument("--input_csv", type=str, default="input.csv", help="Path to the input CSV file.")
    parser.add_argument(
        "--logic_type",
        type=str,
        default="first-order",
        choices=["first-order", "propositional", "semantic-triples", "alist"],
        help="The type of logic to detect."
    )
    parser.add_argument(
        "--column",
        type=str,
        default="model_output",
        help="Name of the column containing the model's text output."
    )
    parser.add_argument(
        "--output_csv",
        type=str,
        default='output.csv',
        # The help text states the default that is actually in effect.
        help="Path to save the output CSV. Defaults to 'output.csv'."
    )
    args = parser.parse_args()

    # Map logic_type argument to the corresponding detection function
    detection_functions = {
        "first-order": detect_first_order,
        "propositional": detect_propositional,
        "semantic-triples": detect_semantic_triples,
        "alist": detect_alists
    }
    detector = detection_functions[args.logic_type]
    output_column_name = f"{args.logic_type}_detected"

    print(f"Reading data from '{args.input_csv}'...")
    try:
        df = pd.read_csv(args.input_csv)
    except FileNotFoundError:
        print(f"Error: The file '{args.input_csv}' was not found.")
        return
    except pd.errors.EmptyDataError:
        # Raised by pandas when the file has no header/columns to parse.
        print(f"Error: The file '{args.input_csv}' is empty.")
        return

    if args.column not in df.columns:
        print(f"Error: Column '{args.column}' not found. Available columns: {list(df.columns)}")
        return

    print(f"Analyzing column '{args.column}' for '{args.logic_type}' patterns...")
    # One pass with a single progress bar.  Building the Series explicitly
    # (rather than concatenating per-batch results) also works when the CSV
    # has a header but no rows, where pd.concat([]) would raise ValueError.
    flags = [
        detector(value)
        for value in tqdm(df[args.column], desc=f"Detecting {args.logic_type} patterns")
    ]
    df[output_column_name] = pd.Series(flags, index=df.index, dtype=bool)

    detection_count = df[output_column_name].sum()
    total_rows = len(df)
    detection_percentage = (detection_count / total_rows) * 100 if total_rows > 0 else 0

    print("\n--- Analysis Complete ---")
    print(f"Logic Type: {args.logic_type}")
    print(f"Total rows processed: {total_rows}")
    print(f"Rows with patterns detected: {detection_count}")
    print(f"Detection Percentage: {detection_percentage:.2f}%")

    # Save the results
    print(f"Saving results to '{args.output_csv}'...")
    df.to_csv(args.output_csv, index=False)
    print("Done.")


if __name__ == "__main__":
    main()