
"""
Lagrangian Ablation Experiment
==============================
Compares: PBME (Full) vs PBME (No Lagrangian)

Uses MOCOEvaluator for standardized hypervolume ratio computation.
Run on your machine where torch is already installed.
"""

import numpy as np
import matplotlib.pyplot as plt
from typing import List, Dict, Tuple
import time

import sys

# === IMPORTS FROM YOUR PROJECT ===
from project_MOCO.MOCO_supplementary.our_method_dl_UCB_variant import AdvancedDecompositionWrapper
from MOCO.problems import MultiObjectiveTSP, MultiObjectiveKnapsack, TriObjectiveTSP
from MOCO.evaluation import MOCOEvaluator


# ============================================================================
# METRICS (using MOCOEvaluator)
# ============================================================================

def get_evaluator_and_points(problem_type, problem_size):
    """Build a MOCOEvaluator configured with the reference point for a problem.

    Args:
        problem_type: 'TSP' selects the 'BiTSP' evaluator setting; any other
            value is treated as 'BiKP' (bi-objective knapsack).
        problem_size: Instance size, forwarded to ``get_standard_points``.

    Returns:
        Tuple ``(evaluator, ptype, points)`` where ``evaluator`` is a
        MOCOEvaluator built with ``points['reference']``, ``ptype`` is the
        evaluator's problem-type string and ``points`` is a dict holding at
        least the 'reference' and 'ideal' entries.
    """
    # Map the experiment's problem name to the evaluator's problem-type string.
    ptype = 'BiTSP' if problem_type == 'TSP' else 'BiKP'

    # get_standard_points is called on an instance here, so a throwaway
    # evaluator with a placeholder reference point is created first.
    temp_eval = MOCOEvaluator(reference_point=(1, 1))  # placeholder
    points = temp_eval.get_standard_points(ptype, problem_size)

    if points is None:
        # Fallback when no standard points are returned for this size.
        # Both settings handled here are bi-objective, so the fallback points
        # must be 2-D; the previous TSP fallback was a 3-tuple, which does not
        # match the 2-objective problems this script builds.
        # NOTE(review): (20, 20) mirrors the old per-axis value — confirm it
        # is an adequate reference point for larger TSP sizes.
        if problem_type == 'TSP':
            points = {'reference': (20, 20), 'ideal': (0, 0)}
        else:
            points = {'reference': (5, 5), 'ideal': (30, 30)}

    # Create evaluator with the resolved reference point
    evaluator = MOCOEvaluator(reference_point=points['reference'])
    return evaluator, ptype, points


def compute_hv_ratio(evaluator, solutions, problem_type, problem_size):
    """Return the hypervolume ratio reported by the evaluator.

    Each objective vector in ``solutions`` is wrapped as a
    ``(None, objectives_tuple)`` pair (the solution slot is left empty) and
    the list is passed to ``evaluator._calculate_hypervolume_metrics``. The
    ``'hv_ratio'`` entry of the returned metrics is the result.

    Args:
        evaluator: Object exposing ``_calculate_hypervolume_metrics``.
        solutions: Iterable of objective vectors.
        problem_type: Forwarded as the ``problem_type`` keyword.
        problem_size: Forwarded as the ``problem_size`` keyword.
    """
    # The evaluator expects (solution, objectives) pairs; only the
    # objectives are available here, so the solution slot is None.
    pairs = []
    for objective_vector in solutions:
        pairs.append((None, tuple(objective_vector)))

    hv_metrics = evaluator._calculate_hypervolume_metrics(
        pairs,
        problem_type=problem_type,
        problem_size=problem_size,
    )
    hv_ratio = hv_metrics['hv_ratio']
    return hv_ratio


def compute_spread(front):
    """Return the extent of a front: the length of its bounding-box diagonal.

    For every objective the range (max - min) over all points is taken; the
    result is the square root of the sum of the squared ranges. Fronts with
    fewer than two points have a spread of 0.0.
    """
    if len(front) < 2:
        return 0.0

    points = np.array(front)
    n_objectives = points.shape[1]

    squared_extent = 0
    for axis in range(n_objectives):
        column = points[:, axis]
        squared_extent += (column.max() - column.min()) ** 2
    return np.sqrt(squared_extent)


