import json
import os
import glob
import argparse
from collections import defaultdict
import numpy as np

def load_optimization_results():
    """Load the ground truth optimization results.

    Reads ``optimization_results.json`` from the current working directory
    and returns the parsed JSON document unchanged.

    Returns:
        The decoded JSON content of the file.

    Raises:
        FileNotFoundError: If the file does not exist in the working directory.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # JSON is UTF-8 by specification; do not depend on the platform's
    # locale encoding, which differs between systems.
    with open('optimization_results.json', 'r', encoding='utf-8') as f:
        return json.load(f)

def categorize_problems_by_pattern():
    """Return the problems grouped by nonlinearity pattern.

    The result maps each pattern label to a list of problem names. A problem
    may be listed under more than one pattern (for example
    ``aircraft_monotone_max`` appears under both ``max`` and ``monotone``).
    A new dictionary with new lists is built on every call.
    """
    # (pattern label, member problems) in the order the patterns are reported.
    grouped = (
        ('max', (
            'aircraft_monotone_max',
            'netasgn_nonlinear_max',
            'nltrans_nonlinear_max',
            'prod_nonlinear_max',
        )),
        ('min', (
            'diet_problem_min_abs',
            'knapsack_problem_nonlinear_min_1',
            'knapsack_problem_nonlinear_min_2',
        )),
        ('monotone', (
            'aircraft_monotone_max',
            'diet_problem_monotone',
        )),
        ('abs_value', (
            'blend_problem_abs',
            'diet_problem_min_abs',
            'multi_nonlinear_abs',
            'netasgn_nonlinear_abs',
        )),
        ('linear_fractional', (
            'blend_problem_frac',
            'diet_problem_nonlinear_frac',
            'netmcol_nonlinear_frac',
        )),
        ('bilinear', (
            'media_selection_nonlinear_binbin',
            'media_selection_nonlinear_bincon',
            'netmcol_nonlinear_bincon',
            'nltrans_nonlinear_bincon',
            'prod_nonlinear_bincon',
            'revenue_maximization_nonlinear_bincon',
        )),
    )
    return {label: list(members) for label, members in grouped}

def load_experiment_results(results_dir):
    """Load all experiment results from the specified results directory.

    Two layouts are read:

    * One sub-directory per problem containing ``experiment_summary.json``.
      The parsed file is stored under the sub-directory's name.
    * An aggregate ``Some_older_results_I_guess`` sub-directory whose summary
      holds a ``problem_summary`` mapping. Each entry of that mapping is
      wrapped into its own single-problem result structure.

    Args:
        results_dir: Path of the directory holding the result sub-directories.

    Returns:
        dict mapping problem name to its experiment result structure. The
        dict is empty when nothing is found.
    """
    summary_name = 'experiment_summary.json'
    older_dir_name = 'Some_older_results_I_guess'
    results = {}

    # Per-problem directories: the directory name is the problem name.
    # Sorted so the load order does not depend on the file system.
    for entry in sorted(glob.glob(os.path.join(results_dir, '*'))):
        dir_name = os.path.basename(entry)
        # The aggregate directory is not a problem; it is expanded below
        # instead of being registered under its own name.
        if dir_name == older_dir_name or not os.path.isdir(entry):
            continue

        summary_file = os.path.join(entry, summary_name)
        if not os.path.exists(summary_file):
            continue

        with open(summary_file, 'r', encoding='utf-8') as f:
            results[dir_name] = json.load(f)

    # Also load results from the Some_older_results_I_guess directory.
    older_summary_file = os.path.join(results_dir, older_dir_name, summary_name)
    if not os.path.exists(older_summary_file):
        return results

    with open(older_summary_file, 'r', encoding='utf-8') as f:
        older_data = json.load(f)

    if 'problem_summary' not in older_data:
        return results

    # NOTE(review): entries from the older aggregate replace a per-directory
    # result with the same problem name (unchanged from the previous
    # behavior) - confirm this precedence is intended.
    for problem_name, problem_data in older_data['problem_summary'].items():
        seeds = problem_data.get('total_seeds', 0)
        # Re-create the per-problem experiment structure for this entry.
        results[problem_name] = {
            'experiment_info': {
                'total_problems': 1,
                'total_seeds': seeds,
                'total_runs': seeds
            },
            'error_statistics': {
                'detection_errors': problem_data.get('detection_errors', 0),
                'reformulation_errors': problem_data.get('reformulation_errors', 0),
                'compilation_errors': problem_data.get('compilation_errors', 0),
                'successful_runs': problem_data.get('successful_runs', 0)
            },
            'detection_accuracy': problem_data.get('detection_accuracy', {}),
            'problem_summary': {
                problem_name: problem_data
            }
        }

    return results

def _first_numeric(candidate):
    """Return ``candidate`` if it is a number, else the first number inside it.

    Dicts are scanned by value and lists by element, both in their own
    order. Booleans are not treated as numbers. Returns ``None`` when no
    number is found.
    """
    def is_number(item):
        # bool is a subclass of int, so it has to be excluded explicitly.
        return isinstance(item, (int, float)) and not isinstance(item, bool)

    if is_number(candidate):
        return candidate
    if isinstance(candidate, dict):
        items = candidate.values()
    elif isinstance(candidate, list):
        items = candidate
    else:
        return None
    return next((item for item in items if is_number(item)), None)


def extract_objective_value(ground_truth_data):
    """Extract the objective value from ground truth data.

    The ``optimal_point`` field is consulted first and ``objective_value``
    second. For each field a plain number is returned as is; for a dict or
    list the first numeric entry is returned.

    NOTE(review): ``optimal_point`` takes precedence over ``objective_value``
    (unchanged from the previous behavior). This presumably relies on the
    ground truth file storing the objective under ``optimal_point`` -
    confirm against ``optimization_results.json``.

    Args:
        ground_truth_data: Mapping for one problem, or ``None``.

    Returns:
        The extracted number, or ``None`` if the input is ``None`` or no
        numeric value is found in either field.
    """
    if ground_truth_data is None:
        return None

    for field in ('optimal_point', 'objective_value'):
        found = _first_numeric(ground_truth_data.get(field))
        # Compare against None so that a legitimate value of 0 is returned.
        if found is not None:
            return found

    return None

def calculate_metrics_for_pattern(pattern_name, pattern_problems, experiment_results, ground_truth):
    """Calculate metrics for a specific nonlinearity pattern.

    For every problem of the pattern that has an entry in
    ``experiment_results``, the per-problem counters are read from
    ``problem_summary[problem_name]`` and accumulated. A run counts as an
    overall success when its objective value lies within 0.001 (absolute)
    of the ground truth objective.

    Args:
        pattern_name: Label of the pattern, copied into the result.
        pattern_problems: Iterable of problem names belonging to the pattern.
        experiment_results: Mapping of problem name to its experiment result
            structure. Problems without an entry are skipped; an entry
            without ``problem_summary`` contributes zero runs.
        ground_truth: Mapping with a ``results`` entry that maps problem name
            to ground truth data.

    Returns:
        dict with the accumulated counters, the rates ``dsr``, ``rsr``,
        ``csr`` and ``accuracy`` (0 when there are no runs), and the
        per-problem rows under ``problem_results``.

    Raises:
        KeyError: If ``ground_truth`` has no ``results`` entry.
    """
    tolerance = 0.001

    total_runs = 0
    detection_successes = 0
    reformulation_successes = 0
    compilation_successes = 0
    overall_successes = 0
    pattern_results = []

    for problem_name in pattern_problems:
        if problem_name not in experiment_results:
            continue

        # A summary file without 'problem_summary' is treated as "no data"
        # for the problem rather than aborting the whole analysis.
        summaries = experiment_results[problem_name].get('problem_summary') or {}
        problem_data = summaries.get(problem_name, {})

        total_seeds = problem_data.get('total_seeds', 0)
        successful_runs = problem_data.get('successful_runs', 0)
        detection_errors = problem_data.get('detection_errors', 0)
        reformulation_errors = problem_data.get('reformulation_errors', 0)
        compilation_errors = problem_data.get('compilation_errors', 0)

        # Get ground truth objective value
        ground_truth_obj = None
        if problem_name in ground_truth['results']:
            ground_truth_obj = extract_objective_value(ground_truth['results'][problem_name])

        # Check objective value accuracy for successful runs. Entries that
        # are not numbers (e.g. null or a status string) cannot match.
        correct_objective_runs = 0
        if successful_runs > 0 and ground_truth_obj is not None:
            correct_objective_runs = sum(
                1
                for obj_val in (problem_data.get('objective_values') or [])
                if isinstance(obj_val, (int, float))
                and abs(obj_val - ground_truth_obj) < tolerance
            )

        # Accumulate metrics; each stage succeeds in the runs without an
        # error recorded for that stage.
        total_runs += total_seeds
        detection_successes += total_seeds - detection_errors
        reformulation_successes += total_seeds - reformulation_errors
        compilation_successes += total_seeds - compilation_errors
        overall_successes += correct_objective_runs

        pattern_results.append({
            'problem_name': problem_name,
            'total_seeds': total_seeds,
            'successful_runs': successful_runs,
            'detection_errors': detection_errors,
            'reformulation_errors': reformulation_errors,
            'compilation_errors': compilation_errors,
            'correct_objective_runs': correct_objective_runs,
            'ground_truth_obj': ground_truth_obj
        })

    def rate(successes):
        # Avoid dividing by zero when no problem of the pattern has runs.
        return successes / total_runs if total_runs > 0 else 0

    return {
        'pattern_name': pattern_name,
        'total_runs': total_runs,
        'detection_successes': detection_successes,
        'reformulation_successes': reformulation_successes,
        'compilation_successes': compilation_successes,
        'overall_successes': overall_successes,
        'dsr': rate(detection_successes),
        'rsr': rate(reformulation_successes),
        'csr': rate(compilation_successes),
        'accuracy': rate(overall_successes),
        'problem_results': pattern_results
    }

def main():
    """Run the performance analysis and write the detailed results file.

    Parses ``--results-dir`` and ``--output-file``, loads the ground truth
    and the experiment results, prints the metrics of every nonlinearity
    pattern followed by the overall metrics, and dumps everything as JSON
    to the output file.
    """
    # Parse command line arguments
    parser = argparse.ArgumentParser(description='Analyze performance metrics for LinearizeLLM experiments')
    parser.add_argument('--results-dir', type=str, default='data/results_o3', 
                       help='Directory containing experiment results (default: data/results_o3)')
    parser.add_argument('--output-file', type=str, default='performance_metrics_results.json',
                       help='Output file for detailed results (default: performance_metrics_results.json)')
    args = parser.parse_args()
    
    # Load data
    ground_truth = load_optimization_results()
    patterns = categorize_problems_by_pattern()
    experiment_results = load_experiment_results(args.results_dir)
    
    print(f"Performance Metrics Analysis for {args.results_dir}")
    print("=" * 50)
    
    all_results = []
    total_overall_runs = 0
    total_overall_detection_successes = 0
    total_overall_reformulation_successes = 0
    total_overall_compilation_successes = 0
    total_overall_successes = 0
    
    for pattern_name, pattern_problems in patterns.items():
        print(f"\nPattern: {pattern_name}")
        print("-" * 30)
        
        result = calculate_metrics_for_pattern(pattern_name, pattern_problems, experiment_results, ground_truth)
        all_results.append(result)
        
        print(f"Problems: {', '.join(pattern_problems)}")
        print(f"Total runs: {result['total_runs']}")
        print(f"Detection Success Rate (DSR): {result['dsr']:.3f} ({result['detection_successes']}/{result['total_runs']})")
        print(f"Reformulation Success Rate (RSR): {result['rsr']:.3f} ({result['reformulation_successes']}/{result['total_runs']})")
        print(f"Compilation Success Rate (CSR): {result['csr']:.3f} ({result['compilation_successes']}/{result['total_runs']})")
        print(f"Accuracy: {result['accuracy']:.3f} ({result['overall_successes']}/{result['total_runs']})")
        
        # Accumulate for overall metrics.
        # NOTE(review): a problem listed under two patterns contributes its
        # runs to both, so it is counted twice in the overall totals -
        # confirm this weighting is intended.
        total_overall_runs += result['total_runs']
        total_overall_detection_successes += result['detection_successes']
        total_overall_reformulation_successes += result['reformulation_successes']
        total_overall_compilation_successes += result['compilation_successes']
        total_overall_successes += result['overall_successes']
    
    def overall_rate(successes):
        # Same convention as the per-pattern rates: 0 when nothing was run,
        # instead of a ZeroDivisionError for an empty results directory.
        return successes / total_overall_runs if total_overall_runs > 0 else 0
    
    overall_dsr = overall_rate(total_overall_detection_successes)
    overall_rsr = overall_rate(total_overall_reformulation_successes)
    overall_csr = overall_rate(total_overall_compilation_successes)
    overall_accuracy = overall_rate(total_overall_successes)
    
    # Report overall metrics
    print("\n" + "=" * 50)
    print("OVERALL METRICS")
    print("=" * 50)
    print(f"Total runs across all patterns: {total_overall_runs}")
    print(f"Overall Detection Success Rate (DSR): {overall_dsr:.3f}")
    print(f"Overall Reformulation Success Rate (RSR): {overall_rsr:.3f}")
    print(f"Overall Compilation Success Rate (CSR): {overall_csr:.3f}")
    print(f"Overall Accuracy: {overall_accuracy:.3f}")
    
    # Save detailed results
    detailed_results = {
        'results_directory': args.results_dir,
        'pattern_results': all_results,
        'overall_metrics': {
            'total_runs': total_overall_runs,
            'detection_successes': total_overall_detection_successes,
            'reformulation_successes': total_overall_reformulation_successes,
            'compilation_successes': total_overall_compilation_successes,
            'overall_successes': total_overall_successes,
            'dsr': overall_dsr,
            'rsr': overall_rsr,
            'csr': overall_csr,
            'accuracy': overall_accuracy
        }
    }
    
    with open(args.output_file, 'w') as f:
        json.dump(detailed_results, f, indent=2)
    
    print(f"\nDetailed results saved to '{args.output_file}'")

# Run the analysis only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()