#!/usr/bin/env python3
"""Process trend gap table for BP Online: compute best per round, aggregate by dataset, and visualize.

This script combines:
1. Compute best pool/population per round
2. Aggregate by dataset groups (Falkenauer T/U by size, Weibull by size)
3. Visualize trends

Usage:
  python3 process_and_visualize_trends.py \
    --gap_table /path/to/trend_gap_table.csv \
    --out_dir /path/to/output \
    --metric mean \
    --population_table /path/to/population_gap_table.csv  # optional
"""

import os
import sys
import argparse
import re
import subprocess
from collections import Counter
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt


def parse_args():
    """Build the command-line parser for this script and parse ``sys.argv``.

    Options:
        --gap_table          (required) path of the input trend gap table CSV.
        --population_table   optional CSV with population data; defaults to None.
        --eoh_optimal_table  CSV with EoH optimal heuristic results; defaults to
                             ``baseline_results/eoh_optimal_gap_table.csv`` next
                             to this script.
        --out_dir            output directory; defaults to None here (the help
                             text says the gap table's directory is used, which
                             must be resolved by the caller).
        --metric             one of ``mean`` / ``min`` / ``median`` (default ``mean``).
        --dpi                figure DPI as an int (default 150).

    Returns:
        The ``argparse.Namespace`` produced by ``ArgumentParser.parse_args()``.
    """
    # The default EoH table lives in baseline_results/ beside this script.
    here = os.path.dirname(os.path.abspath(__file__))
    eoh_default = os.path.join(here, 'baseline_results', 'eoh_optimal_gap_table.csv')

    # (flag, keyword arguments for add_argument), in the order they are registered.
    option_specs = [
        ('--gap_table', {
            'required': True,
            'help': "Input trend_gap_table.csv path (required: your method's gap table)",
        }),
        ('--population_table', {
            'type': str,
            'default': None,
            'help': 'Optional: Path to CSV file containing population data (if different from gap_table). If not provided, population trend will not be plotted.',
        }),
        ('--eoh_optimal_table', {
            'type': str,
            'default': eoh_default,
            'help': f'Path to CSV file containing EoH optimal heuristic results (default: {eoh_default})',
        }),
        ('--out_dir', {
            'type': str,
            'default': None,
            'help': 'Output directory (default: same directory as gap_table)',
        }),
        ('--metric', {
            'type': str,
            'default': 'mean',
            'choices': ['mean', 'min', 'median'],
            'help': 'Aggregation metric for finding best solver (default: mean)',
        }),
        ('--dpi', {
            'type': int,
            'default': 150,
            'help': 'DPI for output figures (default: 150)',
        }),
    ]

    parser = argparse.ArgumentParser(
        description="Process trend gap table for BP Online: compute best per round, aggregate by dataset, and visualize."
    )
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
    return parser.parse_args()


# ==================== Step 2: Compute best per round ====================

def is_instance_level_column(col_name: str) -> bool:
    """Tell whether a column holds the result of a single instance.

    Per-instance columns also end in ``_mean_gap`` even though they describe
    one instance, so they are recognised by the shape of the rest of the name.

    Per-instance format (new):
    - Falkenauer_t120_00_random_mean_gap
    - Hard28_BPP119_random_mean_gap
    - Weibull_shape1p4_scale30p0_1000_00_random_mean_gap

    Dataset-level format (old, possibly no longer produced):
    - Falkenauer_T120_120_mean_gap
    - Hard28_mean_gap

    Rules, applied after removing ``_mean_gap``:
    1. The name must end with an ordering tag (``_random``, ``_size_asc`` or
       ``_size_desc``); otherwise it is treated as dataset-level.
    2. ``Falkenauer_t…`` / ``Falkenauer_u…`` with an all-digit third field.
    3. Anything starting with ``Hard28_``.
    4. Names containing ``Weibull_shape`` and ``scale`` with at least five
       underscore-separated fields.
    5. ``HeavyTail_alpha…`` with at least four fields, or ``Mixture…`` with at
       least three fields.
    """
    if not col_name.endswith('_mean_gap'):
        return False

    stem = col_name.replace('_mean_gap', '')

    # No ordering tag -> dataset-level column in the old naming scheme.
    order_tag = next(
        (tag for tag in ('_random', '_size_asc', '_size_desc') if stem.endswith(tag)),
        None,
    )
    if order_tag is None:
        return False

    instance_key = stem[:-len(order_tag)]
    tokens = instance_key.split('_')

    return any((
        # Falkenauer_t120_00: third field is the instance index.
        instance_key.startswith(('Falkenauer_t', 'Falkenauer_u'))
        and len(tokens) >= 3
        and tokens[2].isdigit(),
        # Hard28_BPP119
        instance_key.startswith('Hard28_'),
        # Weibull_shape1p4_scale30p0_1000_00
        'Weibull_shape' in instance_key
        and 'scale' in instance_key
        and len(tokens) >= 5,
        # HeavyTail_alpha1.2_1000_00
        instance_key.startswith('HeavyTail_alpha') and len(tokens) >= 4,
        # Mixture2_1000_00
        instance_key.startswith('Mixture') and len(tokens) >= 3,
    ))


