import os
import math
import yaml
import time
import numpy as np
from typing import Dict, Any, List

# Import necessary MOCO components
from MOCO.problems import (
    BiObjectiveTSP, 
    MultiObjectiveKnapsack,
    TriObjectiveTSP, 
    BiObjectiveCVRP
)
from MOCO.evaluation import MOCOEvaluator
from benchmarking_helper import generate_ws_config, print_timing_details, run_algorithm_with_timing, analyze_benchmark_results
import multiprocessing
import torch
import logging

# Select the 'spawn' start method for worker processes at import time.
# NOTE(review): presumably chosen because CUDA state does not survive a fork
# in child processes -- confirm against the algorithms that spawn workers.
try:
    multiprocessing.set_start_method('spawn')
except RuntimeError:
    # set_start_method() raises RuntimeError when the start method has already
    # been fixed (e.g. this module is imported a second time); the existing
    # setting is kept in that case.
    pass

# Import necessary libraries
from bopr_correct_implementation import GPUAcceleratedBOPRWrapper
from WS_LKH_DP import WSLKH # WSDP  # WSLKH_TriObjective
from PPLS_DC_optimized import PPLSDC, PPLSDCforCVRP


def _reference_size(problem_type, problem_params):
    """Return the instance-size value used to look up standard reference points."""
    if problem_type in ('BiObjectiveTSP', 'TriObjectiveTSP'):
        return problem_params['n_cities']
    if problem_type == 'MultiObjectiveKnapsack':
        return problem_params['n_items']
    # The only remaining supported problem type is BiObjectiveCVRP.
    return problem_params['n_customers']


def _fallback_points(problem_type, n_objectives):
    """Return default ``(reference_point, ideal_point)`` tuples of length ``n_objectives``.

    Used only when the evaluator cannot supply standard points. The
    bi-objective values are the historical defaults of this script; for
    TriObjectiveTSP the TSP defaults are extended to three components so the
    point dimension matches the number of objectives.
    """
    if problem_type in ('BiObjectiveTSP', 'TriObjectiveTSP'):
        return (35,) * n_objectives, (0,) * n_objectives
    # NOTE(review): BiObjectiveCVRP shares the knapsack-style defaults here,
    # exactly as before -- confirm these are sensible for CVRP.
    return (20,) * n_objectives, (50,) * n_objectives


def _algorithm_parameters(algorithm_name, base_params, reference_point, ref_size):
    """Build the parameter dict passed to the algorithm for one benchmark case.

    Starts from a copy of ``base_params`` plus ``reference_point`` and then
    applies the algorithm-specific overrides in a fixed order, so a name that
    matches several families receives every matching override.
    """
    params = base_params.copy()
    params['reference_point'] = reference_point

    if algorithm_name == "NSGA2qNEHVI":
        # qNEHVI-specific settings
        params.update({
            'batch_size': 5,
            'max_iterations': 30,
            'use_sparse_gp': True,
            'n_inducing_points': 50,
            'population_size': 100,
            'n_generations': 100,
            'verbose': True,
            'matern_nu': 5 if ref_size == 100 else 4,
            'early_stopping_patience': 5
        })

    if "BOPR" in algorithm_name:
        print("Inside BOPR conditional .....")
        # NOTE: earlier experiments used size-dependent settings (e.g. a sparse
        # GP above 30 nodes, a temperature decreasing with instance size for
        # BiKP); the values below are the current fixed test configuration.
        params.update({
            'n_initial': 50,
            'n_iterations': 30,
            'mc_samples': 32,
            'lr': 0.05,
            'temperature': 0.3,
            'sparse_gp': False,
            'inducing_points': 100,
        })

    # The original condition was the chained comparison
    # `algorithm_name == "PPLS-DC" in algorithm_name`, which only matched the
    # exact name; a substring test is consistent with the other families.
    if "PPLS-DC" in algorithm_name:
        # PPLS-DC gets a dedicated parameter set: the dict is replaced
        # wholesale, so base parameters and 'reference_point' are dropped.
        params = {
            'num_processes': 4,
            'max_iterations': 200,
            'cooperation_frequency': 10,
            'archive_size': 100,
            'neighborhood_size': 100,
            'topology': 'ring',
            'use_master': True,
            'adjustment_frequency': 20,
            'patience': 999999
        }

    if algorithm_name == "WS-DP" or "WSDP" in algorithm_name:
        # Weighted-sum DP: number of weight vectors
        params.update({'num_weights': 101})

    if algorithm_name == "WS-LKH" or "WSLKH" in algorithm_name:
        # Weighted-sum LKH: number of weight vectors
        params.update({'num_weights': 80})
    elif algorithm_name == "WS-LKH-Tri":
        # Tri-objective weighted sum: number of weight vectors
        params.update({'num_weights': 210})

    return params


