"""
Analysis script for pruned decision tree error reduction

This script analyzes a CSV dataset of decision tree results to:
1. Find pruned instances where increasing k_exch or k_adj reduces min_errors
2. Create a comparison table of error rates for different operations
3. Write the LaTeX comparison table to results/error_reduction_table.tex
4. Write instance improvement statistics (as LaTeX macros) to results/error_reduction_statistics.tex
"""

import pandas as pd
import os
import pathlib
from typing import List, Dict, Any, Optional, Tuple, Set


def find_instances_with_error_reduction(data: pd.DataFrame, parameter: str) -> Dict[str, List[str]]:
    """
    Find instances where errors decrease when a parameter is increased.
    Also tracks which instances improve with specific parameter values.

    Only rows in which every *other* operation parameter ('k_exch', 'k_adj',
    'k_rais', 'k_repl') is 0 are considered, so the effect of ``parameter``
    is isolated. For each instance the smallest ``min_errors`` at
    ``parameter`` = 1 and at ``parameter`` = 2 is compared with the smallest
    ``min_errors`` at ``parameter`` = 0. An instance without a
    ``parameter`` = 0 row is never reported.

    Args:
        data: DataFrame containing the dataset; must have the columns
            'instancename', 'min_errors', 'k_exch', 'k_adj', 'k_rais', 'k_repl'
        parameter: Parameter to analyze ('k_exch' or 'k_adj'; 'k_rais' and
            'k_repl' are accepted as well)

    Returns:
        Dictionary containing (each list in order of first appearance of the
        instance in ``data``):
            'any_improvement': List of instance names where parameter=1 or parameter=2 reduces errors
            'value_1': List of instances where parameter=1 reduces errors
            'value_2': List of instances where parameter=2 reduces errors

    Raises:
        ValueError: If ``parameter`` is not one of the four operation parameters.
    """
    operation_parameters = ('k_exch', 'k_adj', 'k_rais', 'k_repl')
    if parameter not in operation_parameters:
        # Fail loudly: an unknown name would otherwise yield empty lists that
        # are indistinguishable from "no instance improved".
        raise ValueError(
            f"parameter must be one of {operation_parameters}, got {parameter!r}"
        )

    # Keep only rows where all other operations are disabled.
    isolated = data
    for other in operation_parameters:
        if other != parameter:
            isolated = isolated[isolated[other] == 0]

    # Lowest error count per (instance, parameter value), computed in one pass
    # instead of re-filtering the whole frame for every instance. Keys are
    # (instancename, parameter value) tuples.
    best_errors = (
        isolated.groupby(['instancename', parameter])['min_errors'].min().to_dict()
    )

    instances_any_improvement: List[str] = []
    instances_value_1: List[str] = []
    instances_value_2: List[str] = []

    for instance in data['instancename'].unique():
        base_errors = best_errors.get((instance, 0))
        if base_errors is None:
            continue  # no baseline (parameter = 0) run to compare against

        errors_1 = best_errors.get((instance, 1))
        errors_2 = best_errors.get((instance, 2))
        # NaN comparisons are False, so missing error values never count as an improvement.
        improved_1 = errors_1 is not None and errors_1 < base_errors
        improved_2 = errors_2 is not None and errors_2 < base_errors

        if improved_1:
            instances_value_1.append(instance)
        if improved_2:
            instances_value_2.append(instance)
        if improved_1 or improved_2:
            instances_any_improvement.append(instance)

    return {
        'any_improvement': instances_any_improvement,
        'value_1': instances_value_1,
        'value_2': instances_value_2
    }


def find_configuration(data: pd.DataFrame, instance: str, 
                        k_exch: int, k_adj: int, k_rais: int, k_repl: int) -> Optional[int]:
    """
    Find a specific configuration for an instance and return its min_errors.

    If several rows match, the smallest ``min_errors`` among them is used.

    Args:
        data: DataFrame containing the dataset
        instance: Name of the instance
        k_exch, k_adj, k_rais, k_repl: Parameter values to filter by

    Returns:
        The minimum ``min_errors`` over the matching rows, or None if no row
        matches or every matching ``min_errors`` is missing (NaN). Integral
        values are returned as a plain Python ``int`` (a float column value
        such as 3.0, which pandas produces when the column contains NaN,
        becomes 3); non-integral values are returned unchanged.
    """
    matches = data[(data['instancename'] == instance) & 
                   (data['k_exch'] == k_exch) & 
                   (data['k_adj'] == k_adj) & 
                   (data['k_rais'] == k_rais) & 
                   (data['k_repl'] == k_repl)]

    if matches.empty:
        return None

    # Series.min skips NaN; the result is NaN only if every match is NaN.
    best = matches['min_errors'].min()
    if pd.isna(best):
        return None

    # Unwrap numpy scalars (e.g. numpy.int64 / numpy.float64) to Python scalars.
    value = best.item() if hasattr(best, 'item') else best
    if isinstance(value, float) and value.is_integer():
        return int(value)
    return value