def extract_dataset_label_from_instance_column(col_name: str) -> str:
    """Derive the dataset label from a per-instance column name.

    Examples:
    - Falkenauer_t120_00_random_mean_gap -> Falkenauer_T120_120_random
    - Hard28_BPP119_random_mean_gap -> Hard28_random
    - Weibull_shape1p4_scale30p0_1000_00_random_mean_gap -> Weibull_shape1p4_scale30p0_1000_random

    Returns None (despite the ``str`` annotation) when the name does not end
    in ``_mean_gap``, carries no ordering tag, or does not fit any of the
    known dataset families.
    """
    if not col_name.endswith('_mean_gap'):
        return None

    stem = col_name.replace('_mean_gap', '')

    # Split off the ordering tag; names without one are not per-instance.
    for order in ('random', 'size_asc', 'size_desc'):
        tag = '_' + order
        if stem.endswith(tag):
            instance_key = stem[:-len(tag)]
            break
    else:
        return None

    tokens = instance_key.split('_')

    if instance_key.startswith(('Falkenauer_t', 'Falkenauer_u')):
        # Falkenauer_t120_00 -> Falkenauer_T120_120 (instance index dropped).
        if len(tokens) < 3:
            return None
        family = tokens[1][0].upper()  # 't' -> 'T', 'u' -> 'U'
        size = tokens[1][1:]  # 't120' -> '120'
        return f"Falkenauer_{family}{size}_{size}_{order}"

    if instance_key.startswith('Hard28_'):
        # Hard28_BPP119 -> Hard28
        return f"Hard28_{order}"

    if 'Weibull_shape' in instance_key and 'scale' in instance_key:
        # Weibull_shape1p4_scale30p0_1000_00 -> Weibull_shape1p4_scale30p0_1000
        # The size is the first all-digit field that follows a scale field.
        shape_part = scale_part = size = None
        for token in tokens:
            if token.startswith('shape'):
                shape_part = token
            elif token.startswith('scale'):
                scale_part = token
            elif scale_part is not None and token.isdigit():
                size = token
                break
        if shape_part is None or scale_part is None or size is None:
            return None
        return f"Weibull_{shape_part}_{scale_part}_{size}_{order}"

    if instance_key.startswith('HeavyTail_alpha'):
        # HeavyTail_alpha1.2_1000_00 -> HeavyTail_alpha1.2_1000
        if len(tokens) < 3:
            return None
        return f"HeavyTail_{tokens[1]}_{tokens[2]}_{order}"

    if instance_key.startswith('Mixture'):
        # Mixture2_1000_00 -> Mixture2_1000
        if len(tokens) < 2:
            return None
        return f"{tokens[0]}_{tokens[1]}_{order}"

    return None


def extract_pool_round_number(row_name: str) -> int:
    """Return the integer N of a row label of the exact form ``hN``.

    Any label that is not ``h`` followed only by digits (``'h'``, ``'h1a'``,
    ``'x3'``, ``''`` ...) yields ``-1``.
    """
    if row_name.startswith('h') and len(row_name) >= 2:
        found = re.search(r'^h(\d+)$', row_name)
        if found:
            return int(found.group(1))
    return -1




def compute_aggregate_score(row: pd.Series, cols: list, metric: str) -> float:
    """Reduce the values of ``row`` found under ``cols`` to a single score.

    Missing entries, NaN/None cells and cells that cannot be converted with
    ``float()`` are ignored.

    Args:
        row: mapping-like object supporting ``.get(column)``.
        cols: column names to read from ``row``.
        metric: ``'mean'``, ``'min'`` or ``'median'``; any other value falls
            back to the mean.

    Returns:
        The reduced value as a float, or ``inf`` when no usable value exists.
    """
    usable = []
    for name in cols:
        cell = row.get(name)
        if pd.isna(cell):
            continue
        try:
            usable.append(float(cell))
        except (ValueError, TypeError):
            # Non-numeric cell: skip it rather than fail the whole row.
            continue

    if not usable:
        return float('inf')

    reducers = {'mean': np.mean, 'min': np.min, 'median': np.median}
    reduce_values = reducers.get(metric, np.mean)
    return float(reduce_values(usable))


def process_pool_best_per_round(df: pd.DataFrame, agg_cols: list, metric: str) -> pd.DataFrame:
    """Compute, for every round i, the per-column best (minimum) over h0..hi.

    Rows of ``df`` whose ``solver`` label has the exact form ``hN`` are the
    pool rows. For each round ``i`` from 0 to the largest N present, the output
    holds one row labelled ``h{i}`` whose value in every column of ``agg_cols``
    is the smallest finite value seen in that column among pool rows with
    N <= i. Rounds that have no row of their own still get an output row
    (carrying the best so far); rounds before the first existing pool row are
    skipped.

    Args:
        df: input table with a ``solver`` column of string labels.
        agg_cols: columns to process; names and order are preserved in the output.
        metric: unused; the best value is always the column-wise minimum.
            Kept so existing callers keep working.

    Returns:
        A DataFrame with columns ``['solver'] + agg_cols``, or ``None`` when
        ``df`` contains no pool rows. Columns that never had a usable value
        are present and filled with None.

    Raises:
        ValueError / TypeError: if a non-null cell in ``agg_cols`` cannot be
            converted with ``float()``.
    """
    agg_cols = list(agg_cols)

    # One regex pass both selects the pool rows and captures their round number.
    round_text = df['solver'].str.extract(r'^h(\d+)$', expand=False)
    is_pool_row = round_text.notna()
    if not is_pool_row.any():
        return None

    pool_rows = df[is_pool_row].copy()
    pool_rows['round'] = round_text[is_pool_row].map(int)
    pool_rows = pool_rows.sort_values('round').reset_index(drop=True)

    rounds = pool_rows['round'].tolist()
    max_round = int(rounds[-1])  # rows are sorted, so the last one is the largest

    # Rows are folded into a running minimum exactly once, instead of
    # re-scanning all earlier rounds for every target round.
    running_best = {}
    out_rows = []
    next_row = 0
    for target_round in range(max_round + 1):
        while next_row < len(rounds) and rounds[next_row] <= target_round:
            candidate = pool_rows.iloc[next_row]
            for col in agg_cols:
                val = candidate.get(col)
                if pd.notna(val) and val is not None:
                    val_float = float(val)
                    if val_float < running_best.get(col, float('inf')):
                        running_best[col] = val_float
            next_row += 1

        if next_row == 0:
            # No pool row with round <= target_round yet.
            continue

        out_row = {'solver': f"h{target_round}"}
        for col in agg_cols:
            if col in running_best:
                out_row[col] = running_best[col]
        out_rows.append(out_row)

    result_df = pd.DataFrame(out_rows)
    # Keep every requested column, in the requested order, even when no round
    # produced a value for it.
    ordered_cols = ['solver'] + agg_cols
    for col in ordered_cols:
        if col not in result_df.columns:
            result_df[col] = None
    return result_df[ordered_cols]


