import os
import math
import yaml
import time
import numpy as np
from typing import Dict, Any, List

# Import necessary MOCO components
from MOCO.problems import (
    BiObjectiveTSP, 
    MultiObjectiveKnapsack,
    TriObjectiveTSP,
    BiObjectiveCVRP
)

from MOCO.evaluation import MOCOEvaluator
from benchmarking_helper import generate_ws_config, print_timing_details, run_algorithm_with_timing, analyze_benchmark_results
import multiprocessing
import torch
import logging

# Select the 'spawn' start method for child processes ('spawn' is the method
# PyTorch recommends when subprocesses use CUDA). The start method can be fixed
# only once per process, so leave it alone if something has already fixed it.
if multiprocessing.get_start_method(allow_none=True) is None:
    multiprocessing.set_start_method('spawn')

# Import necessary libraries
from mobo_qNEHVI_optimized_on_discrete_benchmark_June import NSGA2qNEHVI
from mobo_nsga2_on_discrete_benchmarks import NSGA2
from mobo_qparegoGA_optimized_on_discrete_benchmarks_general import MOBOqParEGO
# from mobo_qparegoGA_optimized_on_discrete_benchmarks_general import BiKPMOBOWrapper


def _reference_size(problem_type, problem_params):
    """Return the instance-size parameter used to look up reference points.

    TSP variants are sized by ``n_cities``, the knapsack by ``n_items`` and
    every other problem (currently only BiObjectiveCVRP) by ``n_customers``.
    """
    if problem_type in ('BiObjectiveTSP', 'TriObjectiveTSP'):
        return problem_params['n_cities']
    if problem_type == 'MultiObjectiveKnapsack':
        return problem_params['n_items']
    return problem_params['n_customers']  # BiObjectiveCVRP


def _fallback_points(problem_type):
    """Return hard-coded ``(reference_point, ideal_point)`` defaults.

    Used when the evaluator has no standard points for a problem size or the
    lookup raises. The dimensionality matches the number of objectives.
    """
    if problem_type == 'TriObjectiveTSP':
        # 3-objective problem: the points must be 3-dimensional.
        return (35, 35, 35), (0, 0, 0)  # Adjust based on your needs
    if problem_type == 'BiObjectiveTSP':
        return (35, 35), (0, 0)
    # NOTE(review): BiObjectiveCVRP also lands here and receives the knapsack
    # defaults -- confirm these values are appropriate for CVRP.
    return (20, 20), (50, 50)


def _resolve_points(evaluator, problem_type, ref_type, ref_size):
    """Look up standard reference/ideal points, falling back to defaults.

    Returns ``(reference_point, ideal_point)``. A lookup error is reported and
    treated like a missing entry rather than aborting the benchmark.
    """
    try:
        standard_points = evaluator.get_standard_points(
            problem_type=ref_type,
            problem_size=ref_size
        )
        if standard_points and 'reference' in standard_points:
            reference_point = standard_points['reference']
            ideal_point = standard_points.get('ideal')
            if ideal_point is None:
                # Default to the origin in the reference point's dimension, so
                # 3-objective problems get a 3-D ideal point.
                ideal_point = (0,) * len(reference_point)
            return reference_point, ideal_point
    except Exception as e:
        print(f"Error getting standard points: {e}")
    return _fallback_points(problem_type)


