#!/usr/bin/env python3
"""Process trend gap table: compute best per round, aggregate by size, and visualize.

This script combines:
1. Compute best pool/population per round
2. Aggregate by size bins
3. Visualize trends

Usage:
  python3 scripts/process_and_visualize_trends.py \
    --gap_table /path/to/trend_gap_table.csv \
    --out_dir /path/to/output \
    --metric mean \
    --population_table /path/to/population_gap_table.csv  # optional
"""

import os
import sys
import argparse
import re
import subprocess
from collections import Counter
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt


def parse_args():
    """Parse the command-line options of this script.

    Path defaults are resolved relative to the directory that contains this
    file:

    - ``--eoh_optimal_table`` defaults to
      ``baseline_results/eoh_optimal_gap_table.csv`` when that file exists,
      otherwise to ``baseline_results/eoh_optimal_gap_by_size.csv``.
    - ``--pkl_dir`` defaults to ``TestingData``.

    Returns:
        argparse.Namespace with ``gap_table``, ``population_table``,
        ``eoh_optimal_table``, ``out_dir``, ``metric``, ``dpi`` and ``pkl_dir``.
    """
    script_dir = os.path.dirname(os.path.abspath(__file__))
    baseline_dir = os.path.join(script_dir, 'baseline_results')

    # Prefer the table that has every dataset column; fall back to the by-size table.
    eoh_default = os.path.join(baseline_dir, 'eoh_optimal_gap_table.csv')
    if not os.path.exists(eoh_default):
        eoh_default = os.path.join(baseline_dir, 'eoh_optimal_gap_by_size.csv')

    parser = argparse.ArgumentParser(
        description="Process trend gap table: compute best per round, aggregate by size, and visualize."
    )
    parser.add_argument(
        '--gap_table', required=True,
        help="Input trend_gap_table.csv path (required: your method's gap table)",
    )
    parser.add_argument(
        '--population_table', type=str, default=None,
        help=('Optional: Path to CSV file containing population data (if different from gap_table). '
              'If not provided, population trend will not be plotted.'),
    )
    parser.add_argument(
        '--eoh_optimal_table', type=str, default=eoh_default,
        help=f'Path to CSV file containing EoH optimal heuristic results (default: {eoh_default})',
    )
    parser.add_argument(
        '--out_dir', type=str, default=None,
        help='Output directory (default: same directory as gap_table)',
    )
    parser.add_argument(
        '--metric', type=str, default='mean', choices=['mean', 'min', 'median'],
        help='Aggregation metric for finding best solver (default: mean)',
    )
    parser.add_argument(
        '--dpi', type=int, default=150,
        help='DPI for output figures (default: 150)',
    )
    parser.add_argument(
        '--pkl_dir', type=str, default=os.path.join(script_dir, 'TestingData'),
        help='PKL directory for size inference (default: testing/TestingData)',
    )
    return parser.parse_args()


# ==================== Step 2: Compute best per round ====================

def extract_pool_round_number(row_name: str) -> int:
    """Return ``k`` for a pool row named ``h<k>`` (e.g. 'h3' -> 3), else -1."""
    if row_name.startswith('h') and len(row_name) >= 2:
        digits = re.match(r'h(\d+)$', row_name)
        if digits is not None:
            return int(digits.group(1))
    return -1


def compute_aggregate_score(row: pd.Series, cols: list, metric: str) -> float:
    """Reduce the numeric cells of ``row`` over ``cols`` to a single score.

    Missing cells (None/NaN, or columns absent from ``row``) and cells that
    cannot be converted to float are skipped.

    Args:
        row: one gap-table row.
        cols: column names to read from ``row``.
        metric: 'mean', 'min' or 'median'; any other value falls back to mean.

    Returns:
        The reduced value as a float, or ``float('inf')`` if no usable cell exists.
    """
    numeric = []
    for col in cols:
        cell = row.get(col)
        if cell is None or not pd.notna(cell):
            continue
        try:
            numeric.append(float(cell))
        except (ValueError, TypeError):
            continue

    if not numeric:
        return float('inf')

    reducer = {'min': np.min, 'median': np.median}.get(metric, np.mean)
    return float(reducer(numeric))


def process_pool_best_per_round(df: pd.DataFrame, agg_cols: list, metric: str) -> pd.DataFrame:
    """Process pool: find best solver among h0 to hi for each round i.
    
    For each TSPLIB instance separately:
    - For round i, select the solver with lowest gap on that specific instance
    - TSP100 is handled separately (not included in TSPLIB instances)
    """
    pool_rows = df[df['solver'].str.match(r'^h\d+$', na=False)].copy()
    if len(pool_rows) == 0:
        return None
    
    pool_rows['round'] = pool_rows['solver'].apply(extract_pool_round_number)
    pool_rows = pool_rows[pool_rows['round'] >= 0].sort_values('round').reset_index(drop=True)
    
    if len(pool_rows) == 0:
        return None
    
    max_round = int(pool_rows['round'].max())
    
    # Separate TSPLIB instances from TSP100
    tsplib_cols = [c for c in agg_cols if c.upper() != 'TSP100']
    tsp100_col = 'TSP100' if 'TSP100' in agg_cols else None
    
    best_per_round = {}
    # Track previous round's best values to ensure monotonicity
    prev_best_values = {}  # {col: best_value}
    
    for target_round in range(max_round + 1):
        candidates = pool_rows[pool_rows['round'] <= target_round].copy()
        if len(candidates) == 0:
            continue
        
        # For each TSPLIB instance, find the best solver on that specific instance
        out_row = {
            'solver': f"h{target_round}",
            'best_solver_indices': {},  # Track which solver was best for each instance
        }
        
        # Process each TSPLIB instance separately
        for col in tsplib_cols:
            # Find candidate with lowest gap for this specific instance
            best_val = float('inf')
            best_candidate_idx = None
            
            for idx, candidate in candidates.iterrows():
                val = candidate.get(col)
                if pd.notna(val) and val is not None:
                    val_float = float(val)
                    if val_float < best_val:
                        best_val = val_float
                        best_candidate_idx = idx
            
            # Ensure monotonicity: compare with previous round's best value
            # The value should be <= previous round's value (monotonic decreasing)
            if col in prev_best_values:
                prev_best = prev_best_values[col]
                if prev_best is not None:
                    # If current round found no valid value, use previous round's value
                    if best_val == float('inf'):
                        best_val = prev_best
                    # If previous round's value is better (smaller), use that
                    elif prev_best < best_val:
                        best_val = prev_best
                        # Reset best_candidate_idx to find the solver with this value
                        best_candidate_idx = None
            
            # Set the output value
            if best_val != float('inf'):
                out_row[col] = best_val
                # Find which solver has this best value
                if best_candidate_idx is not None:
                    best_candidate = candidates.loc[best_candidate_idx]
                    out_row['best_solver_indices'][col] = int(best_candidate['round'])
                else:
                    # Find the solver with this value from all candidates
                    for idx, candidate in candidates.iterrows():
                        val = candidate.get(col)
                        if pd.notna(val) and val is not None and abs(float(val) - best_val) < 1e-9:
                            out_row['best_solver_indices'][col] = int(candidate['round'])
                            break
                prev_best_values[col] = best_val
            else:
                # If no valid candidate found and no previous value, set to None
                out_row[col] = None
        
        # For TSP100, also find best solver separately
        if tsp100_col:
            best_val_tsp100 = float('inf')
            best_candidate_idx_tsp100 = None
            
            for idx, candidate in candidates.iterrows():
                val = candidate.get(tsp100_col)
                if pd.notna(val) and val is not None:
                    val_float = float(val)
                    if val_float < best_val_tsp100:
                        best_val_tsp100 = val_float
                        best_candidate_idx_tsp100 = idx
            
            # Ensure monotonicity for TSP100
            if tsp100_col in prev_best_values:
                prev_best_tsp100 = prev_best_values[tsp100_col]
                if prev_best_tsp100 is not None:
                    # If current round found no valid value, use previous round's value
                    if best_val_tsp100 == float('inf'):
                        best_val_tsp100 = prev_best_tsp100
                    # If previous round's value is better (smaller), use that
                    elif prev_best_tsp100 < best_val_tsp100:
                        best_val_tsp100 = prev_best_tsp100
                        # Reset best_candidate_idx_tsp100 to find the solver with this value
                        best_candidate_idx_tsp100 = None
            
            if best_val_tsp100 != float('inf'):
                out_row[tsp100_col] = best_val_tsp100
                # Find which solver has this best value
                if best_candidate_idx_tsp100 is not None:
                    best_candidate = candidates.loc[best_candidate_idx_tsp100]
                    out_row['best_solver_indices'][tsp100_col] = int(best_candidate['round'])
                else:
                    # Find the solver with this value from all candidates
                    for idx, candidate in candidates.iterrows():
                        val = candidate.get(tsp100_col)
                        if pd.notna(val) and val is not None and abs(float(val) - best_val_tsp100) < 1e-9:
                            out_row['best_solver_indices'][tsp100_col] = int(candidate['round'])
                            break
                prev_best_values[tsp100_col] = best_val_tsp100
            else:
                # If no valid candidate found and no previous value, set to None
                out_row[tsp100_col] = None
        
        # Use the most frequently selected solver as the representative
        if out_row['best_solver_indices']:
            solver_counts = Counter(out_row['best_solver_indices'].values())
            most_common_solver_idx = solver_counts.most_common(1)[0][0]
            best_solver_row = candidates[candidates['round'] == most_common_solver_idx].iloc[0]
            out_row['best_solver_idx'] = most_common_solver_idx
            out_row['best_solver_name'] = str(best_solver_row['solver'])
        else:
            out_row['best_solver_idx'] = None
            out_row['best_solver_name'] = None
        
        # Remove internal tracking
        del out_row['best_solver_indices']
        
        best_per_round[target_round] = out_row
    
    return pd.DataFrame(list(best_per_round.values()))