def extract_dataset_info_from_label(dataset_label: str) -> tuple:
    """Split a dataset label into ``(dataset_type, size)``.

    Examples:
        'Falkenauer_T120_120' -> ('Falkenauer_T', 120)
        'Falkenauer_U250_250' -> ('Falkenauer_U', 250)
        'Weibull_5k'         -> ('Weibull', 5000)
        'Weibull_5000'       -> ('Weibull', 5000)

    Only Falkenauer and Weibull labels are understood; every other label, and
    any label of those families that does not fit the patterns above, gives
    ``(None, None)``.
    """
    unknown = (None, None)

    if dataset_label.startswith('Falkenauer_'):
        # Falkenauer_T120_120 / Falkenauer_U250_250: keep T/U and the first number.
        falk = re.match(r'Falkenauer_([TU])(\d+)_\d+', dataset_label)
        if falk is None:
            return unknown
        return (f"Falkenauer_{falk.group(1)}", int(falk.group(2)))

    if dataset_label.startswith('Weibull'):
        # "5k" style first, then a plain number right after "Weibull_".
        kilo = re.search(r'(\d+)k', dataset_label, re.IGNORECASE)
        if kilo is not None:
            return ('Weibull', int(kilo.group(1)) * 1000)
        plain = re.search(r'Weibull_(\d+)', dataset_label, re.IGNORECASE)
        if plain is not None:
            return ('Weibull', int(plain.group(1)))

    return unknown


def process_population_latest(df: pd.DataFrame, agg_cols: list) -> pd.DataFrame:
    """Re-label population rows by generation, taking each row as it is.

    Unlike the pool processing, no best-so-far selection happens: every row
    whose ``solver`` label has the form ``hN`` is emitted once, ordered by N.

    Args:
        df: table with a ``solver`` column; expected to hold population rows.
        agg_cols: columns to carry over. A column that is absent from ``df``
            or holds NaN/None for a row is emitted as None.

    Returns:
        A DataFrame with ``solver`` (``population_best_g{N}``), ``generation``
        (N), ``original_name`` (the input label) and one float-or-None column
        per entry of ``agg_cols``; ``None`` when no ``hN`` row exists.
    """
    is_generation_row = df['solver'].str.match(r'^h\d+$', na=False)
    pop_rows = df[is_generation_row].copy()
    if pop_rows.empty:
        return None

    pop_rows['generation'] = pop_rows['solver'].apply(extract_pool_round_number)
    pop_rows = (
        pop_rows[pop_rows['generation'] >= 0]
        .sort_values('generation')
        .reset_index(drop=True)
    )
    if pop_rows.empty:
        return None

    def float_or_none(cell):
        return float(cell) if pd.notna(cell) and cell is not None else None

    records = []
    for _, source in pop_rows.iterrows():
        generation = int(source['generation'])
        record = {
            'solver': f"population_best_g{generation}",
            'generation': generation,
            'original_name': str(source['solver']),
        }
        record.update({col: float_or_none(source.get(col)) for col in agg_cols})
        records.append(record)

    return pd.DataFrame(records)


# ==================== Step 3: Aggregate by dataset groups ====================

def _falkenauer_from_token(token: str):
    """Parse a ``T120`` / ``U250`` token into ``('Falkenauer_T', 120)``; None if it does not fit."""
    hit = re.match(r'([TU])(\d+)', token)
    if hit is None:
        return None
    return (f"Falkenauer_{hit.group(1)}", int(hit.group(2)))


def _weibull_from_shape_pattern(text: str):
    """Find ``Weibull_shape<..>_scale<..>_<size>`` anywhere in ``text``; None if absent."""
    hit = re.search(r'Weibull_shape([\dp]+)_scale([\dp]+)_(\d+)', text)
    if hit is None:
        return None
    return ('Weibull', int(hit.group(3)))


def _heavytail_from_parts(parts: list):
    """Parse ``['HeavyTail', 'alpha1.2' | 'alpha1p2', '<size>', ...]``; None if it does not fit."""
    if len(parts) < 3 or not parts[1].startswith('alpha'):
        return None
    # Normalise "1p2" to "1.2" only AFTER removing the "alpha" prefix: replacing
    # "p" on the whole token would also corrupt the "p" inside "alpha".
    alpha_hit = re.match(r'[\d.]+', parts[1][len('alpha'):].replace('p', '.'))
    if alpha_hit is None:
        return None
    try:
        size = int(parts[2])
    except ValueError:
        return None
    return (f"HeavyTail_alpha{alpha_hit.group(0)}", size)


def _mixture_from_parts(parts: list):
    """Parse ``['Mixture<k>', '<size>', ...]``; None if it does not fit."""
    if len(parts) < 2:
        return None
    mix_hit = re.match(r'Mixture(\d+)', parts[0])
    if mix_hit is None:
        return None
    try:
        size = int(parts[1])
    except ValueError:
        return None
    return (f"Mixture{mix_hit.group(1)}", size)


def _instance_level_info(col_name: str):
    """Parse a new-format per-instance column; None when no size can be read."""
    base = col_name.replace('_mean_gap', '')
    for order_tag in ('_random', '_size_asc', '_size_desc'):
        if base.endswith(order_tag):
            base = base[:-len(order_tag)]
            break
    parts = base.split('_')

    if base.startswith(('Falkenauer_t', 'Falkenauer_u')):
        # Falkenauer_t120_00 -> ('Falkenauer_T', 120)
        kind = parts[1]
        try:
            return (f"Falkenauer_{kind[0].upper()}", int(kind[1:]))
        except ValueError:
            return None

    if base.startswith('Hard28_'):
        return ('Hard28', 200)

    if 'Weibull_shape' in base and 'scale' in base:
        # The size is the field right after a "scale..." field.
        for idx, part in enumerate(parts[:-1]):
            if part.startswith('scale'):
                try:
                    return ('Weibull', int(parts[idx + 1]))
                except ValueError:
                    continue
        return None

    if base.startswith('HeavyTail_alpha'):
        if len(parts) < 3:
            return None
        try:
            size = int(parts[2])
        except ValueError:
            return None
        # NOTE: this path keeps the alpha text as written ("1p2" stays "1p2"),
        # unlike the dataset-level paths which normalise it to "1.2".
        return (f"HeavyTail_alpha{parts[1].replace('alpha', '')}", size)

    if base.startswith('Mixture'):
        if len(parts) < 2:
            return None
        try:
            return (parts[0], int(parts[1]))
        except ValueError:
            return None

    return None


