import pandas as pd
import os
import re

# Column order of the per-configuration result table. Passing this to the
# DataFrame constructor keeps the table well-formed (and the CSV header
# present) even when no result rows were produced.
_RESULT_COLUMNS = [
    'instance',
    'dataset',
    'parameter',
    'value',
    'heuristic_pruned',
    'heuristic_errors',
    'baseline_errors',
    'result_status',
    'improved_errors',
    'error_reduction',
    'percent_reduction',
    'is_improvement',
]

# Parameter values that are compared against the baseline (value 0).
_PARAM_VALUES = (1, 2)


def _select_configuration(instance_rows, k_repl, k_adj, k_exch):
    """Return the rows of *instance_rows* with the given budgets and k_rais == 0."""
    mask = (
        (instance_rows['k_repl'] == k_repl)
        & (instance_rows['k_adj'] == k_adj)
        & (instance_rows['k_exch'] == k_exch)
        & (instance_rows['k_rais'] == 0)
    )
    return instance_rows[mask]


def _evaluate_parameter(instance_rows, instance, dataset, parameter, value,
                        heuristic_pruned, heuristic_errors, baseline_errors):
    """Build one result entry comparing a single parameter setting to the baseline.

    *parameter* is either 'k_adj' or 'k_exch'; the other one is held at 0.
    The entry keeps result_status 'missing' when no row exists for the
    configuration.
    """
    entry = {
        'instance': instance,
        'dataset': dataset,
        'parameter': parameter,
        'value': value,
        'heuristic_pruned': heuristic_pruned,
        'heuristic_errors': heuristic_errors,
        'baseline_errors': baseline_errors,
        'result_status': 'missing',  # Default to missing
        'improved_errors': None,
        'error_reduction': None,
        'percent_reduction': None,
        'is_improvement': False,
    }

    k_adj = value if parameter == 'k_adj' else 0
    k_exch = value if parameter == 'k_exch' else 0
    config_rows = _select_configuration(instance_rows, heuristic_pruned, k_adj, k_exch)
    if config_rows.empty:
        return entry

    # As with the baseline, the first matching row is used.
    errors = config_rows['min_errors'].iloc[0]
    reduction = baseline_errors - errors
    entry['improved_errors'] = errors
    entry['error_reduction'] = reduction
    # Avoid division by zero when the baseline already has no errors.
    entry['percent_reduction'] = (reduction / baseline_errors) * 100 if baseline_errors > 0 else 0

    if errors < baseline_errors:
        entry['result_status'] = 'improved'
        entry['is_improvement'] = True
    else:
        entry['result_status'] = 'not_improved'
    return entry


def _keep_best_per_instance(param_improvements):
    """Keep only the best improvement per instance.

    For instances that improved under several parameter values, the row with
    the largest error_reduction is kept; ties go to the smallest value.
    Instances that appear once are passed through unchanged and come first.
    """
    duplicate_mask = param_improvements['instance'].duplicated(keep=False)
    if not duplicate_mask.any():
        return param_improvements

    non_duplicated = param_improvements[~duplicate_mask]
    duplicated = param_improvements[duplicate_mask]
    best_improvements = (
        duplicated
        .sort_values(['instance', 'error_reduction', 'value'],
                     ascending=[True, False, True])
        .drop_duplicates(subset=['instance'], keep='first')
    )
    print(f"Removed {len(duplicated) - len(best_improvements)} duplicate entries with worse improvements.")
    return pd.concat([non_duplicated, best_improvements])


def _write_parameter_reports(improvements_df, param):
    """Write the per-parameter improvements CSV and its summary by value.

    Nothing is written when *param* has no improvements.
    """
    param_improvements = improvements_df[improvements_df['parameter'] == param]
    if param_improvements.empty:
        return

    param_improvements = _keep_best_per_instance(param_improvements)
    param_improvements.to_csv(f'results/prune_unpruned_{param}_improvements.csv', index=False)
    print(f"Saved {len(param_improvements)} {param} improvements to 'results/prune_unpruned_{param}_improvements.csv'")

    # Create summary by parameter value
    summary = param_improvements.groupby('value').agg({
        'instance': 'count',
        'error_reduction': ['sum', 'mean', 'median', 'min', 'max'],
        'percent_reduction': ['mean', 'median', 'min', 'max']
    })

    # Flatten multi-level columns
    summary.columns = ['_'.join(col).strip() for col in summary.columns.values]
    summary = summary.reset_index()

    summary.to_csv(f'results/prune_unpruned_{param}_summary.csv', index=False)
    print(f"Saved {param} summary to 'results/prune_unpruned_{param}_summary.csv'")


