#!/usr/bin/env python3
"""
Analyze performance metrics for LinearizeLLM context experiment results.
"""

import json
import os
import glob
import argparse
from collections import defaultdict
import numpy as np

def categorize_problems_by_pattern():
    """Return the problem names grouped by nonlinearity pattern.

    The result maps each pattern label ('max', 'min', 'monotone',
    'abs_value', 'linear_fractional', 'bilinear') to a list of problem
    names. A problem may be listed under more than one pattern:
    'aircraft_monotone_max' appears in both 'max' and 'monotone', and
    'diet_problem_min_abs' appears in both 'min' and 'abs_value'.

    A fresh dict with fresh lists is built on every call.
    """
    max_problems = [
        'aircraft_monotone_max',
        'netasgn_nonlinear_max',
        'nltrans_nonlinear_max',
        'prod_nonlinear_max',
    ]
    min_problems = [
        'diet_problem_min_abs',
        'knapsack_problem_nonlinear_min_1',
        'knapsack_problem_nonlinear_min_2',
    ]
    monotone_problems = [
        'aircraft_monotone_max',
        'diet_problem_monotone',
    ]
    abs_value_problems = [
        'blend_problem_abs',
        'diet_problem_min_abs',
        'multi_nonlinear_abs',
        'netasgn_nonlinear_abs',
    ]
    linear_fractional_problems = [
        'blend_problem_frac',
        'diet_problem_nonlinear_frac',
        'netmcol_nonlinear_frac',
    ]
    bilinear_problems = [
        'media_selection_nonlinear_binbin',
        'media_selection_nonlinear_bincon',
        'netmcol_nonlinear_bincon',
        'nltrans_nonlinear_bincon',
        'prod_nonlinear_bincon',
        'revenue_maximization_nonlinear_bincon',
    ]

    # Insertion order below is the order in which patterns are reported.
    return dict([
        ('max', max_problems),
        ('min', min_problems),
        ('monotone', monotone_problems),
        ('abs_value', abs_value_problems),
        ('linear_fractional', linear_fractional_problems),
        ('bilinear', bilinear_problems),
    ])

def _is_real_number(value):
    """Return True for int/float values, excluding bool.

    ``bool`` is a subclass of ``int``, so a JSON ``true``/``false`` would
    otherwise be accepted as an objective value.
    """
    return isinstance(value, (int, float)) and not isinstance(value, bool)


def load_ground_truth_results():
    """Load ground-truth objective values from ``optimization_results.json``.

    The file is read from the current working directory. Only entries under
    the top-level ``'results'`` key whose ``'success'`` flag is truthy are
    considered. For each such problem the value is taken from:

    1. ``'optimal_point'`` when it is a number; otherwise
    2. ``'objective_value'`` when it is a number, or the first numeric
       value found in it when it is a dict.

    Booleans are never treated as numbers.

    Returns:
        dict mapping problem name to its numeric ground-truth value.
        Problems without a usable numeric value are omitted. If the file
        cannot be read or parsed, a warning is printed and whatever was
        collected so far (usually an empty dict) is returned.
    """
    ground_truth = {}
    try:
        # JSON text is UTF-8; do not depend on the platform default encoding.
        with open('optimization_results.json', 'r', encoding='utf-8') as f:
            data = json.load(f)

        if 'results' in data:
            for problem_name, problem_data in data['results'].items():
                if not problem_data.get('success', False):
                    continue

                optimal_point = problem_data.get('optimal_point')
                if _is_real_number(optimal_point):
                    ground_truth[problem_name] = optimal_point
                    continue

                # Fall back to objective_value when optimal_point is unusable.
                obj_value = problem_data.get('objective_value', {})
                if isinstance(obj_value, dict):
                    # Take the first numeric entry, in the dict's own order.
                    for value in obj_value.values():
                        if _is_real_number(value):
                            ground_truth[problem_name] = value
                            break
                elif _is_real_number(obj_value):
                    ground_truth[problem_name] = obj_value
    except Exception as e:
        # Deliberately best-effort: the analysis can still run without
        # ground truth (accuracy is then reported as 0).
        print(f"Warning: Could not load ground truth results: {e}")

    return ground_truth

# Scenario labels that can appear inside a result key and as a directory name.
_SCENARIO_NAMES = ('no_context', 'partial_info')