# ============================================================================
# ABLATION RUNNER
# ============================================================================

def run_single(problem, problem_type, problem_size, use_lagrangian, evaluator, eval_ptype):
    """Solve one problem instance and collect its quality metrics.

    An AdvancedDecompositionWrapper is built with a fixed configuration in
    which only ``use_lagrangian`` (and, tied to it, ``use_accelerated_dual``)
    varies, then run once. Wall-clock time covers both construction and
    ``run()``.

    Args:
        problem: Problem instance handed to the wrapper.
        problem_type: 'TSP' drops objective vectors containing +inf; any
            other value drops vectors containing -inf.
        problem_size: Forwarded to the hypervolume computation.
        use_lagrangian: Whether the Lagrangian component is enabled.
        evaluator: Evaluator passed to ``compute_hv_ratio``.
        eval_ptype: Evaluator problem-type string for ``compute_hv_ratio``.

    Returns:
        Dict with 'hv_ratio', 'n_solutions', 'spread', 'runtime',
        'disagreement' and the filtered 'objectives' list.
    """
    solver_config = dict(
        decomposition_size=20,
        overlap=4,
        n_weight_vectors=25,
        nb_rounds=80,
        patience=50,
        max_iterations=100,
        use_lagrangian=use_lagrangian,
        use_ftrl=True,
        # The accelerated dual update is switched together with the Lagrangian.
        use_accelerated_dual=use_lagrangian,
        use_diminishing_overlap=False,
        dual_step_size=10.0,  # previously 1.0
        overlap_decay_rate=0.1,
    )

    started = time.time()
    solver = AdvancedDecompositionWrapper(problem, **solver_config)
    raw_results = solver.run()
    elapsed = time.time() - started

    # Value that marks an unusable objective entry: +inf for TSP,
    # -inf for everything else.
    if problem_type == 'TSP':
        invalid_marker = float('inf')
    else:
        invalid_marker = -float('inf')

    # Keep only entries that carry objectives and contain no invalid marker.
    objectives = []
    for entry in raw_results:
        candidate = entry[1]
        if candidate is None:
            continue
        if any(value == invalid_marker for value in candidate):
            continue
        objectives.append(candidate)

    # Metrics: standardized HV ratio from the evaluator, then front shape.
    hv_ratio = compute_hv_ratio(evaluator, objectives, eval_ptype, problem_size)
    spread = compute_spread(objectives)
    disagreement = compute_front_disagreement(objectives)

    summary = {
        'hv_ratio': hv_ratio,
        'n_solutions': len(objectives),
        'spread': spread,
        'runtime': elapsed,
        'disagreement': disagreement,
        'objectives': objectives,
    }
    return summary


def run_ablation(problem_type='TSP', problem_size=20, n_runs=5):
    """Run the with/without-Lagrangian comparison on ``n_runs`` fresh instances.

    For every run a new problem instance is created and solved twice on that
    same instance: first with the Lagrangian enabled, then with it disabled.
    Progress and per-run metrics are printed as the runs complete.

    Args:
        problem_type: 'TSP' builds a 2-objective MultiObjectiveTSP; any other
            value builds a 2-objective MultiObjectiveKnapsack with
            capacity 12.5.
        problem_size: Number of cities / items.
        n_runs: Number of instances to evaluate.

    Returns:
        Dict with keys 'with_lagrangian' and 'without_lagrangian', each a
        list of the per-run dicts produced by ``run_single``.
    """
    # One evaluator (with its reference point) is shared by all runs.
    evaluator, eval_ptype, points = get_evaluator_and_points(problem_type, problem_size)

    results = {'with_lagrangian': [], 'without_lagrangian': []}

    banner = "=" * 60
    print(banner)
    print(f"LAGRANGIAN ABLATION: {problem_type}-{problem_size}, {n_runs} runs")
    print(f"Reference point: {points['reference']}, Ideal point: {points['ideal']}")
    print(banner)

    # (result key, use_lagrangian flag, progress label) — evaluated in order.
    variants = (
        ('with_lagrangian', True, "WITH Lagrangian..."),
        ('without_lagrangian', False, "WITHOUT Lagrangian..."),
    )

    for run_number in range(1, n_runs + 1):
        print(f"\n--- Run {run_number}/{n_runs} ---")

        # Fresh instance per run; both variants solve the same instance.
        if problem_type == 'TSP':
            problem = MultiObjectiveTSP(problem_size, 2)
        else:
            problem = MultiObjectiveKnapsack(n_items=problem_size, n_objectives=2, capacity=12.5)

        for key, lagrangian_on, label in variants:
            print(label, end=" ")
            outcome = run_single(problem, problem_type, problem_size, lagrangian_on, evaluator, eval_ptype)
            results[key].append(outcome)
            print(f"HV_ratio={outcome['hv_ratio']:.4f}, #Sol={outcome['n_solutions']}, Disagree={outcome['disagreement']:.4f}")

    return results


