#!/usr/bin/env python3
import csv
import fcntl
import logging
import os
import shutil
import subprocess
import tempfile
import time
from collections import deque
from concurrent.futures import FIRST_COMPLETED, ProcessPoolExecutor, wait
from glob import glob
from multiprocessing import Pool

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)

# ============ CONFIGURATION - MODIFY THESE ============
# Data flow: $PAPER_PDF_DIR from main.sh (Step 1 output) -> input here
PDF_DIR = '<YOUR_DATA_ROOT>/paper_download/pubmed_papers/'
# Data flow: $PAPER_MARKDOWN_DIR from main.sh (Step 2 output) -> output here
MD_DIR = '<YOUR_DATA_ROOT>/paper_markdown/pubmed/'
# Status file for tracking progress: CSV rows of "<pdf basename>,<status>"
STATUS_FILE = '<YOUR_DATA_ROOT>/paper_extraction_config/processing_status.csv'

# MinerU executable path - use dedicated mineru environment
# Check dedicated environments first (in order), then fall back to PATH
# MODIFY THIS: Add your MinerU environment paths
fallback_paths = [
    '<YOUR_MINERU_ENV>/bin/mineru',  # 🔴 MODIFY: Server env path (no flash-attn conflict)
    '<YOUR_MINERU_ENV_ALT>/bin/mineru',  # 🔴 MODIFY: Local env path
]


def _is_executable(path):
    """Return True if path is a regular file the current user may execute."""
    return os.path.isfile(path) and os.access(path, os.X_OK)


# A path that merely exists but cannot be executed would make every batch
# fail (and, with MAX_FAILS=1, mark every PDF as 'skip'), so require +x.
MINERU_PATH = next((path for path in fallback_paths if _is_executable(path)), None)
if not MINERU_PATH:
    # Fallback to PATH lookup
    MINERU_PATH = shutil.which('mineru')
if not MINERU_PATH:
    logging.error("MinerU executable not found. Please install it: pip install -U 'mineru[all]'")
    # SystemExit directly: the bare `exit` builtin is injected by `site` and
    # is not guaranteed to exist (e.g. under `python -S`).
    raise SystemExit(1)

# Number of worker processes / GPU ids; batch k runs with CUDA_VISIBLE_DEVICES=k % NUM_GPUS
NUM_GPUS = 8
BATCH_SIZE = 10  # Optimal: balances model loading overhead vs batch wait time
TIMEOUT_PER_PDF = 180  # Timeout per PDF in seconds (3 min covers 99%+ of PDFs)
MAX_FAILS = 1  # Mark as 'skip' after this many failures

def load_status():
    """Read STATUS_FILE into a {pdf_basename: status} dict.

    Each CSV row contributes its first column as the key and its second as the
    value; rows with fewer than two columns are ignored and any extra columns
    are dropped. A missing file yields an empty dict. A read error is logged,
    and whatever rows were parsed before the error are returned.
    """
    entries = {}
    if not os.path.exists(STATUS_FILE):
        return entries
    try:
        with open(STATUS_FILE, newline='') as handle:
            for record in csv.reader(handle):
                if len(record) < 2:
                    continue
                entries[record[0]] = record[1]
    except Exception as exc:
        logging.error(f"Error loading status file: {exc}")
    return entries

def save_status(status):
    """Write the status dict to STATUS_FILE via a temp file and an atomic rename.

    Rows are written as "<key>,<value>" CSV lines to STATUS_FILE + '.tmp'.
    The temp file is flushed and fsync'ed before os.replace() swaps it into
    place, so the renamed file is never half-written. Write/rename errors are
    logged rather than raised, and the temp file is removed on a best-effort
    basis. Creating the parent directory is not guarded and may still raise.

    Args:
        status: dict mapping PDF basename -> status string.
    """
    # Ensure directory exists (dirname is '' for a bare file name)
    status_dir = os.path.dirname(STATUS_FILE)
    if status_dir:
        os.makedirs(status_dir, exist_ok=True)

    temp_file = STATUS_FILE + '.tmp'
    try:
        with open(temp_file, 'w', newline='') as f:
            csv.writer(f).writerows(status.items())
            # Make the data durable before the rename publishes it
            f.flush()
            os.fsync(f.fileno())
        # Atomic rename
        os.replace(temp_file, STATUS_FILE)
    except Exception as e:
        logging.error(f"Error saving status file: {e}")
        # Clean up the temp file without letting a cleanup failure escape
        try:
            os.remove(temp_file)
        except OSError:
            pass