def create_comparison_table(data: pd.DataFrame, instances: List[str]) -> pd.DataFrame:
    """
    Create a comparison table of error rates for different operations.

    One row per instance that has a base configuration (all parameters 0);
    instances without one are skipped with a printed warning. Each operation
    column holds the min_errors of that configuration, or "--" when the
    configuration is absent from ``data``.

    Args:
        data: DataFrame containing the dataset
        instances: List of instance names to include in the table

    Returns:
        DataFrame containing the comparison table, sorted by "Dataset". It
        always has the columns Dataset, Initial and one per operation, even
        when no row qualifies.
    """
    # (column header, (k_exch, k_adj, k_rais, k_repl)) for each reported configuration
    configurations = [
        ("$\\kadj = 1$", (0, 1, 0, 0)),
        ("$\\kadj = 2$", (0, 2, 0, 0)),
        ("$\\kexch = 1$", (1, 0, 0, 0)),
        ("$\\kexch = 2$", (2, 0, 0, 0)),
    ]
    columns = ["Dataset", "Initial"] + [header for header, _ in configurations]

    table_rows = []
    for instance in instances:
        # Get base configuration (no operations)
        base_errors = find_configuration(data, instance, 0, 0, 0, 0)
        if base_errors is None:
            print(f"Warning: No base configuration found for instance {instance}")
            continue

        row = {"Dataset": instance, "Initial": base_errors}
        for header, (k_exch, k_adj, k_rais, k_repl) in configurations:
            errors = find_configuration(data, instance, k_exch, k_adj, k_rais, k_repl)
            # Missing configurations are shown as a dash in the LaTeX table.
            row[header] = errors if errors is not None else "--"
        table_rows.append(row)

    # Passing `columns` keeps the schema even when every instance was skipped,
    # so sorting by "Dataset" cannot raise KeyError on an empty table.
    return pd.DataFrame(table_rows, columns=columns).sort_values(by="Dataset")


def table_to_latex(table: pd.DataFrame, caption: str = "Error rates for different operations. Dashes indicate timeouts.") -> str:
    """
    Convert a DataFrame to LaTeX table format (booktabs rules).

    The first column is left-aligned and every further column centered.
    Header cells only have underscores escaped, so they may contain LaTeX
    markup such as math mode. Data cells also have ``&``, ``%`` and ``#``
    escaped, so values such as instance names cannot break the table.

    Args:
        table: DataFrame containing the table data
        caption: Caption for the LaTeX table (inserted verbatim)

    Returns:
        String containing the LaTeX table (no trailing newline)
    """
    # Characters that are special in LaTeX text mode and would break
    # compilation if they appeared unescaped in a data value.
    cell_escapes = str.maketrans({"_": "\\_", "&": "\\&", "%": "\\%", "#": "\\#"})

    columns = table.columns.tolist()
    # One alignment letter per column rather than a fixed "lccccc".
    column_spec = "l" + "c" * max(len(columns) - 1, 0)

    lines = [
        "% auto-generated by eval-pruned-tree-errors.py",
        "\\begin{table}[t]",
        "\\centering",
        f"\\caption{{{caption}}}",
        "\\label{tab:error_reduction}",
        f"\\begin{{tabular}}{{{column_spec}}}",
        "\\toprule",
        " & ".join(str(col).replace("_", "\\_") for col in columns) + " \\\\ \\midrule ",
    ]

    # itertuples keeps each column's own dtype; iterrows would upcast ints to
    # floats in rows whose columns are all numeric.
    for values in table.itertuples(index=False, name=None):
        cells = (str(value).translate(cell_escapes) for value in values)
        lines.append(" & ".join(cells) + " \\\\ ")

    lines += ["\\bottomrule", "\\end{tabular}", "\\end{table}"]
    return "\n".join(lines)