def analyze_tree_improvements():
    """
    Analyzes whether increasing k_adj or k_exch (to values 1 and 2) while keeping k_repl
    at the value from the heuristic pruning can reduce error rates for unpruned trees.

    This script:
    1. Filters for unpruned trees in results/consolidated_results.csv
    2. Matches each instancename (basename up to the first '.') to a Dataset
       in results/heuristic-pruned.csv
    3. For each match, takes the baseline row (k_repl = heuristic pruned,
       k_adj = k_exch = k_rais = 0) and records a consistency error when its
       min_errors is lower than the heuristic error count
    4. Checks if increasing k_adj or k_exch to 1 or 2 reduces errors
    5. Records every checked configuration, flagging those where errors decrease
    6. Outputs all results to CSV files in the results/ directory

    Instances without a matching dataset or without a baseline row are
    skipped with a warning. The result CSVs are written even when no
    configuration could be evaluated (they then contain only the header).

    Returns:
        True once all output files have been written.
    """
    # Create results directory if it doesn't exist
    os.makedirs('results', exist_ok=True)

    # Load the CSV files
    print("Loading data...")
    consolidated_df = pd.read_csv('results/consolidated_results.csv')
    heuristic_df = pd.read_csv('results/heuristic-pruned.csv')

    # Filter to only unpruned trees
    unpruned_df = consolidated_df[consolidated_df['pruned/unpruned tree'] == 'unpruned']
    print(f"Found {len(unpruned_df)} rows for unpruned trees")

    # Get unique instances
    instances = unpruned_df['instancename'].unique()
    print(f"Found {len(instances)} unique instances")

    # Create a mapping dictionary for quick lookup
    heuristic_map = {row['Dataset']: (row['Heuristic pruned'], row['Heuristic errors'])
                     for _, row in heuristic_df.iterrows()}

    consistency_errors = []
    all_results = []  # Track ALL results, not just improvements

    for instance in instances:
        # Extract dataset name (basename of instance, up to the first '.')
        dataset = os.path.basename(instance).split('.')[0]

        if dataset not in heuristic_map:
            print(f"Warning: No matching dataset for instance '{instance}' (extracted: '{dataset}')")
            continue

        y, x = heuristic_map[dataset]  # (nodes pruned, errors)

        # Filter by instance once; every configuration lookup below reuses it.
        instance_rows = unpruned_df[unpruned_df['instancename'] == instance]

        # Baseline configuration (k_repl=y, k_adj=k_exch=k_rais=0)
        baseline_rows = _select_configuration(instance_rows, y, 0, 0)
        if baseline_rows.empty:
            print(f"Warning: No baseline configuration found for {instance} with k_repl={y}")
            continue

        baseline_errors = baseline_rows['min_errors'].iloc[0]

        # Consistency check
        if baseline_errors < x:
            consistency_errors.append({
                'instance': instance,
                'dataset': dataset,
                'heuristic_errors': x,
                'baseline_errors': baseline_errors,
                'difference': x - baseline_errors
            })
            print(f"Consistency error: {instance} has baseline_errors={baseline_errors} < heuristic_errors={x}")

        # k_adj configurations first, then k_exch, each for values 1 and 2
        for parameter in ('k_adj', 'k_exch'):
            for value in _PARAM_VALUES:
                all_results.append(_evaluate_parameter(
                    instance_rows, instance, dataset, parameter, value,
                    y, x, baseline_errors))

    # Explicit columns keep the frame usable when all_results is empty.
    all_results_df = pd.DataFrame(all_results, columns=_RESULT_COLUMNS)

    # Extract just the improvements
    improvements_df = all_results_df[all_results_df['is_improvement'].astype(bool)].copy()

    # Save results
    if consistency_errors:
        consistency_df = pd.DataFrame(consistency_errors)
        consistency_df.to_csv('results/prune_unpruned_consistency_errors.csv', index=False)
        print(f"Saved {len(consistency_errors)} consistency errors to 'results/prune_unpruned_consistency_errors.csv'")

    # Save all results
    all_results_df.to_csv('results/prune_unpruned_all_results.csv', index=False)
    print(f"Saved {len(all_results_df)} results to 'results/prune_unpruned_all_results.csv'")

    # Save improvements
    improvements_df.to_csv('results/prune_unpruned_all_improvements.csv', index=False)
    print(f"Saved {len(improvements_df)} improvements to 'results/prune_unpruned_all_improvements.csv'")

    # Per-parameter improvement files and summaries
    for param in ('k_adj', 'k_exch'):
        _write_parameter_reports(improvements_df, param)

    print("Analysis complete!")
    return True

# Run the analysis only when this file is executed as a script, not on import.
if __name__ == "__main__":
    analyze_tree_improvements()