def print_latex_table(results):
    """Print a console summary and a LaTeX table of the ablation results.

    Args:
        results: Dict with keys 'with_lagrangian' and 'without_lagrangian',
            each a list of per-run dicts. The keys 'hv_ratio', 'n_solutions',
            'spread', 'runtime' and 'disagreement' are read from every run.

    Returns:
        Dict mapping each variant key to its aggregate statistics
        ('hv_mean', 'hv_std', 'n_sol_mean', 'n_sol_std', 'spread_mean',
        'spread_std', 'time_mean', 'dis_mean', 'dis_std'). Standard
        deviations use ``np.std`` with its default settings.
    """

    def _column(runs, key):
        # Values of one metric across all runs of a variant.
        return [r[key] for r in runs]

    stats = {}
    for variant in ['with_lagrangian', 'without_lagrangian']:
        runs = results[variant]
        hv = _column(runs, 'hv_ratio')
        n_sol = _column(runs, 'n_solutions')
        spread = _column(runs, 'spread')
        disagreement = _column(runs, 'disagreement')
        stats[variant] = {
            'hv_mean': np.mean(hv),
            'hv_std': np.std(hv),
            'n_sol_mean': np.mean(n_sol),
            'n_sol_std': np.std(n_sol),
            'spread_mean': np.mean(spread),
            'spread_std': np.std(spread),
            'time_mean': np.mean(_column(runs, 'runtime')),
            'dis_mean': np.mean(disagreement),
            'dis_std': np.std(disagreement),
        }

    print("\n" + "=" * 60)
    print("RESULTS (HV Ratio - Standardized)")
    print("=" * 60)

    # Console
    print(f"\n{'Method':<25} {'HV Ratio':>15} {'#Sol':>10} {'Spread':>12} {'Disagree':>12}")
    print("-" * 65)
    for v, name in [('with_lagrangian', 'PBME (Full)'), ('without_lagrangian', 'PBME (No Lagrangian)')]:
        s = stats[v]
        print(f"{name:<25} {s['hv_mean']:.4f}±{s['hv_std']:.4f}  {s['n_sol_mean']:.1f}±{s['n_sol_std']:.1f}  {s['spread_mean']:.3f}±{s['spread_std']:.3f}  {s['dis_mean']:.3f}±{s['dis_std']:.3f}")

    # LaTeX
    print("\n--- LaTeX ---")
    print(r"\begin{table}[h]")
    print(r"\caption{Ablation: Effect of Lagrangian Coordination}")
    print(r"\label{tab:lagrangian-ablation}")
    print(r"\centering")
    # Five columns: method name plus the four metrics emitted in each row.
    # (The spec used to declare only four, which LaTeX rejects because every
    # row contains five cells.)
    print(r"\begin{tabular}{lcccc}")
    print(r"\toprule")
    print(r"Method & HV Ratio $\uparrow$ & \#Solutions & Spread $\uparrow$  &  Disagreement $\downarrow$\\")
    print(r"\midrule")

    for v, name in [('with_lagrangian', r'\textbf{PBME (Full)}'), ('without_lagrangian', 'PBME (No Lagrangian)')]:
        s = stats[v]
        print(f"{name} & {s['hv_mean']:.3f} $\\pm$ {s['hv_std']:.3f} & {s['n_sol_mean']:.1f} $\\pm$ {s['n_sol_std']:.1f} & {s['spread_mean']:.3f} $\\pm$ {s['spread_std']:.3f} & {s['dis_mean']:.3f} $\\pm$ {s['dis_std']:.3f}\\\\")

    print(r"\bottomrule")
    print(r"\end{tabular}")
    print(r"\end{table}")

    # Relative HV improvement of the full method over the ablated one; the
    # denominator is floored at 1e-6 to avoid dividing by zero.
    baseline_hv = stats['without_lagrangian']['hv_mean']
    imp = (stats['with_lagrangian']['hv_mean'] - baseline_hv) / max(baseline_hv, 1e-6) * 100
    print(f"\nHV Ratio Improvement: {imp:+.2f}%")

    return stats