def _new_scenario_stats():
    """Return a zeroed counter record for one (problem, scenario) pair."""
    return {
        'total_seeds': 0,
        'successful_runs': 0,
        'detection_errors': 0,
        'reformulation_errors': 0,
        'compilation_errors': 0,
        'objective_values': [],
        'ground_truth_obj': None,
    }


def _parse_result_key(result_key):
    """Split a ``<problem>_<scenario>_seed_<n>`` key into its parts.

    Returns ``(problem_name, scenario)``, or ``(None, None)`` when the key
    has fewer than four underscore-separated parts, contains no scenario
    marker, or has nothing in front of the marker.

    The marker is located as two *adjacent* parts ('no', 'context' or
    'partial', 'info'), so a problem name that merely contains one of those
    words elsewhere does not prevent the key from being recognised.
    """
    parts = result_key.split('_')
    if len(parts) < 4:
        return None, None

    for scenario in _SCENARIO_NAMES:
        marker = tuple(scenario.split('_'))
        for idx, pair in enumerate(zip(parts, parts[1:])):
            if pair == marker:
                problem_name = '_'.join(parts[:idx])
                if problem_name:
                    return problem_name, scenario
                return None, None
    return None, None


def _classify_run(result_data):
    """Classify one run record.

    Returns ``(counter_name, objective_values)`` where ``counter_name`` is
    the statistics field to increment and ``objective_values`` is a list
    holding the run's objective value (empty unless the run succeeded and
    reported one).
    """
    # ``or {}`` also covers an explicit JSON null, not just a missing key.
    opt_results = result_data.get('optimization_results') or {}

    if opt_results.get('success', False):
        inner = opt_results.get('optimization_results') or {}
        if inner.get('status') != 'optimal':
            # Ran, but the model was infeasible/unbounded/etc.
            return 'reformulation_errors', []

        model_type = str(inner.get('model_type') or 'unknown').upper()
        if model_type not in ('LP', 'MILP'):
            # Only LP/MILP models count as a successful reformulation.
            return 'reformulation_errors', []

        if 'objective_value' in inner:
            return 'successful_runs', [inner['objective_value']]
        return 'successful_runs', []

    error_type = opt_results.get('error_type', 'UNKNOWN')
    # NAME_ERROR (undefined variable in generated code) is attributed to the
    # reformulation step rather than to compilation.
    if error_type in ('REFORMULATION_ERROR', 'NAME_ERROR'):
        return 'reformulation_errors', []
    if error_type in ('RUNTIME_ERROR', 'SYNTAX_ERROR', 'COMPILATION_ERROR'):
        return 'compilation_errors', []
    # DETECTION_ERROR and any unrecognised type count as detection errors.
    return 'detection_errors', []


def load_context_experiment_results(results_dir):
    """Aggregate per-seed context-experiment results by problem and scenario.

    Reads ``context_experiment_summary.json`` from ``results_dir`` or, if it
    is not there, from the parent of ``results_dir``. Each entry under the
    summary's ``'results'`` key is expected to be keyed as
    ``<problem>_<scenario>_seed_<n>`` with scenario 'no_context' or
    'partial_info'; entries whose key does not match are ignored.

    If the last component of ``results_dir`` is itself a scenario name, only
    that scenario's entries are aggregated. Any other directory name is
    treated as a model directory and all scenarios are aggregated. Trailing
    path separators on ``results_dir`` are ignored.

    Args:
        results_dir: Path to a model directory or a scenario directory.

    Returns:
        dict mapping ``"<problem>_<scenario>"`` to a statistics dict with
        keys 'total_seeds', 'successful_runs', 'detection_errors',
        'reformulation_errors', 'compilation_errors', 'objective_values'
        and 'ground_truth_obj'. Empty when no summary file is found or it
        has no 'results' section.
    """
    results = {}

    ground_truth = load_ground_truth_results()

    # os.path.basename/dirname treat "o3/" as ("o3", ""), so strip trailing
    # separators first; keep the original for a bare root such as "/".
    trimmed_dir = results_dir.rstrip(os.sep + (os.altsep or '')) or results_dir

    summary_file = os.path.join(trimmed_dir, 'context_experiment_summary.json')
    if not os.path.exists(summary_file):
        # Fall back to the parent directory (backward compatibility).
        summary_file = os.path.join(os.path.dirname(trimmed_dir),
                                    'context_experiment_summary.json')
    if not os.path.exists(summary_file):
        return results

    with open(summary_file, 'r', encoding='utf-8') as f:
        summary_data = json.load(f)

    if 'results' not in summary_data:
        return results

    dir_name = os.path.basename(trimmed_dir)
    # None means "model directory": keep every scenario.
    only_scenario = dir_name if dir_name in _SCENARIO_NAMES else None

    problem_stats = defaultdict(lambda: defaultdict(_new_scenario_stats))

    for result_key, result_data in summary_data['results'].items():
        problem_name, scenario = _parse_result_key(result_key)
        if problem_name is None:
            continue
        if only_scenario is not None and scenario != only_scenario:
            continue

        stats = problem_stats[problem_name][scenario]
        stats['total_seeds'] += 1

        if problem_name in ground_truth:
            stats['ground_truth_obj'] = ground_truth[problem_name]

        counter_name, objective_values = _classify_run(result_data)
        stats[counter_name] += 1
        stats['objective_values'].extend(objective_values)

    # Flatten the nested {problem: {scenario: stats}} structure.
    for problem_name, scenario_stats in problem_stats.items():
        for scenario, stats in scenario_stats.items():
            results[f"{problem_name}_{scenario}"] = stats

    return results