def process_population_latest(df: pd.DataFrame, agg_cols: list) -> pd.DataFrame:
    """Convert population rows (named ``h<k>``) into one output row per generation.

    Each generation's values are copied as-is; no best-of selection is made.
    Output rows are ordered by generation and named ``population_best_g<k>``.
    Each row also keeps the generation number and the original solver name.

    Args:
        df: table whose population rows have ``solver`` names ``h<k>``.
        agg_cols: value columns to copy; missing columns or NaN cells become None.

    Returns:
        DataFrame with ``solver``, ``generation``, ``original_name`` and
        ``agg_cols``, or None if ``df`` has no ``h<k>`` rows.
    """
    is_population = df['solver'].str.match(r'^h\d+$', na=False)
    pop_rows = df[is_population].copy()
    if not len(pop_rows):
        return None

    pop_rows['generation'] = pop_rows['solver'].apply(extract_pool_round_number)
    valid = pop_rows[pop_rows['generation'] >= 0]
    pop_rows = valid.sort_values('generation').reset_index(drop=True)
    if not len(pop_rows):
        return None

    def as_optional_float(cell):
        return float(cell) if pd.notna(cell) and cell is not None else None

    records = []
    for _, rec in pop_rows.iterrows():
        gen_num = int(rec['generation'])
        entry = {
            'solver': f"population_best_g{gen_num}",
            'generation': gen_num,
            'original_name': str(rec['solver']),
        }
        entry.update({col: as_optional_float(rec.get(col)) for col in agg_cols})
        records.append(entry)

    return pd.DataFrame(records)


# ==================== Step 3: Aggregate by size ====================

def extract_size_from_label(label: str):
    """Return the integer formed by the trailing digits of ``label``, or None.

    Example: 'BERLIN52' -> 52. Empty or non-string labels and labels that do
    not end in digits yield None.
    """
    if not label or not isinstance(label, str):
        return None
    trailing = re.search(r'(\d+)$', str(label))
    if trailing is None:
        return None
    try:
        return int(trailing.group(1))
    except (ValueError, TypeError):
        return None


def bin_of(size: int) -> str:
    """Map an instance size to its size-bin label.

    The bins are half-open on the right:

    - lt50: size < 50
    - s50_100: 50 <= size < 100
    - s100_200: 100 <= size < 200
    - s200_500: 200 <= size < 500
    - gt500: size >= 500 (500 itself lands here)
    """
    upper_bounds = (
        (50, 'lt50'),
        (100, 's50_100'),
        (200, 's100_200'),
        (500, 's200_500'),
    )
    for limit, label in upper_bounds:
        if size < limit:
            return label
    return 'gt500'


def format_size_label(size) -> str:
    """Return a display label for a numeric size or a size-bin name.

    Bin names are matched case-insensitively and rendered as ranges, for
    example 'lt50' -> '<50' and 's50_100' -> '50-100'. Unknown strings are
    returned unchanged. Ints and floats are truncated to int and
    stringified. Any other value is passed through ``str``.
    """
    if isinstance(size, str):
        pretty = dict(
            lt50='<50',
            s50_100='50-100',
            s100_200='100-200',
            s200_500='200-500',
            gt500='>500',
        )
        return pretty.get(size.lower(), size)
    if isinstance(size, (int, float)):
        return str(int(size))
    return str(size)


def extract_dataset_type(col_name: str) -> tuple:
    """Classify a gap-table column name as ``(dataset_type, size)``.

    Recognized forms (matched case-insensitively):

    - 'TSP100' -> ('TSP100', None)
    - 'TSPLIB_<bin>' for the bins lt50, s50_100, s100_200, s200_500 and gt500
      -> ('TSPLIB', '<bin>'), with the bin name lowercased.
    - 'TSPLIB_<n>' -> ('TSPLIB', n)
    - Any other name ending in digits, e.g. 'BERLIN52' -> ('TSPLIB', 52).

    Anything else, including empty or non-string names, gives (None, None).
    """
    if not col_name or not isinstance(col_name, str):
        return (None, None)

    upper = str(col_name).upper()

    if upper == 'TSP100':
        return ('TSP100', None)

    # Already-binned column: return the bin label itself as the "size".
    if re.match(r'^TSPLIB_(LT50|S50_100|S100_200|S200_500|GT500)$', upper):
        return ('TSPLIB', upper.partition('_')[2].lower())

    numeric = re.match(r'^TSPLIB_(\d+)$', upper)
    if numeric:
        return ('TSPLIB', int(numeric.group(1)))

    # Raw TSPLIB instance: size is the trailing number of the name.
    size = extract_size_from_label(col_name)
    return ('TSPLIB', size) if size is not None else (None, None)


