import pandas as pd
import os
import csv
import numpy as np

# Per-dataset result columns, in CSV output order.
_PRUNABLE_COLUMNS = (
    'prunable_k_adj_1',
    'prunable_k_adj_2',
    'prunable_k_exch_1',
    'prunable_k_exch_2',
)

# Local-search parameter values that are reported in the outputs.
_REPORTED_PARAM_VALUES = (1, 2)

# Suffix of the LaTeX macro names generated for each prunable column.
_LATEX_SUFFIX = {
    'prunable_k_adj_1': 'kadj',
    'prunable_k_adj_2': 'kadjtwo',
    'prunable_k_exch_1': 'kexch',
    'prunable_k_exch_2': 'kexchtwo',
}


def _analyze_local_search(df_pruned, csv_results, param, other_param):
    """Find, per dataset, the largest k_repl reachable with ``param`` in {1, 2}
    without exceeding the baseline error count.

    Only rows with ``other_param == 0`` and ``k_rais == 0`` are considered. The
    baseline of a dataset is its row with ``param == 0`` and ``k_repl == 0``;
    datasets without such a row are skipped.

    ``csv_results`` is updated in place: ``initial_errors`` is set to the
    baseline errors and ``prunable_<param>_<value>`` is set to the max k_repl
    found, to 0 when data for that value exist but nothing qualifies, and left
    as NaN when there is no data for that value.

    Args:
        df_pruned: rows of pruned trees.
        csv_results: dict of dataset name -> per-dataset result dict (mutated).
        param: local-search column being varied ('k_adj' or 'k_exch').
        other_param: the other local-search column, held at 0.

    Returns:
        dict mapping dataset name -> {param value: max k_repl} for datasets where
        at least one further-pruning opportunity was found.

    Raises:
        ValueError: if, with ``param == 0``, a row with k_repl > 0 already
            reaches the baseline error count (inconsistent input data).
    """
    print(f"\nDatasets where increasing {param} allows further pruning (k_repl > 0):")

    df_filtered = df_pruned[(df_pruned[other_param] == 0) & (df_pruned['k_rais'] == 0)]
    if df_filtered.empty:
        print(f"No entries found with {other_param}=0 and k_rais=0.")
        return {}

    results_by_dataset = {}

    for dataset in df_filtered['instancename'].unique():
        dataset_df = df_filtered[df_filtered['instancename'] == dataset]

        baseline = dataset_df[(dataset_df[param] == 0) & (dataset_df['k_repl'] == 0)]
        if baseline.empty:
            continue
        baseline_errors = baseline['min_errors'].values[0]

        # Sanity check: without the local search (param == 0) no k_repl > 0 may
        # already reach the baseline errors. Explicit raise instead of `assert`
        # so the check is not stripped under `python -O`.
        baseline_prunes = dataset_df[
            (dataset_df[param] == 0) & (dataset_df['min_errors'] <= baseline_errors)
        ]
        max_baseline_repl = baseline_prunes['k_repl'].max()
        if not max_baseline_repl <= 0:
            raise ValueError(
                f"Dataset {dataset}: with {param}=0 an entry with "
                f"k_repl={max_baseline_repl} already reaches the baseline "
                f"error count {baseline_errors}"
            )

        entry = csv_results.setdefault(
            dataset, dict.fromkeys(_PRUNABLE_COLUMNS, float('nan'))
        )
        entry['initial_errors'] = baseline_errors

        dataset_results = {}
        for param_val in sorted(dataset_df[param].unique()):
            if param_val not in _REPORTED_PARAM_VALUES:
                continue
            # Normalize to int so a float-typed column (1.0) still maps onto
            # 'prunable_k_adj_1' and not onto an unused 'prunable_k_adj_1.0' key.
            label = int(param_val)
            column = f'prunable_{param}_{label}'

            param_entries = dataset_df[dataset_df[param] == param_val]
            valid_entries = param_entries[
                (param_entries['k_repl'] > 0)
                & (param_entries['min_errors'] <= baseline_errors)
            ]
            if valid_entries.empty:
                # Data exist but nothing qualifies: 0 distinguishes "checked" from NaN.
                entry[column] = 0
                continue

            max_k_repl = valid_entries['k_repl'].max()
            min_errors = valid_entries.loc[
                valid_entries['k_repl'] == max_k_repl, 'min_errors'
            ].values[0]
            entry[column] = max_k_repl

            if not dataset_results:
                print(f"Dataset: {dataset}")
                print(f"Baseline errors (with {param}=0, k_repl=0): {baseline_errors}")
            print(f"  {param}={label}, max k_repl={max_k_repl}, min_errors={min_errors}")
            dataset_results[label] = max_k_repl

        if dataset_results:
            results_by_dataset[dataset] = dataset_results

    if not results_by_dataset:
        print(f"No datasets found where increasing {param} allows increasing k_repl without increasing errors.")

    return results_by_dataset