def check_output_exists(basename):
    """Report whether MD_DIR already holds markdown for this PDF (resume support).

    Searches MD_DIR/<basename without extension>/ recursively (outputs may be
    nested in subdirectories such as 'auto') for any file ending in '.md'.
    """
    target = os.path.join(MD_DIR, os.path.splitext(basename)[0])
    if not os.path.isdir(target):
        return False
    for _dirpath, _dirnames, filenames in os.walk(target):
        if any(name.endswith('.md') for name in filenames):
            return True
    return False

def check_pdf_output(pdf_file, output_dir):
    """Return True if output_dir/<pdf stem>/ contains a '.md' file at any depth.

    MinerU writes each PDF's results into a directory named after the PDF
    (without extension); the markdown may sit in a subdirectory of it.
    """
    stem = os.path.splitext(os.path.basename(pdf_file))[0]
    candidate = os.path.join(output_dir, stem)
    if not os.path.isdir(candidate):
        return False
    return any(
        name.endswith('.md')
        for _dirpath, _dirnames, filenames in os.walk(candidate)
        for name in filenames
    )

def _move_output(name_without_ext, temp_output):
    """Move temp_output/<name_without_ext> into MD_DIR, replacing any existing copy."""
    src_dir = os.path.join(temp_output, name_without_ext)
    dst_dir = os.path.join(MD_DIR, name_without_ext)
    if os.path.exists(dst_dir):
        shutil.rmtree(dst_dir)
    shutil.move(src_dir, dst_dir)


def _collect_outputs(pdf_files, temp_output, gpu_id, results,
                     no_output_status, move_error_status, done_suffix=''):
    """Move each PDF's MinerU output into MD_DIR and append (pdf_file, status) to results.

    A PDF with markdown under temp_output is moved and reported 'done' (or
    move_error_status if the move fails); a PDF without markdown output is
    reported as no_output_status.
    """
    for pdf_file in pdf_files:
        basename = os.path.basename(pdf_file)
        if not check_pdf_output(pdf_file, temp_output):
            results.append((pdf_file, no_output_status))
            logging.warning(f"[GPU {gpu_id}] {basename}: {no_output_status} (no output)")
            continue
        try:
            _move_output(os.path.splitext(basename)[0], temp_output)
        except Exception as e:
            logging.error(f"[GPU {gpu_id}] Failed to move output for {basename}: {e}")
            results.append((pdf_file, move_error_status))
        else:
            results.append((pdf_file, 'done'))
            logging.info(f"[GPU {gpu_id}] {basename}: done{done_suffix}")


def _unreported(pdf_files, results):
    """Return the entries of pdf_files that have no status in results yet."""
    reported = {pdf for pdf, _ in results}
    return [pdf for pdf in pdf_files if pdf not in reported]


def _link_pdfs(pdf_files, temp_dir, gpu_id, results):
    """Symlink each PDF into temp_dir and return the list of successfully linked files.

    PDFs whose symlink cannot be created are appended to results as 'error'.
    """
    linked = []
    for pdf_file in pdf_files:
        basename = os.path.basename(pdf_file)
        try:
            # Absolute target: a relative one would be resolved against temp_dir
            os.symlink(os.path.abspath(pdf_file), os.path.join(temp_dir, basename))
        except OSError as e:
            logging.warning(f"[GPU {gpu_id}] Failed to create symlink for {basename}: {e}")
            results.append((pdf_file, 'error'))
        else:
            linked.append(pdf_file)
    return linked


