import os
import sys

import numpy as np
import pandas as pd

# Columns that must be present in the input CSV, in the order they appear in
# the generated table (the first one is the row label).
_REQUIRED_COLUMNS = ['dataset', 'unpruned_tree_size', 'initial_errors',
                     'prunable_k_adj_1', 'prunable_k_adj_2',
                     'prunable_k_exch_1', 'prunable_k_exch_2']

# Placeholder written for missing values; "--" renders as an en dash in LaTeX.
_MISSING_CELL = "--"

# Characters that are special in LaTeX text mode, mapped to their escaped
# form. Doing this in a single translate() pass avoids double-escaping (e.g.
# the backslash produced for "_" being escaped again as a literal backslash).
_LATEX_ESCAPES = str.maketrans({
    "\\": r"\textbackslash{}",
    "&": r"\&",
    "%": r"\%",
    "$": r"\$",
    "#": r"\#",
    "_": r"\_",
    "{": r"\{",
    "}": r"\}",
    "~": r"\textasciitilde{}",
    "^": r"\textasciicircum{}",
})


def _escape_latex(text):
    """Return ``str(text)`` with LaTeX special characters escaped."""
    return str(text).translate(_LATEX_ESCAPES)


def _format_count(value):
    """Format a count cell as an integer, or ``_MISSING_CELL`` if NaN/blank.

    Raises:
        ValueError: if a non-blank string value cannot be parsed by ``int``.
    """
    if pd.isna(value) or (isinstance(value, str) and not value.strip()):
        return _MISSING_CELL
    return str(int(value))


def generate_latex_table(csv_file, output_file):
    """
    Generate a LaTeX table in booktabs format with a top caption from the prunable datasets CSV.

    Rows are sorted by dataset name. Missing (NaN or blank) numeric values are
    rendered as "--" in every column. Dataset names have LaTeX special
    characters escaped. The header uses the macros ``\\kadj`` and ``\\kexch``
    and the booktabs rules, so the including document must define/load them.

    Args:
        csv_file: Path to prunable_datasets.csv
        output_file: Path to the output LaTeX file. Its parent directory is
            created if needed; a bare file name writes to the current directory.

    Returns:
        None. If required columns are missing, an error is printed to stderr
        and no file is written.

    Raises:
        Whatever ``pd.read_csv`` raises for an unreadable input, and
        ValueError if a numeric cell holds a non-numeric string.
    """
    df = pd.read_csv(csv_file)

    missing_columns = [col for col in _REQUIRED_COLUMNS if col not in df.columns]
    if missing_columns:
        print(f"Error: Missing columns: {', '.join(missing_columns)}", file=sys.stderr)
        return

    df = df.sort_values('dataset')

    latex_content = [
        "% Generated by eval-prune-pruned-trees-convert-to-latex.py",
        "\\begin{table}",
        "\\centering",
        "\\caption{Number of nodes for decision trees that can be pruned without increasing errors after local search. Dashes indicate timeouts.}",
        "\\label{tab:prunable_datasets}",
        "\\begin{tabular}{lcccccc}",
        "\\toprule",
        "Dataset & Tree Size & Initial Errors & $\\kadj = 1$ & $\\kadj = 2$ & $\\kexch = 1$ & $\\kexch = 2$ \\\\",
        "\\midrule",
    ]

    numeric_columns = _REQUIRED_COLUMNS[1:]
    for _, row in df.iterrows():
        cells = [_escape_latex(row['dataset'])]
        cells.extend(_format_count(row[col]) for col in numeric_columns)
        latex_content.append(" & ".join(cells) + " \\\\")

    latex_content += ["\\bottomrule", "\\end{tabular}", "\\end{table}"]

    # os.path.dirname() is "" for a bare file name, and os.makedirs("")
    # raises FileNotFoundError, so only create a directory when there is one.
    output_dir = os.path.dirname(output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    with open(output_file, 'w', encoding='utf-8') as f:
        f.write('\n'.join(latex_content) + '\n')

    print(f"LaTeX table generated: {output_file}")

if __name__ == "__main__":
    # Paths are relative to the current working directory.
    csv_file = "results/prunable_datasets.csv"
    output_file = "results/prunable_datasets.tex"

    if not os.path.exists(csv_file):
        # Report on stderr and exit non-zero so callers (make, CI) notice.
        print(f"Error: {csv_file} does not exist!", file=sys.stderr)
        sys.exit(1)

    generate_latex_table(csv_file, output_file)