def aggregate_by_size(df: pd.DataFrame, pkl_dir: str = None) -> pd.DataFrame:
    """Average TSPLIB instance columns of a gap table into coarse size bins.

    Each TSPLIB instance column gets its size from the trailing digits of its
    name (e.g. BERLIN52 -> 52). It is then averaged into one of these bins:

    - le200: n <= 200
    - s200_500: 200 < n <= 500
    - s500_1000: 500 < n <= 1000
    - gt1000: n > 1000

    Columns whose size cannot be parsed are ignored. TSP100 is kept as a
    separate column and is not binned, since it is a random uniform dataset.

    Args:
        df: gap table with a ``solver`` column plus instance columns.
        pkl_dir: currently unused. Reserved for inferring sizes of columns
            whose name carries no size.

    Returns:
        DataFrame with columns ``solver``, ``le200``, ``s200_500``,
        ``s500_1000`` and ``gt1000``, plus ``TSP100`` when present. Bins with
        no values are None.
    """
    metadata_cols = ['solver', 'best_solver_idx', 'best_solver_name', 'agg_score', 'generation', 'original_name']
    all_cols = [c for c in df.columns if c not in metadata_cols]

    # Separate TSPLIB instances from TSP100
    tsplib_cols = [c for c in all_cols if c.upper() != 'TSP100']
    tsp100_col = 'TSP100' if 'TSP100' in all_cols else None

    bins = ['le200', 's200_500', 's500_1000', 'gt1000']

    def coarse_bin(size: int) -> str:
        # Must return one of `bins`. bin_of() uses a different, finer label
        # set (lt50, s50_100, ...) that is not a key of bin2vals.
        if size <= 200:
            return 'le200'
        if size <= 500:
            return 's200_500'
        if size <= 1000:
            return 's500_1000'
        return 'gt1000'

    # Map TSPLIB columns to bins. Columns without a trailing size are ignored
    # (pkl_dir-based size inference is not implemented).
    col2bin = {}
    for c in tsplib_cols:
        size = extract_size_from_label(c.upper())
        if size is not None:
            col2bin[c] = coarse_bin(size)

    out_rows = []
    for _, row in df.iterrows():
        bin2vals = {b: [] for b in bins}
        for c, b in col2bin.items():
            val = row.get(c)
            if pd.notna(val) and val is not None:
                try:
                    bin2vals[b].append(float(val))
                except (ValueError, TypeError):
                    pass  # non-numeric cell: skip, as elsewhere in this script

        out_row = {'solver': row['solver']}
        for b in bins:
            vals = bin2vals[b]
            out_row[b] = float(np.mean(vals)) if vals else None

        # Keep TSP100 as a separate, unbinned column
        if tsp100_col:
            tsp100_val = row.get(tsp100_col)
            out_row[tsp100_col] = tsp100_val if pd.notna(tsp100_val) else None

        out_rows.append(out_row)

    output_cols = ['solver'] + bins
    if tsp100_col:
        output_cols.append(tsp100_col)

    return pd.DataFrame(out_rows, columns=output_cols)


def aggregate_by_dataset_groups(df: pd.DataFrame) -> pd.DataFrame:
    """Average gap-table columns per dataset type and per individual size.

    Each data column is classified with ``extract_dataset_type``:

    - Columns with a size are averaged into ``<DatasetType>_<size>`` columns.
      These are TSPLIB instances such as BERLIN52 -> 52, and pre-aggregated
      ``TSPLIB_<n>`` or ``TSPLIB_<bin>`` columns.
    - Columns without a size (TSP100) are averaged into a column named after
      the dataset type.
    - Columns matching no known dataset are dropped.

    Returns:
        DataFrame with ``solver``, then the TSPLIB columns, then TSP100.
        TSPLIB columns come with numeric sizes ascending first, then bin-label
        sizes alphabetically. If ``df`` has no data columns or none is
        recognized, ``df`` itself is returned unchanged.
    """
    metadata_cols = ['solver', 'best_solver_idx', 'best_solver_name', 'agg_score', 'generation', 'original_name']
    all_cols = [c for c in df.columns if c not in metadata_cols]

    if not all_cols:
        print("⚠️  Warning: No data columns found in aggregate_by_dataset_groups")
        print(f"   Available columns: {list(df.columns)}")
        return df

    # {dataset_type: {size or 'no_size': [cols]}}
    dataset_groups = {}
    for col in all_cols:
        dataset_type, size = extract_dataset_type(col)
        if not dataset_type:
            continue
        key = 'no_size' if size is None else size
        dataset_groups.setdefault(dataset_type, {}).setdefault(key, []).append(col)

    if not dataset_groups:
        print("  Warning: No dataset groups found in aggregate_by_dataset_groups")
        print(f"   Processed {len(all_cols)} columns but none matched dataset types")
        print(f"   Sample columns: {all_cols[:10]}")
        return df

    def mean_or_none(row, cols):
        """Mean of the numeric, non-missing cells of ``row`` over ``cols``; None if there are none."""
        values = []
        for col in cols:
            val = row.get(col)
            if pd.notna(val) and val is not None:
                try:
                    values.append(float(val))
                except (ValueError, TypeError):
                    pass
        return float(np.mean(values)) if values else None

    out_rows = []
    for _, row in df.iterrows():
        out_row = {'solver': row['solver']}
        for dataset_type, size_groups in dataset_groups.items():
            for size, cols in size_groups.items():
                name = dataset_type if size == 'no_size' else f"{dataset_type}_{size}"
                out_row[name] = mean_or_none(row, cols)
        out_rows.append(out_row)

    def size_sort_key(size):
        # Sizes are ints (instance sizes) or str bin labels from pre-aggregated
        # TSPLIB_<bin> columns. A plain sorted() raises TypeError on that mix,
        # so order ints first, then strings.
        return (isinstance(size, str), size)

    # Order: TSPLIB first (with sizes), then TSP100
    output_cols = ['solver']
    for dataset_type in ('TSPLIB', 'TSP100'):
        groups = dataset_groups.get(dataset_type)
        if groups is None:
            continue
        if 'no_size' in groups:
            output_cols.append(dataset_type)
        sizes = [s for s in groups if s != 'no_size']
        output_cols.extend(f"{dataset_type}_{s}" for s in sorted(sizes, key=size_sort_key))

    result_df = pd.DataFrame(out_rows, columns=output_cols)

    print(f"   Aggregated {len(result_df)} rows into {len(output_cols)} columns")
    print(f"   Output columns: {output_cols[:10]}{'...' if len(output_cols) > 10 else ''}")

    return result_df


