#!/usr/bin/env python3
"""
Analyze context experiment results for both O3 and Gemini models.
"""

import json
import os
import argparse
import glob
from collections import defaultdict
import numpy as np

def load_optimization_results():
    """Return the parsed contents of ``optimization_results.json``.

    The file is opened relative to the current working directory; whatever
    JSON document it contains is returned unchanged.
    """
    ground_truth_path = 'optimization_results.json'
    with open(ground_truth_path, 'r') as handle:
        parsed = json.load(handle)
    return parsed

def categorize_problems_by_pattern():
    """Return a mapping from nonlinearity pattern name to its problem names.

    A fresh dict of fresh lists is built on every call. A problem may be
    listed under more than one pattern (e.g. 'aircraft_monotone_max' is in
    both 'max' and 'monotone').
    """
    pattern_table = (
        ('max', (
            'aircraft_monotone_max',
            'netasgn_nonlinear_max',
            'nltrans_nonlinear_max',
            'prod_nonlinear_max',
        )),
        ('min', (
            'diet_problem_min_abs',
            'knapsack_problem_nonlinear_min_1',
            'knapsack_problem_nonlinear_min_2',
        )),
        ('monotone', (
            'aircraft_monotone_max',
            'diet_problem_monotone',
        )),
        ('abs_value', (
            'blend_problem_abs',
            'diet_problem_min_abs',
            'multi_nonlinear_abs',
            'netasgn_nonlinear_abs',
        )),
        ('linear_fractional', (
            'blend_problem_frac',
            'diet_problem_nonlinear_frac',
            'netmcol_nonlinear_frac',
        )),
        ('bilinear', (
            'media_selection_nonlinear_binbin',
            'media_selection_nonlinear_bincon',
            'netmcol_nonlinear_bincon',
            'nltrans_nonlinear_bincon',
            'prod_nonlinear_bincon',
            'revenue_maximization_nonlinear_bincon',
        )),
    )
    return {pattern: list(members) for pattern, members in pattern_table}

def load_context_experiment_results(results_dir, model_name):
    """Load one model's context-experiment results from ``results_dir``.

    Two on-disk layouts are supported:

    * ``<results_dir>/<model_name>/context_experiment_summary.json`` - if
      this file exists its parsed JSON is returned as-is.
    * Otherwise ``<results_dir>/<model_name>/<scenario>/<problem>/seed_<n>/``
      directories are scanned for the scenarios 'no_context' and
      'partial_info' and seeds 1 through 5.

    For the per-seed layout the return value is ``{'results': {...}}`` keyed
    by ``"<problem>_<scenario>_seed_<n>"``. A seed directory containing
    ``context_experiment_results.json`` contributes that file's content;
    failing that, ``context_experiment_error.json`` contributes
    ``{'error': <content>}``. Seed directories with neither file are skipped.
    """
    model_root = os.path.join(results_dir, model_name)
    summary_path = os.path.join(model_root, 'context_experiment_summary.json')

    # Summary layout: a single pre-aggregated file.
    if os.path.exists(summary_path):
        print(f"Loading {model_name} results from summary file...")
        with open(summary_path, 'r') as handle:
            return json.load(handle)

    # Per-seed layout: walk scenario / problem / seed directories.
    print(f"Loading {model_name} results from individual directories...")
    collected = {}

    for scenario in ('no_context', 'partial_info'):
        scenario_root = os.path.join(model_root, scenario)
        if not os.path.exists(scenario_root):
            continue

        for problem_name in os.listdir(scenario_root):
            problem_root = os.path.join(scenario_root, problem_name)
            if not os.path.isdir(problem_root):
                continue

            for seed in range(1, 6):
                seed_root = os.path.join(problem_root, f"seed_{seed}")
                if not os.path.exists(seed_root):
                    continue

                run_key = f"{problem_name}_{scenario}_seed_{seed}"
                result_path = os.path.join(seed_root, "context_experiment_results.json")
                error_path = os.path.join(seed_root, "context_experiment_error.json")

                # A result file takes precedence over an error file.
                if os.path.exists(result_path):
                    with open(result_path, 'r') as handle:
                        collected[run_key] = json.load(handle)
                elif os.path.exists(error_path):
                    with open(error_path, 'r') as handle:
                        collected[run_key] = {'error': json.load(handle)}

    return {'results': collected}

def _is_plain_number(value):
    """Return True for int/float values, excluding bool.

    ``bool`` is a subclass of ``int``, so a bare ``isinstance(x, (int, float))``
    would let flags such as ``True`` pass as the numeric value 1.
    """
    return isinstance(value, (int, float)) and not isinstance(value, bool)


