"""
Compact script to run enhanced wrapper with tracking and generate all plots

Usage:
    python run_tracking_analysis.py

Choose problem type by changing PROBLEM_TYPE variable:
    - 'BiKP': Bi-objective Knapsack
    - 'BiTSP': Bi-objective TSP
"""

import numpy as np
import matplotlib.pyplot as plt

import sys

from enhanced_wrapper_with_tracking import CachedAdvancedBiKPWrapperWithTracking, MetricsTracker
from plotting.plot_tracking_metrics import quick_analysis

# Import problems from MOCO
from MOCO.problems import BiObjectiveTSP, MultiObjectiveKnapsack


def setup_plot_style():
    """Apply the paper-style matplotlib defaults used by every plot.

    Selects the bundled seaborn "paper" style sheet when available, then
    overrides rcParams for fonts, figure/save DPI, a dashed translucent
    grid, hidden top/right spines, and line/marker sizes. The settings are
    global: they mutate ``plt.rcParams`` for the whole process.
    """
    # Matplotlib 3.6 renamed the bundled seaborn style sheets to
    # 'seaborn-v0_8-*'; older releases only ship the un-prefixed name.
    # Try the new name first and fall back, instead of raising OSError on
    # older installations. If neither exists the current style is kept and
    # only the rcParams overrides below apply.
    for style_name in ('seaborn-v0_8-paper', 'seaborn-paper'):
        if style_name in plt.style.available:
            plt.style.use(style_name)
            break

    plt.rcParams.update({
        # Fonts
        'font.family': 'sans-serif',
        'font.sans-serif': ['DejaVu Sans', 'Arial', 'Helvetica'],
        'font.size': 10,
        'axes.labelsize': 11,
        'axes.titlesize': 12,
        'xtick.labelsize': 10,
        'ytick.labelsize': 10,
        'legend.fontsize': 10,
        'figure.titlesize': 13,
        # Resolution / export
        'figure.dpi': 150,
        'savefig.dpi': 300,
        'savefig.bbox': 'tight',
        # Grid and axes frame
        'axes.grid': True,
        'grid.alpha': 0.3,
        'grid.linestyle': '--',
        'axes.spines.top': False,
        'axes.spines.right': False,
        'axes.linewidth': 1.0,
        # Lines and markers
        'lines.linewidth': 2.0,
        'lines.markersize': 8,
    })


# Metric names averaged field-by-field across runs.
_ADAPTIVITY_FIELDS = (
    'avg_entropy',
    'max_probability',
    'exploration_rate',
    'temperature',
)
_CONFLICT_FIELDS = (
    'dual_var_norm',
    'value_differentiation',
    'coordination_score',
    'active_dual_vars',
    'actual_violations',
)


def percent_reduction(initial, final):
    """Return the relative drop from *initial* to *final* in percent.

    Returns 0.0 when *initial* is not strictly positive, so a zero baseline
    never causes a division by zero.
    """
    if initial > 0:
        return (initial - final) / initial * 100
    return 0.0


def _mean_of_fields(records, fields):
    """Average each named field over a list of per-run metric dicts."""
    return {name: np.mean([rec[name] for rec in records]) for name in fields}


def _average_ragged(sequences):
    """Element-wise mean of sequences that may differ in length.

    Position ``i`` is averaged only over the sequences that actually have
    an ``i``-th element; the result is as long as the longest sequence.
    """
    longest = max(len(seq) for seq in sequences)
    return [
        np.mean([seq[i] for seq in sequences if i < len(seq)])
        for i in range(longest)
    ]


def _first_last(metrics, field):
    """Return *field* from the first and last entries of a metrics dict.

    Relies on dict insertion order; the aggregated dicts are filled in
    ascending iteration order.
    """
    keys = list(metrics)
    return metrics[keys[0]][field], metrics[keys[-1]][field]


def _print_banner(title, lead="", trail=""):
    """Print *title* centred between two 80-column '=' rules."""
    rule = "=" * 80
    print(lead + rule)
    print(title.center(80))
    print(rule + trail)


