"""
Lagrangian Coordination Stress Tests
====================================
Better-designed experiments to demonstrate when Lagrangian coordination matters.

Key insight: Lagrangian coordination helps when:
1. High overlap between subproblems (more shared variables)
2. Objectives conflict on shared regions
3. Many subproblems need to agree

These tests create scenarios that stress-test coordination.
"""

import os
import sys
import time
from collections import defaultdict
from dataclasses import dataclass
from typing import List, Dict, Tuple, Optional

import matplotlib.pyplot as plt
import numpy as np

from project_MOCO.MOCO_supplementary.our_method_dl_UCB_variant import AdvancedDecompositionWrapper
from MOCO.problems import MultiObjectiveTSP, MultiObjectiveKnapsack
from MOCO.evaluation import MOCOEvaluator


# ============================================================================
# CUSTOM PROBLEM: Conflicting Objectives in Overlapping Regions
# ============================================================================

class ConflictingObjectivesTSP:
    """
    Bi-objective TSP whose objectives are designed to CONFLICT on some cities.

    Both objectives are Euclidean tour lengths, each measured on its own
    coordinate set. Objective 1 uses the base coordinates (uniform in the unit
    square); objective 2 uses a copy in which a ``conflict_ratio`` fraction of
    the cities is reflected through the centre of the square (``c -> 1 - c``).
    The intent is that edges touching a reflected city are judged differently
    by the two objectives, which forces coordination in overlapping regions.

    Attributes:
        n_cities: Number of cities.
        conflict_ratio: Fraction of cities reflected for objective 2.
        coordinates: ``[coords1, coords2]``, each an ``(n_cities, 2)`` array.
        conflict_cities: Set of the reflected city indices.
        distance_matrices: Per-objective ``(n_cities, n_cities)`` distances.
        distances1, distances2: Aliases for the two distance matrices.
    """

    def __init__(self, n_cities: int, conflict_ratio: float = 0.3):
        """
        Draw a random instance from NumPy's global RNG.

        Args:
            n_cities: Number of cities.
            conflict_ratio: Fraction of cities whose objective-2 coordinates
                are reflected; ``int(n_cities * conflict_ratio)`` cities are
                picked without replacement.

        Raises:
            ValueError: If the implied number of conflict cities is negative
                or larger than ``n_cities``.
        """
        n_conflict = int(n_cities * conflict_ratio)
        if not 0 <= n_conflict <= n_cities:
            raise ValueError(
                f"conflict_ratio={conflict_ratio} selects {n_conflict} conflict "
                f"cities, which is outside [0, {n_cities}]"
            )

        self.n_cities = n_cities
        self._m_objectives = 2
        self.conflict_ratio = conflict_ratio

        # RNG draws happen in a fixed order (coordinates, then conflict cities)
        # so that seeding np.random reproduces the same instance.
        base_coords = np.random.uniform(0, 1, size=(n_cities, 2))
        conflict_cities = np.random.choice(n_cities, n_conflict, replace=False)

        # Objective 1 sees the base layout; objective 2 sees the layout with
        # the conflict cities moved to the opposite side of the unit square.
        coords1 = base_coords.copy()
        coords2 = base_coords.copy()
        coords2[conflict_cities] = 1.0 - coords2[conflict_cities]

        self.coordinates = [coords1, coords2]
        self.conflict_cities = set(conflict_cities)

        # Pre-compute one distance matrix per objective
        self.distance_matrices = [
            self._compute_distance_matrix(m) for m in range(self._m_objectives)
        ]
        self.distances1 = self.distance_matrices[0]
        self.distances2 = self.distance_matrices[1]

    def _compute_distance_matrix(self, obj_idx):
        """Return the pairwise Euclidean distance matrix for objective ``obj_idx``."""
        coords = self.coordinates[obj_idx]
        # Broadcast (n, 1, 2) - (1, n, 2) -> (n, n, 2); the diagonal is exactly 0.
        deltas = coords[:, np.newaxis, :] - coords[np.newaxis, :, :]
        return np.sqrt((deltas ** 2).sum(axis=-1))

    def evaluate(self, solution: List[int]) -> Tuple[float, float]:
        """
        Return the closed-tour length of ``solution`` under each objective.

        Args:
            solution: Sequence of city indices of length ``n_cities``; the tour
                returns from the last city to the first.

        Raises:
            ValueError: If ``solution`` does not have ``n_cities`` entries.
        """
        if len(solution) != self.n_cities:
            raise ValueError("Invalid tour length")

        tour = np.asarray(solution, dtype=int)
        # Successor of each position, wrapping the last city back to the first.
        successors = np.roll(tour, -1)
        return tuple(
            float(dist[tour, successors].sum()) for dist in self.distance_matrices
        )

    def random_solution(self) -> List[int]:
        """Return a uniformly shuffled tour over all cities (global NumPy RNG)."""
        tour = list(range(self.n_cities))
        np.random.shuffle(tour)
        return tour

    @property
    def num_objectives(self) -> int:
        """Number of objectives (always 2 for this problem)."""
        return self._m_objectives


