import os
import csv
import time
import subprocess
from glob import glob
from multiprocessing import Pool

# Configuration: adjust these paths and sizes for your environment.
# Directory scanned (non-recursively) for '*.pdf' input files.
PDF_DIR = 'folder_pdf_input/Paper_download/pubmed_papers/'
# Output directory handed to the `mineru` command via its `-o` option.
MD_DIR = 'folder_markdown_output/Paper_markdown/pubmed/'
# CSV file of "<pdf basename>,<status>" rows that records progress between passes.
STATUS_FILE = 'folder_config/paper_extraction_config/processing_status.csv'
# Number of pool worker processes; tasks are tagged with GPU ids 0..NUM_GPUS-1.
NUM_GPUS = 8
CHUNK_SIZE = 1000  # Maximum number of PDFs submitted to the pool per chunk

def load_status(path=None):
    """Load the per-file processing status table from a CSV file.

    Each row is expected to hold a key (the PDF base name) in the first
    column and its status string in the second. Rows with fewer than two
    columns -- blank lines, or a row cut short by an interrupted write --
    are skipped instead of raising ``IndexError``. Extra columns are ignored.

    Args:
        path: CSV file to read. Defaults to ``STATUS_FILE``.

    Returns:
        A dict mapping the first column to the second. Empty when the file
        does not exist.
    """
    if path is None:
        path = STATUS_FILE
    status = {}
    # EAFP: opening directly avoids the exists()/open() race of a pre-check.
    try:
        csvfile = open(path, newline='')
    except FileNotFoundError:
        return status
    with csvfile:
        for row in csv.reader(csvfile):
            if len(row) < 2:
                # csv.reader yields [] for blank lines; tolerate short rows too.
                continue
            status[row[0]] = row[1]
    return status

def save_status(status, path=None):
    """Write the status table to a CSV file, one ``key,value`` row per entry.

    The table is written to a sibling ``<path>.tmp`` file first and then
    moved over the destination with ``os.replace``, so an interruption in
    the middle of a write cannot leave a truncated status file behind.
    The parent directory is created when it does not exist yet.

    Args:
        status: Mapping of key (PDF base name) to status string.
        path: Destination CSV file. Defaults to ``STATUS_FILE``.
    """
    if path is None:
        path = STATUS_FILE
    directory = os.path.dirname(path)
    if directory:
        os.makedirs(directory, exist_ok=True)
    # Fixed temp name: assumes a single writer at a time for a given path.
    tmp_path = f'{path}.tmp'
    with open(tmp_path, 'w', newline='') as csvfile:
        csv.writer(csvfile).writerows(status.items())
    os.replace(tmp_path, path)

def process_pdf(args):
    """Convert one PDF with the ``mineru`` command on a given GPU.

    The command is run as an argument list (no shell), so file names that
    contain quotes, ``$``, backticks or other shell metacharacters are passed
    through verbatim. Without an intermediate shell, the 300-second timeout
    also applies to the ``mineru`` process itself.

    Args:
        args: ``(pdf_file, gpu_id)`` tuple; packed into one argument so the
            function can be used directly with ``Pool.imap_unordered``.
            ``gpu_id`` is exported to the child as ``CUDA_VISIBLE_DEVICES``.

    Returns:
        ``(pdf_file, 'done')`` on success, or ``(pdf_file, 'error:<message>')``
        when the command fails, times out, or cannot be started.
    """
    pdf_file, gpu_id = args
    cmd = ['mineru', '-p', pdf_file, '-o', MD_DIR]
    # Inherit the parent's environment and only override the GPU selection.
    env = {**os.environ, 'CUDA_VISIBLE_DEVICES': str(gpu_id)}
    try:
        subprocess.run(cmd, env=env, check=True, timeout=300)
    except Exception as e:
        # Worker boundary: report every failure as a status string so that a
        # single bad file cannot take down the whole pool.
        return pdf_file, f'error:{e}'
    return pdf_file, 'done'

def chunks(lst, chunk_size):
    """Lazily split ``lst`` into consecutive slices of at most ``chunk_size`` items.

    The final slice is shorter when ``len(lst)`` is not a multiple of
    ``chunk_size``; an empty ``lst`` yields nothing.
    """
    starts = range(0, len(lst), chunk_size)
    yield from (lst[begin:begin + chunk_size] for begin in starts)

def main():
    """Continuously convert pending PDFs and record a status for each one.

    Each pass reloads the status table, lists ``*.pdf`` files in ``PDF_DIR``
    and processes every file whose recorded status is not ``'done'`` (files
    never seen before as well as earlier ``'error:...'`` entries). Work is
    submitted to a pool of ``NUM_GPUS`` worker processes in chunks of
    ``CHUNK_SIZE`` files. The status table is saved every 10 completions and
    whenever a chunk ends, including when it ends because of an exception or
    Ctrl-C.

    The function loops forever and does not return. It sleeps 60 seconds when
    there is nothing to do, and also when a whole pass produced no successful
    conversion, so that persistently failing files are not retried in a
    busy loop.
    """
    while True:
        status = load_status()
        pdf_files = sorted(glob(os.path.join(PDF_DIR, '*.pdf')))
        to_process = [f for f in pdf_files if status.get(os.path.basename(f), 'pending') != 'done']

        if not to_process:
            print('No new PDFs to process. Sleeping...')
            time.sleep(60)
            continue

        print(f"Found {len(to_process)} PDFs to process")

        # Ceiling division: `len // CHUNK_SIZE + 1` over-counts by one when the
        # number of files is an exact multiple of CHUNK_SIZE.
        total_chunks = (len(to_process) + CHUNK_SIZE - 1) // CHUNK_SIZE
        succeeded = 0

        # Process in chunks to avoid memory issues
        with Pool(NUM_GPUS) as pool:
            for chunk_num, pdf_chunk in enumerate(chunks(to_process, CHUNK_SIZE), start=1):
                print(f"Processing chunk {chunk_num}/{total_chunks} "
                      f"({len(pdf_chunk)} files)")

                # Tag each task with a GPU id, round-robin by position in the chunk.
                # NOTE(review): the pool hands tasks to whichever worker is free,
                # so this does not guarantee one running task per GPU -- confirm
                # whether per-worker GPU pinning is wanted.
                tasks = [(pdf, i % NUM_GPUS) for i, pdf in enumerate(pdf_chunk)]

                completed_in_chunk = 0
                try:
                    for pdf_file, result_status in pool.imap_unordered(process_pdf, tasks):
                        status[os.path.basename(pdf_file)] = result_status
                        completed_in_chunk += 1
                        if result_status == 'done':
                            succeeded += 1

                        # Save status periodically (every 10 completions)
                        if completed_in_chunk % 10 == 0:
                            save_status(status)

                        print(f"Completed ({completed_in_chunk}/{len(pdf_chunk)}): "
                              f"{os.path.basename(pdf_file)} -> {result_status}")
                finally:
                    # Runs at the normal end of the chunk and also on interruption,
                    # so results collected since the last periodic save are kept.
                    save_status(status)

                print(f"Chunk {chunk_num} complete ({len(pdf_chunk)} files)")

        print(f"All processing complete. Total files processed: {len(to_process)}")

        if succeeded == 0:
            # Every file failed: back off instead of retrying immediately.
            print('No PDF was converted successfully in this pass. Sleeping before retrying...')
            time.sleep(60)

# Entry-point guard: start the processing loop only when run as a script.
# multiprocessing may import this module again inside worker processes
# (e.g. with the "spawn" start method); the guard keeps such imports from
# starting another main() loop.
if __name__ == '__main__':
    main()