def plot_results(results, save_dir='.'):
    """Draw the ablation figures and save each one as PNG and PDF.

    Three figures are produced: a bar chart of mean HV ratio with std error
    bars, a box plot of the HV-ratio distributions, and a scatter plot of
    the first two objectives of every solution pooled over all runs. All
    figures are shown at the end via ``plt.show()``.

    Args:
        results: Dict with 'with_lagrangian' and 'without_lagrangian' lists
            of per-run dicts (reads 'hv_ratio' and 'objectives').
        save_dir: Directory prefix for the output files.
    """
    green, red = '#2ecc71', '#e74c3c'
    runs_on = results['with_lagrangian']
    runs_off = results['without_lagrangian']

    hv_on = [run['hv_ratio'] for run in runs_on]
    hv_off = [run['hv_ratio'] for run in runs_off]
    tick_names = ['With\nLagrangian', 'Without\nLagrangian']

    # --- Figure 1: mean HV ratio per variant, std as error bars ---
    bar_fig, bar_ax = plt.subplots(figsize=(5, 4))
    hv_means = [np.mean(hv_on), np.mean(hv_off)]
    hv_stds = [np.std(hv_on), np.std(hv_off)]
    rects = bar_ax.bar(tick_names, hv_means, yerr=hv_stds, capsize=5, color=[green, red], alpha=0.8)
    bar_ax.set_ylabel('HV Ratio')
    bar_ax.set_title('Hypervolume Ratio Comparison')
    bar_ax.grid(axis='y', alpha=0.3)
    # Write the mean value slightly above each bar.
    for rect, mean_value in zip(rects, hv_means):
        label_x = rect.get_x() + rect.get_width()/2
        label_y = rect.get_height() + 0.02
        bar_ax.text(label_x, label_y, f'{mean_value:.3f}', ha='center', fontsize=10)
    plt.tight_layout()
    bar_fig.savefig(f'{save_dir}/ablation_hv_bar.png', dpi=150, bbox_inches='tight')
    bar_fig.savefig(f'{save_dir}/ablation_hv_bar.pdf', bbox_inches='tight')
    print("Saved: ablation_hv_bar.png/pdf")

    # --- Figure 2: distribution of HV ratios per variant ---
    box_fig, box_ax = plt.subplots(figsize=(5, 4))
    box_parts = box_ax.boxplot([hv_on, hv_off], labels=tick_names, patch_artist=True)
    for box_index, face_color in enumerate((green, red)):
        box_parts['boxes'][box_index].set_facecolor(face_color)
    box_ax.set_ylabel('HV Ratio')
    box_ax.set_title('HV Ratio Distribution')
    box_ax.grid(axis='y', alpha=0.3)
    plt.tight_layout()
    box_fig.savefig(f'{save_dir}/ablation_hv_box.png', dpi=150, bbox_inches='tight')
    box_fig.savefig(f'{save_dir}/ablation_hv_box.pdf', bbox_inches='tight')
    print("Saved: ablation_hv_box.png/pdf")

    # --- Figure 3: objectives pooled over all runs (first two axes only) ---
    front_fig, front_ax = plt.subplots(figsize=(6, 5))

    pooled_on = [obj for run in runs_on for obj in run['objectives']]
    pooled_off = [obj for run in runs_off for obj in run['objectives']]

    # An empty pool is skipped: indexing columns of an empty array would fail.
    if pooled_on:
        cloud_on = np.array(pooled_on)
        front_ax.scatter(cloud_on[:, 0], cloud_on[:, 1], c=green, label='With Lagrangian', s=40, alpha=0.5)
    if pooled_off:
        cloud_off = np.array(pooled_off)
        front_ax.scatter(cloud_off[:, 0], cloud_off[:, 1], c=red, label='Without Lagrangian', s=40, alpha=0.5, marker='s')

    front_ax.set_xlabel('Objective 1')
    front_ax.set_ylabel('Objective 2')
    front_ax.set_title('Pareto Fronts (All Runs Combined)')
    front_ax.legend()
    front_ax.grid(alpha=0.3)
    plt.tight_layout()
    front_fig.savefig(f'{save_dir}/ablation_pareto.png', dpi=150, bbox_inches='tight')
    front_fig.savefig(f'{save_dir}/ablation_pareto.pdf', bbox_inches='tight')
    print("Saved: ablation_pareto.png/pdf")

    plt.show()