class ConflictingKnapsack:
    """
    Knapsack instance with two value vectors that disagree on some items.

    A ``conflict_ratio`` fraction of the items is "controversial": their
    objective-2 value is ``max(base) - base + 5``, so a high objective-1 value
    becomes a low objective-2 value and vice versa. All other items have the
    same value under both objectives.

    Note: ``n_objectives`` is stored and reported by ``num_objectives``, but
    only two value vectors are generated and ``evaluate`` returns two sums.
    """

    def __init__(self, n_items: int, n_objectives: int = 2,
                 conflict_ratio: float = 0.4, capacity: float = 12.5):
        self.n_items = n_items
        self._n_objectives = n_objectives
        self.conflict_ratio = conflict_ratio
        self.capacity = capacity

        # Global-RNG draws, in order: weights, base values, controversial items
        self.weights = np.random.uniform(1, 5, size=n_items)
        base_values = np.random.uniform(5, 15, size=n_items)
        flipped = np.random.choice(
            n_items, int(n_items * conflict_ratio), replace=False
        )
        self.conflict_items = set(flipped)

        # Objective 1 keeps the base values as they are
        self.values1 = base_values.copy()

        # Objective 2 inverts the value of every controversial item
        self.values2 = base_values.copy()
        if len(flipped):
            self.values2[flipped] = base_values.max() - base_values[flipped] + 5

        self.values = [self.values1, self.values2]

    def evaluate(self, selection: List[int]) -> Tuple[float, float]:
        """Sum both objectives over the selected item indices; overweight selections score -inf on both."""
        load = sum(self.weights[idx] for idx in selection)
        if load > self.capacity:
            infeasible = -float('inf')
            return (infeasible, infeasible)

        return tuple(
            sum(vals[idx] for idx in selection)
            for vals in (self.values1, self.values2)
        )

    def random_solution(self) -> List[int]:
        """Visit the items in random order and keep each one that still fits."""
        order = list(range(self.n_items))
        np.random.shuffle(order)

        chosen = []
        load = 0
        for candidate in order:
            candidate_weight = self.weights[candidate]
            if not (load + candidate_weight <= self.capacity):
                continue
            chosen.append(candidate)
            load += candidate_weight

        return chosen

    @property
    def num_objectives(self) -> int:
        return self._n_objectives


# ============================================================================
# STRESS TEST 1: Vary Overlap Ratio
# ============================================================================

def stress_test_overlap_ratio(problem_type='TSP', problem_size=50, n_runs=5):
    """
    Stress test 1: does a higher overlap ratio increase the Lagrangian benefit?

    For a fixed decomposition size, each overlap value is run ``n_runs`` times.
    Every run draws a fresh conflicting problem instance and solves it twice
    with the same settings, first with and then without Lagrangian
    coordination.

    Returns:
        ``(results, decomp_size)`` where ``results[overlap]`` holds the
        per-run scores under the keys ``'with'`` and ``'without'``.
    """
    decomp_size = 15  # held fixed so only the overlap changes
    overlap_values = [2, 5, 8, 10, 12]  # increasing overlap

    results = {}
    for ov in overlap_values:
        results[ov] = {'with': [], 'without': []}

    rule = "=" * 70
    print(rule)
    print("STRESS TEST 1: Varying Overlap Ratio")
    print(f"Problem: {problem_type}-{problem_size}, Decomposition size: {decomp_size}")
    print(rule)

    # (results key, Lagrangian switch), in the order the runs are executed
    variants = (('with', True), ('without', False))

    for overlap in overlap_values:
        print(f"\n--- Overlap: {overlap} (ratio: {overlap / decomp_size:.1%}) ---")

        for run in range(n_runs):
            # Fresh instance per run, shared by both variants
            if problem_type == 'TSP':
                problem = ConflictingObjectivesTSP(problem_size, conflict_ratio=0.4)
            else:
                problem = ConflictingKnapsack(problem_size, conflict_ratio=0.4)

            shared = dict(
                decomposition_size=decomp_size,
                overlap=overlap,
                n_weight_vectors=25,
                nb_rounds=60,
                patience=40,
                max_iterations=80,
                use_ftrl=True,
                use_diminishing_overlap=False,  # IMPORTANT: keep overlap constant!
                dual_step_size=5.0,
            )

            scores = {}
            for key, lagrangian_on in variants:
                wrapper = AdvancedDecompositionWrapper(
                    problem,
                    **shared,
                    use_lagrangian=lagrangian_on,
                    use_accelerated_dual=lagrangian_on,
                )
                scores[key] = compute_hv_simple(wrapper.run(), problem_type)
                results[overlap][key].append(scores[key])

            hv_with = scores['with']
            hv_without = scores['without']
            print(f"  Run {run+1}: With={hv_with:.4f}, Without={hv_without:.4f}, Δ={hv_with-hv_without:+.4f}")

    return results, decomp_size