def _first_plain_number(values):
    """Return the first int/float (non-bool) item of *values*, or None."""
    return next((item for item in values if _is_plain_number(item)), None)


def extract_objective_value(ground_truth_data):
    """Extract a single numeric reference value from a ground-truth entry.

    Lookup order:

    1. ``ground_truth_data['optimal_point']`` - a number is returned
       directly; for a dict or list the first numeric item is returned.
    2. ``ground_truth_data['objective_value']`` - a number is returned
       directly; for a dict the first numeric item is returned.

    Booleans are never treated as numbers.

    Args:
        ground_truth_data: Mapping for one problem (anything exposing
            ``.get``), or None.

    Returns:
        The extracted int/float, or None when the input is None or no
        numeric value is found.
    """
    if ground_truth_data is None:
        return None

    # NOTE(review): 'optimal_point' is consulted before 'objective_value'.
    # Confirm against the ground-truth schema that it holds the objective
    # rather than decision-variable values.
    optimal_point = ground_truth_data.get('optimal_point')
    if _is_plain_number(optimal_point):
        return optimal_point

    from_point = None
    if isinstance(optimal_point, dict):
        from_point = _first_plain_number(optimal_point.values())
    elif isinstance(optimal_point, list):
        from_point = _first_plain_number(optimal_point)
    # Compare against None so a legitimate value of 0 is still returned.
    if from_point is not None:
        return from_point

    # Fall back to 'objective_value' when 'optimal_point' yields nothing.
    objective_value = ground_truth_data.get('objective_value')
    if _is_plain_number(objective_value):
        return objective_value
    if isinstance(objective_value, dict):
        return _first_plain_number(objective_value.values())

    return None

def _success_rate(successes, total):
    """Return successes / total, or 0 when there were no runs."""
    return successes / total if total > 0 else 0


def _score_experiment(experiment_result, ground_truth_obj, tolerance=0.001):
    """Return (detected, reformulated, compiled, correct) flags for one run.

    A run that is not a dict, or that carries an 'error' key, fails every
    stage. Malformed payloads (e.g. ``'optimization_results': None`` or a
    non-numeric objective) count as a failure of that stage instead of
    raising.
    """
    if not isinstance(experiment_result, dict) or 'error' in experiment_result:
        return False, False, False, False

    # Detection: the key only has to be present (its value is not inspected).
    detected = 'extracted_patterns' in experiment_result

    # Reformulation: a non-empty linearized model must have been produced.
    reformulated = bool(experiment_result.get('linearized_model'))

    # Compilation: the optimisation payload must report success.
    outer = experiment_result.get('optimization_results')
    compiled = isinstance(outer, dict) and bool(outer.get('success', False))

    # Accuracy: only judged for compiled runs with a known reference value.
    correct = False
    if compiled and ground_truth_obj is not None:
        inner = outer.get('optimization_results', {})
        if isinstance(inner, dict) and inner.get('status') == 'optimal':
            obj_val = inner.get('objective_value')
            is_number = isinstance(obj_val, (int, float)) and not isinstance(obj_val, bool)
            correct = is_number and abs(obj_val - ground_truth_obj) < tolerance

    return detected, reformulated, compiled, correct