def extract_dataset_info(col_name: str) -> tuple:
    """Extract ``(dataset_type, size)`` from a column name.

    Returns ``(None, None)`` when the name matches none of the known formats.

    Examples:
        # New per-instance format (named *_mean_gap but holding one instance)
        'Falkenauer_t120_00_random_mean_gap' -> ('Falkenauer_T', 120)
        'Hard28_BPP119_random_mean_gap' -> ('Hard28', 200)
        'Weibull_shape1p4_scale30p0_1000_00_random_mean_gap' -> ('Weibull', 1000)
        'HeavyTail_alpha1.2_1000_00_random_mean_gap' -> ('HeavyTail_alpha1.2', 1000)
        'Mixture2_1000_00_random_mean_gap' -> ('Mixture2', 1000)

        # Old dataset-level format
        'Falkenauer_T120_120_mean_gap' -> ('Falkenauer_T', 120)
        'Falkenauer_U250_250_mean_gap' -> ('Falkenauer_U', 250)
        'Weibull_5k_mean_gap' -> ('Weibull', 5000)
        'HeavyTail_alpha1.2_50_mean_gap' -> ('HeavyTail_alpha1.2', 50)
        'HeavyTail_alpha1p2_50_mean_gap' -> ('HeavyTail_alpha1.2', 50)
        'Mixture1_100_mean_gap' -> ('Mixture1', 100)
        'Hard28_mean_gap' -> ('Hard28', 200)

        # Aggregated format
        'Falkenauer_T_60' -> ('Falkenauer_T', 60)
        'Weibull_5k' -> ('Weibull', 5000)
        'Hard28_200' -> ('Hard28', 200)

        # Old per-instance format
        'Falkenauer_T120_120_Falkenauer_t120_00_gap' -> ('Falkenauer_T', 120)

    Hard28 has no size of its own; 200 is a fixed placeholder (28 instances
    with n between 160 and 200, 200 taken as the representative scale).
    """
    # 1) New per-instance format. Only names ending in "_mean_gap" can be
    #    per-instance, so other names skip the probe entirely.
    if col_name.endswith('_mean_gap') and is_instance_level_column(col_name):
        info = _instance_level_info(col_name)
        if info is not None:
            return info

    parts = col_name.split('_')

    # 2) Formats recognisable from the raw column name (aggregated names such
    #    as Falkenauer_T_60 / Weibull_5k, and anything whose leading fields
    #    already carry type and size).
    info = None
    if col_name.startswith('Falkenauer_'):
        if len(parts) >= 3 and parts[1] in ('T', 'U'):
            try:
                info = (f"Falkenauer_{parts[1]}", int(parts[2]))
            except ValueError:
                info = None
    elif col_name.startswith('Weibull_'):
        # Weibull_shape..._scale..._<size> may also appear after a Weibull_<n>_<m>_ prefix.
        info = _weibull_from_shape_pattern(col_name)
        if info is None:
            kilo = re.match(r'(\d+)k', parts[1], re.IGNORECASE)  # Weibull_5k
            plain = re.match(r'(\d+)', parts[1])  # Weibull_500 / Weibull_5000
            if kilo is not None:
                info = ('Weibull', int(kilo.group(1)) * 1000)
            elif plain is not None:
                info = ('Weibull', int(plain.group(1)))
    elif col_name.startswith('HeavyTail_'):
        info = _heavytail_from_parts(parts)
    elif col_name.startswith('Mixture'):
        info = _mixture_from_parts(parts)
    if info is not None:
        return info

    # 3) Strip the dataset-level suffix, or handle names without one.
    if '_mean_gap' in col_name:
        base = col_name.replace('_mean_gap', '')
    elif '_mean_bins' in col_name:
        base = col_name.replace('_mean_bins', '')
    else:
        # Old per-instance names (Falkenauer_T120_120_Falkenauer_t120_00_gap)
        # and the aggregated Hard28 column (Hard28_200). Weibull, HeavyTail and
        # Mixture names of this kind were already decided in step 2.
        if len(parts) >= 3 and parts[0] == 'Falkenauer':
            info = _falkenauer_from_token(parts[1])
        elif col_name.startswith('Hard28'):
            info = ('Hard28', 200)
        return info if info is not None else (None, None)

    # 4) Dataset-level names: Falkenauer_T120_120, Weibull_5k,
    #    HeavyTail_alpha1.2_50, Mixture1_100, Hard28.
    base_parts = base.split('_')
    if base.startswith('Falkenauer_'):
        info = _falkenauer_from_token(base_parts[1])
    elif base.startswith('Weibull'):
        info = _weibull_from_shape_pattern(base)
        if info is None:
            kilo = re.search(r'(\d+)k', base, re.IGNORECASE)
            if kilo is not None:
                info = ('Weibull', int(kilo.group(1)) * 1000)
    elif base.startswith('HeavyTail_'):
        info = _heavytail_from_parts(base_parts)
    elif base.startswith('Mixture'):
        info = _mixture_from_parts(base_parts)
    elif base.startswith('Hard28'):
        info = ('Hard28', 200)

    return info if info is not None else (None, None)


def aggregate_by_dataset_groups(df: pd.DataFrame) -> pd.DataFrame:
    """Collapse a gap table to one column per (dataset type, size) group.

    Every data column is mapped to a ``(dataset_type, size)`` pair with
    ``extract_dataset_info``; columns that map to the same pair are averaged
    per row (NaN/None and non-numeric cells are ignored). Columns that cannot
    be mapped are dropped.

    Column selection: metadata columns (``solver``, ``best_solver_idx``,
    ``best_solver_name``, ``agg_score``, ``generation``, ``original_name``)
    are never aggregated. If any remaining column ends in ``_mean_gap`` only
    those are used; otherwise all remaining columns are used.

    Output columns are ``solver`` followed by ``{dataset_type}_{size}``,
    ordered Falkenauer types first, then Weibull, then the rest
    alphabetically, each by ascending size. Weibull sizes that are exact
    multiples of 1000 are written in ``k`` notation (``Weibull_5k``); other
    sizes keep the plain number so that distinct sizes never share a column.

    Args:
        df: table with a ``solver`` column plus data columns.

    Returns:
        A new DataFrame with one row per input row; a group with no usable
        value in a row holds None there.
    """
    metadata_cols = ('solver', 'best_solver_idx', 'best_solver_name', 'agg_score', 'generation', 'original_name')
    all_cols = [c for c in df.columns if c not in metadata_cols]

    # Prefer the *_mean_gap columns; fall back to everything when there are none.
    mean_gap_cols = [c for c in all_cols if c.endswith('_mean_gap')] or all_cols

    # {dataset_type: {size: [source columns]}}
    dataset_groups = {}
    for col in mean_gap_cols:
        dataset_type, size = extract_dataset_info(col)
        if dataset_type and size:
            dataset_groups.setdefault(dataset_type, {}).setdefault(size, []).append(col)

    def group_label(dataset_type, size):
        # "k" notation only when it is exact; truncating (1500 -> "1k") would
        # merge different sizes into the same output column.
        if dataset_type == 'Weibull' and size >= 1000 and size % 1000 == 0:
            return f"{dataset_type}_{size // 1000}k"
        return f"{dataset_type}_{size}"

    sorted_dataset_types = sorted(
        dataset_groups,
        key=lambda name: (
            0 if name.startswith('Falkenauer') else 1 if name.startswith('Weibull') else 2,
            name,
        ),
    )
    # Labels and source columns in final output order, computed once for all rows.
    ordered_groups = [
        (group_label(dataset_type, size), dataset_groups[dataset_type][size])
        for dataset_type in sorted_dataset_types
        for size in sorted(dataset_groups[dataset_type])
    ]

    out_rows = []
    for _, row in df.iterrows():
        out_row = {'solver': row['solver']}
        for label, cols in ordered_groups:
            values = []
            for col in cols:
                val = row.get(col)
                if pd.notna(val) and val is not None:
                    try:
                        values.append(float(val))
                    except (ValueError, TypeError):
                        # Non-numeric cell: leave it out of the average.
                        pass
            out_row[label] = float(np.mean(values)) if values else None
        out_rows.append(out_row)

    output_cols = ['solver'] + [label for label, _ in ordered_groups]
    return pd.DataFrame(out_rows, columns=output_cols)