def plot_overlap_stress_test(results, decomp_size, save_dir='.'):
    """
    Plot the overlap-ratio stress test and save it as PNG and PDF.

    Left panel: mean score per overlap ratio for both variants, with the
    population standard deviation over runs as error bars. Right panel: bar
    chart of ``mean(with) - mean(without)`` per overlap ratio.

    Args:
        results: Mapping ``overlap -> {'with': [...], 'without': [...]}``.
        decomp_size: Decomposition size used to convert overlaps to ratios.
        save_dir: Output directory; created if it does not exist yet.
    """
    # Create the output directory first: savefig does not create missing
    # directories, and failing here would throw away a finished experiment.
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    overlaps = sorted(results.keys())
    ratios = [ov / decomp_size for ov in overlaps]

    means_with = [np.mean(results[ov]['with']) for ov in overlaps]
    means_without = [np.mean(results[ov]['without']) for ov in overlaps]
    stds_with = [np.std(results[ov]['with']) for ov in overlaps]
    stds_without = [np.std(results[ov]['without']) for ov in overlaps]

    # Mean improvement of the Lagrangian variant at each overlap
    improvements = [m_w - m_wo for m_w, m_wo in zip(means_with, means_without)]

    fig, axes = plt.subplots(1, 2, figsize=(12, 5))

    # Plot 1: HV comparison
    ax = axes[0]
    ax.errorbar(ratios, means_with, yerr=stds_with, marker='o', capsize=5,
                label='With Lagrangian', color='#2ecc71', linewidth=2)
    ax.errorbar(ratios, means_without, yerr=stds_without, marker='s', capsize=5,
                label='Without Lagrangian', color='#e74c3c', linewidth=2)
    ax.set_xlabel('Overlap Ratio (overlap / decomp_size)')
    ax.set_ylabel('HV Ratio')
    ax.set_title('HV vs Overlap Ratio')
    ax.legend()
    ax.grid(alpha=0.3)

    # Plot 2: Improvement vs overlap
    ax = axes[1]
    ax.bar(ratios, improvements, width=0.05, color='#3498db', alpha=0.8)
    ax.axhline(y=0, color='gray', linestyle='--')
    ax.set_xlabel('Overlap Ratio')
    ax.set_ylabel('HV Improvement (With - Without)')
    ax.set_title('Lagrangian Benefit vs Overlap Ratio')
    ax.grid(alpha=0.3, axis='y')

    # Annotate the trend only when the largest overlap beats the smallest one
    if improvements[-1] > improvements[0]:
        ax.annotate('↑ More overlap = more benefit',
                    xy=(ratios[-1], improvements[-1]),
                    xytext=(ratios[-2], max(improvements) * 1.2),
                    arrowprops=dict(arrowstyle='->', color='gray'),
                    fontsize=10)

    plt.tight_layout()
    fig.savefig(f'{save_dir}/stress_overlap_ratio.png', dpi=150, bbox_inches='tight')
    fig.savefig(f'{save_dir}/stress_overlap_ratio.pdf', bbox_inches='tight')
    print("Saved: stress_overlap_ratio.png/pdf")
    plt.show()


# ============================================================================
# STRESS TEST 2: Vary Number of Subproblems
# ============================================================================