def calculate_metrics_for_pattern(pattern_name, pattern_problems, experiment_results):
    """Aggregate success counts and rates over the problems of one pattern.

    For every name in ``pattern_problems`` the first key of
    ``experiment_results`` that starts with ``"<name>_"`` is used; problems
    with no such key are skipped.

    Per problem, each "success" count is ``total_seeds`` minus the matching
    error count. A run counts towards accuracy when its objective value is
    within 1e-4 (absolute) of the problem's ground-truth value; this is
    only evaluated when a ground truth exists and ``successful_runs`` > 0.

    Args:
        pattern_name: Label copied into the result.
        pattern_problems: Iterable of problem names belonging to the pattern.
        experiment_results: Mapping of result key to statistics dict.

    Returns:
        dict with the summed counts, the rates 'dsr', 'rsr', 'csr' and
        'accuracy' (0 when there are no runs), and 'problem_results', a
        list with one detail dict per matched problem.
    """
    tolerance = 1e-4

    run_count = 0
    detected = 0
    reformulated = 0
    compiled = 0
    correct_total = 0
    per_problem = []

    for problem_name in pattern_problems:
        prefix = problem_name + '_'
        first_key = next(
            (key for key in experiment_results if key.startswith(prefix)),
            None,
        )
        if first_key is None:
            continue

        stats = experiment_results[first_key]
        seeds = stats.get('total_seeds', 0)
        succeeded = stats.get('successful_runs', 0)
        n_detection = stats.get('detection_errors', 0)
        n_reformulation = stats.get('reformulation_errors', 0)
        n_compilation = stats.get('compilation_errors', 0)
        truth = stats.get('ground_truth_obj')

        # Count objective values that match the ground truth within tolerance.
        if truth is not None and succeeded > 0:
            n_correct = sum(
                1
                for value in stats.get('objective_values', [])
                if value is not None and abs(value - truth) < tolerance
            )
        else:
            n_correct = 0

        run_count += seeds
        detected += seeds - n_detection
        reformulated += seeds - n_reformulation
        compiled += seeds - n_compilation
        correct_total += n_correct

        per_problem.append({
            'problem_name': problem_name,
            'total_seeds': seeds,
            'successful_runs': succeeded,
            'detection_errors': n_detection,
            'reformulation_errors': n_reformulation,
            'compilation_errors': n_compilation,
            'correct_objective_runs': n_correct,
            'ground_truth_obj': truth,
        })

    def _rate(count):
        # Guard against division by zero when no run matched.
        if run_count > 0:
            return count / run_count
        return 0

    return {
        'pattern_name': pattern_name,
        'total_runs': run_count,
        'detection_successes': detected,
        'reformulation_successes': reformulated,
        'compilation_successes': compiled,
        'overall_successes': correct_total,
        'dsr': _rate(detected),
        'rsr': _rate(reformulated),
        'csr': _rate(compiled),
        'accuracy': _rate(correct_total),
        'problem_results': per_problem,
    }