# ==================== Step 4: Visualize ====================

def extract_generation(row_name: str) -> tuple:
    """Classify a row label and pull out its round/generation number.

    Recognised labels:
        'h<N>'                  -> ('pool', N)     # must be exactly h + digits
        'pool_best_r<N>...'     -> ('pool', N)     # old format, kept for compatibility
        'population_best_g<N>'  -> ('popbest', N)

    Returns:
        ``(type, number)`` as above, or ``(None, None)`` for any other label.
    """
    # (required prefix, pattern capturing the number, reported type)
    rules = (
        ('h', r'^h(\d+)$', 'pool'),
        ('pool_best_r', r'pool_best_r(\d+)', 'pool'),
        ('population_best_g', r'population_best_g(\d+)', 'popbest'),
    )
    for prefix, pattern, kind in rules:
        if not row_name.startswith(prefix):
            continue
        hit = re.search(pattern, row_name)
        if hit:
            return (kind, int(hit.group(1)))
    return (None, None)


def visualize_trend_table(df: pd.DataFrame, out_path: str, dpi: int = 150):
    """Plot the mean gap per dataset type against round/generation.

    Every non-metadata column of ``df`` is assigned to the dataset type that
    ``extract_dataset_info`` reports for it, and one subplot is drawn per
    type. Each row whose ``solver`` name is recognised by
    ``extract_generation`` contributes one point per type: the mean of its
    non-missing values over that type's columns. Pool rows form the
    'Co-evolution' series, population rows the 'EoH' series, and the last
    point of each series is annotated with its value.

    Args:
        df: Table with a ``solver`` column plus one column per dataset.
        out_path: Path the figure is saved to.
        dpi: Resolution passed to ``savefig``.
    """
    skipped_cols = ('solver', 'best_solver_idx', 'best_solver_name', 'agg_score', 'generation', 'original_name')

    def _type_order(name):
        """Sort key: Falkenauer types first, then Weibull, then the rest by name."""
        if name.startswith('Falkenauer'):
            rank = 0
        elif name.startswith('Weibull'):
            rank = 1
        else:
            rank = 2
        return (rank, name)

    def _draw_series(ax, points, fmt, legend_label, text_offset, box_color):
        """Draw one series ordered by generation and annotate its final value."""
        if not points:
            return
        gens, gaps = (list(seq) for seq in zip(*sorted(points)))
        ax.plot(gens, gaps, fmt, label=legend_label, linewidth=2, markersize=6)
        ax.annotate(
            f'{gaps[-1]:.2f}%',
            xy=(gens[-1], gaps[-1]),
            xytext=text_offset,
            textcoords='offset points',
            fontsize=9,
            bbox=dict(boxstyle='round,pad=0.3', facecolor=box_color, alpha=0.7),
            ha='left',
        )

    # dataset type -> columns belonging to it, in column order
    cols_by_type = {}
    for column in df.columns:
        if column in skipped_cols:
            continue
        ds_type, _ = extract_dataset_info(column)
        if ds_type:
            cols_by_type.setdefault(ds_type, []).append(column)

    # row kind -> dataset type -> [(generation, mean gap), ...]
    samples = {
        'pool': {ds_type: [] for ds_type in cols_by_type},
        'popbest': {ds_type: [] for ds_type in cols_by_type},
    }

    for _, record in df.iterrows():
        kind, gen = extract_generation(record['solver'])
        if kind is None:
            continue
        target = samples.get(kind)

        for ds_type, columns in cols_by_type.items():
            present = []
            for column in columns:
                cell = record.get(column)
                if pd.notna(cell):
                    present.append(float(cell))
            if not present:
                continue
            mean_gap = np.mean(present)
            if target is not None:
                target[ds_type].append((gen, mean_gap))

    # Falkenauer first, then Weibull, then everything else alphabetically
    ordered_types = sorted((t for t in cols_by_type if cols_by_type[t]), key=_type_order)
    n_panels = len(ordered_types)
    if n_panels == 0:
        print("⚠️  No dataset columns found for visualization")
        return

    fig, axes = plt.subplots(1, n_panels, figsize=(5 * n_panels, 5))
    if n_panels == 1:
        # a single subplot comes back as a bare Axes, not a sequence
        axes = [axes]

    for ax, ds_type in zip(axes, ordered_types):
        _draw_series(ax, samples['pool'][ds_type], 'o-', 'Co-evolution', (5, 5), 'lightblue')
        _draw_series(ax, samples['popbest'][ds_type], 's-', 'EoH', (5, -15), 'lightcoral')

        ax.set_xlabel('Round/Generation', fontsize=11)
        ax.set_ylabel('Gap (%)', fontsize=11)
        ax.set_title(ds_type, fontsize=12)
        # Skip the legend when nothing labelled was drawn on this axis
        handles, _labels = ax.get_legend_handles_labels()
        if handles:
            ax.legend()
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(out_path, dpi=dpi, bbox_inches='tight')
    plt.close()
    print(f"✅ Saved trend visualization: {out_path}")


