import os
import glob
import jax
import logging
import pickle
import pandas as pd
from mol3D_benchmark import run_benchmarks, plot_ot_vs_angle, process_results, generate_latex_table, aggregate_results  

# Configure the root logger once, at import time, and grab this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(name=__name__)

# Resolve the data directory relative to this file (not the current working
# directory) so the script behaves the same wherever it is launched from.
SCRIPT_DIR = os.path.split(os.path.abspath(__file__))[0]
DATA_FOLDER = os.path.normpath(
    os.path.join(SCRIPT_DIR, "../../data/rotmol3d/")
)

def get_available_gpus():
    """Return the indices of the GPU devices JAX can see.

    Returns:
        ``[0, 1, ..., n - 1]`` with one entry per device reported by
        ``jax.devices('gpu')``. If JAX has no usable GPU backend, or reports
        no GPU devices, a warning is logged and ``[0]`` is returned as a
        placeholder meaning "use the CPU".
    """
    try:
        devices = jax.devices('gpu')
    except RuntimeError:
        # jax.devices raises RuntimeError when the requested backend is unknown
        # or failed to initialise. Catching only that (instead of a bare
        # `except:`) keeps KeyboardInterrupt/SystemExit and genuine bugs visible.
        devices = []
    if not devices:
        logger.warning("No GPUs found, falling back to CPU")
        return [0]  # Use CPU
    return list(range(len(devices)))

def run_benchmark_for_file(filename, gpu_id=0):
    """Run the benchmark for one data file, or re-plot its cached results.

    The results path is ``filename`` with ``".npy"`` replaced by
    ``"_results.pickle"``. If that file already exists, the pickled results
    are loaded, passed through ``process_results`` and plotted with
    ``plot_ot_vs_angle``; no benchmark is run. Otherwise JAX is configured to
    use GPU ``gpu_id`` and ``run_benchmarks(filename)`` is called.

    Args:
        filename: Path to the input ``.npy`` data file.
        gpu_id: Index into ``jax.devices('gpu')`` used when a benchmark has to
            be run.

    Returns:
        The processed results returned by ``process_results`` when a cached
        results file was found; ``None`` when the benchmark was (re)run, since
        freshly computed results are not loaded back here.
    """
    results_file = filename.replace(".npy", "_results.pickle")
    if os.path.exists(results_file):
        logger.info(f"Results file {results_file} already exists, plotting results...")
        # NOTE(review): unpickling can execute arbitrary code; only load
        # results files produced locally by this benchmark.
        with open(results_file, "rb") as f:  # `with` closes the handle promptly
            results = pickle.load(f)
        # NOTE(review): N is a per-dataset parameter forwarded to
        # process_results (12 for EMDB14621, 16 otherwise); its meaning is
        # defined in mol3D_benchmark — confirm there.
        N = 12 if "EMDB14621" in filename else 16
        df_joined = process_results(results, N=N)
        # Plot OT value vs angle
        plot_ot_vs_angle(df_joined, filename, exact=True)
        return df_joined

    # Make GPU the default platform and place computations on the requested GPU.
    jax.config.update('jax_platform_name', 'gpu')
    jax.config.update('jax_default_device', jax.devices('gpu')[gpu_id])

    logger.info(f"Running benchmark for {filename} on GPU {gpu_id}")
    run_benchmarks(filename)
    return None

def main():
    """Benchmark (or re-plot) every matching data file and write a combined table.

    Every file in ``DATA_FOLDER`` matching
    ``*_mask_radius=128_downscale_factor=16_n_angles=18.npy`` is passed to
    ``run_benchmark_for_file``. The DataFrames it returns (``None`` results are
    skipped) are tagged with a molecule id taken from the filename and
    aggregated on ``relative_time``. The aggregates are concatenated under a
    ``molecule`` index level, printed, and written as LaTeX to
    ``DATA_FOLDER/combined_time_table.tex``.
    """
    pattern = os.path.join(DATA_FOLDER, "*_mask_radius=128_downscale_factor=16_n_angles=18.npy")
    # glob returns files in arbitrary order; sort so that processing order and
    # the row order of the combined table are deterministic across runs.
    files = sorted(glob.glob(pattern))

    if not files:
        logger.error(f"No matching files found in {DATA_FOLDER}")
        return

    logger.info(f"Found {len(files)} files to process")

    # Get available GPUs
    gpus = get_available_gpus()
    logger.info(f"Available devices: {gpus}")

    # Process all files sequentially (on the default gpu_id) and collect the
    # results that run_benchmark_for_file returns.
    all_results = []
    for file in files:
        df = run_benchmark_for_file(file)
        if df is None:
            continue
        # NOTE(review): assumes basenames look like "<prefix>_<molecule>_...",
        # i.e. the second '_'-separated field is the molecule id — confirm.
        df['molecule'] = os.path.basename(file).split('_')[1]
        all_results.append(df)

    if not all_results:
        logger.info("No existing results to combine; combined LaTeX table not written")
        return

    # Aggregate each molecule's timings and stack them under a 'molecule' level.
    df_time = pd.concat(
        [
            pd.concat(
                {df_result['molecule'].iloc[0]: aggregate_results(df_result, field='relative_time')},
                names=['molecule'],
            )
            for df_result in all_results
        ]
    )
    print(df_time)

    latex_time = generate_latex_table(df_time, field='relative_time')
    # TODO: a combined 'relative_error' table (combined_error_table.tex) is not generated yet.
    output_path = os.path.join(DATA_FOLDER, "combined_time_table.tex")
    with open(output_path, 'w', encoding='utf-8') as f:
        f.write(latex_time)

    logger.info(f"Combined LaTeX time table has been written to '{output_path}'")


if __name__ == "__main__":
    main()