def _run_mineru(pdf_files, temp_dir, temp_output, gpu_id, results):
    """Run MinerU on temp_dir pinned to gpu_id and append a status for every PDF in pdf_files."""
    try:
        # Build mineru command to process the temp directory
        cmd = [MINERU_PATH, '-p', temp_dir, '-o', temp_output]

        # Set up environment (inherits proxy settings from parent process)
        env = os.environ.copy()
        env['CUDA_VISIBLE_DEVICES'] = str(gpu_id)
        # Remove LD_LIBRARY_PATH to avoid cudnn conflicts
        env.pop('LD_LIBRARY_PATH', None)

        # Calculate timeout based on batch size
        timeout = TIMEOUT_PER_PDF * len(pdf_files)
        logging.info(f"[GPU {gpu_id}] Processing batch of {len(pdf_files)} PDFs (timeout: {timeout}s)")

        result = subprocess.run(cmd, env=env, capture_output=True, timeout=timeout)

        # Log mineru output for debugging
        if result.stderr:
            stderr_text = result.stderr.decode('utf-8', errors='replace')[:2000]
            if stderr_text.strip():
                logging.warning(f"[GPU {gpu_id}] stderr: {stderr_text}")
        if result.returncode != 0:
            logging.error(f"[GPU {gpu_id}] mineru returned code {result.returncode}")

        # Check what mineru actually produced
        if os.path.exists(temp_output):
            output_contents = os.listdir(temp_output)
            logging.info(f"[GPU {gpu_id}] Output dir contents: {output_contents[:5]}{'...' if len(output_contents) > 5 else ''}")
        else:
            logging.warning(f"[GPU {gpu_id}] Output dir does not exist: {temp_output}")

        _collect_outputs(pdf_files, temp_output, gpu_id, results, 'failed', 'error')

    except subprocess.TimeoutExpired as e:
        logging.warning(f"[GPU {gpu_id}] Batch timeout after {e.timeout}s")
        # Some PDFs may have finished before MinerU was killed
        _collect_outputs(_unreported(pdf_files, results), temp_output, gpu_id, results,
                         'timeout', 'timeout', ' (before timeout)')

    except Exception as e:
        logging.error(f"[GPU {gpu_id}] Batch error: {e}")
        # Salvage whatever completed, without re-reporting files already recorded
        _collect_outputs(_unreported(pdf_files, results), temp_output, gpu_id, results,
                         'error', 'error')


def process_batch(args):
    """Convert one batch of PDFs with a single MinerU run on one GPU.

    MinerU processes directories more efficiently (single model load), so:
    1. PDFs that are missing are reported 'missing'; PDFs whose markdown
       already exists in MD_DIR are reported 'done' without reprocessing.
    2. The remaining PDFs are symlinked into a temporary input directory.
    3. MinerU runs on that directory with CUDA_VISIBLE_DEVICES=gpu_id and a
       timeout of TIMEOUT_PER_PDF seconds per linked PDF.
    4. Each PDF's output directory is moved into MD_DIR.
    5. Both temporary directories are removed.

    Args:
        args: tuple of (batch_pdf_files, gpu_id).

    Returns:
        list of (pdf_file, status) tuples, one per PDF in the batch, where
        status is 'done', 'missing', 'failed', 'timeout', or 'error'.
    """
    batch_pdf_files, gpu_id = args
    results = []

    # Filter out missing files and files with existing output (resume support)
    pending = []
    for pdf_file in batch_pdf_files:
        basename = os.path.basename(pdf_file)
        if not os.path.exists(pdf_file):
            results.append((pdf_file, 'missing'))
        elif check_output_exists(basename):
            # Output already exists (maybe from interrupted run), mark as done
            results.append((pdf_file, 'done'))
            logging.info(f"[GPU {gpu_id}] {basename}: already exists, skipping")
        else:
            pending.append(pdf_file)

    if not pending:
        return results

    temp_dir = tempfile.mkdtemp(prefix=f'mineru_batch_gpu{gpu_id}_')
    temp_output = None
    try:
        # Created inside the try so temp_dir is still cleaned up if this fails
        temp_output = tempfile.mkdtemp(prefix=f'mineru_output_gpu{gpu_id}_')
        linked = _link_pdfs(pending, temp_dir, gpu_id, results)
        if linked:
            _run_mineru(linked, temp_dir, temp_output, gpu_id, results)
    finally:
        # Best-effort cleanup of both temp directories (errors ignored)
        shutil.rmtree(temp_dir, ignore_errors=True)
        if temp_output is not None:
            shutil.rmtree(temp_output, ignore_errors=True)

    return results

def distribute_to_gpus(pdf_list, num_gpus, batch_size):
    """Cut pdf_list into consecutive slices of batch_size and tag each with a GPU id.

    GPU ids are handed out round-robin: the k-th batch (0-based) gets
    k % num_gpus. The final batch may be shorter than batch_size.

    Returns:
        list of (batch_files, gpu_id) tuples, in input order.
    """
    starts = range(0, len(pdf_list), batch_size)
    return [
        (pdf_list[start:start + batch_size], index % num_gpus)
        for index, start in enumerate(starts)
    ]

