#!/usr/bin/env python3
"""
LeanClient Proof Validation Script

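This script loads the HuggingFace dataset 'dataset_path', writes each proof
from the 'formal_ground_truth' column to a scratch .lean file inside an
existing Lean project, and checks it with LeanClient. A proof counts as
valid when Lean reports no errors and no "declaration uses 'sorry'" warning.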
"""

import argparse
import time
import sys
from pathlib import Path

try:
    from datasets import load_dataset
    from leanclient import LeanClientPool
except ImportError as e:
    print(f"Error importing required packages: {e}")
    print("Please install required packages:")
    print("pip install datasets leanclient")
    sys.exit(1)


def compiles_ok(client):
    """
    Report whether the Lean file behind ``client`` compiles cleanly.

    A proof is rejected when any diagnostic has error severity (1), or when
    a warning-severity (2) diagnostic carries the exact message
    "declaration uses 'sorry'".

    Args:
        client: LeanClient instance bound to a single proof file.

    Returns:
        Tuple of (file_path, is_valid).
    """
    diagnostics = client.open_file(timeout=120)

    def is_error(diag):
        return diag.get("severity") == 1

    def is_sorry_warning(diag):
        return (
            diag.get("severity") == 2
            and diag.get("message") == "declaration uses 'sorry'"
        )

    # Both scans run in full, exactly as two independent checks.
    error_seen = any(is_error(diag) for diag in diagnostics)
    sorry_seen = any(is_sorry_warning(diag) for diag in diagnostics)

    is_valid = not (error_seen or sorry_seen)
    return client.file_path, is_valid


def _build_arg_parser():
    """Create the command-line parser; new options default to the original hard-coded values."""
    parser = argparse.ArgumentParser(
        description="Validate Lean proofs from HuggingFace dataset using LeanClient",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--project-path",
        type=str,
        required=True,
        help="Path to existing mathlib project directory",
    )
    parser.add_argument(
        "--num-workers",
        type=int,
        default=10,
        help="Number of worker processes",
    )
    parser.add_argument(
        "--num-rows",
        type=int,
        default=100,
        help="Number of rows to process from dataset",
    )
    parser.add_argument(
        "--scratch-dir",
        type=str,
        default="ScratchProofs",
        help="Directory name for temporary proof files (relative to project)",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="dataset_path",
        help="Dataset name or path passed to datasets.load_dataset",
    )
    parser.add_argument(
        "--split",
        type=str,
        default="train",
        help="Dataset split to read proofs from",
    )
    parser.add_argument(
        "--column",
        type=str,
        default="formal_ground_truth",
        help="Column holding the Lean source of each proof",
    )
    return parser


def _load_proofs(dataset_name, split, column, num_rows):
    """Return at most ``num_rows`` proof sources from ``column`` of ``dataset_name[split]``."""
    dataset = load_dataset(dataset_name)
    data = dataset[split]
    count = min(num_rows, len(data))
    return [data[i][column] for i in range(count)]


def _compile_all(project_path, files, num_workers):
    """Run ``compiles_ok`` on every project-relative file and collect the results in a list."""
    if not files:
        # Nothing to check: do not spin up Lean worker processes for no work.
        return []
    with LeanClientPool(
        str(project_path),
        num_workers=num_workers,
        max_opened_files=1,
        initial_build=False,
    ) as pool:
        return list(pool.map(compiles_ok, files, batch_size=1, verbose=True))


def main():
    """
    Validate dataset proofs with LeanClient and print a summary report.

    Each proof is written to a numbered ``P####.lean`` file under the scratch
    directory and checked in parallel with ``compiles_ok``. Only the files
    created by this run are removed afterwards, and removal also happens when
    file creation or compilation fails.

    Exits with status 1 when the project path is missing or not a directory,
    when the dataset cannot be loaded, or when compilation fails. Invalid
    arguments exit with status 2 via argparse.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()
    if args.num_workers < 1:
        parser.error("--num-workers must be at least 1")
    if args.num_rows < 0:
        parser.error("--num-rows must be non-negative")
    if Path(args.scratch_dir).is_absolute():
        # An absolute path would make `relative_to(project_path)` fail below.
        parser.error("--scratch-dir must be relative to --project-path")

    # Validate project path
    project_path = Path(args.project_path)
    if not project_path.exists():
        print(f"Error: Project path '{project_path}' does not exist")
        sys.exit(1)
    if not project_path.is_dir():
        print(f"Error: Project path '{project_path}' is not a directory")
        sys.exit(1)

    # Create scratch directory (nested relative paths are allowed)
    scratch_dir = project_path / args.scratch_dir
    scratch_dir.mkdir(parents=True, exist_ok=True)

    # Wall-clock time is only for display; durations use the monotonic clock.
    start_wall = time.time()
    start = time.perf_counter()
    print(f"Script started at: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(start_wall))}")
    print(f"Project path: {project_path}")
    print(f"Scratch directory: {scratch_dir}")
    print(f"Workers: {args.num_workers}")
    print(f"Rows to process: {args.num_rows}")

    # Load dataset
    print("\nLoading dataset...")
    dataset_load_start = time.perf_counter()
    try:
        proofs = _load_proofs(args.dataset, args.split, args.column, args.num_rows)
    except Exception as e:  # top-level boundary: report and exit
        print(f"Error loading dataset: {e}")
        sys.exit(1)
    dataset_load_time = time.perf_counter() - dataset_load_start
    print(f"Dataset loaded in {dataset_load_time:.2f} seconds")
    print(f"Processing {len(proofs)} proofs from dataset")

    created_paths = []  # absolute paths of files this run wrote, for cleanup
    try:
        # Create proof files
        print("\nCreating proof files...")
        file_creation_start = time.perf_counter()
        files = []
        for i, code in enumerate(proofs, 1):
            proof_file = scratch_dir / f"P{i:04d}.lean"
            # Record before writing so a partially written file is still removed.
            created_paths.append(proof_file)
            proof_file.write_text(code, encoding="utf-8")
            # LeanClient is given paths relative to the project directory.
            files.append(proof_file.relative_to(project_path).as_posix())
        file_creation_time = time.perf_counter() - file_creation_start
        print(f"File creation completed in {file_creation_time:.2f} seconds")
        print(f"Created {len(files)} proof files")

        # Compile proofs using LeanClient
        print("\nCompiling proofs...")
        compilation_start = time.perf_counter()
        try:
            results = _compile_all(project_path, files, args.num_workers)
        except Exception as e:  # top-level boundary: report and exit
            print(f"Error during compilation: {e}")
            sys.exit(1)  # SystemExit still runs the `finally` cleanup below
        compilation_time = time.perf_counter() - compilation_start
        print(f"\nCompilation completed in {compilation_time:.2f} seconds")
    finally:
        # Remove only what this run created; never glob away unrelated files.
        cleanup_start = time.perf_counter()
        for proof_file in created_paths:
            proof_file.unlink(missing_ok=True)
        cleanup_time = time.perf_counter() - cleanup_start

    # Process results
    print("Processing results...")
    processing_start = time.perf_counter()
    invalid_files = [file_path for file_path, is_valid in results if not is_valid]
    total_proofs = len(results)
    valid_count = total_proofs - len(invalid_files)
    processing_time = time.perf_counter() - processing_start
    print(f"Result processing completed in {processing_time:.2f} seconds")

    # Statistics (total includes cleanup)
    total_time = time.perf_counter() - start
    acceptance_rate = (valid_count / total_proofs) * 100 if total_proofs > 0 else 0

    # Final report
    print("\n" + "=" * 60)
    print("VALIDATION RESULTS")
    print("=" * 60)
    print(f"Total proofs processed: {total_proofs}")
    print(f"Successful compilations: {valid_count}")
    print(f"Failed compilations: {total_proofs - valid_count}")
    print(f"Overall acceptance rate: {acceptance_rate:.2f}%")

    print("\n" + "=" * 60)
    print("TIMING BREAKDOWN")
    print("=" * 60)
    print(f"Dataset loading time: {dataset_load_time:.2f} seconds")
    print(f"File creation time: {file_creation_time:.2f} seconds")
    print(f"Compilation time: {compilation_time:.2f} seconds")
    print(f"Result processing time: {processing_time:.2f} seconds")
    print(f"File cleanup time: {cleanup_time:.2f} seconds")
    print(f"Total elapsed time: {total_time:.2f} seconds")
    if total_proofs:
        print(f"Average time per proof: {total_time / total_proofs:.3f} seconds")
    else:
        # Guard against ZeroDivisionError when nothing was processed.
        print("Average time per proof: n/a (no proofs processed)")

    print(f"\nScript completed at: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())}")

    # Show some failed files for debugging (if any)
    if invalid_files:
        print(f"\nFirst 10 failed files: {invalid_files[:10]}")
        if len(invalid_files) > 10:
            print(f"... and {len(invalid_files) - 10} more failed files")

    print("=" * 60)


if __name__ == "__main__":
    # Run the validator only when executed as a script, not when imported.
    main()
