import sys
import os
import csv
import statistics  # For calculating median

# Add the local-searcher subdirectory to Python's module search path
sys.path.append('local-searcher')
from weka_parser import parse_tree

def analyze_tree(tree):
    """Compute structural statistics for a parsed decision tree.

    The tree is walked from ``tree.root`` with an explicit stack instead of
    recursion, so very deep (degenerate) trees do not exceed Python's
    recursion limit.

    Args:
        tree: Object exposing ``root`` and ``inner_nodes``; may be ``None``.
            Every node must provide ``is_leaf()``; non-leaf nodes must also
            provide ``dimension``, ``threshold``, ``left`` and ``right``.

    Returns:
        A dict with these keys:
            'size': ``len(tree.inner_nodes)``.
            'num_dimensions': number of distinct dimensions used by the
                inner nodes reachable from the root.
            'max_distinct_thresholds': largest number of distinct thresholds
                used for any single dimension.
            'max_dimensions_path': largest number of distinct dimensions on
                any root-to-leaf path.
            'max_thresholds_path': largest number of distinct threshold
                values on any root-to-leaf path.
        Every value is 0 when ``tree`` or ``tree.root`` is ``None``.
    """
    if tree is None or tree.root is None:
        return {
            'size': 0,
            'num_dimensions': 0,
            'max_distinct_thresholds': 0,
            'max_dimensions_path': 0,
            'max_thresholds_path': 0
        }

    # dimension -> set of thresholds used with it anywhere in the tree.
    # Its keys double as the set of dimensions that occur in the tree.
    thresholds_by_dimension = {}
    max_dimensions_path = 0
    max_thresholds_path = 0

    # Each entry carries the node plus the dimensions/thresholds seen on the
    # path from the root down to (but excluding) that node. The path sets are
    # immutable, so siblings can share them safely.
    pending = [(tree.root, frozenset(), frozenset())]
    while pending:
        node, path_dimensions, path_thresholds = pending.pop()

        if node.is_leaf():
            max_dimensions_path = max(max_dimensions_path, len(path_dimensions))
            max_thresholds_path = max(max_thresholds_path, len(path_thresholds))
            continue

        dimension = node.dimension
        threshold = node.threshold
        thresholds_by_dimension.setdefault(dimension, set()).add(threshold)

        # Only build a new set when the path actually gains an element.
        if dimension not in path_dimensions:
            path_dimensions = path_dimensions | {dimension}
        # NOTE: thresholds on a path are compared by value only, so equal
        # threshold values on different dimensions are counted once.
        if threshold not in path_thresholds:
            path_thresholds = path_thresholds | {threshold}

        # Push right first so the left subtree is visited first.
        pending.append((node.right, path_dimensions, path_thresholds))
        pending.append((node.left, path_dimensions, path_thresholds))

    return {
        'size': len(tree.inner_nodes),
        'num_dimensions': len(thresholds_by_dimension),
        'max_distinct_thresholds': max(
            (len(values) for values in thresholds_by_dimension.values()),
            default=0,
        ),
        'max_dimensions_path': max_dimensions_path,
        'max_thresholds_path': max_thresholds_path
    }

def get_classification_errors(file_path):
    """Return the integer written on the second-to-last line of a file.

    NOTE(review): presumably that line holds the tree's classification error
    count -- confirm against whatever produces these files.

    Args:
        file_path: Path of the text file to read.

    Returns:
        The second-to-last line, stripped of surrounding whitespace and
        converted with ``int``.

    Raises:
        IndexError: If the file has fewer than two lines.
        ValueError: If that line is not a valid integer.
    """
    with open(file_path, 'r') as handle:
        all_lines = list(handle)
        penultimate = all_lines[-2]
        return int(penultimate.strip())

def calculate_stats(data_list):
    """Summarise a list of numbers by its min, max, mean and median.

    Args:
        data_list: List of numeric values; may be empty.

    Returns:
        A dict with the keys 'min', 'max', 'mean' and 'median'. Every value
        is 0 when ``data_list`` is empty.
    """
    if data_list:
        smallest = min(data_list)
        largest = max(data_list)
        average = sum(data_list) / len(data_list)
        middle = statistics.median(data_list)
        return {
            'min': smallest,
            'max': largest,
            'mean': average,
            'median': middle,
        }

    # Nothing to summarise: report zeros rather than raising.
    return dict.fromkeys(('min', 'max', 'mean', 'median'), 0)