def aggregate_by_dataset_groups_with_custom_bins(df: pd.DataFrame) -> pd.DataFrame:
    """Average gap-table columns into S/M/L/XL TSPLIB size bins plus TSP100.

    The size n is taken from the column name (e.g. BERLIN52 -> 52, or
    ``TSPLIB_<n>``). The bins are:

    - S: n <= 200
    - M: 200 < n <= 500
    - L: 500 < n <= 1000
    - XL: n > 1000

    Columns without a size (TSP100) are averaged into a column named after
    their dataset type.

    Pre-aggregated ``TSPLIB_<bin>`` columns (lt50, s50_100, ...) are skipped
    with a warning. Their bin edges do not line up with S/M/L/XL, and
    comparing their str label with an int previously raised TypeError.

    Returns:
        DataFrame with ``solver``, then ``TSPLIB_S``/``_M``/``_L``/``_XL``
        (only bins that received at least one column), then ``TSP100`` if
        present. If ``df`` has no data columns or none is usable, ``df``
        itself is returned unchanged.
    """
    metadata_cols = ['solver', 'best_solver_idx', 'best_solver_name', 'agg_score', 'generation', 'original_name']
    all_cols = [c for c in df.columns if c not in metadata_cols]

    if not all_cols:
        print("⚠️  Warning: No data columns found in aggregate_by_dataset_groups_with_custom_bins")
        print(f"   Available columns: {list(df.columns)}")
        return df

    size_bins = ['S', 'M', 'L', 'XL']

    def size_to_bin(size: int) -> str:
        """Map size to custom bin: S (≤200), M (201-500), L (501-1000), XL (>1000)"""
        if size <= 200:
            return 'S'
        if size <= 500:
            return 'M'
        if size <= 1000:
            return 'L'
        return 'XL'

    def mean_or_none(row, cols):
        """Mean of the numeric, non-missing cells of ``row`` over ``cols``; None if there are none."""
        values = []
        for col in cols:
            val = row.get(col)
            if pd.notna(val) and val is not None:
                try:
                    values.append(float(val))
                except (ValueError, TypeError):
                    pass
        return float(np.mean(values)) if values else None

    # {dataset_type: {'S': [cols], 'M': [...], 'L': [...], 'XL': [...], 'no_size': [cols]}}
    dataset_groups = {}
    unbinnable = []
    for col in all_cols:
        dataset_type, size = extract_dataset_type(col)
        if not dataset_type:
            continue
        if isinstance(size, str):
            # TSPLIB_<bin> label from an earlier aggregation: cannot be re-binned.
            unbinnable.append(col)
            continue
        groups = dataset_groups.setdefault(dataset_type, {b: [] for b in size_bins + ['no_size']})
        key = 'no_size' if size is None else size_to_bin(size)
        groups[key].append(col)

    if unbinnable:
        print(f"  Warning: skipping {len(unbinnable)} pre-binned column(s) that cannot be mapped "
              f"to S/M/L/XL: {unbinnable[:10]}")

    if not dataset_groups:
        print("  Warning: No dataset groups found in aggregate_by_dataset_groups_with_custom_bins")
        print(f"   Processed {len(all_cols)} columns but none matched dataset types")
        print(f"   Sample columns: {all_cols[:10]}")
        return df

    out_rows = []
    for _, row in df.iterrows():
        out_row = {'solver': row['solver']}
        for dataset_type, groups in dataset_groups.items():
            if groups['no_size']:
                out_row[dataset_type] = mean_or_none(row, groups['no_size'])
            for bin_name in size_bins:
                if groups[bin_name]:
                    out_row[f"{dataset_type}_{bin_name}"] = mean_or_none(row, groups[bin_name])
        out_rows.append(out_row)

    # Order: TSPLIB first (bins S, M, L, XL that have columns), then TSP100
    output_cols = ['solver']
    for dataset_type in ('TSPLIB', 'TSP100'):
        groups = dataset_groups.get(dataset_type)
        if groups is None:
            continue
        output_cols.extend(f"{dataset_type}_{b}" for b in size_bins if groups[b])
        if groups['no_size']:
            output_cols.append(dataset_type)

    result_df = pd.DataFrame(out_rows, columns=output_cols)

    print(f"   Aggregated {len(result_df)} rows into {len(output_cols)} columns")
    print(f"   Output columns: {output_cols}")

    return result_df


def aggregate_by_dataset_groups_with_size_bins(df: pd.DataFrame) -> pd.DataFrame:
    """Average gap-table columns into bin_of() TSPLIB size bins plus TSP100.

    TSPLIB instance columns get their size n from the name (e.g. BERLIN52 -> 52).
    They are averaged into ``TSPLIB_<bin>`` columns with these bins:

    - lt50: n < 50
    - s50_100: 50 <= n < 100
    - s100_200: 100 <= n < 200
    - s200_500: 200 <= n < 500
    - gt500: n >= 500

    Columns that are already ``TSPLIB_<bin>`` aggregates keep their own bin.
    Columns without a size (TSP100) are averaged into a column named after
    their dataset type.

    Returns:
        DataFrame with ``solver``, then the TSPLIB bins that received at least
        one column (in the order above), then ``TSP100`` if present. If ``df``
        has no data columns or none is recognized, ``df`` itself is returned
        unchanged.
    """
    metadata_cols = ['solver', 'best_solver_idx', 'best_solver_name', 'agg_score', 'generation', 'original_name']
    all_cols = [c for c in df.columns if c not in metadata_cols]

    if not all_cols:
        print("  Warning: No data columns found in aggregate_by_dataset_groups_with_size_bins")
        print(f"   Available columns: {list(df.columns)}")
        return df

    size_bins = ['lt50', 's50_100', 's100_200', 's200_500', 'gt500']

    def mean_or_none(row, cols):
        """Mean of the numeric, non-missing cells of ``row`` over ``cols``; None if there are none."""
        values = []
        for col in cols:
            val = row.get(col)
            if pd.notna(val) and val is not None:
                try:
                    values.append(float(val))
                except (ValueError, TypeError):
                    pass
        return float(np.mean(values)) if values else None

    # {dataset_type: {'lt50': [cols], ..., 'gt500': [cols], 'no_size': [cols]}}
    dataset_groups = {}
    for col in all_cols:
        dataset_type, size = extract_dataset_type(col)
        if not dataset_type:
            continue
        groups = dataset_groups.setdefault(dataset_type, {b: [] for b in size_bins + ['no_size']})
        if size is None:
            groups['no_size'].append(col)
            continue
        # For pre-aggregated TSPLIB_<bin> columns, extract_dataset_type returns
        # the bin label itself (e.g. 'lt50'). bin_of() only accepts numbers and
        # raised TypeError on these, so use the label directly.
        bin_name = size if isinstance(size, str) else bin_of(size)
        if bin_name in size_bins:
            groups[bin_name].append(col)

    if not dataset_groups:
        print("  Warning: No dataset groups found in aggregate_by_dataset_groups_with_size_bins")
        print(f"   Processed {len(all_cols)} columns but none matched dataset types")
        print(f"   Sample columns: {all_cols[:10]}")
        return df

    out_rows = []
    for _, row in df.iterrows():
        out_row = {'solver': row['solver']}
        for dataset_type, groups in dataset_groups.items():
            if groups['no_size']:
                out_row[dataset_type] = mean_or_none(row, groups['no_size'])
            for bin_name in size_bins:
                if groups[bin_name]:
                    out_row[f"{dataset_type}_{bin_name}"] = mean_or_none(row, groups[bin_name])
        out_rows.append(out_row)

    # Order: TSPLIB first (with size bins), then TSP100
    output_cols = ['solver']
    for dataset_type in ('TSPLIB', 'TSP100'):
        groups = dataset_groups.get(dataset_type)
        if groups is None:
            continue
        if groups['no_size']:
            output_cols.append(dataset_type)
        output_cols.extend(f"{dataset_type}_{b}" for b in size_bins if groups[b])

    result_df = pd.DataFrame(out_rows, columns=output_cols)

    print(f"   Aggregated {len(result_df)} rows into {len(output_cols)} columns (with size bins)")
    print(f"   Output columns: {output_cols[:10]}{'...' if len(output_cols) > 10 else ''}")

    return result_df


# ==================== Step 4: Visualize ====================

def extract_generation(row_name: str) -> tuple:
    """Classify a trend-table row name and extract its generation number.

    Recognised formats:
      * ``h<N>`` (e.g. ``h0``, ``h12``)          -> ``('pool', N)``
      * ``pool_best_r<N>...`` (legacy pool rows) -> ``('pool', N)``
      * ``population_best_g<N>...``              -> ``('popbest', N)``

    Any other value, including non-string cells such as NaN/None read from
    an empty CSV field, yields ``(None, None)`` instead of raising.

    Args:
        row_name: Value of the ``solver`` column for one row.

    Returns:
        ``(type, gen_num)`` where ``type`` is ``'pool'``, ``'popbest'`` or None.
    """
    if not isinstance(row_name, str):
        # pandas yields float NaN for empty solver cells; treat as unrecognised.
        return (None, None)

    # Pool-best rows in the current short format: h0, h1, h2, ...
    match = re.match(r'^h(\d+)$', row_name)
    if match:
        return ('pool', int(match.group(1)))

    # Legacy pool format kept for backward compatibility, then population rows.
    for pattern, kind in ((r'pool_best_r(\d+)', 'pool'),
                          (r'population_best_g(\d+)', 'popbest')):
        match = re.match(pattern, row_name)
        if match:
            return (kind, int(match.group(1)))
    return (None, None)