def _algorithm_params(algorithm_name, problem_type, base_params, reference_point,
                      ref_size, cuda_device):
    """Build the ``parameters`` dict passed to ``evaluate_algorithm``.

    Starts from a copy of ``base_params`` plus ``reference_point`` and overlays
    the algorithm-specific settings. The commented-out blocks are alternative
    settings kept from earlier experiments (per problem family); swap one in by
    uncommenting it. MOBOqParEGO does not use ``base_params``: its settings are
    nested under a ``'config'`` key instead. Unknown algorithm names get only
    ``base_params`` plus the reference point.
    """
    algorithm_params = base_params.copy()
    algorithm_params['reference_point'] = reference_point

    if algorithm_name == "NSGA2qNEHVI":
        print("inside qNEHVI...")
        # Add qNEHVI specific parameters - BiTSP
        # algorithm_params.update({
        #     'batch_size': 5,
        #     'max_iterations': 30,
        #     'use_sparse_gp': True,
        #     'n_inducing_points': 50,#50,
        #     'population_size': 100,#100,
        #     'n_generations': 100,#30,
        #     'verbose': True,
        #     'matern_nu': 5 if ref_size==100 else 4,
        #     'early_stopping_patience': 5,
        #      'cuda_device': cuda_device
        # })
        # algorithm_params.update({
        #     'batch_size': 5,
        #     'max_iterations': 10 if ref_size == 50 else 10 * (ref_size)//50,
        #     'use_sparse_gp': True,
        #     'n_inducing_points': ref_size + 20,
        #     'population_size': ref_size,
        #     'n_generations': 20 if ref_size == 50 else 20 * (ref_size)//50,
        #     'verbose': True,
        #     'matern_nu': 2.5,
        #     'early_stopping_patience': 5,
        #     'crossover_prob': 0.9,
        #     'mutation_prob': 0.2,
        #     'cuda_device': cuda_device
        # })
        # BiKP 50, 100
        # algorithm_params.update({
        #     'batch_size': 10,
        #     'max_iterations': 30,
        #     'use_sparse_gp': False,#True,
        #     'n_inducing_points': 100,
        #     'population_size': 200, #int(ref_size*2.5),
        #     'n_generations': 50,
        #     'verbose': True,
        #     'matern_nu': 2.5,
        #     'early_stopping_patience': 10,#7
        #     'crossover_prob': 0.9,
        #     'mutation_prob': 0.3,
        #     'cuda_device': cuda_device
        # })
        # For TriTSP
        # algorithm_params.update({
        #     'batch_size': 5,
        #     'max_iterations': 40, #ref_size // 2, #10, # iterations in NSGA
        #     'population_size': ref_size, #30, # iterations in NSGA
        #     'n_generations': 100, #20, # iterations in NSGA
        #     'use_sparse_gp': False,  # Full GP is fine for small problems
        #     'n_inducing_points': 80,
        #     'matern_nu': 1.5,  # Smoother for small search space
        #     'early_stopping_patience': 100,
        #     'crossover_prob': 0.9,
        #     'mutation_prob': 0.3,  # Lower mutation for small problems
        #     'verbose': True,
        #     'cuda_device': cuda_device
        # })

        # Currently active configuration.
        algorithm_params.update({
            # Core parameters
            'batch_size': 5,
            'max_iterations': 10,

            # GP parameters
            'use_sparse_gp': False,  # exact GP (set True to use inducing points)
            'n_inducing_points': 100,  # presumably only used when use_sparse_gp is True -- confirm
            'matern_nu': 2.5,

            # Early stopping: patience far exceeds max_iterations, so this
            # presumably never triggers -- confirm against NSGA2qNEHVI.
            'early_stopping_patience': 1000,
            # 'early_stopping_threshold': 1e-3,  # Minimum improvement threshold

            # NSGA-II parameters (author's note: used for BiTSP, ignored for
            # TriTSP with random search)
            'population_size': 80,
            'n_generations': 80,
            'crossover_prob': 0.9,
            'mutation_prob': 0.4,  # 0.3 best for cvrp20

            # Random search parameters for TriTSP (add these if not already in code)
            # 'n_random_candidates': 2000,  # For qNEHVI-guided random search

            # General
            'verbose': True,
            'cuda_device': cuda_device
        })

    elif algorithm_name == "NSGA2":
        # Add NSGA2 specific parameters (I used the same for both BiTSP & BiKP; except minor change in n_generations.)
        # algorithm_params.update({
        #     'crossover_prob': 0.9,
        #     'mutation_prob': 0.1,
        #     'population_size': 100,
        #     'n_generations': 50 if ref_size < 50 else 100, #50 if ref_size <= 50 else 100: for BiTSP
        #     'tournament_size':3, # knapsack
        #     'verbose': True
        # })
        # Parameters used for TriTSP
        # algorithm_params.update({
        #     'crossover_prob': 0.9,
        #     'mutation_prob': 0.2,
        #     'population_size': 200,
        #     'n_generations': 50 if ref_size <= 20 else 100 if 20 < ref_size <= 50 else 150, #50 if ref_size <= 50 else 200, # for BiTSP/TriTSP
        #     'tournament_size':4, # knapsack
        #     'verbose': True
        # })
        algorithm_params.update({
            'crossover_prob': 0.9,
            'mutation_prob': 0.1,
            'population_size': 100,
            'n_generations': 50 if ref_size <= 50 else 100, #50 if ref_size <= 50 else 100: for BiTSP
            'tournament_size': 4, # knapsack
            'verbose': True
        })
        print("Inside NSGA2 params: ", algorithm_params)

    elif algorithm_name == "MOBOqParEGO":
        # Add qParEGO specific parameters (based on test_mobo_problems)
        # Wrap parameters in 'config' as shown in test_mobo_problems
        # base_size is only referenced by the commented-out log-scaled 'q'
        # configuration directly below; kept so it can be re-enabled.
        base_size = 20 if problem_type=="BiObjectiveTSP" else 50 # knapsack

        # qparego_config = {
        #     'n_initial': 20,
        #     'n_iterations': 10 if ref_size == 20 else 20,
        #     'q': int(4 * (1 + math.log(ref_size / base_size))),
        #     'pop_size': 50 if ref_size == 20 else 100,
        #     'n_generations': 50,
        #     'crossover_prob': 0.9,
        #     'mutation_prob': 0.2,
        #     'tournament_size': 5,
        #     'matern_nu': 2.5 if ref_size == 20 else 5,
        #     'use_sparse_gp': True,
        #     'model_rebuild_interval': 5
        # }
        # qparego_config = {
        #     'n_initial': 100,
        #     'n_iterations': 12,
        #     'q':  min(15, int(6 + 4 * math.log(ref_size / 50))),
        #     'pop_size': ref_size//2,
        #     'n_generations': min(10, ref_size//5),
        #     'crossover_prob': 0.9,
        #     'mutation_prob': 0.2,
        #     'tournament_size': 3,
        #     'matern_nu': 2.5,
        #     'use_sparse_gp': True,
        #     'model_rebuild_interval': 5
        # }

        # for bitsp100 use below
        # qparego_config = {
        #         'n_initial': 100,              # More initial samples for better coverage
        #         'n_iterations': 20,            # More iterations for BO to learn
        #         'q': 15,                        # Standard batch size for 2D problems
        #         'pop_size': 80,                # Good balance for TSP
        #         'n_generations': 80,           # Sufficient for GA convergence
        #         'crossover_prob': 0.95,         # High crossover is good for TSP
        #         'mutation_prob': 0.03,         # Lower mutation (TSP structure is fragile)
        #         'tournament_size': 7,          # Good selection pressure
        #         'matern_nu': 2.5,              # Standard smooth kernel
        #         'use_sparse_gp': True,         # Important for scalability # Keep it True gives better results with faster time!
        #         'model_rebuild_interval': 4    # Rebuild often enough to track progress
        #     }

        # qparego_config = {
        #         'n_initial': 250,              # More initial samples for better coverage
        #         'n_iterations': 15,            # More iterations for BO to learn
        #         'q': 10,                        # Standard batch size for 2D problems
        #         'pop_size': 80,                # Good balance for TSP
        #         'n_generations': 80,           # Sufficient for GA convergence
        #         'crossover_prob': 0.95,         # High crossover is good for TSP
        #         'mutation_prob': 0.03,         # Lower mutation (TSP structure is fragile)
        #         'tournament_size': 3,          # Good selection pressure
        #         'matern_nu': 2.5,              # Standard smooth kernel
        #         'use_sparse_gp': True,         # Important for scalability # Keep it True gives better results with faster time!
        #         'model_rebuild_interval': 4    # Rebuild often enough to track progress
        #     }

        # Final BiKP config
        # qparego_config = {
        #     'n_initial': min(50, ref_size), #ref_size*2,
        #     'n_iterations': ref_size//4 if ref_size < 30 else min(ref_size//8, 10),
        #     'q':  min(15, int(6 + 4 * math.log(ref_size / 50))),
        #     'pop_size': ref_size,
        #     'n_generations': ref_size,
        #     'crossover_prob': 0.9,
        #     'mutation_prob': 0.3,
        #     'tournament_size': 3,
        #     'matern_nu': 1.5,
        #     'use_sparse_gp': False,
        #     'model_rebuild_interval': 5
        # }

        # Final TriTSP config, TriTSp50
        # qparego_config = {
        #     'n_initial': min(40, max(20, ref_size)),
        #     'n_iterations': 10, #max(15, min(30, ref_size//2)),  # More iterations
        #     'q': 16, #min(20, int(6 * 3 * (1 + math.log(max(ref_size, 20) / 50)))),  # Scale with objectives n_samples basically
        #     'pop_size': min(100, max(50, ref_size)),
        #     'n_generations': min(50, max(30, ref_size//2)),  # Cap at 50
        #     'crossover_prob': 0.9,
        #     'mutation_prob': 0.3,
        #     'tournament_size': 3,
        #     'matern_nu': 2.5,  # Smoother GP
        #     'use_sparse_gp': True,  # For efficiency
        #     'model_rebuild_interval': 5  # More frequent updates
        # }

        # Currently active configuration.
        qparego_config = {
            'n_initial': min(40, max(20, ref_size)),
            'n_iterations': 10, # alt: max(15, min(30, ref_size//2))
            'q': 12, # alt: min(20, int(6 * 3 * (1 + math.log(max(ref_size, 20) / 50))))
            'pop_size': min(100, max(50, ref_size)),
            'n_generations': min(50, max(30, ref_size//2)),  # Cap at 50
            'crossover_prob': 0.9,
            'mutation_prob': 0.3,
            'tournament_size': 3,
            'matern_nu': 2.5,
            'use_sparse_gp': True,
            'model_rebuild_interval': 5
        }

        print(qparego_config)
        # Replace algorithm_params with the nested config format
        algorithm_params = {'config': qparego_config}
        algorithm_params['config']['reference_point'] = reference_point

    return algorithm_params


def _release_memory():
    """Best-effort release of Python and CUDA memory between benchmark runs.

    Never raises for CUDA failures: a failed run can leave the CUDA context in
    an error state, and cleanup must not abort the remaining benchmarks.
    """
    import gc  # local import, matching this file's style
    gc.collect()
    if torch.cuda.is_available():
        try:
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
        except RuntimeError as err:
            print(f"Warning: CUDA cleanup failed: {err}")
    # Small delay to ensure cleanup
    time.sleep(2)


def _print_summary(results, overall_time):
    """Print a human-readable summary of every benchmark outcome."""
    print("\n" + "="*50)
    print("BENCHMARK SUMMARY")
    print("="*50)
    print(f"Overall time: {overall_time:.1f} seconds")

    for alg_name, alg_results in results.items():
        print(f"\n{alg_name}:")
        for prob_name, prob_results in alg_results.items():
            print(f"  {prob_name}:")
            for size_name, outcome in prob_results.items():
                if outcome.get('status') == 'success':
                    print(f"    {size_name.upper()}: Runtime: {outcome['runtime']:.2f}s, "
                         f"HV: {outcome['hypervolume']:.4f} "
                         f"({outcome['num_nondominated']} solutions)")
                else:
                    print(f"    {size_name.upper()}: Error: {outcome.get('error', 'Unknown error')}")


def run_flexible_benchmark(
    algorithm_classes,
    problem_types=None,
    problem_sizes=None,
    num_runs=1,
    cuda_device=0,
    base_params=None
    ):
    """
    Flexible benchmark function that supports multiple algorithms and problems

    Every algorithm in ``algorithm_classes`` is run on every configured size of
    every requested problem type. The per-configuration outcomes are collected,
    written to ``configs/benchmark_results_<timestamp>.yaml`` and summarised on
    stdout. A failure in one configuration is recorded and does not stop the
    remaining ones.

    Parameters:
    -----------
    algorithm_classes : dict
        Dictionary mapping algorithm names to algorithm classes. The names
        "NSGA2qNEHVI", "NSGA2" and "MOBOqParEGO" select tuned parameter sets;
        any other name runs with ``base_params`` plus the reference point.
    problem_types : list, optional
        List of problem types to benchmark. Defaults to BiObjectiveTSP,
        MultiObjectiveKnapsack, TriObjectiveTSP and BiObjectiveCVRP. Types not
        in the internal problem map or without an entry in ``problem_sizes``
        are skipped with a message.
    problem_sizes : dict, optional
        ``{problem_type: {size_name: problem_params}}``. Defaults to
        BiObjectiveTSP medium (50 cities), MultiObjectiveKnapsack medium
        (50 items) and BiObjectiveCVRP small (20 customers). TriObjectiveTSP
        has no default size, so it is skipped unless sizes are supplied.
    num_runs : int
        Number of runs per configuration
    cuda_device : int
        CUDA device index; used to report the device when CUDA is available and
        forwarded to NSGA2qNEHVI as ``cuda_device``.
    base_params : dict, optional
        Base parameters for algorithms (will be overridden by algorithm-specific
        params). ``None`` selects population_size=100, n_generations=30,
        verbose=True; pass ``{}`` to start from nothing.

    Returns:
    --------
    results : dict
        ``results[algorithm][problem_type][size_name]`` is either
        ``{'status': 'success', 'runtime', 'local_runtime', 'hypervolume',
        'num_nondominated'}`` or ``{'status': 'error', 'error': <message>}``.
    config : dict
        ``{'metadata': {...}, 'results': <same nested mapping>}``, as written to
        the YAML file.
    """
    import traceback  # local import, matching this file's style

    # Fresh defaults per call (avoids shared mutable default arguments).
    if problem_types is None:
        problem_types = ['BiObjectiveTSP', 'MultiObjectiveKnapsack', 'TriObjectiveTSP', 'BiObjectiveCVRP']
    if problem_sizes is None:
        problem_sizes = {
            'BiObjectiveTSP': {'medium': {'n_cities': 50}},
            'MultiObjectiveKnapsack': {'medium': {'n_items': 50, 'n_objectives': 2, 'capacity': 12.5}},
            'BiObjectiveCVRP': {'small': {'n_customers': 20}},
        }

    # Use specified CUDA device
    device = torch.device(f"cuda:{cuda_device}" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Create timestamp for files
    timestamp = time.strftime("%Y%m%d-%H%M%S")

    # Problem name -> problem class and the key used for standard reference points.
    problem_map = {
        'BiObjectiveTSP': {'class': BiObjectiveTSP, 'ref_type': 'BiTSP'},
        'MultiObjectiveKnapsack': {'class': MultiObjectiveKnapsack, 'ref_type': 'BiKP'},
        'TriObjectiveTSP': {'class': TriObjectiveTSP, 'ref_type': 'TriTSP'},
        'BiObjectiveCVRP': {'class': BiObjectiveCVRP, 'ref_type': 'BiCVRP', 'n_objectives': 2},
    }

    # Set default base parameters if none provided
    if base_params is None:
        base_params = {
            'population_size': 100,
            'n_generations': 30,
            'verbose': True
        }

    # Store results
    results = {}
    config = {
        'metadata': {
            'timestamp': timestamp,
            'num_runs': num_runs,
            'device': device.type,
        },
        'results': {}
    }

    # Overall benchmark start time
    overall_start = time.time()

    for algorithm_name, algorithm_class in algorithm_classes.items():
        print(f"\n{'-'*60}")
        print(f"Benchmarking algorithm: {algorithm_name}")
        print(f"{'-'*60}")

        results[algorithm_name] = {}
        config['results'][algorithm_name] = {}

        for problem_type in problem_types:
            if problem_type not in problem_map:
                print(f"Problem {problem_type} not defined. Skipping.")
                continue

            if problem_type not in problem_sizes:
                print(f"No sizes defined for {problem_type}. Skipping.")
                continue

            results[algorithm_name][problem_type] = {}
            config['results'][algorithm_name][problem_type] = {}

            problem_class = problem_map[problem_type]['class']
            ref_type = problem_map[problem_type]['ref_type']

            # Evaluator used only to look up standard reference points.
            temp_evaluator = MOCOEvaluator(reference_point=(1.0, 1.0))

            for size_name, problem_params in problem_sizes[problem_type].items():
                print(f"\nBenchmarking {problem_type} ({size_name})...")

                ref_size = _reference_size(problem_type, problem_params)
                reference_point, ideal_point = _resolve_points(
                    temp_evaluator, problem_type, ref_type, ref_size
                )

                # # After getting the standard points:
                # if problem_type == 'MultiObjectiveKnapsack':
                #     # Get standard points
                #     knapsack_standard_points = temp_evaluator.get_standard_points('BiKP', ref_size)
                #
                #     # Since we're transforming maximization to minimization,
                #     # we need to negate the reference and ideal points
                #     if knapsack_standard_points and 'reference' in knapsack_standard_points:
                #         # Negate for maximization->minimization transformation
                #         reference_point = tuple(-1 * np.array(knapsack_standard_points['reference']))
                #         ideal_point = tuple(-1 * np.array(knapsack_standard_points.get('ideal', (50, 50))))
                #     else:
                #         break

                print(f"Reference Point: {reference_point}")
                print(f"Ideal Point: {ideal_point}")

                evaluator = MOCOEvaluator(
                    reference_point=reference_point,
                    results_dir=f"benchmark_results_{timestamp}"
                )
                # Disable parallel evaluation to avoid CUDA issues
                evaluator.parallel = False

                algorithm_params = _algorithm_params(
                    algorithm_name, problem_type, base_params,
                    reference_point, ref_size, cuda_device
                )

                problem_start_time = time.time()
                try:
                    print(f"Running {algorithm_name} on {problem_type} ({size_name}) with {num_runs} runs...")

                    result = evaluator.evaluate_algorithm(
                        algorithm_class=algorithm_class,
                        problem_class=problem_class,
                        algorithm_name=algorithm_name,
                        parameters=algorithm_params,
                        problem_params=problem_params,
                        num_runs=num_runs
                    )
                    problem_runtime = time.time() - problem_start_time

                    # Plain Python scalars so the results serialise cleanly to YAML.
                    benchmark_result = {
                        'status': 'success',
                        'runtime': float(result.runtime),
                        'local_runtime': float(problem_runtime),
                        'hypervolume': float(result.hypervolume),
                        'num_nondominated': int(result.num_nondominated)
                    }
                    results[algorithm_name][problem_type][size_name] = benchmark_result
                    config['results'][algorithm_name][problem_type][size_name] = benchmark_result

                    print(f"\nResults for {algorithm_name} on {problem_type} ({size_name}):")
                    print(f"  Runtime: {benchmark_result['runtime']:.2f} seconds")
                    print(f"  Local Runtime: {problem_runtime:.2f} seconds")
                    print(f"  Hypervolume: {benchmark_result['hypervolume']:.4f}")
                    print(f"  Non-dominated Solutions: {benchmark_result['num_nondominated']}")

                except Exception as e:
                    print(f"Error benchmarking {algorithm_name} on {problem_type} ({size_name}): {e}")
                    traceback.print_exc()

                    error_result = {'status': 'error', 'error': str(e)}
                    results[algorithm_name][problem_type][size_name] = error_result
                    config['results'][algorithm_name][problem_type][size_name] = error_result

                finally:
                    # Clean up after every configuration, including failed ones
                    # (e.g. CUDA OOM), so memory does not leak into the next run.
                    print(f"Cleaning up after {problem_type} ({size_name})...")
                    _release_memory()

    overall_time = time.time() - overall_start
    config['metadata']['overall_time'] = overall_time

    # Save configuration and results
    os.makedirs("configs", exist_ok=True)
    config_path = f"configs/benchmark_results_{timestamp}.yaml"

    try:
        with open(config_path, 'w') as f:
            yaml.dump(config, f, default_flow_style=False)
        print(f"\nResults saved to: {config_path}")
    except Exception as e:
        print(f"Error saving config: {e}")

    _print_summary(results, overall_time)

    return results, config



# Example usage
if __name__ == "__main__":
    # Algorithms to benchmark. Other available entries (add to enable):
    #   "NSGA2": NSGA2,
    #   "MOBOqParEGO": MOBOqParEGO,  (BiKPMOBOWrapper was an earlier alternative)
    algorithms = dict(NSGA2qNEHVI=NSGA2qNEHVI)

    # Problem instances, keyed by problem type and then by size label.
    # Previously used configurations (add to the dict below to enable):
    #   'BiObjectiveTSP':         'medium': {'n_cities': 50}
    #                             'large':  {'n_cities': 100}
    #   'MultiObjectiveKnapsack': 'small':  {'n_items': 50, 'n_objectives': 2, 'capacity': 12.5}
    #                             'medium': {'n_items': 100, 'n_objectives': 2, 'capacity': 25.0}
    #                             'large':  {'n_items': 200, 'n_objectives': 2, 'capacity': 25.0}
    #   'TriObjectiveTSP':        'small':  {'n_cities': 20}
    #                             'medium': {'n_cities': 50}
    #                             'large':  {'n_cities': 100}
    #   'BiObjectiveCVRP':        'small':  {'n_customers': 20}
    #                             'medium': {'n_customers': 50}
    #                             'large':  {'n_customers': 100}
    problem_sizes = dict(
        BiObjectiveTSP=dict(small=dict(n_cities=20)),
    )

    # An empty dict (rather than None) stops run_flexible_benchmark from
    # applying its built-in base parameters; algorithm-specific parameters are
    # still set inside the benchmark function.
    base_params = dict()

    results, config = run_flexible_benchmark(
        algorithm_classes=algorithms,
        problem_sizes=problem_sizes,
        num_runs=1,
        cuda_device=4,
        base_params=base_params,
    )
    