def write_error_reduction_statistics(exch_results: Dict[str, List[str]], 
                                    adj_results: Dict[str, List[str]],
                                    output_path: pathlib.Path) -> None:
    """
    Write error reduction statistics to a LaTeX file.

    Two macros are defined for each operation (Exchange, Adjustment) and each
    result key ('any_improvement' -> Any, 'value_1' -> One, 'value_2' -> Two):
    one with the count (e.g. ``\\impExchangeOneInstances``) and one with the
    sorted, comma-separated names (e.g. ``\\impExchangeOneList``). The union
    of both 'any_improvement' lists is exported as
    ``\\totalImprovedInstances`` and ``\\allImprovedList``. In the list
    macros, the LaTeX special characters ``_ & % #`` in instance names are
    escaped so the macros can be used in running text.

    Args:
        exch_results: Results from find_instances_with_error_reduction for k_exch
        adj_results: Results from find_instances_with_error_reduction for k_adj
        output_path: Path to write the LaTeX file (overwritten, UTF-8)
    """
    # LaTeX text-mode special characters that may occur in instance names.
    latex_escapes = str.maketrans({"_": "\\_", "&": "\\&", "%": "\\%", "#": "\\#"})

    def format_names(names) -> str:
        """Sort names and join them comma-separated, LaTeX-escaped."""
        return ", ".join(str(name).translate(latex_escapes) for name in sorted(names))

    lines = ["% Error reduction statistics - generated by eval-pruned-tree-errors.py", ""]

    # (macro infix, results) per operation; (macro suffix, result key) per list.
    operations = (("Exchange", exch_results), ("Adjustment", adj_results))
    result_kinds = (("Any", "any_improvement"), ("One", "value_1"), ("Two", "value_2"))
    for operation, results in operations:
        lines.append(f"% {operation} Operations")
        for suffix, key in result_kinds:
            names = results[key]
            lines.append(f"\\newcommand{{\\imp{operation}{suffix}Instances}}{{{len(names)}}}")
            lines.append(f"\\newcommand{{\\imp{operation}{suffix}List}}{{{format_names(names)}}}")
            lines.append("")

    # Combined set of instances that improve with either operation
    all_improved = set(exch_results['any_improvement']) | set(adj_results['any_improvement'])
    lines.append("% Combined Statistics")
    lines.append(f"\\newcommand{{\\totalImprovedInstances}}{{{len(all_improved)}}}")
    lines.append(f"\\newcommand{{\\allImprovedList}}{{{format_names(all_improved)}}}")

    with open(output_path, 'w', encoding='utf-8') as f:
        f.write("\n".join(lines) + "\n")


def analyze_decision_tree_errors(filename: str, output_dir: str = 'results') -> None:
    """
    Main analysis function for decision tree errors.

    Reads the CSV and keeps the rows whose "pruned/unpruned tree" column
    equals "pruned". It then finds instances whose errors drop when k_exch or
    k_adj is increased, prints the findings, and writes
    error_reduction_table.tex and error_reduction_statistics.tex into
    ``output_dir``.

    Args:
        filename: Path to the CSV file containing the dataset
        output_dir: Directory for the LaTeX outputs; created (including
            parents) if it does not exist. Defaults to 'results'.
    """
    # Read the CSV file
    print(f"Reading data from {filename}...")
    data = pd.read_csv(filename)

    # Filter only pruned instances
    pruned_data = data[data["pruned/unpruned tree"] == "pruned"]
    print(f"Found {len(pruned_data)} pruned tree entries")

    # Find instances where errors decrease
    print("\nFinding instances where errors decrease...")
    exch_results = find_instances_with_error_reduction(pruned_data, 'k_exch')
    adj_results = find_instances_with_error_reduction(pruned_data, 'k_adj')

    instances_with_exch_reduction = exch_results['any_improvement']
    instances_with_adj_reduction = adj_results['any_improvement']

    # Union of both operations. Sorted so the printed list is reproducible:
    # the iteration order of a set of strings varies between interpreter runs
    # because of string hash randomization.
    all_relevant_instances = sorted(
        set(instances_with_exch_reduction) | set(instances_with_adj_reduction)
    )

    print(f"\nInstances where increasing k_exch reduces errors ({len(instances_with_exch_reduction)}):")
    print(instances_with_exch_reduction)

    print(f"\nInstances where increasing k_adj reduces errors ({len(instances_with_adj_reduction)}):")
    print(instances_with_adj_reduction)

    print(f"\nAll instances with error reduction from either operation ({len(all_relevant_instances)}):")
    print(all_relevant_instances)

    # Create the comparison table
    print("\nCreating comparison table...")
    table_data = create_comparison_table(pruned_data, all_relevant_instances)

    # Print the table
    print("\nTable of Error Rates for Different Operations:")
    print(table_data.to_string(index=False))

    # Create results directory if it doesn't exist
    print("\nCreating results directory...")
    results_dir = pathlib.Path(output_dir)
    results_dir.mkdir(parents=True, exist_ok=True)

    # Generate LaTeX table and save to file
    print("Generating LaTeX table...")
    latex_table = table_to_latex(table_data)
    table_path = results_dir / 'error_reduction_table.tex'
    table_path.write_text(latex_table, encoding='utf-8')
    print(f"LaTeX table saved to {table_path}")

    # Write error reduction statistics to LaTeX file
    print("Writing error reduction statistics to LaTeX file...")
    stats_path = results_dir / 'error_reduction_statistics.tex'
    write_error_reduction_statistics(exch_results, adj_results, stats_path)
    print(f"Error reduction statistics saved to {stats_path}")


def main():
    """Run the error-reduction analysis on the default consolidated results CSV."""
    csv_path = 'results/consolidated_results.csv'
    print("=== Decision Tree Error Reduction Analysis ===")
    analyze_decision_tree_errors(csv_path)
    closing_messages = (
        "\nAnalysis complete!",
        "Results saved to the 'results/' directory",
    )
    for message in closing_messages:
        print(message)


# Run the analysis only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