def visualize_by_dataset_sizes(df: pd.DataFrame, out_dir: str, dpi: int = 150, eoh_optimal_df: pd.DataFrame = None):
    """Plot gap against round/generation with one subplot per dataset size.

    All (dataset type, size) pairs that ``extract_dataset_info`` finds among
    the non-metadata columns of ``df`` are laid out on one roughly square
    grid and saved as ``trend_by_dataset_sizes.png`` in ``out_dir``. Each
    subplot shows the pool ('Co-evolution') and population ('EoH') series
    read from the column named ``<type>_<size>`` (Weibull sizes of 1000 or
    more are written as ``<size // 1000>k``), plus one horizontal line per
    baseline that provides a value for that column.

    Baselines (``eoh_optimal``, ``best_fit``, ``first_fit``) are read from
    CSV files in ``baseline_results/`` next to this script, falling back to
    ``<name>_gap_by_dataset.csv`` in the script directory. A baseline whose
    file is missing or unreadable is skipped.

    Args:
        df: Table with a ``solver`` column plus dataset columns.
        out_dir: Existing directory the figure is written to.
        dpi: Resolution passed to ``savefig``.
        eoh_optimal_df: Accepted for interface compatibility; this function
            does not read it and loads the EoH baseline from disk instead.
    """
    meta_cols = ('solver', 'best_solver_idx', 'best_solver_name', 'agg_score', 'generation', 'original_name')
    baseline_names = ('eoh_optimal', 'best_fit', 'first_fit')
    # baseline -> (line colour, line style, legend label); built once rather
    # than once per subplot
    baseline_style = {
        'eoh_optimal': ('red', '--', 'EoH Optimal'),
        'best_fit': ('green', '-.', 'Best Fit'),
        'first_fit': ('orange', ':', 'First Fit'),
    }

    def _group_column(ds_type, size):
        """Return the per-size column name used for plotting and baseline lookup."""
        if ds_type == 'Weibull':
            size_label = f"{size//1000}k" if size >= 1000 else str(size)
            return f"{ds_type}_{size_label}"
        return f"{ds_type}_{size}"

    def _type_order(name):
        """Sort key: Falkenauer types first, then Weibull, then the rest by name."""
        rank = 0 if name.startswith('Falkenauer') else 1 if name.startswith('Weibull') else 2
        return (rank, name)

    def _read_solver_table(path):
        """Read a CSV, re-reading with the first column as index when 'solver' is absent."""
        table = pd.read_csv(path)
        if 'solver' not in table.columns:
            table = pd.read_csv(path, index_col=0).reset_index()
            if 'index' in table.columns:
                table = table.rename(columns={'index': 'solver'})
        return table

    def _draw_series(ax, points, fmt, legend_label, text_offset, box_color):
        """Draw one series ordered by generation and annotate its final value."""
        if not points:
            return
        gens, gaps = (list(seq) for seq in zip(*sorted(points)))
        ax.plot(gens, gaps, fmt, label=legend_label, linewidth=2, markersize=5)
        ax.annotate(
            f'{gaps[-1]:.2f}%',
            xy=(gens[-1], gaps[-1]),
            xytext=text_offset,
            textcoords='offset points',
            fontsize=8,
            bbox=dict(boxstyle='round,pad=0.3', facecolor=box_color, alpha=0.7),
            ha='left',
        )

    # ---- dataset type -> distinct sizes present in df ----
    sizes_by_type = {}
    for column in df.columns:
        if column in meta_cols:
            continue
        ds_type, size = extract_dataset_info(column)
        if ds_type and size:
            sizes_by_type.setdefault(ds_type, set()).add(size)

    # ---- locate baseline files ----
    script_dir = os.path.dirname(os.path.abspath(__file__))
    baseline_dir = os.path.join(script_dir, 'baseline_results')
    baseline_files = {
        'eoh_optimal': os.path.join(baseline_dir, 'eoh_optimal_gap_by_dataset.csv'),
        'best_fit': os.path.join(baseline_dir, 'best_fit_gap_table.csv'),
        'first_fit': os.path.join(baseline_dir, 'first_fit_gap_table.csv'),
    }
    # Backward compatibility: fall back to the older location beside the script
    for name in baseline_names:
        if not os.path.exists(baseline_files[name]):
            legacy_path = os.path.join(script_dir, f'{name}_gap_by_dataset.csv')
            if os.path.exists(legacy_path):
                baseline_files[name] = legacy_path

    # ---- load baselines: {baseline name: {column name: mean gap}} ----
    baseline_gaps = {}
    for name in baseline_names:
        path = baseline_files[name]
        if not os.path.exists(path):
            continue
        try:
            table = _read_solver_table(path)
        except Exception as e:
            # Baselines are optional decoration; report and carry on without it.
            print(f"  ⚠️  Failed to load {name} table: {e}")
            continue

        matching = table[table['solver'] == name]
        if len(matching) == 0:
            print(f"  ⚠️  {name} row not found in baseline table")
            continue
        baseline_row = matching.iloc[0]

        # dataset type -> size -> baseline columns that belong to that group
        grouped = {}
        for column in table.columns:
            if column == 'solver':
                continue
            ds_type, size = extract_dataset_info(column)
            if ds_type and size:
                grouped.setdefault(ds_type, {}).setdefault(size, []).append(column)

        gaps = {}
        for ds_type, by_size in grouped.items():
            for size, columns in by_size.items():
                present = []
                for column in columns:
                    cell = baseline_row.get(column)
                    if pd.notna(cell):
                        present.append(float(cell))
                if present:
                    gaps[_group_column(ds_type, size)] = float(np.mean(present))

        baseline_gaps[name] = gaps
        if gaps:
            print(f"  ✅ Loaded {name} baseline ({len(gaps)} dataset groups)")

    # ---- one panel per (type, size): Falkenauer, Weibull, then the rest ----
    panels = [
        (ds_type, size, _group_column(ds_type, size))
        for ds_type in sorted(sizes_by_type, key=_type_order)
        for size in sorted(sizes_by_type[ds_type])
    ]
    if not panels:
        print("⚠️  No dataset sizes found for visualization")
        return

    # Classify every row once instead of once per subplot.
    tagged_rows = []
    for _, record in df.iterrows():
        kind, gen = extract_generation(record['solver'])
        if kind is not None:
            tagged_rows.append((kind, gen, record))

    # Roughly square grid
    n_plots = len(panels)
    n_cols = int(np.ceil(np.sqrt(n_plots)))
    n_rows = int(np.ceil(n_plots / n_cols))

    # squeeze=False always yields a 2-D array of Axes, so a single flatten()
    # covers the 1x1, single-row and multi-row layouts alike.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(6 * n_cols, 5 * n_rows), squeeze=False)
    axes = axes.flatten()

    for ax, (ds_type, size, col_name) in zip(axes, panels):
        points = {'pool': [], 'popbest': []}
        for kind, gen, record in tagged_rows:
            cell = record.get(col_name)
            if pd.notna(cell) and kind in points:
                points[kind].append((gen, float(cell)))

        _draw_series(ax, points['pool'], 'o-', 'Co-evolution', (5, 5), 'lightblue')
        _draw_series(ax, points['popbest'], 's-', 'EoH', (5, -15), 'lightcoral')

        # Horizontal reference line for each baseline that covers this column
        for name in baseline_names:
            if name in baseline_gaps and col_name in baseline_gaps[name]:
                baseline_gap = baseline_gaps[name][col_name]
                color, linestyle, legend_name = baseline_style[name]
                ax.axhline(y=baseline_gap, color=color, linestyle=linestyle, linewidth=2,
                           label=f"{legend_name} ({baseline_gap:.2f}%)", alpha=0.8, zorder=5)

        ax.set_xlabel('Round/Generation', fontsize=10)
        ax.set_ylabel('Gap (%)', fontsize=10)
        ax.set_title(col_name, fontsize=11)
        # Skip the legend when nothing labelled was drawn on this axis
        handles, _labels = ax.get_legend_handles_labels()
        if handles:
            ax.legend(fontsize=8)
        ax.grid(True, alpha=0.3)

    # Hide the grid cells left over when the grid is larger than n_plots
    for spare_ax in axes[n_plots:]:
        spare_ax.set_visible(False)

    plt.tight_layout()
    out_path = os.path.join(out_dir, 'trend_by_dataset_sizes.png')
    plt.savefig(out_path, dpi=dpi, bbox_inches='tight')
    # Close this specific figure rather than whatever happens to be current.
    plt.close(fig)
    print(f"✅ Saved: {out_path}")


