
"""
Proper Lagrangian Coordination Ablation Tests (Multi-Seed Version)
===================================================================

Key improvements:
1. 10 seeds with PAIRED comparison (same seeded problem instance and initial solution for with/without)
2. Reports mean ± std 
3. Shaded confidence regions in plots
4. Statistical significance test (paired t-test)

Usage:
    python lagrangian_ablation_multiseed.py --test all --problem TSP --size 50

Author: Fixed ablation design with proper statistics
"""

import numpy as np
import matplotlib.pyplot as plt
from typing import List, Dict, Tuple, Optional
from collections import defaultdict
from dataclasses import dataclass
import torch
import random
import sys
from scipy import stats

# Adjust path as needed

from project_MOCO.MOCO_supplementary.our_method_dl_UCB_variant import (
    FixedDecomposedGameOptUCBHedge,
    AdvancedDecompositionWrapper
)

# ============================================================================
# GLOBAL CONFIGURATION
# ============================================================================

# Default number of seeded runs each test averages over (default of the tests' n_seeds parameter).
N_SEEDS = 10  # Number of seeds for averaging
# First seed of a sweep; the tests use BASE_SEED + seed_idx for run seed_idx.
# NOTE(review): the original note "5, 10, 200-best, 250" presumably lists values tried
# earlier, with 200 giving the best result — confirm with the author.
BASE_SEED = 250  # Starting seed 5, 10, 200-best, 250