def aggregate_trackers(all_trackers, aggregated_tracker):
    """Average per-iteration metrics of several runs into one tracker.

    Args:
        all_trackers: Per-run trackers. Each must expose the dict attributes
            ``adaptivity_metrics``, ``sequential_learning_metrics`` and
            ``conflict_metrics``, keyed by integer iteration index.
        aggregated_tracker: Tracker whose three metric dicts receive the
            averaged entries (mutated in place).

    Returns:
        ``aggregated_tracker``, for convenience.

    For every iteration, each metric is averaged over the runs that
    recorded that iteration; runs that stopped earlier are simply left out
    of the mean. An empty ``all_trackers`` leaves the target untouched.
    """
    # Cover the iterations of *all three* metric families: taking the range
    # from adaptivity metrics alone would silently drop learning/conflict
    # entries recorded at later iterations. ``default`` handles zero runs.
    last_iteration = max(
        (
            key
            for tracker in all_trackers
            for metrics in (
                tracker.adaptivity_metrics,
                tracker.sequential_learning_metrics,
                tracker.conflict_metrics,
            )
            for key in metrics
        ),
        default=-1,
    )

    for iteration in range(last_iteration + 1):
        adaptivity = [
            t.adaptivity_metrics[iteration]
            for t in all_trackers
            if iteration in t.adaptivity_metrics
        ]
        learning = [
            t.sequential_learning_metrics[iteration]
            for t in all_trackers
            if iteration in t.sequential_learning_metrics
        ]
        conflict = [
            t.conflict_metrics[iteration]
            for t in all_trackers
            if iteration in t.conflict_metrics
        ]

        if adaptivity:
            aggregated_tracker.adaptivity_metrics[iteration] = _mean_of_fields(
                adaptivity, _ADAPTIVITY_FIELDS
            )

        if learning:
            aggregated_tracker.sequential_learning_metrics[iteration] = {
                'subproblem_improvements': _average_ragged(
                    [m['subproblem_improvements'] for m in learning]
                ),
                'cumulative_best_reward': np.mean(
                    [m['cumulative_best_reward'] for m in learning]
                ),
                # Counts are averaged and then truncated to whole numbers.
                'subproblems_improved': int(
                    np.mean([m['subproblems_improved'] for m in learning])
                ),
                'total_subproblems': int(
                    np.mean([m['total_subproblems'] for m in learning])
                ),
            }

        if conflict:
            aggregated_tracker.conflict_metrics[iteration] = _mean_of_fields(
                conflict, _CONFLICT_FIELDS
            )

    return aggregated_tracker