def run_scaling_experiment(problem_type='TSP', sizes=None, n_runs=3):
    """Run the with/without-Lagrangian ablation across several problem sizes.

    For each size, ``n_runs`` fresh instances are created and each is solved
    twice on the same instance (Lagrangian on, then off). Per-run HV ratios
    are printed as they complete.

    Args:
        problem_type: 'TSP' builds a 2-objective MultiObjectiveTSP; any other
            value builds a MultiObjectiveKnapsack.
        sizes: Iterable of problem sizes. Defaults to ``[20, 50, 100]`` when
            omitted or None.
        n_runs: Number of instances per size.

    Returns:
        Dict with 'sizes' plus, per size and in the same order, the mean and
        std of the HV ratio for both variants ('hv_with', 'hv_with_std',
        'hv_without', 'hv_without_std') and of the paired difference
        with - without ('hv_diff', 'hv_diff_std').
    """
    # Always work on a private list: the sizes are stored in the returned
    # dict, so sharing a default (or the caller's) list would let later
    # mutations of the result leak back into subsequent calls.
    sizes = [20, 50, 100] if sizes is None else list(sizes)

    scaling_results = {
        'sizes': sizes,
        'hv_with': [],
        'hv_without': [],
        'hv_with_std': [],
        'hv_without_std': [],
        'hv_diff': [],
        'hv_diff_std': [],
    }

    print("=" * 60)
    print(f"SCALING EXPERIMENT: {problem_type}")
    print(f"Sizes: {sizes}, Runs per size: {n_runs}")
    print("=" * 60)

    for size in sizes:
        print(f"\n>>> Problem size: {size}")

        # Reference points depend on the size, so the evaluator is rebuilt.
        evaluator, eval_ptype, _ = get_evaluator_and_points(problem_type, size)

        hvs_with = []
        hvs_without = []

        for run in range(n_runs):
            # Create problem
            if problem_type == 'TSP':
                problem = MultiObjectiveTSP(size, 2)
            else:
                # NOTE(review): run_ablation passes capacity=12.5 explicitly,
                # while this call relies on the constructor's default —
                # confirm both experiments are meant to use the same capacity.
                problem = MultiObjectiveKnapsack(size, 2)

            # Both variants are evaluated on the same instance so that the
            # per-run difference below is a paired comparison.
            r_with = run_single(problem, problem_type, size, True, evaluator, eval_ptype)
            hvs_with.append(r_with['hv_ratio'])

            r_without = run_single(problem, problem_type, size, False, evaluator, eval_ptype)
            hvs_without.append(r_without['hv_ratio'])

            print(f"  Run {run+1}: With={r_with['hv_ratio']:.4f}, Without={r_without['hv_ratio']:.4f}, Diff={r_with['hv_ratio']-r_without['hv_ratio']:.4f}")

        # Store results
        scaling_results['hv_with'].append(np.mean(hvs_with))
        scaling_results['hv_without'].append(np.mean(hvs_without))
        scaling_results['hv_with_std'].append(np.std(hvs_with))
        scaling_results['hv_without_std'].append(np.std(hvs_without))

        # Std of the paired differences, not the difference of the stds.
        diffs = [w - wo for w, wo in zip(hvs_with, hvs_without)]
        scaling_results['hv_diff'].append(np.mean(diffs))
        scaling_results['hv_diff_std'].append(np.std(diffs))

    return scaling_results