def analyze_context_experiment_results(results_data, ground_truth, patterns,
                                       scenarios=('no_context', 'partial_info'),
                                       seeds=(1, 2, 3, 4, 5)):
    """Aggregate per-pattern success statistics for each scenario.

    Runs are looked up in ``results_data['results']`` under the key
    ``"<problem>_<scenario>_seed_<seed>"``; missing keys are not counted.

    Args:
        results_data: Dict with a 'results' mapping of run key -> run record.
        ground_truth: Dict with a 'results' mapping of problem name ->
            ground-truth entry (passed to ``extract_objective_value``).
        patterns: Mapping of pattern name -> list of problem names.
        scenarios: Scenario names to analyse.
        seeds: Seed numbers to look up for every problem.

    Returns:
        ``{scenario: {pattern_name: stats}}`` where ``stats`` holds the run
        and success counts, the rates 'dsr', 'rsr', 'csr' and 'accuracy'
        (each ``successes / total_runs``, or 0 with no runs), and
        'problem_results', a per-problem breakdown list.
    """
    experiments = results_data['results']
    reference = ground_truth['results']
    scenario_results = {}

    for scenario in scenarios:
        scenario_data = {}

        for pattern_name, pattern_problems in patterns.items():
            total_runs = 0
            detection_successes = 0
            reformulation_successes = 0
            compilation_successes = 0
            overall_successes = 0
            pattern_results = []

            for problem_name in pattern_problems:
                ground_truth_obj = None
                if problem_name in reference:
                    ground_truth_obj = extract_objective_value(reference[problem_name])

                problem_runs = 0
                problem_detected = 0
                problem_reformulated = 0
                problem_compiled = 0
                problem_correct = 0

                for seed in seeds:
                    experiment_key = f"{problem_name}_{scenario}_seed_{seed}"
                    if experiment_key not in experiments:
                        continue

                    problem_runs += 1
                    detected, reformulated, compiled, correct = _score_experiment(
                        experiments[experiment_key], ground_truth_obj
                    )
                    # bool adds as 0/1, keeping the counters plain ints.
                    problem_detected += detected
                    problem_reformulated += reformulated
                    problem_compiled += compiled
                    problem_correct += correct

                total_runs += problem_runs
                detection_successes += problem_detected
                reformulation_successes += problem_reformulated
                compilation_successes += problem_compiled
                overall_successes += problem_correct

                pattern_results.append({
                    'problem_name': problem_name,
                    'total_seeds': problem_runs,
                    'successful_runs': problem_compiled,
                    'detection_errors': problem_runs - problem_detected,
                    'reformulation_errors': problem_runs - problem_reformulated,
                    'compilation_errors': problem_runs - problem_compiled,
                    'correct_objective_runs': problem_correct,
                    'ground_truth_obj': ground_truth_obj
                })

            scenario_data[pattern_name] = {
                'pattern_name': pattern_name,
                'total_runs': total_runs,
                'detection_successes': detection_successes,
                'reformulation_successes': reformulation_successes,
                'compilation_successes': compilation_successes,
                'overall_successes': overall_successes,
                'dsr': _success_rate(detection_successes, total_runs),
                'rsr': _success_rate(reformulation_successes, total_runs),
                'csr': _success_rate(compilation_successes, total_runs),
                'accuracy': _success_rate(overall_successes, total_runs),
                'problem_results': pattern_results
            }

        scenario_results[scenario] = scenario_data

    return scenario_results

def main():
    """Command-line entry point: analyze every model and report the results.

    Reads ``--results-dir`` and ``--output-file`` from the command line,
    analyzes the 'o3' and 'gemini-2.5-flash' result sets, writes the
    combined analysis as indented JSON to the output file, and prints a
    per-pattern and overall summary for each model and scenario.
    """
    cli = argparse.ArgumentParser(description='Analyze context experiment results')
    cli.add_argument('--results-dir', type=str, default='data/context_experiment_results',
                     help='Directory containing context experiment results')
    cli.add_argument('--output-file', type=str, default='context_experiment_analysis.json',
                     help='Output file for analysis results')
    options = cli.parse_args()

    # Shared inputs for every model.
    reference = load_optimization_results()
    pattern_map = categorize_problems_by_pattern()

    # Models with no loadable results are reported and left out.
    analyses = {}
    for model_name in ('o3', 'gemini-2.5-flash'):
        print(f"Analyzing {model_name} context experiment results...")

        loaded = load_context_experiment_results(options.results_dir, model_name)
        if loaded and loaded['results']:
            analyses[model_name] = analyze_context_experiment_results(loaded, reference, pattern_map)
        else:
            print(f"No results found for {model_name}")

    # Persist the full analysis.
    with open(options.output_file, 'w') as out_handle:
        json.dump(analyses, out_handle, indent=2)

    print(f"Analysis results saved to {options.output_file}")

    # Console summary.
    banner = "=" * 80
    print("\n" + banner)
    print("CONTEXT EXPERIMENT ANALYSIS SUMMARY")
    print(banner)

    success_keys = (
        'detection_successes',
        'reformulation_successes',
        'compilation_successes',
        'overall_successes',
    )

    for model_name, per_scenario in analyses.items():
        print(f"\n{model_name.upper()} RESULTS:")
        print("-" * 40)

        for scenario, per_pattern in per_scenario.items():
            print(f"\n{scenario.replace('_', ' ').title()}:")

            for pattern_name, stats in per_pattern.items():
                print(f"  {pattern_name}: DSR={stats['dsr']:.3f}, RSR={stats['rsr']:.3f}, CSR={stats['csr']:.3f}, Acc={stats['accuracy']:.3f}")

            # Overall rates pool the raw counts of every pattern in the scenario.
            run_count = sum(stats['total_runs'] for stats in per_pattern.values())
            if run_count > 0:
                overall_dsr, overall_rsr, overall_csr, overall_accuracy = (
                    sum(stats[key] for stats in per_pattern.values()) / run_count
                    for key in success_keys
                )

                print(f"  Overall: DSR={overall_dsr:.3f}, RSR={overall_rsr:.3f}, CSR={overall_csr:.3f}, Acc={overall_accuracy:.3f}")

# Run the analysis only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()