def set_all_seeds(seed: int):
    """Seed the Python, NumPy and PyTorch random generators with ``seed``.

    The CUDA generators are seeded as well, but only when CUDA is available.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


# ============================================================================
# PROBLEM 1: TSP with Subproblem Boundary Conflicts
# ============================================================================

class SubproblemConflictTSP:
    """
    Bi-objective TSP designed to create conflicts at SUBPROBLEM BOUNDARIES.

    Design principle:
    - Divide cities into contiguous index "regions" matching expected subproblem boundaries
    - Within each region, create clear local structure
    - At boundaries, create "contested" cities where:
      * Objective 1 shortens links towards the left region and lengthens links to the right
      * Objective 2 does the opposite
      * The two objectives therefore disagree on how a boundary city should connect

    This is intended to stress the Lagrangian coordination mechanism.
    """

    def __init__(self, n_cities: int, n_regions: int = 5, conflict_strength: float = 2.0, seed: int = None):
        """
        Args:
            n_cities: Total number of cities
            n_regions: Number of regions (should match expected # subproblems)
            conflict_strength: Factor by which boundary distances are divided/multiplied
                (higher = more conflict)
            seed: Random seed; when given, ALL global RNGs are reseeded via set_all_seeds

        Raises:
            ValueError: If n_regions < 1, or if there are fewer cities than regions
                (which would otherwise surface as a ZeroDivisionError while assigning regions).
        """
        if n_regions < 1:
            raise ValueError(f"n_regions must be at least 1, got {n_regions}")
        if 0 < n_cities < n_regions:
            raise ValueError(
                f"n_cities ({n_cities}) must be >= n_regions ({n_regions}) "
                "so that every region holds at least one city"
            )

        if seed is not None:
            set_all_seeds(seed)

        self.n_cities = n_cities
        self.n_regions = n_regions
        self.conflict_strength = conflict_strength
        self._m_objectives = 2

        # Assign cities to regions
        self.cities_per_region = n_cities // n_regions
        self.region_assignments = self._assign_regions()
        self.boundary_cities = self._identify_boundaries()

        # Create distance matrices with boundary conflicts
        self.distances1 = self._create_conflict_distances(objective=1)
        self.distances2 = self._create_conflict_distances(objective=2)
        self.distance_matrices = [self.distances1, self.distances2]

        print(f"Created SubproblemConflictTSP: {n_cities} cities, {n_regions} regions")
        print(f"  Boundary cities: {self.boundary_cities}")
        print(f"  Conflict strength: {conflict_strength}")

    def _assign_regions(self) -> Dict[int, int]:
        """Map each city index to its region; leftover cities join the last region."""
        assignments = {}
        for i in range(self.n_cities):
            region = min(i // self.cities_per_region, self.n_regions - 1)
            assignments[i] = region
        return assignments

    def _identify_boundaries(self) -> List[int]:
        """Return the cities on either side of every region change (unordered, no duplicates)."""
        boundaries = []
        for i in range(1, self.n_cities):
            if self.region_assignments[i] != self.region_assignments[i-1]:
                boundaries.append(i)
                boundaries.append(i-1)
        return list(set(boundaries))

    def _create_conflict_distances(self, objective: int) -> np.ndarray:
        """
        Create a symmetric distance matrix with conflicts at boundaries.

        Base distances are Euclidean distances between deterministic coordinates
        (regions spaced 10 apart on x, cities 0.5 apart inside a region, a sine
        wave on y). For every boundary city that has both a left and a right
        neighbouring region:
        - Objective 1: distances to two left-region cities are divided by
          conflict_strength, distances to two right-region cities are multiplied
        - Any other objective value: the reverse

        Args:
            objective: 1 for the left-preferring matrix; anything else gives the
                right-preferring matrix.

        Returns:
            (n_cities, n_cities) array with a zero diagonal.
        """
        # Start with base grid coordinates
        coords = np.zeros((self.n_cities, 2))
        for i in range(self.n_cities):
            region = self.region_assignments[i]
            local_idx = i - region * self.cities_per_region
            # Arrange in a line with clusters
            coords[i, 0] = region * 10 + local_idx * 0.5
            coords[i, 1] = np.sin(i * 0.5) * 2  # Some vertical spread

        # Compute base distances
        dist = np.zeros((self.n_cities, self.n_cities))
        for i in range(self.n_cities):
            for j in range(self.n_cities):
                if i != j:
                    dist[i, j] = np.linalg.norm(coords[i] - coords[j])

        # Add conflict structure at boundaries
        for boundary_city in self.boundary_cities:
            region = self.region_assignments[boundary_city]

            # Find neighbors in left region and right region
            left_neighbors = [c for c in range(self.n_cities)
                              if self.region_assignments[c] == region - 1]
            right_neighbors = [c for c in range(self.n_cities)
                               if self.region_assignments[c] == region + 1]

            # Boundary cities of the first/last region have only one neighbouring region
            if not left_neighbors or not right_neighbors:
                continue

            # NOTE(review): [:2] selects the two LOWEST-INDEXED cities of each
            # neighbouring region, not the two geometrically closest ones (for the
            # left region these are the farthest). A pair touched from both of its
            # endpoints is multiplied once and divided once, so it ends up
            # (approximately) unscaled. Kept as-is to preserve experiment results.
            if objective == 1:
                # Make left neighbors closer, right neighbors farther
                for ln in left_neighbors[:2]:
                    dist[boundary_city, ln] /= self.conflict_strength
                    dist[ln, boundary_city] /= self.conflict_strength
                for rn in right_neighbors[:2]:
                    dist[boundary_city, rn] *= self.conflict_strength
                    dist[rn, boundary_city] *= self.conflict_strength
            else:
                # Make right neighbors closer, left neighbors farther
                for rn in right_neighbors[:2]:
                    dist[boundary_city, rn] /= self.conflict_strength
                    dist[rn, boundary_city] /= self.conflict_strength
                for ln in left_neighbors[:2]:
                    dist[boundary_city, ln] *= self.conflict_strength
                    dist[ln, boundary_city] *= self.conflict_strength

        return dist

    def evaluate(self, solution: List[int]) -> Tuple[float, float]:
        """Return the closed-tour length under both objectives.

        A solution whose length differs from n_cities yields (inf, inf).
        The tour is not checked for being a permutation.
        """
        if len(solution) != self.n_cities:
            return (float('inf'), float('inf'))

        obj1 = sum(self.distances1[solution[i], solution[(i+1) % self.n_cities]]
                   for i in range(self.n_cities))
        obj2 = sum(self.distances2[solution[i], solution[(i+1) % self.n_cities]]
                   for i in range(self.n_cities))

        return (obj1, obj2)

    def random_solution(self) -> List[int]:
        """Return a random tour (permutation of all cities) drawn from NumPy's global RNG."""
        tour = list(range(self.n_cities))
        np.random.shuffle(tour)
        return tour

    @property
    def num_objectives(self) -> int:
        """Number of objectives (always 2)."""
        return 2


# ============================================================================
# PROBLEM 2: Knapsack with Correlated Items in Overlap Regions
# ============================================================================

class SubproblemConflictKnapsack:
    """
    Bi-objective knapsack whose cluster-boundary items are valued in opposite
    directions by the two objectives.

    Design principle:
    - Items are grouped into contiguous index "clusters" matching expected subproblems
    - Objective 1 boosts the FIRST item of every cluster except the first cluster
      and discounts all other boundary items
    - Objective 2 boosts the LAST item of every cluster except the last cluster
      and discounts all other boundary items
    - This creates preference conflicts for boundary items
    """

    def __init__(self, n_items: int, n_clusters: int = 5,
                 conflict_strength: float = 2.0, capacity: float = None, seed: int = None):
        """
        Args:
            n_items: Total number of items
            n_clusters: Number of clusters (should match expected # subproblems)
            conflict_strength: Factor by which boundary item values are multiplied/divided
            capacity: Knapsack capacity; defaults to 0.25 * n_items
            seed: Random seed; when given, ALL global RNGs are reseeded via set_all_seeds

        Raises:
            ValueError: If n_clusters < 1, or if there are fewer items than clusters
                (which would otherwise surface as a ZeroDivisionError).
        """
        if n_clusters < 1:
            raise ValueError(f"n_clusters must be at least 1, got {n_clusters}")
        if 0 < n_items < n_clusters:
            raise ValueError(
                f"n_items ({n_items}) must be >= n_clusters ({n_clusters}) "
                "so that every cluster holds at least one item"
            )

        if seed is not None:
            set_all_seeds(seed)

        self.n_items = n_items
        self.n_clusters = n_clusters
        self.conflict_strength = conflict_strength
        self._n_objectives = 2

        # Default capacity
        if capacity is None:
            capacity = n_items * 0.25
        self.capacity = capacity

        # Assign items to clusters; leftover items join the last cluster
        self.items_per_cluster = n_items // n_clusters
        self.cluster_assignments = {i: min(i // self.items_per_cluster, n_clusters - 1)
                                    for i in range(n_items)}
        self.boundary_items = self._identify_boundaries()

        # RNG draw order matters for reproducibility: weights, then values1, then values2.
        self.weights = np.random.uniform(1, 3, size=n_items)

        # Generate values with boundary conflicts
        self.values1 = self._create_conflict_values(objective=1)
        self.values2 = self._create_conflict_values(objective=2)

        print(f"Created SubproblemConflictKnapsack: {n_items} items, {n_clusters} clusters")
        print(f"  Boundary items: {self.boundary_items}")
        print(f"  Capacity: {capacity}")

    def _identify_boundaries(self) -> List[int]:
        """Return the items on either side of every cluster change (unordered, no duplicates)."""
        boundaries = []
        for i in range(1, self.n_items):
            if self.cluster_assignments[i] != self.cluster_assignments[i-1]:
                boundaries.append(i)
                boundaries.append(i-1)
        return list(set(boundaries))

    def _create_conflict_values(self, objective: int) -> np.ndarray:
        """
        Draw uniform(5, 15) base values and rescale the boundary items.

        Each call draws a fresh set of base values from NumPy's global RNG, so
        the two objectives do not share base values.

        Args:
            objective: 1 boosts cluster-start boundary items; anything else
                boosts cluster-end boundary items. All other boundary items are
                divided by conflict_strength.

        Returns:
            Array of n_items values.
        """
        base_values = np.random.uniform(5, 15, size=self.n_items)
        values = base_values.copy()

        for item in self.boundary_items:
            cluster = self.cluster_assignments[item]

            if objective == 1:
                # Boost the first item of any cluster that has a left neighbour
                if cluster > 0 and item == cluster * self.items_per_cluster:
                    values[item] *= self.conflict_strength
                else:
                    values[item] /= self.conflict_strength
            else:
                # Boost the last item of any cluster that has a right neighbour
                if cluster < self.n_clusters - 1 and item == (cluster + 1) * self.items_per_cluster - 1:
                    values[item] *= self.conflict_strength
                else:
                    values[item] /= self.conflict_strength

        return values

    def evaluate(self, selection: List[int]) -> Tuple[float, float]:
        """Return the total value under both objectives, or (-inf, -inf) if overweight.

        ``selection`` is either a binary vector of length n_items or a list of
        item indices. A length-n_items sequence containing only 0s and 1s is
        always read as a binary vector.
        """
        # Handle both formats
        if len(selection) == self.n_items and all(s in [0, 1] for s in selection):
            # Binary vector
            selected_indices = [i for i, s in enumerate(selection) if s == 1]
        else:
            # List of indices
            selected_indices = selection

        total_weight = sum(self.weights[i] for i in selected_indices)

        if total_weight > self.capacity:
            return (-float('inf'), -float('inf'))

        obj1 = sum(self.values1[i] for i in selected_indices)
        obj2 = sum(self.values2[i] for i in selected_indices)

        return (obj1, obj2)

    def random_solution(self) -> List[int]:
        """Return a feasible index list built greedily over a random item order."""
        selection = []
        total_weight = 0
        items = list(range(self.n_items))
        np.random.shuffle(items)

        for item in items:
            if total_weight + self.weights[item] <= self.capacity:
                selection.append(item)
                total_weight += self.weights[item]

        return selection

    @property
    def num_objectives(self) -> int:
        """Number of objectives (always 2)."""
        return 2


# ============================================================================
# INSTRUMENTED OPTIMIZER: Track Disagreement Metrics
# ============================================================================

class InstrumentedOptimizer(FixedDecomposedGameOptUCBHedge):
    """
    Optimizer subclass that records coordination metrics on every global update.

    It relies on attributes provided by the base class (problem_size,
    position_to_subproblems, value_estimates, use_lagrangian, dual_vars);
    their definitions are not part of this file.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments to the base optimizer, then set up metric histories."""
        super().__init__(*args, **kwargs)

        # One entry is appended to each history per update_global_parameters call
        # (dual_var_history only when use_lagrangian is truthy).
        self.disagreement_history = []
        self.variance_history = []
        self.dual_var_history = []
        self.reward_history = []
        self.overlap_positions = self._get_overlap_positions()

    def _get_overlap_positions(self) -> List[int]:
        """Get positions that appear in more than one subproblem."""
        return [pos for pos in range(self.problem_size)
                if len(self.position_to_subproblems[pos]) > 1]

    def measure_disagreement(self, solution: List[int]) -> float:
        """
        Mean population variance (ddof=0) of the value estimates at overlap positions.

        High variance across the candidate values of a position is used here as
        a proxy for uncertainty / potential conflict. ``solution`` is accepted
        for interface symmetry but is not used. Returns 0.0 when there are no
        overlap positions.
        """
        if not self.overlap_positions:
            return 0.0

        total_disagreement = 0.0

        for pos in self.overlap_positions:
            # detach()/cpu() keep this working for grad-tracking or CUDA tensors,
            # on which a bare .numpy() raises; for plain CPU tensors it is a no-op.
            value_estimates = self.value_estimates[pos].detach().cpu().numpy()

            var = np.var(value_estimates)
            total_disagreement += var

        return total_disagreement / len(self.overlap_positions)

    def measure_overlap_variance(self) -> float:
        """
        Average variance of value estimates at overlapping positions.

        Uses torch.var with its default (Bessel-corrected) estimator, so the
        numbers differ from measure_disagreement, which uses np.var (ddof=0).
        Returns 0.0 when there are no overlap positions.
        """
        if not self.overlap_positions:
            return 0.0

        variances = []
        for pos in self.overlap_positions:
            var = torch.var(self.value_estimates[pos]).item()
            variances.append(var)

        return np.mean(variances)

    def update_global_parameters(self, solution: List[int], reward: float):
        """Record the current metrics, then delegate to the base-class update."""
        # Metrics are captured BEFORE the parent update changes the estimates
        self.disagreement_history.append(self.measure_disagreement(solution))
        self.variance_history.append(self.measure_overlap_variance())
        self.reward_history.append(reward)

        if self.use_lagrangian:
            self.dual_var_history.append(self.dual_vars.mean().item())

        # Call parent update
        super().update_global_parameters(solution, reward)

    def get_coordination_metrics(self) -> Dict:
        """Return the recorded histories plus the overlap positions and their count."""
        return {
            'disagreement': self.disagreement_history,
            'variance': self.variance_history,
            'dual_vars': self.dual_var_history,
            'rewards': self.reward_history,
            'n_overlap_positions': len(self.overlap_positions),
            'overlap_positions': self.overlap_positions,
        }


# ============================================================================
# STATISTICAL UTILITIES
# ============================================================================

def paired_ttest(with_results: List[float], without_results: List[float]) -> Tuple[float, float]:
    """
    Perform a two-sided paired t-test on element-wise differences (with - without).

    Args:
        with_results: Measurements for the "with" condition, one per paired run.
        without_results: Measurements for the "without" condition, in the same run order.

    Returns:
        t_statistic, p_value

    Raises:
        ValueError: If the two sequences differ in length; pairing would otherwise
            be silently truncated to the shorter one.
    """
    if len(with_results) != len(without_results):
        raise ValueError(
            "paired_ttest requires equally sized samples, got "
            f"{len(with_results)} and {len(without_results)}"
        )

    differences = [w - wo for w, wo in zip(with_results, without_results)]
    # A paired t-test is a one-sample t-test of the differences against 0
    t_stat, p_val = stats.ttest_1samp(differences, 0)
    return t_stat, p_val


def compute_confidence_interval(data: List[float], confidence: float = 0.95) -> Tuple[float, float, float]:
    """
    Student-t confidence interval around the sample mean.

    Returns:
        mean, lower_bound, upper_bound
    """
    center = np.mean(data)
    dof = len(data) - 1
    # Two-sided critical value of the t distribution for the requested level
    critical = stats.t.ppf((1 + confidence) / 2, dof)
    half_width = stats.sem(data) * critical
    return center, center - half_width, center + half_width


def print_statistics_table(results: Dict, metric_name: str = "Reward"):
    """Print mean ± std, mean improvement, win rate and a per-seed table.

    ``results`` must provide the paired sequences ``results['with']`` and
    ``results['without']``.
    """
    lagr_vals, plain_vals = results['with'], results['without']

    avg_lagr, sd_lagr = np.mean(lagr_vals), np.std(lagr_vals)
    avg_plain, sd_plain = np.mean(plain_vals), np.std(plain_vals)
    improvement = avg_lagr - avg_plain

    print(f"\n--- Summary: {metric_name} ---")
    print(f"With Lagrangian:    {avg_lagr:.4f} ± {sd_lagr:.4f}")
    print(f"Without Lagrangian: {avg_plain:.4f} ± {sd_plain:.4f}")
    print(f"Mean improvement:   {improvement:+.4f}")

    # A "win" is a run where the Lagrangian variant is strictly better
    wins = len([None for a, b in zip(lagr_vals, plain_vals) if a > b])
    n_runs = len(lagr_vals)
    print(f"Win rate: {wins}/{n_runs} ({100*wins/n_runs:.0f}%)")

    # One row per paired run
    separator = "-" * 45
    print(f"\nPer-Seed Results:")
    print(f"{'Seed':<6} {'With':>12} {'Without':>12} {'Δ':>12}")
    print(separator)
    row_no = 0
    for a, b in zip(lagr_vals, plain_vals):
        print(f"{row_no:<6} {a:>12.4f} {b:>12.4f} {a - b:>+12.4f}")
        row_no += 1
    print(separator)


def _report_paired_metric(heading: str, with_vals: List[float], without_vals: List[float],
                          decimals: int) -> Tuple[float, float]:
    """Print one with/without comparison section; return (reduction in %, p-value)."""
    mean_with = np.mean(with_vals)
    mean_without = np.mean(without_vals)
    t_stat, p_val = stats.ttest_rel(with_vals, without_vals)
    # Relative reduction versus the baseline; defined as 0 for a non-positive baseline
    reduction = (mean_without - mean_with) / mean_without * 100 if mean_without > 0 else 0

    print(f"\n{heading}")
    print(f"   With Lagrangian:    {mean_with:.{decimals}f} ± {np.std(with_vals):.{decimals}f}")
    print(f"   Without Lagrangian: {mean_without:.{decimals}f} ± {np.std(without_vals):.{decimals}f}")
    print(f"   Reduction: {reduction:+.1f}%")
    print(f"   Paired t-test: t = {t_stat:.4f}, p = {p_val:.6f}")
    print(f"   {'*** SIGNIFICANT' if p_val < 0.05 else 'not significant'}")

    return reduction, p_val


def print_coordination_metrics(metrics_with: List[Dict], metrics_without: List[Dict]) -> Dict:
    """
    Print and return coordination metrics for the Lagrangian ablation.

    Three paired comparisons are reported (each with a paired t-test):
    - Final overlap variance (expected to be LOWER with Lagrangian)
    - Final disagreement score
    - Area under the variance curve (lower = variance decreased FASTER)

    Args:
        metrics_with: One metrics dict per run with Lagrangian coordination; each
            must hold the lists 'variance' and 'disagreement'.
        metrics_without: Same, for the runs without Lagrangian coordination,
            in the same run order.

    Returns:
        Dict with the per-run values, percentage reductions and p-values of the
        three comparisons. Empty trajectories contribute 0.
    """
    # np.trapz was renamed to np.trapezoid in NumPy 2.0; use whichever exists.
    integrate = getattr(np, 'trapezoid', None) or np.trapz

    # Extract final variance values from each run
    final_var_with = [m['variance'][-1] if m['variance'] else 0 for m in metrics_with]
    final_var_without = [m['variance'][-1] if m['variance'] else 0 for m in metrics_without]

    # Extract final disagreement values
    final_dis_with = [m['disagreement'][-1] if m['disagreement'] else 0 for m in metrics_with]
    final_dis_without = [m['disagreement'][-1] if m['disagreement'] else 0 for m in metrics_without]

    # Compute area under variance curve (lower = faster convergence)
    auc_var_with = [integrate(m['variance']) if m['variance'] else 0 for m in metrics_with]
    auc_var_without = [integrate(m['variance']) if m['variance'] else 0 for m in metrics_without]

    print("\n" + "=" * 70)
    print("COORDINATION METRICS (The Lagrangian Effect)")
    print("=" * 70)

    reduction, p_val = _report_paired_metric(
        "1. FINAL OVERLAP VARIANCE (lower = better coordination)",
        final_var_with, final_var_without, 6)
    reduction_dis, p_val_dis = _report_paired_metric(
        "2. FINAL DISAGREEMENT SCORE (lower = better consensus)",
        final_dis_with, final_dis_without, 6)
    reduction_auc, p_val_auc = _report_paired_metric(
        "3. VARIANCE AUC (lower = faster convergence)",
        auc_var_with, auc_var_without, 4)

    print("\n" + "=" * 70)

    # Per-seed breakdown
    print("\nPer-Seed Final Variance:")
    print(f"{'Seed':<6} {'With':>12} {'Without':>12} {'Δ':>12}")
    print("-" * 45)
    for i, (w, wo) in enumerate(zip(final_var_with, final_var_without)):
        print(f"{i:<6} {w:>12.6f} {wo:>12.6f} {w - wo:>+12.6f}")
    print("-" * 45)

    return {
        'final_var_with': final_var_with,
        'final_var_without': final_var_without,
        'final_var_reduction': reduction,
        'final_var_p': p_val,
        'final_dis_with': final_dis_with,
        'final_dis_without': final_dis_without,
        'final_dis_reduction': reduction_dis,
        'final_dis_p': p_val_dis,
        'auc_with': auc_var_with,
        'auc_without': auc_var_without,
        'auc_reduction': reduction_auc,
        'auc_p': p_val_auc,
    }


# ============================================================================
# TEST 1: Direct Optimizer Comparison (Single Scalarization)
# ============================================================================

def test_single_scalarization(problem_class, problem_kwargs, weights, n_seeds=N_SEEDS, max_iters=400):
    """
    Compare the optimizer with and without Lagrangian coordination on ONE
    weighted-sum scalarization, repeated over several seeds.

    Pairing: for each seed both variants get the same problem instance and the
    same initial solution. The RNG is NOT reset between the two runs, so the
    second run continues from the random state the first one left behind.

    Returns a dict with the per-seed best rewards ('with' / 'without'), the
    coordination metrics of every run and the list of seeds used.
    """
    rule = "=" * 70
    print("\n" + rule)
    print("TEST: Single Scalarization Comparison (Paired Seeds)")
    print(f"Weights: {weights}")
    print(f"Number of seeds: {n_seeds}")
    print(rule)

    results = {name: [] for name in ('with', 'without', 'metrics_with', 'metrics_without', 'seeds')}

    for run_no in range(1, n_seeds + 1):
        seed = BASE_SEED + run_no - 1
        results['seeds'].append(seed)
        print(f"\n--- Seed {seed} ({run_no}/{n_seeds}) ---")

        # Seed first, then build the problem instance for this run
        set_all_seeds(seed)
        problem = problem_class(**problem_kwargs, seed=seed)
        if hasattr(problem, 'n_cities'):
            problem_size = problem.n_cities
        else:
            problem_size = problem.n_items

        def evaluate_fn(solution):
            """Weighted-sum reward; infeasible solutions get a large penalty."""
            objectives = problem.evaluate(solution)
            for value in objectives:
                if value == float('inf') or value == -float('inf'):
                    return -1e10
            weighted = sum(w * o for w, o in zip(weights, objectives))
            # TSP objectives are minimized, so the reward is the negated sum
            if hasattr(problem, 'n_cities'):
                return -weighted
            return weighted

        # Settings shared by both variants
        base_params = dict(
            problem_size=problem_size,
            evaluate_fn=evaluate_fn,
            decomposition_size=10,
            overlap=4,
            max_iterations=max_iters,
            patience=9999,
            use_ftrl=True,
            use_diminishing_overlap=False,
            dual_step_size=0.1,
            overlap_decay_rate=0.1,
        )

        # Drawn right after seeding and problem construction
        initial = problem.random_solution()

        def run_variant(enabled):
            """Run one optimizer variant from a copy of the shared initial solution."""
            optimizer = InstrumentedOptimizer(
                **base_params,
                use_lagrangian=enabled,
                use_accelerated_dual=enabled,
            )
            _, best_reward = optimizer.optimize(initial_solution=initial.copy())
            return best_reward, optimizer.get_coordination_metrics()

        # Lagrangian variant first, baseline second (no reseeding in between)
        reward_on, metrics_on = run_variant(True)
        reward_off, metrics_off = run_variant(False)

        results['with'].append(reward_on)
        results['without'].append(reward_off)
        results['metrics_with'].append(metrics_on)
        results['metrics_without'].append(metrics_off)

        print(f"  With Lagrangian:    {reward_on:.4f}")
        print(f"  Without Lagrangian: {reward_off:.4f}")
        print(f"  Δ = {reward_on - reward_off:+.4f}")
        print(f"  Overlap positions: {metrics_on['n_overlap_positions']}")

    print_statistics_table(results, metric_name="Best Reward")

    return results


# ============================================================================
# TEST 2: Vary Overlap Amount
# ============================================================================

def test_vary_overlap(problem_class, problem_kwargs, weights, overlap_values, n_seeds=N_SEEDS, max_iters=400):
    """
    Test hypothesis: More overlap → more Lagrangian benefit.

    For every overlap value the same seeds are replayed, and the RNG is reset
    before each (overlap, seed) run, so a given seed yields the same problem
    instance and initial solution across all overlaps. Within a run the RNG is
    NOT reset between the with/without variants.

    Overlap values that are >= the decomposition size (10) are skipped; their
    entry in the result keeps empty lists.

    Returns:
        Dict mapping each overlap value to {'with': [...], 'without': [...], 'seeds': [...]}.
    """
    decomp_size = 10

    print("\n" + "=" * 70)
    print("TEST: Varying Overlap Amount")
    print(f"Overlap values: {overlap_values}")
    print(f"Number of seeds: {n_seeds}")
    print("=" * 70)

    results = {ov: {'with': [], 'without': [], 'seeds': []} for ov in overlap_values}

    for overlap in overlap_values:
        print(f"\n{'='*50}")
        print(f"Overlap: {overlap}")
        print(f"{'='*50}")

        # The skip condition does not depend on the seed, so decide once per
        # overlap instead of building a problem for every seed first.
        if overlap >= decomp_size:
            print(f"  Skipping overlap {overlap} >= decomp_size {decomp_size}")
            continue

        for seed_idx in range(n_seeds):
            seed = BASE_SEED + seed_idx

            # Reset seed for EACH (overlap, seed) combination
            # This ensures same problem/initial across all overlaps for same seed
            set_all_seeds(seed)

            problem = problem_class(**problem_kwargs, seed=seed)
            problem_size = problem.n_cities if hasattr(problem, 'n_cities') else problem.n_items

            results[overlap]['seeds'].append(seed)

            def evaluate_fn(solution):
                """Weighted-sum reward; infeasible solutions get a large penalty."""
                obj = problem.evaluate(solution)
                if any(o == float('inf') or o == -float('inf') for o in obj):
                    return -1e10
                # TSP objectives are minimized, so the reward is the negated sum
                if hasattr(problem, 'n_cities'):
                    return -sum(w * o for w, o in zip(weights, obj))
                else:
                    return sum(w * o for w, o in zip(weights, obj))

            base_params = {
                'problem_size': problem_size,
                'evaluate_fn': evaluate_fn,
                'decomposition_size': decomp_size,
                'overlap': overlap,
                'max_iterations': max_iters,
                'patience': 9999,
                'use_ftrl': True,
                'use_diminishing_overlap': False,
                'dual_step_size': 0.1,
                'overlap_decay_rate': 0.1,
            }

            # Same initial solution (from the seed reset above)
            initial = problem.random_solution()

            # With Lagrangian
            opt = InstrumentedOptimizer(**base_params, use_lagrangian=True, use_accelerated_dual=True)
            _, reward_with = opt.optimize(initial_solution=initial.copy())
            results[overlap]['with'].append(reward_with)

            # Without Lagrangian (sequential randomness, no seed reset)
            opt = InstrumentedOptimizer(**base_params, use_lagrangian=False, use_accelerated_dual=False)
            _, reward_without = opt.optimize(initial_solution=initial.copy())
            results[overlap]['without'].append(reward_without)

            print(f"  Seed {seed}: With={reward_with:.4f}, Without={reward_without:.4f}, Δ={reward_with-reward_without:+.4f}")

        # Print stats for this overlap (nothing to report when n_seeds is 0)
        if results[overlap]['with']:
            print_statistics_table(results[overlap], metric_name=f"Reward (Overlap={overlap})")

    return results


# ============================================================================
# TEST 3: Track Coordination Dynamics
# ============================================================================

def test_coordination_dynamics(problem_class, problem_kwargs, weights, n_seeds=N_SEEDS, max_iters=400):
    """
    Collect per-iteration coordination metrics with and without Lagrangian
    coordination over several seeds.

    Hypothesis under test: with Lagrangian coordination the variance at
    overlap positions decreases faster.

    For each seed both variants share the problem instance and the initial
    solution; the RNG is not reset between the two runs.

    Returns:
        (metrics of the Lagrangian runs, metrics of the baseline runs), one
        metrics dict per seed in each list.
    """
    rule = "=" * 70
    print("\n" + rule)
    print("TEST: Coordination Dynamics Over Time (Paired Seeds)")
    print(f"Number of seeds: {n_seeds}")
    print(rule)

    collected_metrics = {'with': [], 'without': []}
    collected_rewards = {'with': [], 'without': []}

    for run_no in range(1, n_seeds + 1):
        seed = BASE_SEED + run_no - 1
        print(f"\n--- Seed {seed} ({run_no}/{n_seeds}) ---")

        # Seed first, then build the problem instance for this run
        set_all_seeds(seed)
        problem = problem_class(**problem_kwargs, seed=seed)
        if hasattr(problem, 'n_cities'):
            problem_size = problem.n_cities
        else:
            problem_size = problem.n_items

        def evaluate_fn(solution):
            """Weighted-sum reward; infeasible solutions get a large penalty."""
            objectives = problem.evaluate(solution)
            for value in objectives:
                if value == float('inf') or value == -float('inf'):
                    return -1e10
            weighted = sum(w * o for w, o in zip(weights, objectives))
            # TSP objectives are minimized, so the reward is the negated sum
            if hasattr(problem, 'n_cities'):
                return -weighted
            return weighted

        # Settings shared by both variants
        base_params = dict(
            problem_size=problem_size,
            evaluate_fn=evaluate_fn,
            decomposition_size=10,
            overlap=4,
            max_iterations=max_iters,
            patience=9999,
            use_ftrl=True,
            use_diminishing_overlap=False,
            dual_step_size=0.1,
            overlap_decay_rate=0.1,
        )

        # Shared starting point for both variants
        initial = problem.random_solution()

        # Lagrangian variant
        lagr_opt = InstrumentedOptimizer(**base_params, use_lagrangian=True, use_accelerated_dual=True)
        _, reward_with = lagr_opt.optimize(initial_solution=initial.copy())
        collected_metrics['with'].append(lagr_opt.get_coordination_metrics())
        collected_rewards['with'].append(reward_with)

        # Baseline variant, continuing from the current RNG state
        plain_opt = InstrumentedOptimizer(**base_params, use_lagrangian=False, use_accelerated_dual=False)
        _, reward_without = plain_opt.optimize(initial_solution=initial.copy())
        collected_metrics['without'].append(plain_opt.get_coordination_metrics())
        collected_rewards['without'].append(reward_without)

        print(f"  With: {reward_with:.4f}, Without: {reward_without:.4f}, Δ={reward_with - reward_without:+.4f}")

    print_statistics_table(
        {'with': collected_rewards['with'], 'without': collected_rewards['without']},
        metric_name="Best Reward (Dynamics Test)",
    )

    return collected_metrics['with'], collected_metrics['without']


# ============================================================================
# PLOTTING WITH CONFIDENCE INTERVALS
# ============================================================================

def pad_and_aggregate(metrics_list: List[Dict], key: str) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Pad trajectories to same length and compute mean ± std.

    Each dict in ``metrics_list`` contributes the sequence stored under
    ``key``. Entries where the key is missing, ``None`` or empty are skipped.
    Shorter trajectories are extended by repeating their last value so that
    all runs can be stacked and averaged position by position.

    Trajectories may be lists, tuples or NumPy arrays.

    Args:
        metrics_list: One metrics dict per run.
        key: Name of the trajectory to aggregate.

    Returns:
        mean, std, x_values (three empty arrays when no run has data)
    """
    trajectories = []
    for metrics in metrics_list:
        values = metrics.get(key)
        if values is None:
            continue
        trajectory = np.atleast_1d(np.asarray(values))
        if len(trajectory) > 0:
            trajectories.append(trajectory)

    if not trajectories:
        return np.array([]), np.array([]), np.array([])

    max_len = max(len(t) for t in trajectories)

    # Pad with last value. Working on arrays (instead of list concatenation)
    # keeps this correct for tuples and ndarrays as well.
    padded = np.stack([
        np.concatenate([t, np.repeat(t[-1:], max_len - len(t), axis=0)], axis=0)
        for t in trajectories
    ])

    mean = np.mean(padded, axis=0)
    std = np.std(padded, axis=0)
    x = np.arange(len(mean))

    return mean, std, x


def plot_single_scalarization_results(results, save_path='lagrangian_single_scalarization.png'):
    """Plot results - focus on REWARD comparison where the effect is visible.

    Produces a 1x4 figure:
      1. Bar chart of mean reward with std error bars for both variants.
      2. Per-seed rewards as jittered points, with mean lines and win rate.
      3. Mean ± std of the 'variance' trajectory for both variants.
      4. Per-seed and mean ± std 'dual_vars' trajectories of the
         with-Lagrangian runs.

    Args:
        results: Mapping with keys 'with' and 'without' (sequences of per-seed
            best rewards, paired by position) and 'metrics_with' /
            'metrics_without' (lists of per-seed metric dicts, aggregated via
            pad_and_aggregate).
        save_path: Output image path. A PDF copy is written next to it using
            the same base name.
    """
    import os  # local import: only needed to derive the PDF file name

    rewards_with = results['with']
    rewards_without = results['without']

    fig, axes = plt.subplots(1, 4, figsize=(18, 5))

    n_seeds = len(rewards_with)

    mean_with = np.mean(rewards_with)
    mean_without = np.mean(rewards_without)
    std_with = np.std(rewards_with)
    std_without = np.std(rewards_without)
    mean_delta = mean_with - mean_without

    # Win rate (guarded so an empty result set does not divide by zero)
    wins = sum(1 for w, wo in zip(rewards_with, rewards_without) if w > wo)
    win_rate = 100 * wins / n_seeds if n_seeds else 0.0

    # Std reduction
    std_reduction = (std_without - std_with) / std_without * 100 if std_without > 0 else 0

    # ========== Plot 1: Bar Chart with Error Bars (SHOWS VARIANCE CLEARLY) ==========
    ax = axes[0]

    x_pos = [0, 1]
    means = [mean_with, mean_without]
    stds = [std_with, std_without]
    colors = ['#2ecc71', '#e74c3c']

    ax.bar(x_pos, means, yerr=stds, capsize=10, color=colors, alpha=0.7,
           error_kw={'linewidth': 2, 'capthick': 2})

    ax.set_xticks(x_pos)
    ax.set_xticklabels(['With\nLagrangian', 'Without\nLagrangian'])
    ax.set_ylabel('Best Reward', fontsize=12)
    ax.set_title(f'Mean ± Std (n={n_seeds})', fontsize=12)
    ax.grid(alpha=0.3, axis='y')

    # Annotate the std values just above the upper error cap. The gap is
    # given in points so it does not depend on the scale of the rewards.
    for i, (m, s) in enumerate(zip(means, stds)):
        ax.annotate(f'σ={s:.1f}', xy=(i, m + s), xytext=(0, 6),
                    textcoords='offset points', ha='center', va='bottom',
                    fontsize=10, fontweight='bold')

    # Delta annotation
    ax.text(0.5, 0.02, f'Δ = {mean_delta:+.1f}, Std reduction: {std_reduction:.0f}%',
            transform=ax.transAxes, ha='center', fontsize=10, fontweight='bold',
            bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8))

    # ========== Plot 2: Individual Points (Shows Distribution) ==========
    ax = axes[1]

    # Scatter with jitter. A dedicated generator keeps the figure
    # reproducible and leaves the global NumPy random state untouched.
    rng = np.random.default_rng(0)
    jitter_with = rng.uniform(-0.15, 0.15, size=len(rewards_with))
    jitter_without = rng.uniform(-0.15, 0.15, size=len(rewards_without))
    ax.scatter(0 + jitter_with, rewards_with, color='#2ecc71', s=50, alpha=0.7,
               edgecolor='darkgreen')
    ax.scatter(1 + jitter_without, rewards_without, color='#e74c3c', s=50, alpha=0.7,
               edgecolor='darkred')

    # Add mean lines
    ax.hlines(mean_with, -0.3, 0.3, colors='darkgreen', linewidth=2, label=f'Mean: {mean_with:.1f}')
    ax.hlines(mean_without, 0.7, 1.3, colors='darkred', linewidth=2, label=f'Mean: {mean_without:.1f}')

    ax.set_xticks([0, 1])
    ax.set_xticklabels(['With\nLagrangian', 'Without\nLagrangian'])
    ax.set_ylabel('Best Reward', fontsize=12)
    ax.set_title(f'Individual Seeds\nWin rate: {win_rate:.0f}%', fontsize=12)
    # Without this call the labels attached to the mean lines are never shown.
    ax.legend(fontsize=10)
    ax.grid(alpha=0.3, axis='y')

    # ========== Plot 3: Variance at Overlaps Over Time ==========
    ax = axes[2]

    for metrics_list, label, color in [
        (results['metrics_with'], 'With Lagrangian', '#2ecc71'),
        (results['metrics_without'], 'Without Lagrangian', '#e74c3c')
    ]:
        mean_var, std_var, x = pad_and_aggregate(metrics_list, 'variance')
        if len(mean_var) > 0:
            ax.plot(x, mean_var, label=label, color=color, linewidth=2)
            ax.fill_between(x, mean_var - std_var, mean_var + std_var, alpha=0.2, color=color)

    ax.set_xlabel('Iteration', fontsize=12)
    ax.set_ylabel('Variance at Overlap Positions', fontsize=12)
    ax.set_title('Coordination Quality Over Time', fontsize=12)
    ax.legend(fontsize=10)
    ax.grid(alpha=0.3)

    # ========== Plot 4: Dual Variable Evolution ==========
    ax = axes[3]

    # Individual trajectories (light)
    for metrics in results['metrics_with']:
        if metrics['dual_vars']:
            ax.plot(metrics['dual_vars'], alpha=0.2, color='#3498db', linewidth=1)

    # Mean trajectory (bold)
    mean_dual, std_dual, x = pad_and_aggregate(results['metrics_with'], 'dual_vars')
    if len(mean_dual) > 0:
        ax.plot(x, mean_dual, color='#2c3e50', linewidth=2.5, label='Mean ± Std')
        ax.fill_between(x, mean_dual - std_dual, mean_dual + std_dual, alpha=0.3, color='#2c3e50')

    ax.set_xlabel('Iteration', fontsize=12)
    ax.set_ylabel('Mean Dual Variable', fontsize=12)
    ax.set_title('Dual Variable Evolution', fontsize=12)
    ax.legend(fontsize=10)
    ax.grid(alpha=0.3)

    # Derive the PDF name from the base name instead of a substring replace,
    # which would overwrite the image when save_path has another extension
    # and would also rewrite '.png' occurring inside a directory name.
    pdf_path = os.path.splitext(save_path)[0] + '.pdf'

    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.savefig(pdf_path, bbox_inches='tight')
    print(f"\nSaved: {save_path}")
    plt.show()


def plot_overlap_results(results, save_path='lagrangian_overlap_test.png'):
    """Plot results from overlap variation test with mean ± std error bars.

    Left panel: mean best reward per overlap size for both variants.
    Right panel: mean paired improvement (with - without) per overlap size,
    with significance stars from a one-sample t-test of the paired
    differences against zero, plus a linear trend line when more than two
    overlap sizes are available.

    Args:
        results: Mapping ``overlap -> {'with': [...], 'without': [...]}`` of
            per-seed best rewards, paired by position. Overlaps whose 'with'
            list is empty are ignored.
        save_path: Output image path. A PDF copy is written next to it using
            the same base name.
    """
    import os  # local import: only needed to derive the PDF file name

    overlaps = sorted(ov for ov in results if results[ov]['with'])

    if not overlaps:
        print("No valid overlap results to plot!")
        return

    means_with = [np.mean(results[ov]['with']) for ov in overlaps]
    means_without = [np.mean(results[ov]['without']) for ov in overlaps]
    stds_with = [np.std(results[ov]['with']) for ov in overlaps]
    stds_without = [np.std(results[ov]['without']) for ov in overlaps]

    # Paired differences, computed once and reused for the bars and the test.
    diffs_by_overlap = {
        ov: [w - wo for w, wo in zip(results[ov]['with'], results[ov]['without'])]
        for ov in overlaps
    }
    improvements = [np.mean(diffs_by_overlap[ov]) for ov in overlaps]
    # Error bars show the standard deviation of the paired differences.
    improvement_stds = [np.std(diffs_by_overlap[ov]) for ov in overlaps]

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # ========== Plot 1: Absolute performance ==========
    ax = axes[0]
    ax.errorbar(overlaps, means_with, yerr=stds_with, marker='o', capsize=5,
                label='With Lagrangian', color='#2ecc71', linewidth=2, markersize=8)
    ax.errorbar(overlaps, means_without, yerr=stds_without, marker='s', capsize=5,
                label='Without Lagrangian', color='#e74c3c', linewidth=2, markersize=8)
    ax.set_xlabel('Overlap Size', fontsize=12)
    ax.set_ylabel('Best Reward', fontsize=12)
    ax.set_title('Performance vs Overlap (mean ± std)', fontsize=12)
    ax.legend(fontsize=10)
    ax.grid(alpha=0.3)

    # ========== Plot 2: Improvement with significance ==========
    ax = axes[1]
    colors = ['#2ecc71' if imp > 0 else '#e74c3c' for imp in improvements]

    ax.bar(overlaps, improvements, yerr=improvement_stds, capsize=5,
           color=colors, alpha=0.8, width=0.8)
    ax.axhline(y=0, color='gray', linestyle='--', linewidth=1.5)

    # Add significance stars
    for ov, imp, spread in zip(overlaps, improvements, improvement_stds):
        diffs = diffs_by_overlap[ov]
        # The t-test is undefined for fewer than two samples or zero spread;
        # skip it instead of producing warnings and a meaningless p-value.
        if len(diffs) < 2 or spread == 0:
            continue
        _, p_val = stats.ttest_1samp(diffs, 0)
        sig_str = '***' if p_val < 0.001 else '**' if p_val < 0.01 else '*' if p_val < 0.05 else ''
        if not sig_str:
            continue
        # Place the stars just beyond the error bar; the gap is in points so
        # it does not depend on the scale of the rewards.
        if imp > 0:
            ax.annotate(sig_str, xy=(ov, imp + spread), xytext=(0, 3),
                        textcoords='offset points', ha='center', va='bottom',
                        fontsize=12, fontweight='bold')
        else:
            ax.annotate(sig_str, xy=(ov, imp - spread), xytext=(0, -3),
                        textcoords='offset points', ha='center', va='top',
                        fontsize=12, fontweight='bold')

    ax.set_xlabel('Overlap Size', fontsize=12)
    ax.set_ylabel('Improvement (With - Without)', fontsize=12)
    ax.set_title('Lagrangian Benefit vs Overlap', fontsize=12)
    ax.grid(alpha=0.3, axis='y')

    # Trend line
    if len(overlaps) > 2:
        z = np.polyfit(overlaps, improvements, 1)
        p = np.poly1d(z)
        x_line = np.linspace(min(overlaps), max(overlaps), 100)
        ax.plot(x_line, p(x_line), 'b--', alpha=0.7, linewidth=2,
                label=f'Trend (slope={z[0]:.4f})')
        ax.legend(fontsize=10)

    # Derive the PDF name from the base name instead of a substring replace,
    # which would overwrite the image when save_path has another extension
    # and would also rewrite '.png' occurring inside a directory name.
    pdf_path = os.path.splitext(save_path)[0] + '.pdf'

    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.savefig(pdf_path, bbox_inches='tight')
    print(f"\nSaved: {save_path}")
    plt.show()


def plot_dynamics_results(metrics_with, metrics_without, save_path='lagrangian_dynamics.png'):
    """Plot coordination dynamics comparison with confidence bands.

    Draws three side-by-side panels, one per trajectory key ('variance',
    'disagreement', 'rewards'). Each panel shows the mean curve and a
    mean ± std band for the with-Lagrangian and without-Lagrangian runs,
    aggregated across seeds by pad_and_aggregate.

    Args:
        metrics_with: Per-seed metric dicts of the runs with Lagrangian.
        metrics_without: Per-seed metric dicts of the runs without Lagrangian.
        save_path: Output image path. A PDF copy is written next to it using
            the same base name.
    """
    import os  # local import: only needed to derive the PDF file name

    n_seeds = len(metrics_with)

    fig, axes = plt.subplots(1, 3, figsize=(16, 5))

    # (trajectory key, y-axis label, panel title) for each of the three axes
    panels = [
        ('variance', 'Variance at Overlap Positions',
         f'Variance Reduction (n={n_seeds} seeds)'),
        ('disagreement', 'Disagreement Score',
         'Disagreement Reduction Comparison'),
        ('rewards', 'Reward', 'Reward Trajectory'),
    ]

    # Both variants are drawn on every panel with a fixed colour each.
    variants = [
        (metrics_with, 'With Lagrangian', '#2ecc71'),
        (metrics_without, 'Without Lagrangian', '#e74c3c'),
    ]

    for ax, (key, y_label, title) in zip(axes, panels):
        for metrics_list, label, color in variants:
            mean_curve, std_curve, x = pad_and_aggregate(metrics_list, key)
            if len(mean_curve) > 0:
                ax.plot(x, mean_curve, label=label, color=color, linewidth=2)
                ax.fill_between(x, mean_curve - std_curve, mean_curve + std_curve,
                                alpha=0.2, color=color)

        ax.set_xlabel('Iteration', fontsize=12)
        ax.set_ylabel(y_label, fontsize=12)
        ax.set_title(title, fontsize=12)
        ax.legend(fontsize=10)
        ax.grid(alpha=0.3)

    # Derive the PDF name from the base name instead of a substring replace,
    # which would overwrite the image when save_path has another extension
    # and would also rewrite '.png' occurring inside a directory name.
    pdf_path = os.path.splitext(save_path)[0] + '.pdf'

    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.savefig(pdf_path, bbox_inches='tight')
    print(f"\nSaved: {save_path}")
    plt.show()


# ============================================================================
# LATEX TABLE GENERATION
# ============================================================================

def generate_latex_table(results: Dict, caption: str = "Lagrangian Ablation Results") -> str:
    """Render the with/without reward comparison as a LaTeX table.

    Args:
        results: Mapping with keys 'with' and 'without', each a sequence of
            per-seed rewards paired by position.
        caption: Text placed in the table's caption.

    Returns:
        The LaTeX source of a two-row table reporting mean ± std for each
        variant and the mean ± std of the paired differences. The difference
        carries significance stars from paired_ttest unless the paired
        differences have zero spread, in which case no test is run.
    """
    rewards_with = results['with']
    rewards_without = results['without']

    mu_with, sigma_with = np.mean(rewards_with), np.std(rewards_with)
    mu_without, sigma_without = np.mean(rewards_without), np.std(rewards_without)

    paired = [w - wo for w, wo in zip(rewards_with, rewards_without)]
    mu_delta, sigma_delta = np.mean(paired), np.std(paired)

    # With identical differences the t-test is skipped and no stars are shown.
    stars = ''
    if sigma_delta != 0:
        _, p_value = paired_ttest(rewards_with, rewards_without)
        for cutoff, marker in ((0.001, '$^{***}$'), (0.01, '$^{**}$'), (0.05, '$^{*}$')):
            if p_value < cutoff:
                stars = marker
                break

    return f"""
\\begin{{table}}[h]
\\centering
\\caption{{{caption}}}
\\label{{tab:lagrangian-ablation}}
\\begin{{tabular}}{{lcc}}
\\toprule
Method & Reward & $\\Delta$ \\\\
\\midrule
Without Lagrangian & ${mu_without:.2f} \\pm {sigma_without:.2f}$ & -- \\\\
With Lagrangian & ${mu_with:.2f} \\pm {sigma_with:.2f}$ & ${mu_delta:+.2f} \\pm {sigma_delta:.2f}${stars} \\\\
\\bottomrule
\\end{{tabular}}
\\end{{table}}
"""


def generate_coordination_latex_table(coord_metrics: Dict, caption: str = "Lagrangian Coordination Metrics") -> str:
    """Render the coordination metrics comparison as a LaTeX table.

    Args:
        coord_metrics: Mapping providing, for the prefixes 'final_var' and
            'final_dis', the per-seed values '<prefix>_with' and
            '<prefix>_without', a precomputed '<prefix>_reduction'
            (formatted as a percentage) and a p-value '<prefix>_p'.
        caption: Text placed in the table's caption.

    Returns:
        The LaTeX source of a table with one row per metric showing the mean
        without and with Lagrangian (the latter in bold), the reduction with
        significance stars, and a footnote stating the number of seeds.
    """
    def _stars(p_value):
        # Checked from strictest to loosest threshold; a NaN p-value fails
        # every comparison and therefore yields no marker.
        for cutoff, marker in ((0.001, '$^{***}$'), (0.01, '$^{**}$'), (0.05, '$^{*}$')):
            if p_value < cutoff:
                return marker
        return ''

    var_with_mean = np.mean(coord_metrics['final_var_with'])
    var_without_mean = np.mean(coord_metrics['final_var_without'])
    var_change = coord_metrics['final_var_reduction']
    var_stars = _stars(coord_metrics['final_var_p'])

    dis_with_mean = np.mean(coord_metrics['final_dis_with'])
    dis_without_mean = np.mean(coord_metrics['final_dis_without'])
    dis_change = coord_metrics['final_dis_reduction']
    dis_stars = _stars(coord_metrics['final_dis_p'])

    # The seed count reported in the footnote is taken from the variance list.
    seed_count = len(coord_metrics['final_var_with'])

    return f"""
\\begin{{table}}[h]
\\centering
\\caption{{{caption}}}
\\label{{tab:lagrangian-coordination}}
\\begin{{tabular}}{{lccc}}
\\toprule
Metric & Without Lagrangian & With Lagrangian & Reduction \\\\
\\midrule
Final Overlap Variance & ${var_without_mean:.4f}$ & $\\mathbf{{{var_with_mean:.4f}}}$ & ${var_change:+.1f}\\%${var_stars} \\\\
Final Disagreement & ${dis_without_mean:.4f}$ & $\\mathbf{{{dis_with_mean:.4f}}}$ & ${dis_change:+.1f}\\%${dis_stars} \\\\
\\bottomrule
\\end{{tabular}}
\\vspace{{0.5em}}
\\\\
\\footnotesize{{$^{{*}}p < 0.05$, $^{{**}}p < 0.01$, $^{{***}}p < 0.001$ (paired t-test). {seed_count} seeds.}}
\\end{{table}}
"""


# ============================================================================
# MAIN
# ============================================================================

if __name__ == "__main__":
    # Command-line entry point: parse options, pick the problem class, then
    # run the selected ablation test(s) and save their figures.
    import argparse
    import os

    parser = argparse.ArgumentParser(description='Proper Lagrangian Ablation Tests (Multi-Seed)')
    parser.add_argument('--test', type=str, default='all',
                        choices=['all', 'single', 'overlap', 'dynamics'])
    parser.add_argument('--problem', type=str, default='TSP', choices=['TSP', 'Knapsack'])
    parser.add_argument('--size', type=int, default=50)
    parser.add_argument('--seeds', type=int, default=N_SEEDS, help='Number of seeds (default: 10)')
    parser.add_argument('--max_iters', type=int, default=400, help='Max iterations per run')
    parser.add_argument('--conflict', type=float, default=2.0,
                        help='Conflict strength (higher = more subproblem disagreement)')
    parser.add_argument('--save_dir', type=str, default='.')

    args = parser.parse_args()

    # Reject a non-positive seed count before any work is done.
    if args.seeds < 1:
        parser.error('--seeds must be at least 1')

    # Create the output directory up front so that a long run cannot fail at
    # the final savefig call because the folder is missing.
    save_dir = args.save_dir or '.'
    os.makedirs(save_dir, exist_ok=True)

    # Update global
    N_SEEDS = args.seeds

    # Problem class and kwargs (problem will be instantiated per-seed)
    if args.problem == 'TSP':
        problem_class = SubproblemConflictTSP
        problem_kwargs = {
            'n_cities': args.size,
            'n_regions': 5,
            'conflict_strength': args.conflict
        }
    else:
        problem_class = SubproblemConflictKnapsack
        problem_kwargs = {
            'n_items': args.size,
            'n_clusters': 5,
            'conflict_strength': args.conflict
        }

    # Test weights (single scalarization)
    weights = [0.5, 0.5]

    # Run tests
    if args.test in ['all', 'single']:
        print("\n" + "=" * 70)
        print("RUNNING: Single Scalarization Test")
        print("=" * 70)
        results = test_single_scalarization(
            problem_class, problem_kwargs, weights,
            n_seeds=args.seeds, max_iters=args.max_iters
        )
        plot_single_scalarization_results(
            results,
            save_path=os.path.join(save_dir, f'lagrangian_single_{args.problem}.png'))

        # Generate LaTeX table
        latex_table = generate_latex_table(
            results,
            caption=f"Lagrangian Ablation on {args.problem}-{args.size}")
        print("\n" + "=" * 70)
        print("LATEX TABLE:")
        print("=" * 70)
        print(latex_table)

    if args.test in ['all', 'overlap']:
        print("\n" + "=" * 70)
        print("RUNNING: Overlap Variation Test")
        print("=" * 70)
        overlap_values = [2, 4, 6, 8]
        overlap_results = test_vary_overlap(
            problem_class, problem_kwargs, weights, overlap_values,
            n_seeds=args.seeds, max_iters=args.max_iters
        )
        plot_overlap_results(
            overlap_results,
            save_path=os.path.join(save_dir, f'lagrangian_overlap_{args.problem}.png'))

    if args.test in ['all', 'dynamics']:
        print("\n" + "=" * 70)
        print("RUNNING: Coordination Dynamics Test")
        print("=" * 70)
        metrics_with, metrics_without = test_coordination_dynamics(
            problem_class, problem_kwargs, weights,
            n_seeds=args.seeds, max_iters=args.max_iters
        )
        plot_dynamics_results(
            metrics_with, metrics_without,
            save_path=os.path.join(save_dir, f'lagrangian_dynamics_{args.problem}.png'))

    print("\n" + "=" * 70)
    print("✅ Ablation tests complete!")
    print("=" * 70)