def plot_scaling_results(scaling_results, save_dir='.'):
    """Plot the scaling experiment and print a per-size summary table.

    Two figures are saved as PNG and PDF: the HV ratio of both variants
    against problem size, and the paired HV difference (with - without)
    against problem size with a shaded +/- one std band. Both are shown via
    ``plt.show()`` before the summary table is printed.

    Args:
        scaling_results: Dict as returned by ``run_scaling_experiment``.
        save_dir: Directory prefix for the output files.
    """
    sizes = scaling_results['sizes']
    blue = '#3498db'

    # --- Figure 1: HV ratio of each variant versus problem size ---
    cmp_fig, cmp_ax = plt.subplots(figsize=(6, 5))
    cmp_ax.errorbar(
        sizes, scaling_results['hv_with'], yerr=scaling_results['hv_with_std'],
        marker='o', capsize=5, color='#2ecc71', label='With Lagrangian', linewidth=2, markersize=8,
    )
    cmp_ax.errorbar(
        sizes, scaling_results['hv_without'], yerr=scaling_results['hv_without_std'],
        marker='s', capsize=5, color='#e74c3c', label='Without Lagrangian', linewidth=2, markersize=8,
    )
    cmp_ax.set_xlabel('Problem Size (n)')
    cmp_ax.set_ylabel('HV Ratio')
    cmp_ax.set_title('HV Ratio vs Problem Size')
    cmp_ax.legend()
    cmp_ax.grid(alpha=0.3)
    cmp_ax.set_xticks(sizes)
    plt.tight_layout()
    cmp_fig.savefig(f'{save_dir}/scaling_hv_comparison.png', dpi=150, bbox_inches='tight')
    cmp_fig.savefig(f'{save_dir}/scaling_hv_comparison.pdf', bbox_inches='tight')
    print("Saved: scaling_hv_comparison.png/pdf")

    # --- Figure 2: paired HV difference versus problem size ---
    diff_mean = scaling_results['hv_diff']
    diff_std = scaling_results['hv_diff_std']

    diff_fig, diff_ax = plt.subplots(figsize=(6, 5))
    diff_ax.errorbar(sizes, diff_mean, yerr=diff_std,
                     marker='D', capsize=5, color=blue, linewidth=2, markersize=8)
    # Zero line: points above it mean the Lagrangian variant scored higher.
    diff_ax.axhline(y=0, color='gray', linestyle='--', alpha=0.5)
    band_low = [mean - std for mean, std in zip(diff_mean, diff_std)]
    band_high = [mean + std for mean, std in zip(diff_mean, diff_std)]
    diff_ax.fill_between(sizes, band_low, band_high, alpha=0.2, color=blue)
    diff_ax.set_xlabel('Problem Size (n)')
    diff_ax.set_ylabel('HV Difference (With - Without)')
    diff_ax.set_title('Lagrangian Benefit vs Problem Size')
    diff_ax.grid(alpha=0.3)
    diff_ax.set_xticks(sizes)

    # Annotate only when the difference at the last size exceeds the one at
    # the first size.
    grows_with_size = len(diff_mean) >= 2 and diff_mean[-1] > diff_mean[0]
    if grows_with_size:
        largest_size = sizes[-1]
        final_diff = diff_mean[-1]
        diff_ax.annotate('Lagrangian more\nbeneficial at scale',
                         xy=(largest_size, final_diff),
                         xytext=(largest_size*0.7, final_diff*1.5),
                         arrowprops=dict(arrowstyle='->', color='gray'),
                         fontsize=10, ha='center')

    plt.tight_layout()
    diff_fig.savefig(f'{save_dir}/scaling_hv_diff.png', dpi=150, bbox_inches='tight')
    diff_fig.savefig(f'{save_dir}/scaling_hv_diff.pdf', bbox_inches='tight')
    print("Saved: scaling_hv_diff.png/pdf")

    plt.show()

    # --- Console summary, one row per problem size ---
    print("\n--- Scaling Summary ---")
    print(f"{'Size':<10} {'HV With':>15} {'HV Without':>15} {'Diff':>15}")
    print("-" * 55)
    for row, size in enumerate(sizes):
        with_cell = f"{scaling_results['hv_with'][row]:.4f}±{scaling_results['hv_with_std'][row]:.4f}  "
        without_cell = f"{scaling_results['hv_without'][row]:.4f}±{scaling_results['hv_without_std'][row]:.4f}  "
        diff_cell = f"{diff_mean[row]:+.4f}±{diff_std[row]:.4f}"
        print(f"{size:<10} " + with_cell + without_cell + diff_cell)