def _print_summary(results, overall_time):
    """Print the per-algorithm / per-problem / per-size result overview."""
    print("\n" + "="*50)
    print("BENCHMARK SUMMARY")
    print("="*50)
    print(f"Overall time: {overall_time:.1f} seconds")

    for alg_name, alg_results in results.items():
        print(f"\n{alg_name}:")
        for prob_name, prob_results in alg_results.items():
            print(f"  {prob_name}:")
            for size_name, result in prob_results.items():
                if result.get('status') == 'success':
                    print(f"    {size_name.upper()}: Runtime: {result['runtime']:.2f}s, "
                         f"HV: {result['hypervolume']:.4f} "
                         f"({result['num_nondominated']} solutions)")
                else:
                    print(f"    {size_name.upper()}: Error: {result.get('error', 'Unknown error')}")


def run_flexible_benchmark(
    algorithm_classes,
    problem_types=None,
    problem_sizes=None,
    num_runs=1,
    cuda_device=0,
    base_params=None
):
    """
    Flexible benchmark function that supports multiple algorithms and problems including BOPR

    Every algorithm is evaluated on every requested problem type and size via
    ``MOCOEvaluator.evaluate_algorithm``. A failure of one case is recorded as
    an error entry and does not stop the remaining cases. The collected
    results are written to ``configs/benchmark_results_drl_<timestamp>.yaml``
    and a summary is printed.

    Parameters:
    -----------
    algorithm_classes : dict
        Dictionary mapping algorithm names to algorithm classes
    problem_types : list, optional
        List of problem types to benchmark. Defaults to all four supported
        types (BiObjectiveTSP, MultiObjectiveKnapsack, TriObjectiveTSP,
        BiObjectiveCVRP).
    problem_sizes : dict, optional
        Dictionary of problem configurations, keyed by problem type and then
        by size label. Defaults to one size per problem type.
    num_runs : int
        Number of runs per configuration
    cuda_device : int
        CUDA device to use
    base_params : dict
        Base parameters for algorithms (will be overridden by algorithm-specific params)

    Returns:
    --------
    tuple
        ``(results, config)`` where ``results`` maps
        algorithm -> problem type -> size label -> result dict, and ``config``
        holds the same results together with run metadata (the structure that
        is dumped to YAML).
    """
    # Defaults are created per call rather than as mutable default arguments,
    # so nested dicts handed to the evaluator are never shared between calls.
    if problem_types is None:
        problem_types = ['BiObjectiveTSP', 'MultiObjectiveKnapsack', 'TriObjectiveTSP', 'BiObjectiveCVRP']
    if problem_sizes is None:
        problem_sizes = {
            'BiObjectiveTSP': {'medium': {'n_cities': 50}},
            'MultiObjectiveKnapsack': {'small': {'n_items': 50, 'n_objectives': 2, 'capacity': 12.5}},
            'TriObjectiveTSP': {'small': {'n_cities': 20}},
            'BiObjectiveCVRP': {'small': {'n_customers': 20}},
        }

    # Use specified CUDA device
    device = torch.device(f"cuda:{cuda_device}" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Create timestamp for files
    timestamp = time.strftime("%Y%m%d-%H%M%S")

    # Define problem mapping
    problem_map = {
        'BiObjectiveTSP': {'class': BiObjectiveTSP, 'ref_type': 'BiTSP', 'n_objectives': 2},
        'MultiObjectiveKnapsack': {'class': MultiObjectiveKnapsack, 'ref_type': 'BiKP', 'n_objectives': 2},
        'TriObjectiveTSP': {'class': TriObjectiveTSP, 'ref_type': 'TriTSP', 'n_objectives': 3},
        'BiObjectiveCVRP': {'class': BiObjectiveCVRP, 'ref_type': 'BiCVRP', 'n_objectives': 2}
    }

    # Set default base parameters if none provided
    if base_params is None:
        base_params = {
            'population_size': 100,
            'n_generations': 30,
            'verbose': True
        }

    # Store results
    results = {}
    config = {
        'metadata': {
            'timestamp': timestamp,
            'num_runs': num_runs,
            'device': device.type,
        },
        'results': {}
    }

    # Overall benchmark start time
    overall_start = time.time()

    for algorithm_name, algorithm_class in algorithm_classes.items():
        print(f"\n{'-'*60}")
        print(f"Benchmarking algorithm: {algorithm_name}")
        print(f"{'-'*60}")

        # Algorithm-specific results
        results[algorithm_name] = {}
        config['results'][algorithm_name] = {}

        for problem_type in problem_types:
            if problem_type not in problem_map:
                print(f"Problem {problem_type} not defined. Skipping.")
                continue

            if problem_type not in problem_sizes:
                print(f"No sizes defined for {problem_type}. Skipping.")
                continue

            # Problem-specific results
            results[algorithm_name][problem_type] = {}
            config['results'][algorithm_name][problem_type] = {}

            problem_info = problem_map[problem_type]
            problem_class = problem_info['class']
            ref_type = problem_info['ref_type']

            # Throw-away evaluator used only to query the standard points
            temp_evaluator = MOCOEvaluator(reference_point=(1.0, 1.0))

            for size_name, problem_params in problem_sizes[problem_type].items():
                print(f"\nBenchmarking {problem_type} ({size_name})...")

                ref_size = _reference_size(problem_type, problem_params)

                # An explicit 'n_objectives' in the problem configuration wins
                # over the per-problem default from problem_map.
                n_objectives = problem_params.get('n_objectives', problem_info['n_objectives'])

                # Start from the fallback points and replace them with the
                # evaluator's standard points when those are available.
                reference_point, ideal_point = _fallback_points(problem_type, n_objectives)
                try:
                    standard_points = temp_evaluator.get_standard_points(
                        problem_type=ref_type,
                        problem_size=ref_size
                    )

                    if standard_points and 'reference' in standard_points:
                        # Single tuple assignment: either both points come from
                        # the standard set or both stay at the fallback.
                        reference_point, ideal_point = (
                            standard_points['reference'],
                            standard_points.get('ideal', (0,) * n_objectives),
                        )
                except Exception as e:
                    # Best effort: keep the fallback points on any lookup failure.
                    print(f"Error getting standard points: {e}")

                print(f"Reference Point: {reference_point}")
                print(f"Ideal Point: {ideal_point}")

                # Create evaluator
                evaluator = MOCOEvaluator(
                    reference_point=reference_point,
                    results_dir=f"benchmark_results_drl_{timestamp}"
                )

                # Disable parallel evaluation to avoid CUDA issues
                evaluator.parallel = False

                algorithm_params = _algorithm_parameters(
                    algorithm_name, base_params, reference_point, ref_size
                )

                # Start problem timer
                problem_start_time = time.time()

                try:
                    print(f"Running {algorithm_name} on {problem_type} ({size_name}) with {num_runs} runs...")

                    # Run algorithm evaluation
                    result = evaluator.evaluate_algorithm(
                        algorithm_class=algorithm_class,
                        problem_class=problem_class,
                        algorithm_name=algorithm_name,
                        parameters=algorithm_params,
                        problem_params=problem_params,
                        num_runs=num_runs
                    )

                    # Wall-clock time measured here, next to the runtime the
                    # evaluator reports itself
                    problem_runtime = time.time() - problem_start_time

                    # Plain Python scalars so the entry can be dumped to YAML
                    benchmark_result = {
                        'status': 'success',
                        'runtime': float(result.runtime),
                        'local_runtime': float(problem_runtime),
                        'hypervolume': float(result.hypervolume),
                        'num_nondominated': int(result.num_nondominated)
                    }

                    results[algorithm_name][problem_type][size_name] = benchmark_result
                    config['results'][algorithm_name][problem_type][size_name] = benchmark_result

                    # Display results
                    print(f"\nResults for {algorithm_name} on {problem_type} ({size_name}):")
                    print(f"  Runtime: {result.runtime:.2f} seconds")
                    print(f"  Local Runtime: {problem_runtime:.2f} seconds")
                    print(f"  Hypervolume: {result.hypervolume:.4f}")
                    print(f"  Non-dominated Solutions: {result.num_nondominated}")

                except Exception as e:
                    # One failing case must not abort the whole benchmark:
                    # report it, record it, and move on.
                    print(f"Error benchmarking {algorithm_name} on {problem_type} ({size_name}): {e}")
                    import traceback
                    traceback.print_exc()

                    error_result = {'status': 'error', 'error': str(e)}
                    results[algorithm_name][problem_type][size_name] = error_result
                    config['results'][algorithm_name][problem_type][size_name] = error_result

    # Calculate overall time
    overall_time = time.time() - overall_start
    config['metadata']['overall_time'] = overall_time

    # Save configuration and results
    os.makedirs("configs", exist_ok=True)
    config_path = f"configs/benchmark_results_drl_{timestamp}.yaml"

    try:
        with open(config_path, 'w') as f:
            yaml.dump(config, f, default_flow_style=False)
        print(f"\nResults saved to: {config_path}")
    except Exception as e:
        # Saving is best effort; the results are still returned to the caller.
        print(f"Error saving config: {e}")

    _print_summary(results, overall_time)

    return results, config


# Example usage
if __name__ == "__main__":
    # Size tiers shared by the node-based problems, ordered small -> large.
    size_tiers = (('small', 20), ('medium', 50), ('large', 100))

    # Problem configurations, keyed by problem type and then by size label.
    # Each problem type gets its own freshly built dicts.
    problem_sizes = {
        'BiObjectiveTSP': {
            tier: {'n_cities': count} for tier, count in size_tiers
        },
        # The 'small' knapsack case (50 items, capacity 12.5) is currently
        # left out of this run.
        'MultiObjectiveKnapsack': {
            tier: {'n_items': count, 'n_objectives': 2, 'capacity': 25}
            for tier, count in (('medium', 100), ('large', 200))
        },
        'TriObjectiveTSP': {
            tier: {'n_cities': count} for tier, count in size_tiers
        },
        'BiObjectiveCVRP': {
            tier: {'n_customers': count} for tier, count in size_tiers
        },
    }

    # Algorithms to benchmark. Only BOPR is enabled; other entries used with
    # this script in the past (WS-LKH-Tri, WS-DP, WS-LKH, PPLS-DC, NSGA2,
    # NSGA2qNEHVI, MOBOqParEGO) can be added here under their names.
    algorithms = {}
    algorithms["BOPR"] = GPUAcceleratedBOPRWrapper

    # No shared base parameters: an empty dict (rather than None) means the
    # benchmark function does not substitute its own defaults.
    base_params = {}

    # Collect the call arguments, then run the benchmark.
    benchmark_kwargs = {
        'algorithm_classes': algorithms,
        'problem_sizes': problem_sizes,
        'num_runs': 1,
        'cuda_device': 2,
        'base_params': base_params,
    }
    results, config = run_flexible_benchmark(**benchmark_kwargs)