import pandas as pd
import os

def merge_unpruned_tree_size(prunable_csv, tree_analysis_csv):
    """
    Merge the tree size of unpruned trees from tree_analysis_results.csv into prunable_datasets.csv.

    For each dataset X in prunable_datasets.csv, looks for the corresponding tree X_data-unpru
    in tree_analysis_results.csv and adds its tree size as the column 'unpruned_tree_size'.
    Datasets without a matching unpruned tree get a missing value in that column.

    Args:
        prunable_csv: Path to a CSV with a 'dataset' column. The file is
            overwritten in place with the extra 'unpruned_tree_size' column.
        tree_analysis_csv: Path to a CSV with 'Tree' and 'Tree Size' columns.

    Raises:
        KeyError: If one of the required columns is missing from its CSV.
    """
    unpruned_suffix = '_data-unpru'

    # Read the CSV files
    prunable_df = pd.read_csv(prunable_csv)
    tree_analysis_df = pd.read_csv(tree_analysis_csv)

    print(f"Loaded {len(prunable_df)} records from {prunable_csv}")
    print(f"Loaded {len(tree_analysis_df)} records from {tree_analysis_csv}")

    # Create a mapping from dataset name to unpruned tree size.
    # When the same unpruned tree appears more than once, the last row wins.
    tree_size_mapping = {}
    for tree_name, tree_size in zip(tree_analysis_df['Tree'], tree_analysis_df['Tree Size']):
        # Blank 'Tree' cells are read as NaN (a float); skip anything that is
        # not a string instead of crashing on .endswith().
        if not isinstance(tree_name, str):
            continue
        if tree_name.endswith(unpruned_suffix):
            # Strip only the trailing suffix: str.replace() would also remove
            # occurrences of the marker inside the dataset name itself.
            dataset_name = tree_name[:-len(unpruned_suffix)]
            tree_size_mapping[dataset_name] = tree_size

    # Count matches
    matched_datasets = int(prunable_df['dataset'].isin(list(tree_size_mapping)).sum())
    print(f"Matched {matched_datasets}/{len(prunable_df)} datasets with unpruned tree sizes")

    # Add the unpruned tree size to the prunable datasets
    prunable_df['unpruned_tree_size'] = prunable_df['dataset'].map(tree_size_mapping)

    # Save the updated dataframe
    prunable_df.to_csv(prunable_csv, index=False)

    print(f"Updated {prunable_csv} with unpruned tree sizes")

if __name__ == "__main__":
    prunable_csv = "results/prunable_datasets.csv"
    tree_analysis_csv = "results/tree_analysis_results.csv"

    # Find the first required input that is absent, checking the prunable
    # datasets file before the tree analysis file; the lookup is lazy, so the
    # second path is only probed when the first one exists.
    missing = next(
        (path for path in (prunable_csv, tree_analysis_csv) if not os.path.exists(path)),
        None,
    )

    if missing is None:
        # Both inputs are present: run the merge.
        merge_unpruned_tree_size(prunable_csv, tree_analysis_csv)
    else:
        print(f"Error: {missing} does not exist!")