def compute_front_disagreement(objectives, tol=1e-6):
    """Return a scale-free measure of how unevenly the front's points are spaced.

    The Euclidean distance between every pair of points is divided by the
    front's bounding-box diagonal (plus ``tol``), and the standard deviation
    of those normalized distances is returned. The script reports this value
    as "disagreement", with higher meaning worse coordination.

    Args:
        objectives: Sequence of objective vectors.
        tol: Added to the normalizer so a degenerate front (all points
            identical) does not cause a division by zero.

    Returns:
        The disagreement as a Python float; 0.0 when fewer than three
        points are given.
    """
    if len(objectives) < 3:
        return 0.0

    points = np.array(objectives)

    # Distances for every unordered pair (i < j), in row-major pair order.
    pair_distances = []
    for idx, anchor in enumerate(points):
        for other in points[idx + 1:]:
            pair_distances.append(np.linalg.norm(anchor - other))
    pair_distances = np.array(pair_distances)

    # Normalizer: diagonal of the axis-aligned bounding box of the front.
    extent = points.max(axis=0) - points.min(axis=0)
    scale = np.linalg.norm(extent) + tol
    return float(np.std(pair_distances / scale))


# ============================================================================
# MAIN
# ============================================================================

if __name__ == "__main__":
    # === CONFIG ===
    PROBLEM_TYPE = 'TSP'   # 'TSP' or 'Knapsack'
    PROBLEM_SIZE = 50
    N_RUNS = 10            # Runs per variant (5-10 recommended for paper results)

    # The scaling sweep repeats the ablation for every size in SCALING_SIZES
    # with N_RUNS instances each. It is off by default, which matches the
    # script's previous effective behaviour (the sweep was commented out and
    # the flag had no effect). Set to True to run it after the single-size
    # ablation.
    RUN_SCALING = False
    SCALING_SIZES = [20, 50, 100]  # Problem sizes for scaling
    SAVE_DIR = "."

    # Run single-size ablation
    print("\n" + "="*60)
    print("SINGLE SIZE ABLATION")
    print("="*60)
    results = run_ablation(PROBLEM_TYPE, PROBLEM_SIZE, N_RUNS)
    stats = print_latex_table(results)
    plot_results(results, save_dir=SAVE_DIR)

    # Run scaling experiment (only when enabled above)
    if RUN_SCALING:
        print("\n" + "="*60)
        print("SCALING EXPERIMENT")
        print("="*60)
        scaling_results = run_scaling_experiment(PROBLEM_TYPE, SCALING_SIZES, n_runs=N_RUNS)
        plot_scaling_results(scaling_results, save_dir=SAVE_DIR)

    print("\n✅ Done!")