# ==================== Main ====================

def process_single_table(table_path: str, table_name: str, args, eoh_optimal_table_path: str = None):
    """Run the whole pipeline for one table (gap table or population table).

    Loads the CSV, derives the best-per-round pool rows and the population
    rows (unless the table already is in processed form), aggregates by
    dataset group, writes the CSV outputs and renders both trend figures.
    Every output is written to the directory that contains ``table_path``;
    ``args.out_dir`` is not consulted here.

    Args:
        table_path: Path of the CSV to process.
        table_name: Label used in the progress messages.
        args: Parsed command-line namespace; ``args.metric`` and ``args.dpi``
            are read.
        eoh_optimal_table_path: Optional path of the EoH optimal table. It is
            loaded when the file exists and handed to the per-size plot.

    Returns:
        ``(combined_df, dataset_viz_df, eoh_optimal_dataset_df)``, or
        ``(None, None, None)`` when neither pool nor population rows could be
        derived from the table.
    """
    metadata_cols = ['solver', 'best_solver_idx', 'best_solver_name', 'agg_score', 'generation', 'original_name']
    processed_prefixes = ('pool_best_r', 'population_best_g')
    group_prefixes = ('Falkenauer_T_', 'Falkenauer_U_', 'Weibull_', 'HeavyTail_', 'Mixture')
    banner = '=' * 60

    def _read_solver_table(path):
        """Read a CSV, re-reading with the first column as index when 'solver' is absent."""
        table = pd.read_csv(path)
        if 'solver' not in table.columns:
            table = pd.read_csv(path, index_col=0).reset_index()
            if 'index' in table.columns:
                table = table.rename(columns={'index': 'solver'})
        return table

    def _combine(pool_part, pop_part):
        """Concatenate whichever of the two frames exist; None when both are missing."""
        parts = [part for part in (pool_part, pop_part) if part is not None]
        if not parts:
            return None
        if len(parts) == 1:
            return parts[0]
        return pd.concat(parts, ignore_index=True)

    def _solver_names(table):
        """Non-missing solver names of a table, as strings."""
        return [str(s) for s in table['solver'].values if pd.notna(s)]

    # Outputs live in the folder that holds the table
    out_dir = os.path.dirname(os.path.abspath(table_path))
    os.makedirs(out_dir, exist_ok=True)

    print(f"\n{banner}")
    print(f"📊 Processing {table_name}: {table_path}")
    print(f"📁 Output directory: {out_dir}")
    print(f"{banner}")

    # Load the table
    df = _read_solver_table(table_path)

    # Columns that take part in aggregation
    agg_cols = [c for c in df.columns if c not in metadata_cols]

    # Already aggregated? (columns named after dataset groups)
    has_dataset_groups = any(c.startswith(group_prefixes) for c in df.columns)
    # Already reduced to pool_best_r* / population_best_g* rows?
    has_processed_solvers = any(name.startswith(processed_prefixes) for name in _solver_names(df))

    if has_dataset_groups and has_processed_solvers:
        # Aggregated columns and processed rows: use the table as-is
        print("✅ Data is already processed, using directly")
        dataset_df = df
        combined_df = df
    else:
        print(f"📏 Aggregating across {len(agg_cols)} datasets using metric: {args.metric}")

        # Step 2: Compute best per round
        print("\n📦 Step 2: Computing best per round...")
        pool_df = process_pool_best_per_round(df, agg_cols, args.metric)
        pop_df = process_population_latest(df, agg_cols)  # population rows come from the same table

        if pool_df is not None:
            # Save pool_best_per_round_gap_table.csv
            pool_path = os.path.join(out_dir, 'pool_best_per_round_gap_table.csv')
            pool_df.to_csv(pool_path, index=False)
            print(f"  ✅ Saved: {pool_path}")

            # Run the external aggregate script on the pool-best table
            print("\n📊 Running aggregate on pool best...")
            _script_dir = os.path.dirname(os.path.abspath(__file__))
            aggregate_script = os.path.join(_script_dir, 'aggregate_comparison_results.py')
            aggregate_output = os.path.join(out_dir, 'pool_best_per_round_gap_table_aggregated.csv')

            # Best-effort step: a failing script, or a failure to launch the
            # interpreter at all (OSError), is reported but does not stop
            # the rest of the pipeline.
            try:
                result = subprocess.run(
                    [sys.executable, aggregate_script, '--input', pool_path, '--output', aggregate_output],
                    capture_output=True,
                    text=True,
                    check=True
                )
            except (subprocess.CalledProcessError, OSError) as e:
                print(f"  ⚠️  Aggregate failed: {e}")
                captured_stderr = getattr(e, 'stderr', None)
                if captured_stderr:
                    print(f"  Error: {captured_stderr}")
            else:
                print(f"  ✅ Aggregate completed: {aggregate_output}")
                if result.stdout:
                    print(f"  {result.stdout}")

        # Merge pool and population rows for aggregation
        combined_df = _combine(pool_df, pop_df)
        if combined_df is None:
            print("❌ No data to process")
            return None, None, None

        # Step 3: Aggregate by dataset groups
        print("\n📊 Step 3: Aggregating by dataset groups...")
        dataset_df = aggregate_by_dataset_groups(combined_df)

    # Save trend_gap_by_dataset.csv
    dataset_path = os.path.join(out_dir, 'trend_gap_by_dataset.csv')
    dataset_df.to_csv(dataset_path, index=False)
    print(f"  ✅ Saved: {dataset_path}")

    # Load the EoH optimal data (optional; a load failure is only reported)
    eoh_optimal_dataset_df = None
    if eoh_optimal_table_path and os.path.exists(eoh_optimal_table_path):
        try:
            eoh_optimal_dataset_df = _read_solver_table(eoh_optimal_table_path)
            print("  ✅ Loaded EoH optimal table for visualization")
        except Exception as e:
            print(f"  ⚠️  Failed to load EoH optimal table: {e}")

    # Visualize
    print("\n🎨 Step 4: Visualizing trends...")
    visualize_trend_table(combined_df, os.path.join(out_dir, 'trend_visualization.png'), args.dpi)

    # Prepare the data for the per-size figure. Column and solver checks are
    # taken from df again at this point rather than reused from above.
    viz_agg_cols = [c for c in df.columns if c not in metadata_cols]
    has_processed = any(
        name.startswith(processed_prefixes) or re.match(r'^h\d+$', name)
        for name in _solver_names(df)
    )

    if has_processed:
        dataset_viz_df = df
    else:
        # NOTE(review): this repeats the Step 2 calls with what look like the
        # same inputs; reusing combined_df is probably equivalent, but that
        # depends on helpers not visible here - confirm before changing.
        viz_pool_df = process_pool_best_per_round(df, viz_agg_cols, args.metric)
        viz_pop_df = process_population_latest(df, viz_agg_cols)
        dataset_viz_df = _combine(viz_pool_df, viz_pop_df)
        if dataset_viz_df is None:
            dataset_viz_df = df

    visualize_by_dataset_sizes(dataset_viz_df, out_dir, args.dpi, eoh_optimal_dataset_df)

    print(f"\n✅ {table_name} processing complete! Outputs saved to: {out_dir}")

    return combined_df, dataset_viz_df, eoh_optimal_dataset_df