def _positive_values(csv_results, column):
    """Return the non-NaN, strictly positive values of ``column`` across datasets."""
    return [
        results[column]
        for results in csv_results.values()
        if not np.isnan(results[column]) and results[column] > 0
    ]


def _format_average(values):
    """Format the mean of ``values`` with two decimals, or 'n/a' if there are none.

    Avoids numpy's "Mean of empty slice" warning and a literal ``nan`` in LaTeX.
    """
    if not values:
        return "n/a"
    return f"{float(np.mean(values)):.2f}"


def _write_latex(path, num_datasets_adj, num_datasets_exch, positives):
    """Write the summary statistics to ``path`` as LaTeX newcommand macros."""
    macros = [
        ('numdatasetskadjprune', num_datasets_adj),
        ('numdatasetsexchprune', num_datasets_exch),
    ]
    macros += [
        (f'avgprunablenodes{_LATEX_SUFFIX[col]}', _format_average(positives[col]))
        for col in _PRUNABLE_COLUMNS
    ]
    macros += [
        (f'count{_LATEX_SUFFIX[col]}', len(positives[col]))
        for col in _PRUNABLE_COLUMNS
    ]
    with open(path, 'w', encoding='utf-8') as f:
        for name, value in macros:
            f.write(f"\\newcommand{{\\{name}}}{{{value}}}\n")


def _write_csv(path, csv_results):
    """Write one CSV row per dataset with at least one positive prunable count."""
    with open(path, 'w', newline='', encoding='utf-8') as f:
        writer = csv.writer(f)
        writer.writerow(['dataset', 'initial_errors', *_PRUNABLE_COLUMNS])
        for dataset, results in csv_results.items():
            # NaN > 0 is False, so settings without data never select a row.
            if not any(results[col] > 0 for col in _PRUNABLE_COLUMNS):
                continue
            row = [dataset, results['initial_errors']]
            # NaN (no data for that setting) becomes an empty CSV cell.
            row += [
                '' if np.isnan(results[col]) else int(results[col])
                for col in _PRUNABLE_COLUMNS
            ]
            writer.writerow(row)


def analyze_pruned_trees(csv_file, output_dir='results'):
    """
    Analyze pruned trees to determine if they could be further pruned without increasing errors
    after local search operations.

    For each local search parameter (k_adj, k_exch) set to 1 or 2, finds the
    largest k_repl whose min_errors does not exceed the dataset's baseline
    (all of k_adj, k_exch, k_rais, k_repl equal to 0).

    Args:
        csv_file: path to the consolidated results CSV. Expected columns:
            'pruned/unpruned tree', 'instancename', 'k_adj', 'k_exch',
            'k_rais', 'k_repl', 'min_errors'.
        output_dir: directory for the outputs (created if missing).

    Outputs:
    1. Summary statistics as LaTeX commands (<output_dir>/pruning_stats.tex)
    2. CSV with prunable node information for each dataset
       (<output_dir>/prunable_datasets.csv)

    Raises:
        ValueError: if the input violates the baseline sanity check.
    """
    df = pd.read_csv(csv_file)
    os.makedirs(output_dir, exist_ok=True)

    df_pruned = df[df['pruned/unpruned tree'] == 'pruned']
    if df_pruned.empty:
        print("No pruned trees found in the data.")
        return

    # dataset name -> {'initial_errors': ..., 'prunable_k_adj_1': ..., ...}
    csv_results = {}

    print("Analyzing whether pruned trees could be further pruned without increasing errors after local search...")
    adj_results = _analyze_local_search(df_pruned, csv_results, 'k_adj', 'k_exch')
    exch_results = _analyze_local_search(df_pruned, csv_results, 'k_exch', 'k_adj')

    num_datasets_adj = len(adj_results)
    num_datasets_exch = len(exch_results)

    if adj_results or exch_results:
        print("\nSummary:")
        print(f"Datasets where increasing k_adj allows further pruning: {num_datasets_adj}")
        print(f"Datasets where increasing k_exch allows further pruning: {num_datasets_exch}")
    else:
        print("\nNo datasets found where further pruning is possible without increasing errors.")

    positives = {col: _positive_values(csv_results, col) for col in _PRUNABLE_COLUMNS}

    tex_path = os.path.join(output_dir, 'pruning_stats.tex')
    csv_path = os.path.join(output_dir, 'prunable_datasets.csv')
    _write_latex(tex_path, num_datasets_adj, num_datasets_exch, positives)
    _write_csv(csv_path, csv_results)

    print("\nOutputs saved:")
    print(f"1. LaTeX summary statistics: {tex_path}")
    print(f"2. CSV with prunable information: {csv_path}")

if __name__ == "__main__":
    import sys

    # Optional single argument: path to the consolidated results CSV.
    # Without an argument the historical default location is used.
    if len(sys.argv) > 2:
        print("Usage: python analyze_trees.py [consolidated_results.csv]")
        sys.exit(1)
    input_csv = sys.argv[1] if len(sys.argv) == 2 else "results/consolidated_results.csv"
    analyze_pruned_trees(input_csv)