def _plot_trend_series(ax, generations, values, fmt, label, label_offset, label_color):
    """Draw one gap-vs-generation line on ``ax`` and annotate its last point.

    ``None`` entries in ``values`` are dropped *before* sorting, so points that
    share a generation never force a ``None < float`` comparison. The remaining
    points are sorted by ``(generation, value)``.

    Returns:
        True if at least one point was plotted, False otherwise.
    """
    points = sorted((g, v) for g, v in zip(generations, values) if v is not None)
    if not points:
        return False
    xs = [g for g, _ in points]
    ys = [v for _, v in points]
    ax.plot(xs, ys, fmt, label=label, linewidth=2, markersize=6)
    ax.annotate(
        f'{ys[-1]:.2f}%',
        xy=(xs[-1], ys[-1]),
        xytext=label_offset,
        textcoords='offset points',
        fontsize=9,
        bbox=dict(boxstyle='round,pad=0.3', facecolor=label_color, alpha=0.7),
        ha='left',
    )
    return True


def visualize_trend_table(df: pd.DataFrame, out_path: str, dpi: int = 150, eoh_optimal_df: pd.DataFrame = None):
    """Plot pool (Co-evolution) vs population-best (EoH) gap trends.

    Left subplot: per-row mean over every non-TSP100 dataset column
    ("TSPLIB Average"). Right subplot: the ``TSP100`` column. Rows are
    classified with ``extract_generation``; unrecognised rows are ignored.
    A row contributes to each subplot independently, so a missing TSPLIB
    value no longer hides that row's TSP100 value.

    Args:
        df: Trend table with a ``solver`` column plus one column per dataset.
        out_path: Destination image path.
        dpi: Resolution of the saved figure.
        eoh_optimal_df: Optional table; if it has a ``solver == 'eoh_optimal'``
            row, its values over the same columns are drawn as dashed
            reference lines. Ignored when it has no ``solver`` column.
    """
    metadata_cols = ['solver', 'best_solver_idx', 'best_solver_name', 'agg_score', 'generation', 'original_name']
    all_cols = [c for c in df.columns if c not in metadata_cols]
    tsplib_cols = [c for c in all_cols if str(c).upper() != 'TSP100']
    tsp100_col = 'TSP100' if 'TSP100' in df.columns else None

    def _numeric(value):
        """Return ``value`` as a float, or None when it is missing/NaN."""
        return float(value) if value is not None and pd.notna(value) else None

    def _tsplib_average(record):
        """Mean of the non-missing TSPLIB columns of ``record``, or None."""
        vals = [v for v in (_numeric(record.get(c)) for c in tsplib_cols) if v is not None]
        return float(np.mean(vals)) if vals else None

    # Reference values from the EoH optimal table (if provided and well-formed).
    eoh_tsplib_avg = None
    eoh_tsp100 = None
    if eoh_optimal_df is not None and 'solver' in eoh_optimal_df.columns:
        eoh_rows = eoh_optimal_df[eoh_optimal_df['solver'] == 'eoh_optimal']
        if len(eoh_rows) > 0:
            eoh_record = eoh_rows.iloc[0]
            eoh_tsplib_avg = _tsplib_average(eoh_record)
            if tsp100_col:
                eoh_tsp100 = _numeric(eoh_record.get(tsp100_col))

    series = {
        'pool': {'generation': [], 'tsplib_avg': [], 'tsp100': []},
        'popbest': {'generation': [], 'tsplib_avg': [], 'tsp100': []},
    }
    for _, row in df.iterrows():
        typ, gen = extract_generation(row['solver'])
        if typ not in series:
            continue
        tsplib_avg = _tsplib_average(row)
        tsp100_val = _numeric(row.get(tsp100_col)) if tsp100_col else None
        if tsplib_avg is None and tsp100_val is None:
            continue
        bucket = series[typ]
        bucket['generation'].append(gen)
        bucket['tsplib_avg'].append(tsplib_avg)
        bucket['tsp100'].append(tsp100_val)

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
    # (series key, line format, legend label, annotation offset, annotation colour)
    styles = (
        ('pool', 'o-', 'Co-evolution', (5, 5), 'lightblue'),
        ('popbest', 's-', 'EoH', (5, -15), 'lightcoral'),
    )

    for key, fmt, label, offset, color in styles:
        _plot_trend_series(ax1, series[key]['generation'], series[key]['tsplib_avg'],
                           fmt, label, offset, color)
    if eoh_tsplib_avg is not None:
        ax1.axhline(y=eoh_tsplib_avg, color='red', linestyle='--', linewidth=2,
                    label=f'EoH Optimal ({eoh_tsplib_avg:.2f}%)', alpha=0.8, zorder=1)

    if tsp100_col:
        for key, fmt, label, offset, color in styles:
            _plot_trend_series(ax2, series[key]['generation'], series[key]['tsp100'],
                               fmt, label, offset, color)
    if eoh_tsp100 is not None:
        ax2.axhline(y=eoh_tsp100, color='red', linestyle='--', linewidth=2,
                    label=f'EoH Optimal ({eoh_tsp100:.2f}%)', alpha=0.8, zorder=1)

    for ax, title in ((ax1, 'TSPLIB Average'), (ax2, 'TSP100')):
        ax.set_xlabel('Round/Generation', fontsize=11)
        ax.set_ylabel('Gap (%)', fontsize=11)
        ax.set_title(title, fontsize=12)
        # Only build a legend when something labelled was drawn; an empty
        # legend() call makes Matplotlib emit a warning.
        if ax.get_legend_handles_labels()[0]:
            ax.legend()
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(out_path, dpi=dpi, bbox_inches='tight')
    plt.close(fig)
    print(f" Saved trend visualization: {out_path}")