def main():
    """Command-line entry point.

    Resolves the EoH optimal table, runs ``process_single_table`` on the gap
    table and, when given, on the population table, and finally renders
    merged figures into the gap table's directory when both runs produced
    data.
    """
    args = parse_args()
    banner = '=' * 60

    # Auto-detect the EoH optimal table when none was given or the given
    # path does not exist
    script_dir = os.path.dirname(os.path.abspath(__file__))
    eoh_optimal_table = args.eoh_optimal_table
    path_unusable = eoh_optimal_table is None or (
        eoh_optimal_table and not os.path.exists(eoh_optimal_table)
    )
    if path_unusable:
        candidate = os.path.join(script_dir, 'baseline_results', 'eoh_optimal_gap_table.csv')
        if os.path.exists(candidate):
            eoh_optimal_table = candidate
            print(f"📊 Auto-detected EoH optimal table: {candidate}")
        else:
            eoh_optimal_table = None

    def _render_merged(gap_combined, gap_viz, pop_combined, pop_viz, eoh_df):
        """Concatenate the results of both tables and redraw both figures."""
        print("\n" + banner)
        print("📊 Merging visualizations from both tables")
        print(banner)

        # Merged output goes to the folder that holds the gap table
        merged_out_dir = os.path.dirname(os.path.abspath(args.gap_table))
        os.makedirs(merged_out_dir, exist_ok=True)

        # Stack the rows of both tables for plotting
        merged_combined_df = pd.concat([gap_combined, pop_combined], ignore_index=True)
        merged_dataset_viz_df = pd.concat([gap_viz, pop_viz], ignore_index=True)

        print("\n🎨 Generating merged visualizations...")
        visualize_trend_table(merged_combined_df, os.path.join(merged_out_dir, 'trend_visualization.png'), args.dpi)
        visualize_by_dataset_sizes(merged_dataset_viz_df, merged_out_dir, args.dpi, eoh_df)

        print(f"\n✅ Merged visualizations saved to: {merged_out_dir}")

    gap_combined_df = gap_dataset_viz_df = None
    pop_combined_df = pop_dataset_viz_df = None
    eoh_optimal_dataset_df = None

    # Gap table
    if args.gap_table:
        gap_combined_df, gap_dataset_viz_df, eoh_optimal_dataset_df = process_single_table(
            args.gap_table, "Gap Table", args, eoh_optimal_table
        )

    # Population table (only when provided)
    if args.population_table:
        pop_combined_df, pop_dataset_viz_df, _ = process_single_table(
            args.population_table, "Population Table", args, eoh_optimal_table
        )

    # With both tables given and both runs yielding data, draw merged figures
    both_given = args.gap_table and args.population_table
    if both_given and gap_combined_df is not None and pop_combined_df is not None:
        _render_merged(
            gap_combined_df, gap_dataset_viz_df,
            pop_combined_df, pop_dataset_viz_df,
            eoh_optimal_dataset_df,
        )

    print("\n" + banner)
    print("✅ All processing complete!")
    print(banner)




# Run the command-line pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()