if __name__ == "__main__":
    # Setup plotting style
    setup_plot_style()

    _print_banner(" Running Enhanced Wrapper with Tracking ", trail="\n")

    # ========================================================================
    # Configuration
    # ========================================================================
    PROBLEM_TYPE = 'BiTSP'  # 'BiTSP' (bi-objective TSP) or 'BiKP' (bi-objective knapsack)
    NUM_RUNS = 2            # Number of runs for statistical reliability
    BASE_SEED = 42          # Run i seeds NumPy's global RNG with BASE_SEED + i

    # ========================================================================
    # Configure wrapper parameters
    # ========================================================================
    # Earlier configuration, kept for reference:
    #   'n_weight_vectors': 5, 'max_iterations': 30 (base, scaled by problem size),
    #   'use_correlation_decomposition': False, 'use_elite_decomposition': False,
    #   'use_metric_decomposition': True, 'decomposition_size': 15, 'overlap': 6,
    #   'learning_rate': 0.1, 'temperature': 1.0, 'ucb_coeff': 2.0,
    #   'samples_per_subproblem': 100
    kwargs = {
        'learning_rate': 0.05,        # Reduced for stability
        # Original note: "less exploration".
        # NOTE(review): a larger UCB coefficient usually means *more*
        # exploration -- confirm the intended direction.
        'ucb_coefficient': 3.0,
        'temperature': 1.0,           # Ensure it starts at 1
        'temp_decay': 0.98,           # Much slower decay
        'hybrid_ratio': 0.7,
        'adaptive_hybrid': True,
        'decomposition_size': 25,     # Tried 40; 25 was best at scale 100; 15 also used for size-100 problems
        'overlap': 10,                # Tried 6, 8, 12, 16
        'use_adaptive_weights': True,
        # 'adaptive_strategy': 'geometric',  # Use simple approach
        'max_iterations': 200,
        'nb_rounds': 25,              # Tried 5; 20 gave the best results so far at large scale
        'patience': 100,              # Previously 20
        'use_lagrangian': True,
        'use_ftrl': True,
        'dual_step_size': 0.2,        # Much larger than before
        'use_accelerated_dual': False,
        'use_diminishing_overlap': True,
        'overlap_decay_rate': 0.1,    # Slower decay
        # 'use_selective_sensitivity': True
        'use_learned_operators': True,
        'n_weight_vectors': 20,       # For wrapper
        # 'max_weights': 20
    }

    # ========================================================================
    # Run multiple times for statistical reliability
    # ========================================================================
    all_trackers = []
    all_pareto_fronts = []  # Kept for inspection; not used by the aggregation below

    for run_idx in range(NUM_RUNS):
        _print_banner(f" RUN {run_idx + 1}/{NUM_RUNS} ", lead="\n", trail="\n")

        # Different seed per run; the message reports the seed actually used.
        # NOTE(review): assumes the MOCO problem constructors draw from
        # NumPy's global RNG -- confirm.
        seed = BASE_SEED + run_idx

        # Create fresh problem instance for each run
        if PROBLEM_TYPE == 'BiKP':
            print(f"Creating BiObjective Knapsack problem (seed={seed})...")
            np.random.seed(seed)
            problem = MultiObjectiveKnapsack(
                n_items=50,
                n_objectives=2,
                capacity=12.5
            )
            # The constructor's capacity is rescaled to that percentage of
            # the total item weight.
            total_weight = np.sum(problem.weights)
            problem.capacity = problem.capacity * total_weight / 100
            print(f"  - Items: {problem.n_items}")
            print(f"  - Capacity: {problem.capacity:.2f}")

        elif PROBLEM_TYPE == 'BiTSP':
            print(f"Creating BiObjective TSP problem (seed={seed})...")
            np.random.seed(seed)
            ncities = 50
            problem = BiObjectiveTSP(n_cities=ncities)
            print(f"  - Cities: {problem.n_cities}")

        else:
            raise ValueError(f"Unknown problem type: {PROBLEM_TYPE}")

        print(f"  - Objectives: {problem.num_objectives}\n")

        # Run optimization with tracking
        wrapper = CachedAdvancedBiKPWrapperWithTracking(problem, **kwargs)
        pareto_front = wrapper.run()

        print(f"\n✓ Run {run_idx + 1} complete!")
        print(f"  - Pareto front size: {len(pareto_front)}")

        # Store results
        all_trackers.append(wrapper.get_tracker())
        all_pareto_fronts.append(pareto_front)

    # ========================================================================
    # Aggregate results across all runs
    # ========================================================================
    _print_banner(" AGGREGATING RESULTS ACROSS ALL RUNS ", lead="\n", trail="\n")

    aggregated_tracker = aggregate_trackers(all_trackers, MetricsTracker())

    print(f"✓ Aggregated metrics from {NUM_RUNS} runs")
    print(f"  - Total iterations: {len(aggregated_tracker.adaptivity_metrics)}\n")

    # ========================================================================
    # Print aggregated summary
    # ========================================================================
    _print_banner(" AGGREGATED TRACKING SUMMARY ")

    if aggregated_tracker.adaptivity_metrics:
        first_entropy, last_entropy = _first_last(
            aggregated_tracker.adaptivity_metrics, 'avg_entropy'
        )
        # Guarded like the dual-norm reduction below: a zero initial
        # entropy reports 0% instead of dividing by zero.
        entropy_reduction = percent_reduction(first_entropy, last_entropy)

        print(f"\n1. ADAPTIVITY (averaged over {NUM_RUNS} runs):")
        print(f"   - Initial entropy: {first_entropy:.4f}")
        print(f"   - Final entropy: {last_entropy:.4f}")
        print(f"   - Reduction: {entropy_reduction:.1f}%")

    if aggregated_tracker.sequential_learning_metrics:
        first_reward, last_reward = _first_last(
            aggregated_tracker.sequential_learning_metrics, 'cumulative_best_reward'
        )
        improvement = last_reward - first_reward

        print(f"\n2. SEQUENTIAL LEARNING (averaged over {NUM_RUNS} runs):")
        print(f"   - Initial best reward: {first_reward:.4f}")
        print(f"   - Final best reward: {last_reward:.4f}")
        print(f"   - Total improvement: {improvement:.4f}")

    if aggregated_tracker.conflict_metrics:
        first_dual, last_dual = _first_last(
            aggregated_tracker.conflict_metrics, 'dual_var_norm'
        )
        dual_reduction = percent_reduction(first_dual, last_dual)

        print(f"\n3. CONFLICT REDUCTION (averaged over {NUM_RUNS} runs):")
        print(f"   - Initial dual norm: {first_dual:.6f}")
        print(f"   - Final dual norm: {last_dual:.6f}")
        print(f"   - Reduction: {dual_reduction:.1f}%")

    print(f"\n{'='*80}\n")

    # ========================================================================
    # Analyze and visualize aggregated metrics
    # ========================================================================
    print("Generating visualizations from aggregated data...")
    quick_analysis(aggregated_tracker, save_plots=True, show_report=True)

    _print_banner(" Done! Check 'plots/' directory for all visualizations ", lead="\n")
    print("\nGenerated files:")
    print("  - plots/adaptivity_metrics.png")
    print("  - plots/sequential_learning.png")
    print("  - plots/coordination_conflicts.png")
    print("  - plots/overview_dashboard.png")
    print("  - plots/*.csv (exported data)")