def stress_test_n_subproblems(problem_type='TSP', problem_size=100, n_runs=5):
    """
    Stress test 2: do more subproblems increase the Lagrangian benefit?

    The overlap ratio is held fixed while the decomposition size shrinks, so
    the estimated number of subproblems grows. Every run draws a fresh
    conflicting problem instance and solves it with and then without
    Lagrangian coordination.

    Returns:
        Mapping ``decomp_size -> {'with': [...], 'without': [...],
        'n_subproblems': estimate}``.
    """
    decomp_sizes = [50, 30, 20, 15, 10]  # smaller size = more subproblems
    overlap_ratio = 0.4  # fixed overlap ratio

    results = {}
    for ds in decomp_sizes:
        results[ds] = {'with': [], 'without': [], 'n_subproblems': None}

    rule = "=" * 70
    print(rule)
    print("STRESS TEST 2: Varying Number of Subproblems")
    print(f"Problem: {problem_type}-{problem_size}, Overlap ratio: {overlap_ratio:.0%}")
    print(rule)

    # (results key, Lagrangian switch), in the order the runs are executed
    variants = (('with', True), ('without', False))

    for decomp_size in decomp_sizes:
        overlap = int(decomp_size * overlap_ratio)

        # Rough estimate: one subproblem per non-overlapping stride, plus one
        stride = max(1, decomp_size - overlap)
        n_subproblems = (problem_size // stride) + 1
        results[decomp_size]['n_subproblems'] = n_subproblems

        print(f"\n--- Decomp size: {decomp_size}, ~{n_subproblems} subproblems ---")

        for run in range(n_runs):
            if problem_type == 'TSP':
                problem = ConflictingObjectivesTSP(problem_size, conflict_ratio=0.4)
            else:
                problem = ConflictingKnapsack(problem_size, conflict_ratio=0.4)

            shared = dict(
                decomposition_size=decomp_size,
                overlap=overlap,
                n_weight_vectors=25,
                nb_rounds=60,
                patience=40,
                max_iterations=80,
                use_ftrl=True,
                use_diminishing_overlap=False,
                dual_step_size=5.0,
            )

            scores = {}
            for key, lagrangian_on in variants:
                wrapper = AdvancedDecompositionWrapper(
                    problem,
                    **shared,
                    use_lagrangian=lagrangian_on,
                    use_accelerated_dual=lagrangian_on,
                )
                scores[key] = compute_hv_simple(wrapper.run(), problem_type)
                results[decomp_size][key].append(scores[key])

            hv_with = scores['with']
            hv_without = scores['without']
            print(f"  Run {run+1}: With={hv_with:.4f}, Without={hv_without:.4f}")

    return results


def plot_subproblems_stress_test(results, save_dir='.'):
    """
    Plot the number-of-subproblems stress test and save it as PNG and PDF.

    Draws grouped bars (with / without Lagrangian) per decomposition size,
    ordered from the largest size to the smallest, and labels each group with
    its positive mean improvement, if any.

    Args:
        results: Mapping ``decomp_size -> {'with': [...], 'without': [...],
            'n_subproblems': int}``.
        save_dir: Output directory; created if it does not exist yet.
    """
    # Create the output directory first: savefig does not create missing
    # directories, and failing here would throw away a finished experiment.
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    decomp_sizes = sorted(results.keys(), reverse=True)
    n_subproblems = [results[ds]['n_subproblems'] for ds in decomp_sizes]

    means_with = [np.mean(results[ds]['with']) for ds in decomp_sizes]
    means_without = [np.mean(results[ds]['without']) for ds in decomp_sizes]
    improvements = [m_w - m_wo for m_w, m_wo in zip(means_with, means_without)]

    fig, ax = plt.subplots(figsize=(8, 5))

    x = np.arange(len(decomp_sizes))
    width = 0.35

    bars1 = ax.bar(x - width/2, means_with, width, label='With Lagrangian', color='#2ecc71')
    bars2 = ax.bar(x + width/2, means_without, width, label='Without Lagrangian', color='#e74c3c')

    ax.set_xlabel('Number of Subproblems (approx)')
    ax.set_ylabel('HV Ratio')
    ax.set_title('HV vs Number of Subproblems')
    ax.set_xticks(x)
    ax.set_xticklabels([f'{n}\n(d={d})' for n, d in zip(n_subproblems, decomp_sizes)])
    ax.legend()
    ax.grid(alpha=0.3, axis='y')

    # Label only the groups where the Lagrangian variant is better on average
    for i, (b1, b2, imp) in enumerate(zip(bars1, bars2, improvements)):
        if imp > 0:
            ax.annotate(f'+{imp:.3f}', xy=(i, max(b1.get_height(), b2.get_height()) + 0.01),
                        ha='center', fontsize=9, color='#2ecc71')

    plt.tight_layout()
    fig.savefig(f'{save_dir}/stress_n_subproblems.png', dpi=150, bbox_inches='tight')
    fig.savefig(f'{save_dir}/stress_n_subproblems.pdf', bbox_inches='tight')
    print("Saved: stress_n_subproblems.png/pdf")
    plt.show()


# ============================================================================
# STRESS TEST 3: Conflict Intensity
# ============================================================================

def stress_test_conflict_intensity(problem_type='TSP', problem_size=50, n_runs=5):
    """
    Stress test 3: does stronger objective conflict increase the Lagrangian benefit?

    The problem's ``conflict_ratio`` is varied while the decomposition
    settings stay fixed. Every run draws a fresh problem instance and solves
    it with and then without Lagrangian coordination.

    Returns:
        Mapping ``conflict_ratio -> {'with': [...], 'without': [...]}``.
    """
    conflict_ratios = [0.0, 0.2, 0.4, 0.6, 0.8]

    results = {}
    for cr in conflict_ratios:
        results[cr] = {'with': [], 'without': []}

    rule = "=" * 70
    print(rule)
    print("STRESS TEST 3: Varying Objective Conflict Intensity")
    print(f"Problem: {problem_type}-{problem_size}")
    print(rule)

    # (results key, Lagrangian switch), in the order the runs are executed
    variants = (('with', True), ('without', False))

    for conflict_ratio in conflict_ratios:
        print(f"\n--- Conflict ratio: {conflict_ratio:.0%} ---")

        for run in range(n_runs):
            if problem_type == 'TSP':
                problem = ConflictingObjectivesTSP(problem_size, conflict_ratio=conflict_ratio)
            else:
                problem = ConflictingKnapsack(problem_size, conflict_ratio=conflict_ratio)

            shared = dict(
                decomposition_size=15,
                overlap=8,  # high overlap to stress coordination
                n_weight_vectors=25,
                nb_rounds=60,
                patience=40,
                max_iterations=80,
                use_ftrl=True,
                use_diminishing_overlap=False,
                dual_step_size=5.0,
            )

            scores = {}
            for key, lagrangian_on in variants:
                wrapper = AdvancedDecompositionWrapper(
                    problem,
                    **shared,
                    use_lagrangian=lagrangian_on,
                    use_accelerated_dual=lagrangian_on,
                )
                scores[key] = compute_hv_simple(wrapper.run(), problem_type)
                results[conflict_ratio][key].append(scores[key])

            hv_with = scores['with']
            hv_without = scores['without']
            print(f"  Run {run+1}: With={hv_with:.4f}, Without={hv_without:.4f}")

    return results


def plot_conflict_stress_test(results, save_dir='.'):
    """
    Plot the conflict-intensity stress test and save it as PNG and PDF.

    Left panel: mean score per conflict ratio for both variants, with the
    population standard deviation over runs as error bars. Right panel: bar
    chart of the mean improvement, blue where positive and grey otherwise.

    Args:
        results: Mapping ``conflict_ratio -> {'with': [...], 'without': [...]}``.
        save_dir: Output directory; created if it does not exist yet.
    """
    # Create the output directory first: savefig does not create missing
    # directories, and failing here would throw away a finished experiment.
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    conflicts = sorted(results.keys())

    means_with = [np.mean(results[cr]['with']) for cr in conflicts]
    means_without = [np.mean(results[cr]['without']) for cr in conflicts]
    stds_with = [np.std(results[cr]['with']) for cr in conflicts]
    stds_without = [np.std(results[cr]['without']) for cr in conflicts]
    improvements = [m_w - m_wo for m_w, m_wo in zip(means_with, means_without)]

    fig, axes = plt.subplots(1, 2, figsize=(12, 5))

    # Plot 1: HV comparison
    ax = axes[0]
    ax.errorbar(conflicts, means_with, yerr=stds_with, marker='o', capsize=5,
                label='With Lagrangian', color='#2ecc71', linewidth=2)
    ax.errorbar(conflicts, means_without, yerr=stds_without, marker='s', capsize=5,
                label='Without Lagrangian', color='#e74c3c', linewidth=2)
    ax.set_xlabel('Conflict Ratio')
    ax.set_ylabel('HV Ratio')
    ax.set_title('HV vs Objective Conflict')
    ax.legend()
    ax.grid(alpha=0.3)

    # Plot 2: Improvement
    ax = axes[1]
    colors = ['#3498db' if imp > 0 else '#95a5a6' for imp in improvements]
    ax.bar(conflicts, improvements, width=0.15, color=colors, alpha=0.8)
    ax.axhline(y=0, color='gray', linestyle='--')
    ax.set_xlabel('Conflict Ratio')
    ax.set_ylabel('HV Improvement')
    ax.set_title('Lagrangian Benefit vs Conflict')
    ax.grid(alpha=0.3, axis='y')

    plt.tight_layout()
    fig.savefig(f'{save_dir}/stress_conflict.png', dpi=150, bbox_inches='tight')
    fig.savefig(f'{save_dir}/stress_conflict.pdf', bbox_inches='tight')
    print("Saved: stress_conflict.png/pdf")
    plt.show()


# ============================================================================
# STRESS TEST 4: Track Coordination Dynamics Over Iterations
# ============================================================================

def stress_test_dynamics(problem_type='TSP', problem_size=50, n_runs=3):
    """
    Stress test 4: how does front disagreement change with the round budget?

    Hypothesis: with Lagrangian coordination, disagreement decreases faster.

    The wrapper does not expose per-iteration disagreement here, so each
    checkpoint is a separate run with ``nb_rounds`` set to the checkpoint
    value. The disagreement of the returned front is then recorded.

    Returns:
        ``(results, checkpoints)`` where ``results['with'|'without'][cp]`` is
        the list of disagreement values collected over the runs.
    """
    rule = "=" * 70
    print(rule)
    print("STRESS TEST 4: Coordination Dynamics Over Time")
    print(f"Problem: {problem_type}-{problem_size}")
    print(rule)

    checkpoints = [10, 25, 40, 60, 80]

    results = {}
    for key in ('with', 'without'):
        results[key] = {cp: [] for cp in checkpoints}

    # Marker for an infeasible objective value: +inf for TSP, -inf otherwise
    infeasible = float('inf') if problem_type == 'TSP' else -float('inf')

    for run in range(n_runs):
        print(f"\n--- Run {run+1}/{n_runs} ---")

        # One instance per run, reused for every checkpoint and both variants
        if problem_type == 'TSP':
            problem = ConflictingObjectivesTSP(problem_size, conflict_ratio=0.4)
        else:
            problem = ConflictingKnapsack(problem_size, conflict_ratio=0.4)

        for cp in checkpoints:
            for use_lag, key in [(True, 'with'), (False, 'without')]:
                settings = dict(
                    decomposition_size=15,
                    overlap=8,
                    n_weight_vectors=25,
                    nb_rounds=cp,  # stop at checkpoint
                    patience=200,  # don't early stop
                    max_iterations=80,
                    use_ftrl=True,
                    use_lagrangian=use_lag,
                    use_accelerated_dual=use_lag,
                    use_diminishing_overlap=False,
                    dual_step_size=5.0,
                )

                res = AdvancedDecompositionWrapper(problem, **settings).run()

                # Keep only evaluated entries without an infeasible component
                front = [entry[1] for entry in res if entry[1] is not None]
                front = [o for o in front if all(x != infeasible for x in o)]

                results[key][cp].append(compute_front_disagreement(front))

        print(f"  Done run {run+1}")

    return results, checkpoints


def plot_dynamics_stress_test(results, checkpoints, save_dir='.'):
    """
    Plot front disagreement per checkpoint and save it as PNG and PDF.

    Shows the mean disagreement at each checkpoint for both variants, with
    the population standard deviation over runs as error bars.

    Args:
        results: ``{'with': {cp: [...]}, 'without': {cp: [...]}}``.
        checkpoints: Checkpoint values to plot, in x-axis order.
        save_dir: Output directory; created if it does not exist yet.
    """
    # Create the output directory first: savefig does not create missing
    # directories, and failing here would throw away a finished experiment.
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    fig, ax = plt.subplots(figsize=(8, 5))

    means_with = [np.mean(results['with'][cp]) for cp in checkpoints]
    means_without = [np.mean(results['without'][cp]) for cp in checkpoints]
    stds_with = [np.std(results['with'][cp]) for cp in checkpoints]
    stds_without = [np.std(results['without'][cp]) for cp in checkpoints]

    ax.errorbar(checkpoints, means_with, yerr=stds_with, marker='o', capsize=5,
                label='With Lagrangian', color='#2ecc71', linewidth=2)
    ax.errorbar(checkpoints, means_without, yerr=stds_without, marker='s', capsize=5,
                label='Without Lagrangian', color='#e74c3c', linewidth=2)

    ax.set_xlabel('Iterations')
    ax.set_ylabel('Front Disagreement (lower = better)')
    ax.set_title('Coordination Dynamics: Disagreement Over Time')
    ax.legend()
    ax.grid(alpha=0.3)

    plt.tight_layout()
    fig.savefig(f'{save_dir}/stress_dynamics.png', dpi=150, bbox_inches='tight')
    fig.savefig(f'{save_dir}/stress_dynamics.pdf', bbox_inches='tight')
    print("Saved: stress_dynamics.png/pdf")
    plt.show()


# ============================================================================
# STRESS TEST 5: Dual Step Size Sensitivity
# ============================================================================

def stress_test_dual_step_size(problem_type='TSP', problem_size=50, n_runs=5):
    """
    Stress test 5: sensitivity of the Lagrangian variant to the dual step size.

    Every run draws a fresh conflicting problem instance, solves it once
    without Lagrangian coordination as a baseline, and then once per step
    size with coordination enabled.

    Returns:
        ``(results, baseline_results)`` where ``results[step_size]`` is the
        list of scores per run and ``baseline_results`` the baseline scores.
    """
    step_sizes = [0.1, 1.0, 5.0, 10.0, 20.0]

    results = {}
    for ss in step_sizes:
        results[ss] = []
    baseline_results = []

    rule = "=" * 70
    print(rule)
    print("STRESS TEST 5: Dual Step Size Sensitivity")
    print(f"Problem: {problem_type}-{problem_size}")
    print(rule)

    for run in range(n_runs):
        print(f"\n--- Run {run+1}/{n_runs} ---")

        if problem_type == 'TSP':
            problem = ConflictingObjectivesTSP(problem_size, conflict_ratio=0.4)
        else:
            problem = ConflictingKnapsack(problem_size, conflict_ratio=0.4)

        # Baseline (no Lagrangian)
        params_base = dict(
            decomposition_size=15,
            overlap=8,
            n_weight_vectors=25,
            nb_rounds=60,
            patience=40,
            max_iterations=80,
            use_ftrl=True,
            use_lagrangian=False,
            use_diminishing_overlap=False,
        )
        baseline_run = AdvancedDecompositionWrapper(problem, **params_base).run()
        baseline_results.append(compute_hv_simple(baseline_run, problem_type))

        # Same settings with coordination switched on, one run per step size
        for ss in step_sizes:
            tuned = dict(params_base)
            tuned.update(
                use_lagrangian=True,
                use_accelerated_dual=True,
                dual_step_size=ss,
            )
            tuned_run = AdvancedDecompositionWrapper(problem, **tuned).run()
            results[ss].append(compute_hv_simple(tuned_run, problem_type))

    return results, baseline_results


def plot_step_size_stress_test(results, baseline_results, save_dir='.'):
    """
    Plot the dual-step-size sensitivity test and save it as PNG and PDF.

    Shows the mean score per step size on a log-scaled x axis (error bars are
    the population standard deviation over runs), the baseline mean as a
    dashed line with a ±1 std band, and a marker on the best step size.

    Args:
        results: Mapping ``step_size -> [score per run]``.
        baseline_results: Scores of the runs without Lagrangian coordination.
        save_dir: Output directory; created if it does not exist yet.
    """
    # Create the output directory first: savefig does not create missing
    # directories, and failing here would throw away a finished experiment.
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    step_sizes = sorted(results.keys())

    means = [np.mean(results[ss]) for ss in step_sizes]
    stds = [np.std(results[ss]) for ss in step_sizes]
    baseline_mean = np.mean(baseline_results)
    baseline_std = np.std(baseline_results)

    fig, ax = plt.subplots(figsize=(8, 5))

    ax.errorbar(step_sizes, means, yerr=stds, marker='o', capsize=5,
                color='#3498db', linewidth=2, label='With Lagrangian')
    ax.axhline(y=baseline_mean, color='#e74c3c', linestyle='--',
               label=f'No Lagrangian ({baseline_mean:.3f})')
    ax.fill_between(step_sizes, baseline_mean - baseline_std, baseline_mean + baseline_std,
                    color='#e74c3c', alpha=0.2)

    ax.set_xlabel('Dual Step Size')
    ax.set_ylabel('HV Ratio')
    ax.set_xscale('log')
    ax.set_title('Lagrangian Performance vs Dual Step Size')
    ax.legend()
    ax.grid(alpha=0.3)

    # Mark the step size with the highest mean score
    best_idx = np.argmax(means)
    ax.scatter([step_sizes[best_idx]], [means[best_idx]], s=200,
               facecolors='none', edgecolors='#2ecc71', linewidth=3, zorder=5)
    ax.annotate(f'Best: {step_sizes[best_idx]}',
                xy=(step_sizes[best_idx], means[best_idx]),
                xytext=(step_sizes[best_idx]*2, means[best_idx]+0.02),
                fontsize=10)

    plt.tight_layout()
    fig.savefig(f'{save_dir}/stress_step_size.png', dpi=150, bbox_inches='tight')
    fig.savefig(f'{save_dir}/stress_step_size.pdf', bbox_inches='tight')
    print("Saved: stress_step_size.png/pdf")
    plt.show()


# ============================================================================
# HELPER FUNCTIONS
# ============================================================================

def compute_hv_simple(results, problem_type, ref_point=(15.0, 15.0)):
    """
    Score a set of results with a normalized quality metric.

    For ``problem_type == 'TSP'`` (both objectives minimized) this is the
    exact two-objective dominated hypervolume with respect to ``ref_point``,
    divided by the area of the box between the origin and ``ref_point``. Every
    point that is strictly better than the reference point in both objectives
    contributes, so a wider front scores higher than a single point.

    For any other problem type (objectives maximized) the score is NOT a
    hypervolume: it is the mean over all objective values divided by 30, a
    rough normalization.

    Args:
        results: Iterable of entries whose element ``[1]`` is an objective
            vector or ``None``. Entries with ``None`` or with a non-finite
            component (inf, -inf, NaN) are ignored.
        problem_type: ``'TSP'`` selects the hypervolume; anything else selects
            the mean-based score.
        ref_point: Reference point for the TSP hypervolume. It should be worse
            than every tour length of interest; points beyond it score zero.

    Returns:
        The score as a float; ``0.0`` when no usable objective vector remains.

    Raises:
        ValueError: For ``'TSP'`` when the objective vectors or ``ref_point``
            are not two-dimensional.
    """
    objectives = [r[1] for r in results if r[1] is not None]
    # Infeasible solutions are reported with infinite objective values.
    objectives = [o for o in objectives if np.all(np.isfinite(o))]
    if not objectives:
        return 0.0

    objs = np.asarray(objectives, dtype=float)

    if problem_type != 'TSP':
        # Maximization: average objective value with a rough normalization
        return float(np.mean(objs) / 30.0)

    if objs.ndim != 2 or objs.shape[1] != 2 or len(ref_point) != 2:
        raise ValueError("TSP hypervolume is only implemented for two objectives")

    ref_x, ref_y = ref_point
    max_hv = ref_x * ref_y
    if max_hv <= 0:
        return 0.0

    # Only points strictly inside the reference box dominate any area
    inside = objs[(objs[:, 0] < ref_x) & (objs[:, 1] < ref_y)]

    # Sweep by increasing first objective. A point adds the strip between its
    # own second objective and the best one seen so far; dominated points
    # (no improvement in the second objective) add nothing.
    hv = 0.0
    best_y = ref_y
    for x, y in sorted(map(tuple, inside)):
        if y < best_y:
            hv += (ref_x - x) * (best_y - y)
            best_y = y

    return float(hv / max_hv)


def compute_front_disagreement(objectives, tol=1e-6):
    """
    Return the spread of the normalized pairwise distances within a front.

    All pairwise Euclidean distances between objective vectors are divided by
    the diagonal of the front's bounding box (plus ``tol`` so the division is
    never by zero); the result is their population standard deviation.
    Fronts with fewer than three points return ``0.0``.
    """
    if len(objectives) < 3:
        return 0.0

    points = np.array(objectives)
    count = len(points)

    # Distance for every unordered pair (a, b) with a < b
    pair_dists = np.array([
        np.linalg.norm(points[a] - points[b])
        for a in range(count)
        for b in range(a + 1, count)
    ])

    bounding_diagonal = np.linalg.norm(points.max(axis=0) - points.min(axis=0)) + tol
    normalized = pair_dists / bounding_diagonal
    return float(np.std(normalized))


# ============================================================================
# SUMMARY TABLE
# ============================================================================

def generate_summary_table(all_results, save_dir='.'):
    """
    Print a short findings summary and a LaTeX table template.

    For each of the ``'overlap'``, ``'subproblems'`` and ``'conflict'`` entries
    present in ``all_results``, one bullet line is printed based on the mean
    improvement ``mean(with) - mean(without)`` per setting.

    The LaTeX table printed at the end is a static template with ``+X.XXX``
    placeholders; it is not filled in from the results. ``save_dir`` is
    accepted but not used, and nothing is written to disk.
    """
    def mean_gain(entry):
        # Mean improvement of the Lagrangian variant for one setting
        return np.mean(entry['with']) - np.mean(entry['without'])

    rule = "=" * 70
    print("\n" + rule)
    print("SUMMARY: When Does Lagrangian Coordination Help?")
    print(rule)

    findings = []

    if 'overlap' in all_results:
        by_overlap = all_results['overlap']
        gains = [mean_gain(by_overlap[ov]) for ov in sorted(by_overlap.keys())]
        first, last = gains[0], gains[-1]
        if last > first:
            trend = "increases"
        else:
            trend = "decreases"
        findings.append(f"• Overlap ratio: Lagrangian benefit {trend} with overlap (Δ from {first:.4f} to {last:.4f})")

    if 'subproblems' in all_results:
        by_size = all_results['subproblems']
        gains = [mean_gain(by_size[ds]) for ds in sorted(by_size.keys(), reverse=True)]
        best_improvement = max(gains)
        findings.append(f"• Subproblems: Max improvement {best_improvement:.4f} with more subproblems")

    if 'conflict' in all_results:
        by_conflict = all_results['conflict']
        gains = [mean_gain(by_conflict[cr]) for cr in sorted(by_conflict.keys())]
        findings.append(f"• Conflict: Improvement ranges from {min(gains):.4f} to {max(gains):.4f}")

    for line in findings:
        print(line)

    print("\n--- LaTeX Table ---")
    print(r"""
\begin{table}[h]
\caption{Stress Test Summary: Lagrangian Coordination Benefits}
\label{tab:stress-summary}
\centering
\begin{tabular}{lcc}
\toprule
Stress Test & Condition & Lagrangian Improvement \\
\midrule
Overlap Ratio & 80\% overlap & +X.XXX \\
Subproblems & 10+ subproblems & +X.XXX \\
Conflict Intensity & 60\% conflict & +X.XXX \\
\bottomrule
\end{tabular}
\end{table}
""")


# ============================================================================
# MAIN
# ============================================================================

if __name__ == "__main__":
    import argparse

    # Command-line entry point: run the selected stress test(s), plot each one,
    # and print the summary when all tests were run.
    parser = argparse.ArgumentParser(description='Lagrangian Coordination Stress Tests')
    parser.add_argument('--test', type=str, default='all',
                        choices=['all', 'overlap', 'subproblems', 'conflict', 'dynamics', 'stepsize'],
                        help='Which stress test to run')
    parser.add_argument('--problem', type=str, default='TSP', choices=['TSP', 'Knapsack'])
    parser.add_argument('--size', type=int, default=50)
    parser.add_argument('--runs', type=int, default=5)
    parser.add_argument('--save_dir', type=str, default='.')

    args = parser.parse_args()

    # Fail fast: make sure the output directory exists (or can be created)
    # before any experiment starts, so a bad --save_dir costs no compute.
    if args.save_dir:
        os.makedirs(args.save_dir, exist_ok=True)

    all_results = {}

    if args.test in ['all', 'overlap']:
        print("\n" + "="*70)
        print("RUNNING: Overlap Ratio Stress Test")
        print("="*70)
        results, decomp_size = stress_test_overlap_ratio(args.problem, args.size, args.runs)
        all_results['overlap'] = results
        plot_overlap_stress_test(results, decomp_size, args.save_dir)

    if args.test in ['all', 'subproblems']:
        print("\n" + "="*70)
        print("RUNNING: Subproblems Stress Test")
        print("="*70)
        # This test always uses a problem of size 100 (larger problem);
        # --size is not applied here.
        results = stress_test_n_subproblems(args.problem, 100, args.runs)
        all_results['subproblems'] = results
        plot_subproblems_stress_test(results, args.save_dir)

    if args.test in ['all', 'conflict']:
        print("\n" + "="*70)
        print("RUNNING: Conflict Intensity Stress Test")
        print("="*70)
        results = stress_test_conflict_intensity(args.problem, args.size, args.runs)
        all_results['conflict'] = results
        plot_conflict_stress_test(results, args.save_dir)

    if args.test in ['all', 'dynamics']:
        print("\n" + "="*70)
        print("RUNNING: Dynamics Stress Test")
        print("="*70)
        # The dynamics test is capped at 3 runs regardless of --runs
        results, checkpoints = stress_test_dynamics(args.problem, args.size, min(args.runs, 3))
        all_results['dynamics'] = (results, checkpoints)
        plot_dynamics_stress_test(results, checkpoints, args.save_dir)

    if args.test in ['all', 'stepsize']:
        print("\n" + "="*70)
        print("RUNNING: Step Size Sensitivity Test")
        print("="*70)
        results, baseline = stress_test_dual_step_size(args.problem, args.size, args.runs)
        all_results['stepsize'] = (results, baseline)
        plot_step_size_stress_test(results, baseline, args.save_dir)

    # The summary is only generated when every test was run
    if args.test == 'all':
        generate_summary_table(all_results, args.save_dir)

    print("\n✅ Stress tests complete!")