def generate_latex_file(stats, output_file):
    """Write the summary statistics to ``output_file`` as LaTeX macros.

    The file starts with a comment line and a ``\\totalTrees`` macro holding
    ``len(stats['Tree'])``. For each metric it then defines four macros,
    ``\\min<Name>``, ``\\max<Name>``, ``\\mean<Name>`` and ``\\median<Name>``,
    followed by a blank line.

    Args:
        stats: Dict with a 'Tree' entry (a sized collection) and, for every
            metric written below, a dict with 'min', 'max', 'mean' and
            'median' values.
        output_file: Path of the file to create or overwrite.
    """
    # (key in ``stats``, suffix used in the LaTeX macro name)
    macro_suffixes = (
        ('Tree Size', 'TreeSize'),
        ('Num Dimensions', 'Dimensions'),
        ('Max Distinct Thresholds', 'DistinctThresholds'),
        ('Max Dimensions on Path', 'DimensionsPath'),
        ('Max Thresholds on Path', 'ThresholdsPath'),
        ('Classification Errors', 'Errors'),
    )
    # (statistic key, which is also the macro prefix, and its number format)
    stat_formats = (
        ('min', '.0f'),
        ('max', '.0f'),
        ('mean', '.2f'),
        ('median', '.2f'),
    )

    with open(output_file, 'w') as out:
        out.write("% Tree statistics - generated by compute-tree-statistics.py\n")

        # Total number of analysed trees
        tree_count = len(stats['Tree'])
        out.write(f"\\newcommand{{\\totalTrees}}{{{tree_count}}}\n\n")

        # One group of four macros per metric, separated by blank lines
        for metric, suffix in macro_suffixes:
            for stat_key, spec in stat_formats:
                rendered = format(stats[metric][stat_key], spec)
                out.write(f"\\newcommand{{\\{stat_key}{suffix}}}{{{rendered}}}\n")
            out.write("\n")

def main():
    """Analyze every tree file and write per-tree and summary outputs.

    Each regular file in ``results/weka-trees/`` is parsed with
    ``parse_tree``, analyzed with ``analyze_tree``, and its classification
    error count is read with ``get_classification_errors``. The per-tree rows
    go to ``results/tree_analysis_results.csv`` and the min/max/mean/median
    summary of every metric goes to ``results/tree_statistics.tex``.
    """
    weka_trees_dir = 'results/weka-trees/'
    output_csv = 'results/tree_analysis_results.csv'
    output_latex = 'results/tree_statistics.tex'

    # Numeric columns, in CSV order; also the metrics that get summarised.
    metrics = ['Tree Size', 'Num Dimensions', 'Max Distinct Thresholds',
               'Max Dimensions on Path', 'Max Thresholds on Path', 'Classification Errors']
    results = []

    # os.listdir returns entries in arbitrary order; sort them so the CSV
    # rows come out in the same order on every run and platform.
    for filename in sorted(os.listdir(weka_trees_dir)):
        file_path = os.path.join(weka_trees_dir, filename)
        if not os.path.isfile(file_path):
            continue  # skip sub-directories and other non-regular entries

        tree = parse_tree(file_path)
        analysis = analyze_tree(tree)
        errors = get_classification_errors(file_path)
        results.append({
            'Tree': filename,
            'Tree Size': analysis['size'],
            'Num Dimensions': analysis['num_dimensions'],
            'Max Distinct Thresholds': analysis['max_distinct_thresholds'],
            'Max Dimensions on Path': analysis['max_dimensions_path'],
            'Max Thresholds on Path': analysis['max_thresholds_path'],
            'Classification Errors': errors
        })

    # Write individual results to CSV
    with open(output_csv, 'w', newline='') as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=['Tree'] + metrics)
        writer.writeheader()
        writer.writerows(results)

    # Calculate summary statistics
    stats = {'Tree': [row['Tree'] for row in results]}
    for metric in metrics:
        stats[metric] = calculate_stats([row[metric] for row in results])

    # Generate LaTeX file
    generate_latex_file(stats, output_latex)

    print(f"Analysis results have been written to {output_csv}")
    print(f"Summary statistics have been written to {output_latex}")

# Run the analysis only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