# Status prefixes that count toward MAX_FAILS; stored as "<prefix>_<count>".
_FAILURE_PREFIXES = ('failed', 'timeout', 'error')


def _failure_count(state):
    """Return N from a '<prefix>_N' status string, or None if it has no valid count."""
    if '_' not in state:
        return None
    try:
        return int(state.split('_')[1])
    except ValueError:
        return None


def _select_pdfs(all_pdfs, status):
    """Pick the PDFs still needing work, updating status in place for recoveries/auto-skips.

    Returns:
        (to_process, done_count, skipped_count, recovered_count)
    """
    to_process = []
    done_count = skipped_count = recovered_count = 0
    for pdf_path in all_pdfs:
        basename = os.path.basename(pdf_path)
        current = status.get(basename, 'new')

        if current == 'done':
            done_count += 1
            continue

        # Output exists but was never marked done (e.g. interrupted run)
        if check_output_exists(basename):
            status[basename] = 'done'
            done_count += 1
            recovered_count += 1
            continue

        # Skip if permanently failed
        if current == 'skip':
            skipped_count += 1
            continue

        # Entries like "failed_2" are skipped once they reach MAX_FAILS
        if current.startswith(_FAILURE_PREFIXES):
            count = _failure_count(current)
            if count is not None and count >= MAX_FAILS:
                status[basename] = 'skip'
                logging.info(f"Auto-skipping {basename} after {count} failures")
                skipped_count += 1
                continue

        to_process.append(pdf_path)
    return to_process, done_count, skipped_count, recovered_count


def _record_result(status, basename, result):
    """Store one batch result in status; return 'done', 'failed', or None for tallying."""
    if result == 'done':
        status[basename] = 'done'
        return 'done'
    if result == 'missing':
        status[basename] = 'missing'
        return 'failed'
    if result in _FAILURE_PREFIXES:
        # Increment failure count carried over from a previous "<prefix>_<n>" status
        old = status.get(basename, 'new')
        previous = _failure_count(old) if old.startswith(_FAILURE_PREFIXES) else None
        count = 1 if previous is None else previous + 1
        status[basename] = f"{result}_{count}"
        # Auto-skip after max failures
        if count >= MAX_FAILS:
            status[basename] = 'skip'
            logging.info(f"Skipping {basename} after {MAX_FAILS} failures")
        return 'failed'
    status[basename] = result
    return None


def _status_stats(status):
    """Count status values per category for the end-of-cycle summary."""
    stats = dict.fromkeys(('done', 'skip', 'timeout', 'failed', 'error', 'missing', 'other'), 0)
    for value in status.values():
        if value in ('done', 'skip', 'missing'):
            stats[value] += 1
        else:
            bucket = next((p for p in ('timeout', 'failed', 'error') if value.startswith(p)), 'other')
            stats[bucket] += 1
    return stats


def _iter_batch_results(executor, gpu_batches):
    """Run (batch, gpu_id) tasks on executor, yielding each batch's results as it finishes.

    Batches are queued per gpu_id and only one batch per gpu_id is in flight at
    a time. A plain pool hands the next task to whichever worker frees up
    first, which can start two batches on the same GPU while another GPU idles.
    """
    queues = {}
    for batch in gpu_batches:
        queues.setdefault(batch[1], deque()).append(batch)

    in_flight = {}  # future -> gpu_id

    def submit_next(gpu_id):
        if queues[gpu_id]:
            in_flight[executor.submit(process_batch, queues[gpu_id].popleft())] = gpu_id

    for gpu_id in queues:
        submit_next(gpu_id)

    while in_flight:
        finished, _ = wait(in_flight, return_when=FIRST_COMPLETED)
        for future in finished:
            gpu_id = in_flight.pop(future)
            # Keep the GPU busy before handing results back to the caller
            submit_next(gpu_id)
            yield future.result()


