#!/usr/bin/env python3
"""
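Corrected baseline analysis for LinearizeLLM experiments: for each nonlinearity
pattern, computes the detection (DSR), reformulation (RSR) and compilation (CSR)
success rates and the objective-value accuracy, using corrected error-classification logic.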
"""

import json
import os
import glob
import argparse
from collections import defaultdict
import numpy as np

def load_optimization_results(path='optimization_results.json'):
    """Load the ground-truth optimization results.

    Args:
        path: JSON file to read. Defaults to ``optimization_results.json``
            relative to the current working directory (the previous
            hard-coded location), so existing callers are unaffected.

    Returns:
        The parsed JSON document. ``calculate_metrics_for_pattern`` looks up
        per-problem records under its top-level ``'results'`` key.
    """
    # Explicit encoding: the default is locale-dependent and can mis-decode
    # non-ASCII content on some platforms.
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)

def categorize_problems_by_pattern():
    """Map each nonlinearity pattern name to the benchmark problems containing it.

    Membership is not exclusive: a problem that mixes two kinds of
    nonlinearity (e.g. ``diet_problem_min_abs``) is listed under every
    relevant pattern. Fresh lists are built on each call, so callers may
    mutate the result safely.
    """
    max_problems = [
        'aircraft_monotone_max',
        'netasgn_nonlinear_max',
        'nltrans_nonlinear_max',
        'prod_nonlinear_max',
    ]
    min_problems = [
        'diet_problem_min_abs',
        'knapsack_problem_nonlinear_min_1',
        'knapsack_problem_nonlinear_min_2',
    ]
    monotone_problems = [
        'aircraft_monotone_max',
        'diet_problem_monotone',
    ]
    abs_problems = [
        'blend_problem_abs',
        'diet_problem_min_abs',
        'multi_nonlinear_abs',
        'netasgn_nonlinear_abs',
    ]
    fractional_problems = [
        'blend_problem_frac',
        'diet_problem_nonlinear_frac',
        'netmcol_nonlinear_frac',
    ]
    bilinear_problems = [
        'media_selection_nonlinear_binbin',
        'media_selection_nonlinear_bincon',
        'netmcol_nonlinear_bincon',
        'nltrans_nonlinear_bincon',
        'prod_nonlinear_bincon',
        'revenue_maximization_nonlinear_bincon',
    ]
    # Keyword order fixes the pattern iteration order seen by callers.
    return dict(
        max=max_problems,
        min=min_problems,
        monotone=monotone_problems,
        abs_value=abs_problems,
        linear_fractional=fractional_problems,
        bilinear=bilinear_problems,
    )

# Subdirectory holding one combined summary for many problems; it is expanded
# into per-problem entries rather than treated as a problem itself.
_OLDER_RESULTS_DIRNAME = 'Some_older_results_I_guess'

# Result directories whose name differs from the problem name they hold.
# Any directory not listed here maps to a problem of the same name.
_DIR_ALIASES = {
    'prod_nonlinear_max_results': 'prod_nonlinear_max',
    'netmcol_nonlinear_bincon_results': 'netmcol_nonlinear_bincon',
    'revenue_maximization_nonlinear_bincon_results': 'revenue_maximization_nonlinear_bincon',
    'netasgn_nonlinear_abs_results': 'netasgn_nonlinear_abs',
}


def _read_json_file(path):
    """Parse and return the UTF-8 JSON document at *path*."""
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)


def _split_older_summary(older_data):
    """Expand a combined summary into ``{problem_name: experiment_result}``.

    Each entry of ``older_data['problem_summary']`` is wrapped in the same
    shape as a per-directory ``experiment_summary.json`` (``experiment_info``,
    ``error_statistics``, ``detection_accuracy``, ``problem_summary``), with
    missing counters defaulting to 0.
    """
    split = {}
    for problem_name, problem_data in (older_data.get('problem_summary') or {}).items():
        seeds = problem_data.get('total_seeds', 0)
        split[problem_name] = {
            'experiment_info': {
                'total_problems': 1,
                'total_seeds': seeds,
                'total_runs': seeds,
            },
            'error_statistics': {
                'detection_errors': problem_data.get('detection_errors', 0),
                'reformulation_errors': problem_data.get('reformulation_errors', 0),
                'compilation_errors': problem_data.get('compilation_errors', 0),
                'successful_runs': problem_data.get('successful_runs', 0),
            },
            'detection_accuracy': problem_data.get('detection_accuracy', {}),
            'problem_summary': {
                problem_name: problem_data,
            },
        }
    return split


def load_experiment_results(results_dir):
    """Load all experiment summaries found under *results_dir*.

    Every immediate subdirectory that contains an ``experiment_summary.json``
    contributes one entry, keyed by problem name (the directory name after
    resolving ``_DIR_ALIASES``). Subdirectories are visited in sorted order,
    so when two directories alias the same problem (e.g. ``prod_nonlinear_max``
    and ``prod_nonlinear_max_results``) the one later in sort order wins
    deterministically; ``glob`` alone returns an arbitrary order.

    The ``Some_older_results_I_guess`` subdirectory is not treated as a
    problem; its combined summary is split into per-problem entries.

    Args:
        results_dir: Directory whose subdirectories hold experiment results.

    Returns:
        Dict mapping problem name to its experiment-summary dict. Empty when
        *results_dir* does not exist or contains no summaries.
    """
    results = {}

    for pattern_dir in sorted(glob.glob(os.path.join(results_dir, '*'))):
        dir_name = os.path.basename(pattern_dir)
        if dir_name == _OLDER_RESULTS_DIRNAME or not os.path.isdir(pattern_dir):
            continue
        summary_file = os.path.join(pattern_dir, 'experiment_summary.json')
        if os.path.exists(summary_file):
            results[_DIR_ALIASES.get(dir_name, dir_name)] = _read_json_file(summary_file)

    older_summary_file = os.path.join(results_dir, _OLDER_RESULTS_DIRNAME,
                                      'experiment_summary.json')
    if os.path.exists(older_summary_file):
        # NOTE(review): entries from the older combined summary replace any
        # same-named entry loaded from a per-problem directory above. Confirm
        # that this precedence is intended rather than "fill gaps only".
        results.update(_split_older_summary(_read_json_file(older_summary_file)))

    return results

def _first_real_number(candidate):
    """Return *candidate* if it is a real number, else the first real number
    among its dict values or list items, else None. Booleans are excluded
    (``bool`` is a subclass of ``int`` but is not a meaningful objective)."""
    def is_real(value):
        return isinstance(value, (int, float)) and not isinstance(value, bool)

    if is_real(candidate):
        return candidate
    if isinstance(candidate, dict):
        items = candidate.values()
    elif isinstance(candidate, list):
        items = candidate
    else:
        return None
    return next((value for value in items if is_real(value)), None)


def extract_objective_value(ground_truth_data):
    """Return a scalar objective value from one problem's ground-truth record.

    ``objective_value`` is consulted first. ``optimal_point`` is used only as
    a fallback, because it may be a per-variable dict or list and so appears
    to hold decision-variable values rather than the objective. Comparing run
    objectives against a variable value would mis-score accuracy.
    NOTE(review): confirm this against the optimization_results.json schema.

    For either field, a bare number is returned as-is. For a dict or list, the
    first numeric value or item is returned.

    Args:
        ground_truth_data: One problem's ground-truth dict, or None.

    Returns:
        The numeric value found, or None when *ground_truth_data* is None or
        neither field yields a number.
    """
    if ground_truth_data is None:
        return None

    for field in ('objective_value', 'optimal_point'):
        value = _first_real_number(ground_truth_data.get(field))
        if value is not None:
            return value

    return None

def _safe_rate(numerator, denominator):
    """Return ``numerator / denominator``, or 0 when the denominator is not positive."""
    return numerator / denominator if denominator > 0 else 0


def calculate_metrics_for_pattern(pattern_name, pattern_problems, experiment_results, ground_truth,
                                  tolerance=0.001):
    """Aggregate success metrics over the problems of one nonlinearity pattern.

    For every problem in *pattern_problems* that has results, per-stage
    successes are counted as ``total_seeds - <stage>_errors``. Problems
    without results are skipped.

    Accuracy is computed only when the problem has at least one successful
    run and a ground-truth objective is known. Each reported entry in
    ``objective_values`` within *tolerance* (absolute difference) of that
    objective counts as one correct run.

    Args:
        pattern_name: Label copied into the result.
        pattern_problems: Problem names belonging to this pattern.
        experiment_results: Mapping from problem name to experiment summary.
            Each summary's ``problem_summary[problem_name]`` is read.
        ground_truth: Ground-truth document; per-problem records are read
            from its ``'results'`` mapping.
        tolerance: Absolute tolerance for matching an objective value.
            The default of 0.001 matches the previous hard-coded value.

    Returns:
        Dict with the summed counts, the rates ``dsr``/``rsr``/``csr``/
        ``accuracy`` (0 when there are no runs), and a per-problem breakdown
        under ``'problem_results'``.
    """
    ground_truth_results = ground_truth.get('results') or {}

    total_runs = 0
    detection_successes = 0
    reformulation_successes = 0
    compilation_successes = 0
    overall_successes = 0
    pattern_results = []

    for problem_name in pattern_problems:
        experiment = experiment_results.get(problem_name)
        if experiment is None:
            continue  # no results for this problem in the loaded directory
        problem_data = (experiment.get('problem_summary') or {}).get(problem_name, {})

        total_seeds = problem_data.get('total_seeds', 0)
        successful_runs = problem_data.get('successful_runs', 0)
        detection_errors = problem_data.get('detection_errors', 0)
        reformulation_errors = problem_data.get('reformulation_errors', 0)
        compilation_errors = problem_data.get('compilation_errors', 0)

        ground_truth_obj = None
        if problem_name in ground_truth_results:
            ground_truth_obj = extract_objective_value(ground_truth_results[problem_name])

        # Non-numeric entries (None, strings, booleans) are skipped instead
        # of raising TypeError in the subtraction below.
        correct_objective_runs = 0
        if successful_runs > 0 and ground_truth_obj is not None:
            for obj_val in problem_data.get('objective_values') or []:
                if (isinstance(obj_val, (int, float)) and not isinstance(obj_val, bool)
                        and abs(obj_val - ground_truth_obj) < tolerance):
                    correct_objective_runs += 1

        total_runs += total_seeds
        detection_successes += total_seeds - detection_errors
        reformulation_successes += total_seeds - reformulation_errors
        compilation_successes += total_seeds - compilation_errors
        overall_successes += correct_objective_runs

        pattern_results.append({
            'problem_name': problem_name,
            'total_seeds': total_seeds,
            'successful_runs': successful_runs,
            'detection_errors': detection_errors,
            'reformulation_errors': reformulation_errors,
            'compilation_errors': compilation_errors,
            'correct_objective_runs': correct_objective_runs,
            'ground_truth_obj': ground_truth_obj
        })

    return {
        'pattern_name': pattern_name,
        'total_runs': total_runs,
        'detection_successes': detection_successes,
        'reformulation_successes': reformulation_successes,
        'compilation_successes': compilation_successes,
        'overall_successes': overall_successes,
        'dsr': _safe_rate(detection_successes, total_runs),
        'rsr': _safe_rate(reformulation_successes, total_runs),
        'csr': _safe_rate(compilation_successes, total_runs),
        'accuracy': _safe_rate(overall_successes, total_runs),
        'problem_results': pattern_results
    }

def main():
    """Command-line entry point: print per-pattern and overall metrics, then save them as JSON.

    Ground truth is read via ``load_optimization_results()``. Experiment
    summaries are read from ``--results-dir``, and the detailed breakdown is
    written to ``--output-file``.
    """
    parser = argparse.ArgumentParser(description='Analyze performance metrics for LinearizeLLM baseline experiments with corrected error classification')
    parser.add_argument('--results-dir', type=str, default='data/results_gemini_2_5_flash',
                        help='Directory containing experiment results (default: data/results_gemini_2_5_flash)')
    parser.add_argument('--output-file', type=str, default='gemini_baseline_metrics_corrected.json',
                        help='Output file for detailed results (default: gemini_baseline_metrics_corrected.json)')
    args = parser.parse_args()

    ground_truth = load_optimization_results()
    patterns = categorize_problems_by_pattern()
    experiment_results = load_experiment_results(args.results_dir)

    print(f"Corrected Baseline Performance Metrics Analysis for {args.results_dir}")
    print("=" * 50)

    all_results = []
    for pattern_name, pattern_problems in patterns.items():
        print(f"\nPattern: {pattern_name}")
        print("-" * 30)

        result = calculate_metrics_for_pattern(pattern_name, pattern_problems, experiment_results, ground_truth)
        all_results.append(result)

        print(f"Problems: {', '.join(pattern_problems)}")
        print(f"Total runs: {result['total_runs']}")
        print(f"Detection Success Rate (DSR): {result['dsr']:.3f} ({result['detection_successes']}/{result['total_runs']})")
        print(f"Reformulation Success Rate (RSR): {result['rsr']:.3f} ({result['reformulation_successes']}/{result['total_runs']})")
        print(f"Compilation Success Rate (CSR): {result['csr']:.3f} ({result['compilation_successes']}/{result['total_runs']})")
        print(f"Accuracy: {result['accuracy']:.3f} ({result['overall_successes']}/{result['total_runs']})")

    # A problem may be listed under several patterns. Summing the per-pattern
    # totals would count its runs more than once, so the overall metrics are
    # computed over the de-duplicated problem list (first-seen order).
    unique_problems = list(dict.fromkeys(
        problem for problems in patterns.values() for problem in problems
    ))
    missing = [problem for problem in unique_problems if problem not in experiment_results]
    if missing:
        print(f"\nWarning: no results found in {args.results_dir} for: {', '.join(missing)}")

    overall = calculate_metrics_for_pattern('overall', unique_problems, experiment_results, ground_truth)

    print("\n" + "=" * 50)
    print("OVERALL METRICS")
    print("=" * 50)
    print(f"Total runs across all patterns (each problem counted once): {overall['total_runs']}")
    if overall['total_runs'] == 0:
        # Rates are reported as 0 instead of raising ZeroDivisionError.
        print("Warning: no runs found; overall rates are reported as 0.")
    print(f"Overall Detection Success Rate (DSR): {overall['dsr']:.3f}")
    print(f"Overall Reformulation Success Rate (RSR): {overall['rsr']:.3f}")
    print(f"Overall Compilation Success Rate (CSR): {overall['csr']:.3f}")
    print(f"Overall Accuracy: {overall['accuracy']:.3f}")

    overall_keys = ('total_runs', 'detection_successes', 'reformulation_successes',
                    'compilation_successes', 'overall_successes',
                    'dsr', 'rsr', 'csr', 'accuracy')
    detailed_results = {
        'results_directory': args.results_dir,
        'pattern_results': all_results,
        'overall_metrics': {key: overall[key] for key in overall_keys}
    }

    with open(args.output_file, 'w', encoding='utf-8') as f:
        json.dump(detailed_results, f, indent=2)

    print(f"\nDetailed results saved to '{args.output_file}'")

if __name__ == '__main__':
    # Run the command-line analysis only when executed as a script, not on import.
    main()