#!/usr/bin/env python
"""
Lean Proof Validation Script

This script loads the HuggingFace dataset 'dataset_path' and validates
each proof in its 'formal_ground_truth' column using Lean 4.
"""

import argparse
import time
import multiprocessing as mp
from typing import List, Tuple
import sys

try:
    from datasets import load_dataset
    from lean_interact import LeanREPLConfig, AutoLeanServer, Command, TempRequireProject
    from tqdm import tqdm
except ImportError as e:
    print(f"Error importing required packages: {e}")
    print("Please install required packages:")
    print("pip install datasets lean-interact tqdm")
    sys.exit(1)


def validate_proof(proof_data: Tuple[int, str]) -> Tuple[int, bool]:
    """
    Check one proof with a freshly created AutoLeanServer.

    The server is built from the module-level ``global_config`` (set by
    ``init_worker``), and the proof text is submitted as a single command.
    The proof is accepted only when the response contains no message with
    severity 'error' and its list of sorries is empty.

    Args:
        proof_data: Tuple of (index, proof_string)

    Returns:
        Tuple of (index, is_valid). Any exception raised while creating the
        server, running the command, or inspecting the response is printed
        and reported as (index, False).
    """
    position, lean_source = proof_data

    try:
        # One server per call, configured from the worker-wide settings
        repl = AutoLeanServer(config=global_config)

        # Submit the whole proof as a single command
        outcome = repl.run(Command(cmd=lean_source))

        # Accept only proofs that are both error-free and sorry-free
        error_free = all(message.severity != 'error' for message in outcome.messages)
        sorry_free = len(outcome.sorries) == 0
        verdict = error_free and sorry_free
    except Exception as failure:
        print(f"Error validating proof {position}: {failure}")
        return position, False

    return position, verdict


def init_worker(config: LeanREPLConfig):
    """
    Store ``config`` as the module-level ``global_config``.

    ``main`` passes this function as the pool initializer, so it runs in
    each worker process; ``validate_proof`` then reads ``global_config``
    when it creates its server.

    Args:
        config: Lean REPL configuration to publish in this process.
    """
    # Equivalent to a ``global`` declaration followed by an assignment
    globals()["global_config"] = config


def main():
    """
    Command-line entry point: validate every proof in the dataset and report.

    Steps:
        1. Parse --dataset, --max-parallel, --num-rows and --memory-limit.
        2. Load the dataset's 'train' split, optionally truncated to the
           first --num-rows rows.
        3. Build a Lean REPL configuration (Lean v4.15.0 requiring mathlib).
        4. Validate each 'formal_ground_truth' entry with validate_proof in
           a process pool whose workers are set up by init_worker.
        5. Print acceptance statistics, timing, and the first invalid indices.

    Exits with status 1 if the dataset cannot be loaded or the Lean
    configuration cannot be created; argument errors exit through argparse.
    """
    parser = argparse.ArgumentParser(
        description="Validate Lean proofs from HuggingFace dataset",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="dataset_path",
        help="HuggingFace dataset name or path to load"
    )
    parser.add_argument(
        "--max-parallel",
        type=int,
        default=4,
        help="Maximum number of parallel processes"
    )
    parser.add_argument(
        "--num-rows",
        type=int,
        default=None,
        help="Number of rows to process (default: all rows)"
    )
    parser.add_argument(
        "--memory-limit",
        type=int,
        default=16384,  # 16 GB in MB
        help="Memory limit per REPL in MB"
    )

    args = parser.parse_args()

    # Reject an unusable pool size now, rather than letting the pool
    # constructor fail after the dataset and Lean project have been set up.
    if args.max_parallel < 1:
        parser.error("--max-parallel must be at least 1")

    print("Loading dataset...")
    try:
        dataset = load_dataset(args.dataset)
        data = dataset['train']  # Assuming the data is in 'train' split
    except Exception as e:
        print(f"Error loading dataset: {e}")
        sys.exit(1)

    # Limit number of rows if specified
    if args.num_rows is not None:
        data = data.select(range(min(args.num_rows, len(data))))

    print(f"Dataset loaded: {len(data)} rows to validate")

    # Create Lean REPL configuration
    print("Setting up Lean REPL configuration...")
    try:
        project = TempRequireProject(lean_version="v4.15.0", require="mathlib")
        config = LeanREPLConfig(
            project=project,
            verbose=True,
            memory_hard_limit_mb=args.memory_limit
        )
    except Exception as e:
        print(f"Error creating Lean configuration: {e}")
        sys.exit(1)

    # Pair every proof with its row index so results can be traced back
    proof_data = [(i, row['formal_ground_truth']) for i, row in enumerate(data)]

    print(f"Starting validation with {args.max_parallel} parallel processes...")
    start_time = time.time()

    # Each worker receives the config once through init_worker
    with mp.Pool(
        processes=args.max_parallel,
        initializer=init_worker,
        initargs=(config,)
    ) as pool:
        # Use tqdm for progress bar
        results = []
        with tqdm(total=len(proof_data), desc="Validating proofs") as pbar:
            for result in pool.imap(validate_proof, proof_data):
                results.append(result)
                pbar.update(1)

    # Calculate statistics
    end_time = time.time()
    elapsed_time = end_time - start_time

    # Sort results by index to maintain order
    results.sort(key=lambda x: x[0])

    valid_count = sum(1 for _, is_valid in results if is_valid)
    total_count = len(results)
    acceptance_rate = (valid_count / total_count) * 100 if total_count > 0 else 0

    # Print results
    print("\n" + "="*60)
    print("VALIDATION RESULTS")
    print("="*60)
    print(f"Total proofs validated: {total_count}")
    print(f"Valid proofs: {valid_count}")
    print(f"Invalid proofs: {total_count - valid_count}")
    print(f"Overall acceptance rate: {acceptance_rate:.2f}%")
    print(f"Total elapsed time: {elapsed_time:.2f} seconds")
    # With zero rows (e.g. --num-rows 0) there is no per-proof average;
    # guard the division the same way the acceptance rate is guarded.
    if total_count > 0:
        print(f"Average time per proof: {elapsed_time/total_count:.3f} seconds")
    else:
        print("Average time per proof: n/a (no proofs validated)")
    print("="*60)

    # Show a short sample of the failing rows
    invalid_indices = [idx for idx, is_valid in results if not is_valid]
    if invalid_indices:
        print(f"\nInvalid proof indices (first 10): {invalid_indices[:10]}")
        if len(invalid_indices) > 10:
            print(f"... and {len(invalid_indices) - 10} more")


# Entry point: run main() only when this file is executed directly, so that
# importing the module (for example to reuse validate_proof) has no side effects.
if __name__ == "__main__":
    main()