def main():
    """Poll PDF_DIR forever, converting PDFs to markdown with MinerU across NUM_GPUS GPUs.

    Each cycle reloads STATUS_FILE, selects PDFs that are neither done nor
    skipped, runs them in batches of BATCH_SIZE (at most one batch per GPU at
    a time), saves the status after every batch, and prints a summary. When
    nothing is pending it sleeps 60 seconds before rescanning.
    """
    print("=== PDF to Markdown Batch Processor Starting ===")
    print(f"PDF Directory: {PDF_DIR}")
    print(f"Output Directory: {MD_DIR}")
    print(f"Status File: {STATUS_FILE}")
    print(f"MinerU Path: {MINERU_PATH}")
    print(f"Settings: {NUM_GPUS} GPUs, {BATCH_SIZE} PDFs/batch, {TIMEOUT_PER_PDF}s/PDF timeout")
    print(f"         Skip after {MAX_FAILS} failures")
    print("\nBATCH MODE: Each GPU processes a directory of PDFs")
    print("            Model loads once per batch for better efficiency.")
    print("\nRESUME SUPPORT: Yes - can stop and restart anytime")
    print("                Status saved after each batch")
    print("                Also checks MD_DIR for existing outputs")
    print("=" * 60)

    # Ensure output directories exist (dirname is '' for a bare STATUS_FILE name)
    os.makedirs(MD_DIR, exist_ok=True)
    status_dir = os.path.dirname(STATUS_FILE)
    if status_dir:
        os.makedirs(status_dir, exist_ok=True)

    while True:
        print(f"\n[{time.strftime('%Y-%m-%d %H:%M:%S')}] Starting new processing cycle...")

        print("Loading status file...")
        status = load_status()
        print(f"Loaded {len(status)} entries from status file")

        print("Scanning PDF directory...")
        all_pdfs = glob(os.path.join(PDF_DIR, '*.pdf'))
        print(f"Found {len(all_pdfs)} total PDF files")

        print("Filtering PDFs based on status...")
        to_process, done_count, skipped_count, recovered_count = _select_pdfs(all_pdfs, status)

        # Save any recovered status
        if recovered_count > 0:
            save_status(status)
            print(f"Recovered {recovered_count} files from existing outputs")

        print(f"Filter results: {done_count} done, {skipped_count} skipped, {len(to_process)} to process")

        if not to_process:
            print("No PDFs to process. Sleeping for 60 seconds...")
            logging.info("No PDFs to process. Sleeping...")
            # Persist any auto-skips made while filtering
            save_status(status)
            time.sleep(60)
            continue

        print(f"\n>>> Starting to process {len(to_process)} PDFs in BATCH mode...")
        logging.info(f"Processing {len(to_process)} PDFs in batch mode")

        gpu_batches = distribute_to_gpus(to_process, NUM_GPUS, BATCH_SIZE)
        total_batches = len(gpu_batches)
        print(f">>> Distributed into {total_batches} batches across {NUM_GPUS} GPUs")
        print(">>> Creating process pool...")

        with ProcessPoolExecutor(max_workers=NUM_GPUS) as executor:
            print(f">>> Pool created with {NUM_GPUS} workers")
            print(">>> Starting batch processing (one batch per GPU at a time)...")

            batch_iter = _iter_batch_results(executor, gpu_batches)
            for batch_num, batch_results in enumerate(batch_iter, start=1):
                batch_done = batch_failed = 0
                for pdf_file, result in batch_results:
                    outcome = _record_result(status, os.path.basename(pdf_file), result)
                    if outcome == 'done':
                        batch_done += 1
                    elif outcome == 'failed':
                        batch_failed += 1

                # Save status after each batch so an interrupted run can resume
                save_status(status)
                done_so_far = sum(1 for v in status.values() if v == 'done')
                print(f"    >>> Batch {batch_num}/{total_batches}: {batch_done} done, {batch_failed} failed | Total: {done_so_far}/{len(all_pdfs)} done")

        stats = _status_stats(status)
        print(f"\n>>> Cycle complete! Stats: Done={stats['done']}, Skip={stats['skip']}, "
              f"Timeout={stats['timeout']}, Failed={stats['failed']}, Error={stats['error']}, "
              f"Missing={stats['missing']}, Other={stats['other']}")
        logging.info(f"Stats - Done: {stats['done']}, Skip: {stats['skip']}, "
                     f"Timeout: {stats['timeout']}, Failed: {stats['failed']}, "
                     f"Error: {stats['error']}, Missing: {stats['missing']}, Other: {stats['other']}")

if __name__ == "__main__":
    # Start the polling loop only when executed as a script; merely importing
    # this module (as multiprocessing worker processes may do) must not run it.
    main()