def main():
    """Command-line entry point: report metrics per scenario and pattern.

    Parses ``--results-dir`` (required) and ``--output-file``, loads the
    aggregated experiment results, and for each scenario ('no_context',
    'partial_info') prints the per-pattern metrics followed by totals
    summed over all patterns. The collected results for both scenarios are
    written to the output file as indented JSON.
    """
    cli = argparse.ArgumentParser(description='Analyze performance metrics for LinearizeLLM context experiment results')
    cli.add_argument('--results-dir', type=str, required=True,
                     help='Directory containing context experiment results')
    cli.add_argument('--output-file', type=str, default='context_metrics_results.json',
                     help='Output file for detailed results')
    args = cli.parse_args()

    patterns = categorize_problems_by_pattern()
    experiment_results = load_context_experiment_results(args.results_dir)

    print(f"Context Experiment Performance Metrics Analysis for {args.results_dir}")
    print("=" * 60)

    # Count fields summed across patterns; the order is also the order in
    # which they are written to the JSON 'overall_metrics' section.
    counted_fields = (
        'total_runs',
        'detection_successes',
        'reformulation_successes',
        'compilation_successes',
        'overall_successes',
    )

    all_scenario_results = {}

    for scenario in ('no_context', 'partial_info'):
        print(f"\n{'='*60}")
        print(f"SCENARIO: {scenario.upper()}")
        print(f"{'='*60}")

        # Keep only the entries whose key ends with this scenario's name.
        suffix = f'_{scenario}'
        scenario_results = {
            key: stats
            for key, stats in experiment_results.items()
            if key.endswith(suffix)
        }

        totals = dict.fromkeys(counted_fields, 0)
        pattern_reports = []

        for pattern_name, pattern_problems in patterns.items():
            print(f"\nPattern: {pattern_name}")
            print("-" * 30)

            report = calculate_metrics_for_pattern(pattern_name, pattern_problems, scenario_results)
            pattern_reports.append(report)

            print(f"Problems: {', '.join(pattern_problems)}")
            print(f"Total runs: {report['total_runs']}")
            print(f"Detection Success Rate (DSR): {report['dsr']:.3f} ({report['detection_successes']}/{report['total_runs']})")
            print(f"Reformulation Success Rate (RSR): {report['rsr']:.3f} ({report['reformulation_successes']}/{report['total_runs']})")
            print(f"Compilation Success Rate (CSR): {report['csr']:.3f} ({report['compilation_successes']}/{report['total_runs']})")
            print(f"Accuracy: {report['accuracy']:.3f} ({report['overall_successes']}/{report['total_runs']})")

            for field in counted_fields:
                totals[field] += report[field]

        run_count = totals['total_runs']
        has_runs = run_count > 0

        if has_runs:
            print("\n" + "=" * 60)
            print(f"OVERALL METRICS - {scenario.upper()}")
            print("=" * 60)
            print(f"Total runs across all patterns: {run_count}")
            print(f"Overall Detection Success Rate (DSR): {totals['detection_successes']/run_count:.3f}")
            print(f"Overall Reformulation Success Rate (RSR): {totals['reformulation_successes']/run_count:.3f}")
            print(f"Overall Compilation Success Rate (CSR): {totals['compilation_successes']/run_count:.3f}")
            print(f"Overall Accuracy: {totals['overall_successes']/run_count:.3f}")

        overall_metrics = dict(totals)
        for rate_key, count_key in (
            ('dsr', 'detection_successes'),
            ('rsr', 'reformulation_successes'),
            ('csr', 'compilation_successes'),
            ('accuracy', 'overall_successes'),
        ):
            # Rates fall back to 0 when the scenario has no runs at all.
            overall_metrics[rate_key] = totals[count_key] / run_count if has_runs else 0

        all_scenario_results[scenario] = {
            'pattern_results': pattern_reports,
            'overall_metrics': overall_metrics,
        }

    with open(args.output_file, 'w') as f:
        json.dump(all_scenario_results, f, indent=2)

    print(f"\nDetailed results saved to '{args.output_file}'")

# Run the analysis only when executed as a script, not when imported.
if __name__ == "__main__":
    main()