def visualize_by_dataset_sizes(df: pd.DataFrame, out_path: str, dpi: int = 150, eoh_optimal_dataset_df: pd.DataFrame = None):
    """Plot per-size-bin gap trends (Co-evolution vs EoH) in a single figure.

    One subplot is drawn for every TSPLIB size bin found in ``df`` (ordered
    lt50, s50_100, s100_200, s200_500, gt500), followed by a TSP100 subplot
    when a plain TSP100 column exists. Subplots sit on a two-column grid and
    unused grid cells are hidden.

    Args:
        df: Trend rows whose ``solver`` values are classified by
            ``extract_generation``; dataset columns are classified by
            ``extract_dataset_type``.
        out_path: Destination image path.
        dpi: Resolution of the saved figure.
        eoh_optimal_dataset_df: Optional table whose ``eoh_optimal`` row
            provides a dotted reference line for each subplot.
    """
    meta = ['solver', 'best_solver_idx', 'best_solver_name', 'agg_score', 'generation', 'original_name']
    data_cols = [c for c in df.columns if c not in meta]

    print(f"   Found {len(data_cols)} columns to process")
    print(f"   Sample columns: {data_cols[:10]}")

    ordered_bins = ['lt50', 's50_100', 's100_200', 's200_500', 'gt500']

    # panel name -> {'TSPLIB': column or None, 'TSP100': column or None}
    panels = {}
    for col in data_cols:
        try:
            ds_type, size = extract_dataset_type(str(col))
            if ds_type == 'TSPLIB' and size in ordered_bins:
                panels.setdefault(size, {'TSPLIB': None, 'TSP100': None})['TSPLIB'] = col
            elif ds_type == 'TSP100' and size is None:
                # TSP100 has no size bin; it gets a dedicated panel.
                panels.setdefault('TSP100', {'TSPLIB': None, 'TSP100': None})['TSP100'] = col
        except Exception as e:
            print(f"    Error processing column '{col}': {e}")

    if not panels:
        print("  No size bin columns found, skipping dataset visualization")
        print(f"   Available columns: {data_cols[:20]}")
        return

    print(f"   Found {len(panels)} size bins: {list(panels.keys())}")

    # (panel name, dataset type) -> EoH optimal gap
    reference = {}
    if eoh_optimal_dataset_df is not None:
        eoh_rows = eoh_optimal_dataset_df[eoh_optimal_dataset_df['solver'] == 'eoh_optimal']
        if len(eoh_rows) > 0:
            print("   Found EoH optimal row, extracting values...")
            first = eoh_rows.iloc[0]
            for col in eoh_optimal_dataset_df.columns:
                if col in meta:
                    continue
                ds_type, size = extract_dataset_type(str(col))
                if ds_type == 'TSPLIB' and size in ordered_bins:
                    ref_key = (size, 'TSPLIB')
                elif ds_type == 'TSP100' and size is None:
                    ref_key = ('TSP100', 'TSP100')
                else:
                    continue
                val = first.get(col)
                if pd.notna(val) and val is not None:
                    reference[ref_key] = float(val)
            print(f"   Extracted {len(reference)} EoH optimal values")

    # panel name -> {'generation': [...], 'gap': [...]}
    pool_series = {}
    eoh_series = {}
    for _, row in df.iterrows():
        kind, gen = extract_generation(row['solver'])
        if kind is None:
            continue
        for panel, mapping in panels.items():
            # TSP100 panel reads its TSP100 column; size-bin panels read TSPLIB.
            col = mapping.get('TSP100' if panel == 'TSP100' else 'TSPLIB')
            if not col:
                continue
            val = row.get(col) if col in row.index else None
            if not (pd.notna(val) and val is not None):
                continue
            pool_series.setdefault(panel, {'generation': [], 'gap': []})
            eoh_series.setdefault(panel, {'generation': [], 'gap': []})
            if kind == 'pool':
                target = pool_series[panel]
            elif kind == 'popbest':
                target = eoh_series[panel]
            else:
                continue
            target['generation'].append(gen)
            target['gap'].append(float(val))

    panel_order = [b for b in ordered_bins if b in panels]
    if 'TSP100' in panels:
        panel_order.append('TSP100')

    n_panels = len(panel_order)
    if n_panels == 0:
        print("  No size bins to plot, skipping dataset visualization")
        return

    n_cols = 2
    n_rows = -(-n_panels // n_cols)  # ceiling division
    fig, grid = plt.subplots(n_rows, n_cols, figsize=(14, 6 * n_rows))
    axes_list = list(np.ravel(grid)) if isinstance(grid, np.ndarray) else [grid]

    # (series store, line format, legend label, colour)
    series_specs = (
        (pool_series, 'o-', 'Co-evolution', 'blue'),
        (eoh_series, 's--', 'EoH', 'green'),
    )

    for ax, panel in zip(axes_list, panel_order):
        title = 'TSP100' if panel == 'TSP100' else format_size_label(panel)

        for store, fmt, label, color in series_specs:
            entry = store.get(panel)
            if entry and entry['generation']:
                xs, ys = zip(*sorted(zip(entry['generation'], entry['gap'])))
                ax.plot(list(xs), list(ys), fmt, label=label,
                        linewidth=2, markersize=6, color=color, alpha=0.8, zorder=3)

        ref_key = (panel, 'TSP100' if panel == 'TSP100' else 'TSPLIB')
        if ref_key in reference:
            ref_gap = reference[ref_key]
            ax.axhline(y=ref_gap, color='red', linestyle=':', linewidth=2,
                       label=f'EoH Optimal ({ref_gap:.2f}%)',
                       alpha=0.8, zorder=2)

        ax.set_xlabel('Round/Generation', fontsize=11)
        ax.set_ylabel('Gap (%)', fontsize=11)
        ax.set_title(f'{title}', fontsize=12, fontweight='bold')
        ax.legend(fontsize=9, loc='best')
        ax.grid(True, alpha=0.3)

    for ax in axes_list[n_panels:]:
        ax.set_visible(False)

    plt.tight_layout()
    plt.savefig(out_path, dpi=dpi, bbox_inches='tight')
    plt.close()
    print(f" Saved dataset visualization: {out_path}")


def visualize_by_size_bins(df: pd.DataFrame, out_path: str, dpi: int = 150, eoh_optimal_size_df: pd.DataFrame = None):
    """Plot gap trends on a 2x2 grid, one panel per instance-size bin.

    The bins are the columns ``le200``, ``s200_500``, ``s500_1000`` and
    ``gt1000``. If any of them is missing from ``df``, a notice is printed and
    nothing is plotted. Each panel shows the Co-evolution (pool) and EoH
    (population-best) series with their last value annotated, plus an
    optional dashed EoH-optimal reference line.

    Args:
        df: Trend rows with a ``solver`` column understood by ``extract_generation``.
        out_path: Destination image path.
        dpi: Resolution of the saved figure.
        eoh_optimal_size_df: Optional table whose ``eoh_optimal`` row gives a
            reference gap per bin.
    """
    bin_keys = ['le200', 's200_500', 's500_1000', 'gt1000']
    if any(key not in df.columns for key in bin_keys):
        print("  Missing size bin columns, skipping size bin visualization")
        return
    bin_titles = ['≤200', '200-500', '500-1000', '>1000']

    # bin -> EoH optimal gap
    reference = {}
    if eoh_optimal_size_df is None:
        print(f"   eoh_optimal_size_df is None")
    else:
        print(f"   EoH optimal size_df shape: {eoh_optimal_size_df.shape}")
        print(f"   EoH optimal size_df columns: {list(eoh_optimal_size_df.columns)}")
        print(f"   EoH optimal size_df solvers: {list(eoh_optimal_size_df['solver'].values)}")
        matches = eoh_optimal_size_df[eoh_optimal_size_df['solver'] == 'eoh_optimal']
        print(f"   Found {len(matches)} eoh_optimal rows")
        if len(matches) == 0:
            print(f"   No eoh_optimal row found in eoh_optimal_size_df")
        else:
            first = matches.iloc[0]
            for key in bin_keys:
                val = first.get(key)
                if pd.notna(val) and val is not None:
                    reference[key] = float(val)
                    print(f"   EoH optimal {key}: {reference[key]:.4f}%")

    pool_series = {key: {'generation': [], 'gap': []} for key in bin_keys}
    eoh_series = {key: {'generation': [], 'gap': []} for key in bin_keys}
    stores = {'pool': pool_series, 'popbest': eoh_series}

    for _, row in df.iterrows():
        kind, gen = extract_generation(row['solver'])
        if kind is None:
            continue
        store = stores.get(kind)
        for key in bin_keys:
            val = row.get(key)
            if not (pd.notna(val) and val is not None):
                continue
            if store is not None:
                store[key]['generation'].append(gen)
                store[key]['gap'].append(float(val))

    fig, grid = plt.subplots(2, 2, figsize=(14, 10))
    panels = grid.flatten()

    # (series store, line format, legend label, annotation offset, annotation colour)
    series_specs = (
        (pool_series, 'o-', 'Co-evolution', (5, 5), 'lightblue'),
        (eoh_series, 's-', 'EoH', (5, -15), 'lightcoral'),
    )

    for ax, key, title in zip(panels, bin_keys, bin_titles):
        for store, fmt, label, offset, face in series_specs:
            gens = store[key]['generation']
            if not gens:
                continue
            xs, ys = zip(*sorted(zip(gens, store[key]['gap'])))
            ax.plot(list(xs), list(ys), fmt, label=label, linewidth=2, markersize=5, zorder=3)
            ax.annotate(f'{ys[-1]:.2f}%',
                        xy=(xs[-1], ys[-1]),
                        xytext=offset,
                        textcoords='offset points',
                        fontsize=9,
                        bbox=dict(boxstyle='round,pad=0.3', facecolor=face, alpha=0.7),
                        ha='left', zorder=4)

        # Reference line uses the highest zorder so it stays visible above the data.
        if key in reference:
            ref_gap = reference[key]
            ax.axhline(y=ref_gap, color='red', linestyle='--', linewidth=3,
                       label=f'EoH Optimal ({ref_gap:.2f}%)', alpha=0.9, zorder=5)
            print(f"     Added EoH optimal line for {key}: {ref_gap:.4f}%")
        else:
            print(f"     No EoH optimal data for {key}")

        ax.set_xlabel('Round/Generation', fontsize=10)
        ax.set_ylabel('Gap (%)', fontsize=10)
        ax.set_title(f'Size: {title}', fontsize=11)
        ax.legend()
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(out_path, dpi=dpi, bbox_inches='tight')
    plt.close()
    print(f" Saved size bin visualization: {out_path}")


# ==================== EoH-only plot & main pipeline ====================

def visualize_eoh_optimal_only(eoh_optimal_size_df: pd.DataFrame, out_path: str, dpi: int = 150):
    """Plot only the EoH optimal reference gap for each size bin.

    Draws a 2x2 grid (bins ``le200``, ``s200_500``, ``s500_1000``, ``gt1000``)
    with one dashed horizontal line per bin, taken from the row whose
    ``solver`` is ``'eoh_optimal'``. It prints a message and returns without
    plotting when any of these hold:
      * the table is None or empty;
      * the table has no ``solver`` column;
      * the table has no ``eoh_optimal`` row;
      * the row has no bin values.

    Args:
        eoh_optimal_size_df: Table with a ``solver`` column and per-bin columns.
        out_path: Destination image path.
        dpi: Resolution of the saved figure.
    """
    if eoh_optimal_size_df is None or len(eoh_optimal_size_df) == 0:
        print("  No EoH optimal data to visualize")
        return
    if 'solver' not in eoh_optimal_size_df.columns:
        print("  No 'solver' column in EoH optimal data, skipping visualization")
        return

    bins = ['le200', 's200_500', 's500_1000', 'gt1000']
    bin_labels = ['≤200', '200-500', '500-1000', '>1000']

    eoh_row = eoh_optimal_size_df[eoh_optimal_size_df['solver'] == 'eoh_optimal']
    if len(eoh_row) == 0:
        print("  No eoh_optimal row found in data")
        return

    record = eoh_row.iloc[0]
    eoh_gaps = {}
    for b in bins:
        val = record.get(b)
        if val is not None and pd.notna(val):
            eoh_gaps[b] = float(val)

    if not eoh_gaps:
        print("  No size bin data found")
        return

    fig, axes = plt.subplots(2, 2, figsize=(14, 10))

    for ax, bin_name, bin_label in zip(axes.flatten(), bins, bin_labels):
        if bin_name in eoh_gaps:
            eoh_gap = eoh_gaps[bin_name]
            ax.axhline(y=eoh_gap, color='red', linestyle='--', linewidth=3,
                       label=f'EoH Optimal ({eoh_gap:.2f}%)', alpha=0.8)
            # Keep 0 in view with 20% headroom; extend below 0 for negative
            # gaps so the line is never clipped out of the axes.
            ax.set_ylim([min(0, eoh_gap * 1.2), max(eoh_gap * 1.2, 10)])
            # Legend only where a labelled line exists (avoids Matplotlib warning).
            ax.legend()

        ax.set_xlabel('Round/Generation', fontsize=10)
        ax.set_ylabel('Gap (%)', fontsize=10)
        ax.set_title(f'Size: {bin_label} - EoH Optimal', fontsize=11)
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(out_path, dpi=dpi, bbox_inches='tight')
    plt.close(fig)
    print(f" Saved EoH optimal only visualization: {out_path}")


def _load_solver_table(path: str) -> pd.DataFrame:
    """Read a gap-table CSV and try to expose row labels as a ``solver`` column.

    If the plain read has no ``solver`` column, the first column is re-read as
    the index and reset. When that column had no header, pandas names it
    ``index``, and it is renamed to ``solver``. The result may still lack a
    ``solver`` column if the first column carried some other header.
    """
    table = pd.read_csv(path)
    if 'solver' not in table.columns:
        table = pd.read_csv(path, index_col=0).reset_index()
        if 'index' in table.columns:
            table = table.rename(columns={'index': 'solver'})
    return table


def _is_processed_solver_name(name) -> bool:
    """True for row names emitted by per-round processing (h<N>, pool_best_r*, population_best_g*)."""
    s = str(name)
    return (s.startswith('pool_best_r')
            or s.startswith('population_best_g')
            or re.match(r'^h\d+$', s) is not None)


def _concat_frames(*frames):
    """Concatenate the non-None, non-empty frames; return None when none remain."""
    present = [f for f in frames if f is not None and len(f) > 0]
    if not present:
        return None
    if len(present) == 1:
        return present[0]
    return pd.concat(present, ignore_index=True)


def _save_pool_and_aggregate(pool_df: pd.DataFrame, out_dir: str) -> None:
    """Write pool-best rows to CSV and run aggregate_comparison_results.py on them.

    The aggregation step is best-effort: failures of the external script are
    reported but not raised, so visualization can still proceed.
    """
    pool_path = os.path.join(out_dir, 'pool_best_per_round_gap_table.csv')
    pool_df.to_csv(pool_path, index=False)
    print(f"   Saved: {pool_path}")

    print("\n Running aggregate on pool best...")
    script_dir = os.path.dirname(os.path.abspath(__file__))
    aggregate_script = os.path.join(script_dir, 'aggregate_comparison_results.py')
    aggregate_output = os.path.join(out_dir, 'pool_best_per_round_gap_table_aggregated.csv')

    try:
        result = subprocess.run(
            [sys.executable, aggregate_script, '--input', pool_path, '--output', aggregate_output],
            capture_output=True,
            text=True,
            check=True,
        )
    except subprocess.CalledProcessError as e:
        print(f"    Aggregate failed: {e}")
        if e.stderr:
            print(f"  Error: {e.stderr}")
        return
    except OSError as e:  # interpreter could not be launched at all
        print(f"    Aggregate failed: {e}")
        return

    print(f"   Aggregate completed: {aggregate_output}")
    if result.stdout:
        print(f"  {result.stdout}")


def process_single_table(table_path: str, table_name: str, args, eoh_optimal_table_path: str = None):
    """Run the best-per-round / aggregate / visualize pipeline for one table.

    All outputs are written to the directory containing ``table_path``.

    Args:
        table_path: CSV gap table to process.
        table_name: Human-readable name used in progress messages.
        args: Parsed CLI arguments; ``args.metric`` and ``args.dpi`` are read.
        eoh_optimal_table_path: Optional CSV with an ``eoh_optimal`` row that
            is drawn as a reference line in both visualizations.

    Returns:
        ``(combined_df, dataset_viz_df, eoh_optimal_dataset_df)``, or
        ``(None, None, None)`` when there is no pool/population data.
    """
    out_dir = os.path.dirname(os.path.abspath(table_path))
    os.makedirs(out_dir, exist_ok=True)

    print(f"\n{'='*60}")
    print(f" Processing {table_name}: {table_path}")
    print(f" Output directory: {out_dir}")
    print(f"{'='*60}")

    df = _load_solver_table(table_path)

    metadata_cols = ['solver', 'best_solver_idx', 'best_solver_name', 'agg_score', 'generation', 'original_name']
    agg_cols = [c for c in df.columns if c not in metadata_cols]

    # Detect input that has already gone through grouping / per-round selection.
    has_dataset_groups = any('TSPLIB_' in str(c) or c == 'TSP100' for c in df.columns)
    has_processed_solvers = any(
        _is_processed_solver_name(s) for s in df['solver'].values if pd.notna(s)
    )
    # astype(str) keeps the .str accessor usable whatever the column dtype;
    # NaN becomes 'nan', which matches none of the patterns below.
    solver_names = df['solver'].astype(str)
    is_h_row = solver_names.str.match(r'^h\d+$')

    if has_dataset_groups and has_processed_solvers:
        print(" Data is already processed")
        pool_rows_raw = df[is_h_row].copy()
        if len(pool_rows_raw) > 0:
            print("   Recomputing pool_best_per_round with latest logic...")
            pool_df = process_pool_best_per_round(pool_rows_raw, agg_cols, args.metric)
        else:
            pool_mask = is_h_row | solver_names.str.startswith('pool_best_r')
            pool_df = df[pool_mask].copy() if pool_mask.any() else None
            if pool_df is None:
                print("    No pool_best rows found, computing from raw data...")
                pool_df = process_pool_best_per_round(df, agg_cols, args.metric)

        dataset_df = df
        if pool_df is not None and len(pool_df) > 0:
            combined_df = _concat_frames(pool_df, process_population_latest(df, agg_cols))
            _save_pool_and_aggregate(pool_df, out_dir)
        else:
            combined_df = df
    else:
        print(f" Aggregating across {len(agg_cols)} datasets using metric: {args.metric}")

        # Step 2: Compute best per round
        print("\n Step 2: Computing best per round...")
        pool_df = process_pool_best_per_round(df, agg_cols, args.metric)
        pop_df = process_population_latest(df, agg_cols)

        if pool_df is not None and len(pool_df) > 0:
            _save_pool_and_aggregate(pool_df, out_dir)

        combined_df = _concat_frames(pool_df, pop_df)
        if combined_df is None:
            print(" No data to process")
            return None, None, None

        # Step 3: Aggregate by dataset groups
        print("\n Step 3: Aggregating by dataset groups...")
        dataset_df = aggregate_by_dataset_groups(combined_df)

    dataset_path = os.path.join(out_dir, 'trend_gap_by_dataset.csv')
    dataset_df.to_csv(dataset_path, index=False)
    print(f"   Saved: {dataset_path}")

    # Optional EoH reference table. Only accept it if it has a 'solver' column,
    # because both visualizers filter on it.
    eoh_optimal_dataset_df = None
    if eoh_optimal_table_path and os.path.exists(eoh_optimal_table_path):
        try:
            eoh_df = _load_solver_table(eoh_optimal_table_path)
        except (OSError, ValueError) as e:  # I/O or CSV parse problems
            print(f"    Failed to load EoH optimal table: {e}")
        else:
            if 'solver' in eoh_df.columns:
                eoh_optimal_dataset_df = eoh_df
                print("   Loaded EoH optimal table for visualization")
            else:
                print("    EoH optimal table has no 'solver' column; ignoring it")

    print("\n Step 4: Visualizing trends...")
    visualize_trend_table(combined_df, os.path.join(out_dir, 'trend_visualization.png'),
                          args.dpi, eoh_optimal_dataset_df)

    if has_processed_solvers:
        dataset_viz_df = df
    else:
        # NOTE(review): if the helpers are deterministic and do not mutate df,
        # these equal the Step 2 results and could be reused instead.
        dataset_viz_df = _concat_frames(
            process_pool_best_per_round(df, agg_cols, args.metric),
            process_population_latest(df, agg_cols),
        )
        if dataset_viz_df is None:
            dataset_viz_df = df

    visualize_by_dataset_sizes(dataset_viz_df, os.path.join(out_dir, 'trend_by_dataset_sizes.png'),
                               args.dpi, eoh_optimal_dataset_df)

    print(f"\n {table_name} processing complete! Outputs saved to: {out_dir}")

    return combined_df, dataset_viz_df, eoh_optimal_dataset_df


def main():
    """Run the full pipeline for the gap table and, optionally, the population table.

    Steps:
      1. Parse CLI arguments.
      2. Resolve which EoH optimal table to use. If no path is set, or a
         non-empty path points at a file that does not exist, fall back to
         ``baseline_results/eoh_optimal_gap_table.csv`` next to this script
         (or ``None`` if that is missing too).
      3. Call ``process_single_table`` for each table given on the command line.
      4. If both tables produced a combined frame, concatenate the per-table
         results and re-render both trend figures into the gap table's directory.
    """
    args = parse_args()

    # --- Resolve the EoH optimal table path ---------------------------------
    script_dir = os.path.dirname(os.path.abspath(__file__))
    eoh_table_path = args.eoh_optimal_table
    # An empty string is left untouched; only None or a missing non-empty
    # path triggers the fallback lookup.
    needs_fallback = eoh_table_path is None or (
        bool(eoh_table_path) and not os.path.exists(eoh_table_path)
    )
    if needs_fallback:
        fallback_path = os.path.join(script_dir, 'baseline_results', 'eoh_optimal_gap_table.csv')
        if os.path.exists(fallback_path):
            eoh_table_path = fallback_path
            print(f" Auto-detected EoH optimal table: {fallback_path}")
        else:
            eoh_table_path = None

    # --- Per-table processing -----------------------------------------------
    # Each pair stays None unless its table was actually processed.
    gap_combined_df = gap_dataset_viz_df = None
    pop_combined_df = pop_dataset_viz_df = None
    eoh_optimal_dataset_df = None

    if args.gap_table:
        gap_combined_df, gap_dataset_viz_df, eoh_optimal_dataset_df = process_single_table(
            args.gap_table, "Gap Table", args, eoh_table_path
        )

    if args.population_table:
        # The EoH frame loaded for the gap table is the one reused for merging.
        pop_combined_df, pop_dataset_viz_df, _ = process_single_table(
            args.population_table, "Population Table", args, eoh_table_path
        )

    banner = '=' * 60

    # --- Merged visualizations (only when both tables yielded results) ------
    if gap_combined_df is not None and pop_combined_df is not None:
        print(f"\n{banner}")
        print(" Merging visualizations from both tables")
        print(banner)

        # Merged figures are written into the directory holding the gap table.
        merged_out_dir = os.path.dirname(os.path.abspath(args.gap_table))
        os.makedirs(merged_out_dir, exist_ok=True)

        merged_combined_df = pd.concat([gap_combined_df, pop_combined_df], ignore_index=True)
        merged_dataset_viz_df = pd.concat([gap_dataset_viz_df, pop_dataset_viz_df], ignore_index=True)

        print("\n Generating merged visualizations...")
        visualize_trend_table(
            merged_combined_df,
            os.path.join(merged_out_dir, 'trend_visualization.png'),
            args.dpi,
        )
        visualize_by_dataset_sizes(
            merged_dataset_viz_df,
            os.path.join(merged_out_dir, 'trend_by_dataset_sizes.png'),
            args.dpi,
            eoh_optimal_dataset_df,
        )

        print(f"\n Merged visualizations saved to: {merged_out_dir}")

    print(f"\n{banner}")
    print(" All processing complete!")
    print(banner)




# Run the pipeline only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()

