#!/usr/bin/env python3
"""Process trend gap table: compute best per round, aggregate by size, and visualize.

This script combines:
1. Compute best pool/population per round
2. Aggregate by size bins or dataset groups
3. Visualize trends

Usage:
  python3 process_and_visualize_trends.py \
    --gap_table /path/to/trend_gap_table.csv \
    --out_dir /path/to/output \
    --metric mean \
  --aggregation dataset_size_bins  # default; alternatives: 'size', 'dataset', 'dataset_size'
"""

import os
import sys
import argparse
import re
import subprocess
from collections import Counter
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt


def parse_args():
    """Parse the command-line options of the trend processing pipeline.

    Only ``--gap_table`` is mandatory. ``--population_table`` and
    ``--eoh_optimal_table`` are accepted but not used by ``main()`` in this
    script; ``--eoh_optimal_table`` defaults to
    ``baseline_results/eoh_optimal_gap_table.csv`` located next to this script.

    Returns:
        argparse.Namespace with the parsed options.
    """
    here = os.path.dirname(os.path.abspath(__file__))
    eoh_default = os.path.join(here, 'baseline_results', 'eoh_optimal_gap_table.csv')

    parser = argparse.ArgumentParser(
        description="Process trend gap table: compute best per round, aggregate by size, and visualize."
    )
    add = parser.add_argument

    add('--gap_table',
        required=True,
        help="Input trend_gap_table.csv path (required: your method's gap table)")
    add('--population_table',
        type=str,
        default=None,
        help='Optional: Path to CSV file containing population data (if different from gap_table). If not provided, population trend will not be plotted.')
    add('--eoh_optimal_table',
        type=str,
        default=eoh_default,
        help=f'Path to CSV file containing EoH optimal heuristic results (default: {eoh_default})')
    add('--out_dir',
        type=str,
        default=None,
        help='Output directory (default: same directory as gap_table)')
    add('--metric',
        type=str,
        default='mean',
        choices=['mean', 'min', 'median'],
        help='Aggregation metric for finding best solver (default: mean)')
    add('--aggregation',
        type=str,
        default='dataset_size_bins',
        choices=['size', 'dataset', 'dataset_size', 'dataset_size_bins'],
        help='Aggregation method: size (by size bins, no dataset), dataset (by dataset only), dataset_size (by dataset and exact size), dataset_size_bins (by dataset and size bins, default)')
    add('--dpi',
        type=int,
        default=150,
        help='DPI for output figures (default: 150)')

    return parser.parse_args()


# ==================== Step 1: Compute best per round ====================

def extract_pool_round_number(row_name: str) -> int:
    """Return the round encoded in a pool solver name such as 'h0' or 'h12'.

    Any name that is not exactly 'h' followed by digits yields -1.
    """
    # Cheap prefix/length guard first; the regex then confirms the digits.
    looks_like_pool = row_name.startswith('h') and len(row_name) >= 2
    found = re.match(r'h(\d+)$', row_name) if looks_like_pool else None
    return -1 if found is None else int(found[1])


def process_pool_best_per_round(df: pd.DataFrame, agg_cols: list, metric: str) -> pd.DataFrame:
    """
    Build the cumulative best (minimum) value of every instance column per pool round.

    Pool rows are those whose 'solver' is 'h<N>'. For each round R from the
    earliest to the latest observed round, the output row 'h<R>' holds, for every
    column in ``agg_cols``, the minimum numeric value over all pool rows with
    round <= R. The best solver is chosen independently per column (instance-wise),
    so each column is monotonically non-increasing across rounds. A round that has
    no row of its own (e.g. h0 and h2 present, h1 missing) still gets an output row
    that carries the best-so-far values forward.

    Cells that are missing, non-numeric or NaN (including a literal 'nan' string)
    are ignored, so they can never become the "best" value.

    Args:
        df: DataFrame with a 'solver' column and instance columns.
        agg_cols: Instance columns to process (e.g. 'A__A-n32-k5').
        metric: Not used; kept so existing call sites keep working.

    Returns:
        DataFrame with columns ['solver'] + agg_cols and one row per round, or an
        empty DataFrame when ``df`` contains no pool rows. Columns without any valid
        value so far are None/NaN.
    """
    pool_rows = df[df['solver'].str.match(r'^h\d+$', na=False)]
    if len(pool_rows) == 0:
        return pd.DataFrame()

    # Bucket pool rows by round so every row is scanned exactly once, instead of
    # rescanning all earlier rounds for every target round.
    round_pattern = re.compile(r'^h(\d+)$')
    rows_by_round = {}
    for _, row in pool_rows.iterrows():
        rnd = int(round_pattern.match(row['solver']).group(1))
        rows_by_round.setdefault(rnd, []).append(row)

    running_best = dict.fromkeys(agg_cols)  # column -> best float so far (None = none yet)
    out_rows = []
    for target_round in range(min(rows_by_round), max(rows_by_round) + 1):
        for row in rows_by_round.get(target_round, ()):
            for col in agg_cols:
                val = row.get(col)
                if val is None or pd.isna(val):
                    continue
                try:
                    val_float = float(val)
                except (ValueError, TypeError):
                    continue
                if np.isnan(val_float):
                    # e.g. the string 'nan': a NaN would otherwise block every later minimum.
                    continue
                best = running_best[col]
                if best is None or val_float < best:
                    running_best[col] = val_float
        out_rows.append({'solver': f"h{target_round}", **running_best})

    return pd.DataFrame(out_rows)


# ==================== Step 2: Aggregate ====================

def extract_size_from_label(label: str) -> int:
    """
    Return the integer that follows 'n' (and precedes '-') in a CVRP label, or None.

    In CVRPLIB naming this is usually the instance's node count. Only the first
    'n<digits>-' occurrence is used; None is returned when there is none.

    Examples:
        'A__A-n32-k5' -> 32
        'B__B-n57-k9' -> 57
        'P__P-n101-k4' -> 101
    """
    hit = re.search(r'n(\d+)-', label)
    if hit is None:
        return None
    return int(hit[1])


def extract_dataset_type(col_name: str) -> tuple:
    """
    Split a '<dataset>__<instance>' column name into (dataset_type, size).

    ``dataset_type`` is the single uppercase letter before '__'; ``size`` is the
    integer after 'n' in the instance part, or None when absent. Names without a
    single-uppercase-letter prefix yield (None, None).

    Examples:
        'A__A-n32-k5' -> ('A', 32)
        'B__B-n57-k9' -> ('B', 57)
        'E__E-n76-k10' -> ('E', 76)
        'F__F-n72-k4' -> ('F', 72)
        'M__M-n101-k4' -> ('M', 101)
        'P__P-n101-k4' -> ('P', 101)
        'X__X-n101-k4' -> ('X', 101)
    """
    prefix = re.match(r'^([A-Z])__', col_name)
    if prefix is None:
        return (None, None)

    # Same pattern extract_size_from_label uses: first 'n<digits>-' occurrence.
    size_hit = re.search(r'n(\d+)-', col_name)
    size = int(size_hit.group(1)) if size_hit else None
    return (prefix.group(1), size)


def bin_of(size: int) -> str:
    """Map an instance size to one of four CVRP size bins (upper bounds inclusive).

    size <= 40 -> 'le40'; 40 < size <= 60 -> 's40_60';
    60 < size <= 80 -> 's60_80'; otherwise 'gt80'.
    """
    for upper, label in ((40, 'le40'), (60, 's40_60'), (80, 's60_80')):
        if size <= upper:
            return label
    return 'gt80'


def aggregate_by_size(df: pd.DataFrame) -> pd.DataFrame:
    """
    Average gap values into size bins, ignoring which dataset a column belongs to.

    Bins (size n as parsed by extract_size_from_label, binned by bin_of):
    - le40:   n <= 40
    - s40_60: 40 < n <= 60
    - s60_80: 60 < n <= 80
    - gt80:   n > 80

    Metadata columns and columns without a parseable size are ignored. Each output
    row keeps the input row's 'solver' and holds, per bin, the mean of the numeric,
    non-missing cells (None when the bin has no such cell).
    """
    skip = ('solver', 'best_solver_idx', 'best_solver_name', 'agg_score', 'generation', 'original_name')
    bin_names = ['le40', 's40_60', 's60_80', 'gt80']

    # Resolve each data column's bin once, before walking the rows.
    column_bins = {}
    for column in df.columns:
        if column in skip:
            continue
        n = extract_size_from_label(column)
        if n is not None:
            column_bins[column] = bin_of(n)

    records = []
    for _, row in df.iterrows():
        collected = {name: [] for name in bin_names}
        for column, bin_name in column_bins.items():
            raw = row.get(column)
            if raw is None or not pd.notna(raw):
                continue
            try:
                collected[bin_name].append(float(raw))
            except (ValueError, TypeError):
                continue

        record = {'solver': row['solver']}
        for name in bin_names:
            values = collected[name]
            record[name] = float(np.mean(values)) if values else None
        records.append(record)

    return pd.DataFrame(records, columns=['solver'] + bin_names)


def aggregate_by_dataset_size_bins(df: pd.DataFrame) -> pd.DataFrame:
    """
    Average gap values per (dataset, size bin) pair; the default aggregation.

    Columns are grouped by their single-letter dataset prefix (see
    extract_dataset_type) and by bin_of(size). Columns without a dataset prefix or
    without a parseable size are ignored. Every dataset found gets all four bins,
    even bins that receive no column (those cells are None). Output columns are
    'solver' followed by '<Dataset>_<bin>' with datasets sorted alphabetically and
    bins in the order le40, s40_60, s60_80, gt80, e.g. A_le40, A_s40_60, ..., X_gt80.
    """
    skip = ('solver', 'best_solver_idx', 'best_solver_name', 'agg_score', 'generation', 'original_name')
    bin_order = ['le40', 's40_60', 's60_80', 'gt80']

    # {dataset_type: {bin: [columns]}}
    groups = {}
    for column in df.columns:
        if column in skip:
            continue
        dataset, n = extract_dataset_type(column)
        if not dataset or n is None:
            continue
        per_bin = groups.setdefault(dataset, {b: [] for b in bin_order})
        per_bin[bin_of(n)].append(column)

    def mean_or_none(row, columns):
        # Mean of the numeric, non-missing cells of `row` in `columns`; None if none.
        nums = []
        for column in columns:
            raw = row.get(column)
            if raw is None or not pd.notna(raw):
                continue
            try:
                nums.append(float(raw))
            except (ValueError, TypeError):
                continue
        return float(np.mean(nums)) if nums else None

    datasets = sorted(groups)
    out_columns = ['solver'] + [f"{d}_{b}" for d in datasets for b in bin_order]

    records = []
    for _, row in df.iterrows():
        record = {'solver': row['solver']}
        for d in datasets:
            for b in bin_order:
                record[f"{d}_{b}"] = mean_or_none(row, groups[d][b])
        records.append(record)

    return pd.DataFrame(records, columns=out_columns)


def aggregate_by_dataset(df: pd.DataFrame) -> pd.DataFrame:
    """
    Average gap values per dataset, regardless of instance size (widest grouping).

    Columns are grouped by their single-letter dataset prefix (see
    extract_dataset_type); columns without such a prefix are ignored. Output
    columns are 'solver' followed by the dataset letters in alphabetical order; a
    cell is None when the dataset has no numeric, non-missing value in that row.
    """
    skip = ('solver', 'best_solver_idx', 'best_solver_name', 'agg_score', 'generation', 'original_name')

    cols_by_dataset = {}
    for column in df.columns:
        if column in skip:
            continue
        dataset, _size = extract_dataset_type(column)
        if dataset:
            cols_by_dataset.setdefault(dataset, []).append(column)

    def mean_or_none(row, columns):
        # Mean of the numeric, non-missing cells of `row` in `columns`; None if none.
        nums = []
        for column in columns:
            raw = row.get(column)
            if raw is None or not pd.notna(raw):
                continue
            try:
                nums.append(float(raw))
            except (ValueError, TypeError):
                continue
        return float(np.mean(nums)) if nums else None

    records = [
        {
            'solver': row['solver'],
            **{d: mean_or_none(row, cols) for d, cols in cols_by_dataset.items()},
        }
        for _, row in df.iterrows()
    ]

    return pd.DataFrame(records, columns=['solver'] + sorted(cols_by_dataset))


def aggregate_by_dataset_groups(df: pd.DataFrame) -> pd.DataFrame:
    """
    Average gap values per (dataset, exact size) pair (medium grouping).

    Columns are grouped by their single-letter dataset prefix and the exact size
    parsed from the name (see extract_dataset_type). Columns without a parseable
    size contribute nothing. Output columns are 'solver' followed by
    '<Dataset>_<size>', datasets sorted alphabetically and sizes ascending
    (e.g. A_32, A_33, ..., P_16, P_19, ...). A cell is None when the group has no
    numeric, non-missing value in that row.
    """
    skip = ('solver', 'best_solver_idx', 'best_solver_name', 'agg_score', 'generation', 'original_name')

    # {dataset_type: {size: [columns]}}
    by_dataset = {}
    for column in df.columns:
        if column in skip:
            continue
        dataset, n = extract_dataset_type(column)
        if not dataset:
            continue
        sizes = by_dataset.setdefault(dataset, {})
        if n is not None:
            sizes.setdefault(n, []).append(column)

    def mean_or_none(row, columns):
        # Mean of the numeric, non-missing cells of `row` in `columns`; None if none.
        nums = []
        for column in columns:
            raw = row.get(column)
            if raw is None or not pd.notna(raw):
                continue
            try:
                nums.append(float(raw))
            except (ValueError, TypeError):
                continue
        return float(np.mean(nums)) if nums else None

    records = []
    for _, row in df.iterrows():
        record = {'solver': row['solver']}
        for dataset, sizes in by_dataset.items():
            for n, cols in sizes.items():
                record[f"{dataset}_{n}"] = mean_or_none(row, cols)
        records.append(record)

    out_columns = ['solver'] + [
        f"{dataset}_{n}"
        for dataset in sorted(by_dataset)
        for n in sorted(by_dataset[dataset])
    ]
    return pd.DataFrame(records, columns=out_columns)


# ==================== Step 3: Visualize ====================

def extract_generation(row_name: str) -> tuple:
    """Classify a solver name as a pool round.

    Returns:
        ('pool', N) for names of the form 'h<N>', otherwise (None, None).
        Non-string inputs (e.g. NaN from an empty CSV cell) are treated as
        non-pool rows instead of raising.
    """
    # Solver cells read via pandas can be float NaN when empty; don't crash on them.
    if not isinstance(row_name, str):
        return (None, None)
    match = re.match(r'h(\d+)$', row_name)
    if match:
        return ('pool', int(match.group(1)))
    return (None, None)


def visualize_trend_table(df: pd.DataFrame, out_path: str, dpi: int = 150):
    """Plot the per-round average gap of pool rows and save the figure.

    For every row whose 'solver' is a pool round name ('h<N>', see
    extract_generation), the unweighted mean of all non-metadata cells that are
    numeric and non-missing is computed and plotted against N. The last point is
    annotated with its value. Rows with no usable cell are skipped.

    Args:
        df: Table with a 'solver' column plus gap columns (e.g. the output of one
            of the aggregate_* functions).
        out_path: Destination image path.
        dpi: Resolution of the saved figure.
    """
    metadata_cols = ['solver', 'best_solver_idx', 'best_solver_name', 'agg_score', 'generation', 'original_name']
    value_cols = [c for c in df.columns if c not in metadata_cols]

    points = []  # (round, dataset average)
    for _, row in df.iterrows():
        typ, gen = extract_generation(row['solver'])
        if typ != 'pool':
            continue

        vals = []
        for col in value_cols:
            val = row.get(col)
            if val is None or pd.isna(val):
                continue
            try:
                vals.append(float(val))
            except (ValueError, TypeError):
                # Skip non-numeric cells, consistent with the aggregate_* helpers.
                continue
        if vals:
            points.append((gen, float(np.mean(vals))))
    points.sort()

    fig, ax = plt.subplots(1, 1, figsize=(10, 6))
    try:
        if points:
            gens = [g for g, _ in points]
            avgs = [a for _, a in points]
            ax.plot(gens, avgs, 'o-', label='Co-evolution', linewidth=2, markersize=6)

            # Annotate the final value
            final_gen, final_avg = points[-1]
            ax.annotate(f'{final_avg:.2f}%',
                        xy=(final_gen, final_avg),
                        xytext=(5, 5),
                        textcoords='offset points',
                        fontsize=10, fontweight='bold')
            # Only draw a legend when a labelled line exists (avoids an empty-legend warning).
            ax.legend(fontsize=11)

        ax.set_xlabel('Round', fontsize=12)
        ax.set_ylabel('Average Gap (%)', fontsize=12)
        ax.set_title('CVRP: Trend of Average Gap Across Datasets', fontsize=14, fontweight='bold')
        ax.grid(True, alpha=0.3)

        fig.tight_layout()
        fig.savefig(out_path, dpi=dpi, bbox_inches='tight')
    finally:
        # Release this specific figure even if saving fails.
        plt.close(fig)

    print(f"   Saved visualization to {out_path}")


def _run_pool_aggregate(best_per_round_path: str, out_dir: str) -> None:
    """Run aggregate_comparison_results.py (next to this script) on the pool-best CSV.

    A non-zero exit of the child process is reported on stdout and otherwise ignored.
    """
    here = os.path.dirname(os.path.abspath(__file__))
    script = os.path.join(here, 'aggregate_comparison_results.py')
    target = os.path.join(out_dir, 'pool_best_per_round_aggregated.csv')
    command = [sys.executable, script, '--input', best_per_round_path, '--output', target]

    try:
        completed = subprocess.run(command, capture_output=True, text=True, check=True)
    except subprocess.CalledProcessError as err:
        print(f"   Aggregate failed: {err}")
        if err.stderr:
            print(f"  Error: {err.stderr}")
        return

    print(f"   Aggregate completed: {target}")
    if completed.stdout:
        print(f"  {completed.stdout}")


def main():
    """Entry point: best-per-round -> aggregation -> trend plot, written to out_dir."""
    args = parse_args()

    # Output directory defaults to the gap table's own directory.
    out_dir = (os.path.abspath(args.out_dir) if args.out_dir
               else os.path.dirname(os.path.abspath(args.gap_table)))
    os.makedirs(out_dir, exist_ok=True)

    print(f" Loading gap table from: {args.gap_table}")
    df = pd.read_csv(args.gap_table)

    non_instance = ('solver', 'best_solver_idx', 'best_solver_name', 'agg_score', 'generation', 'original_name')
    instance_cols = [c for c in df.columns if c not in non_instance]
    print(f"  Found {len(instance_cols)} instance columns")

    # Step 1: cumulative instance-wise best per pool round.
    print(f"\n Processing pool best per round (instance-wise, metric: {args.metric})...")
    best_per_round_df = process_pool_best_per_round(df, instance_cols, args.metric)

    if len(best_per_round_df) == 0:
        # No 'h<N>' rows: fall back to aggregating the raw table.
        print("   No pool data found, skipping pool best per round")
        best_per_round_df = df.copy()
    else:
        best_path = os.path.join(out_dir, 'pool_best_per_round.csv')
        best_per_round_df.to_csv(best_path, index=False)
        print(f"   Saved to {best_path}")

        print("\n Running aggregate on pool best...")
        _run_pool_aggregate(best_path, out_dir)

    # Step 2: aggregate columns; any unlisted choice uses dataset_size_bins.
    print(f"\n Aggregating pool best per round by {args.aggregation}...")
    aggregators = {
        'size': aggregate_by_size,
        'dataset': aggregate_by_dataset,
        'dataset_size': aggregate_by_dataset_groups,
    }
    aggregate = aggregators.get(args.aggregation, aggregate_by_dataset_size_bins)
    aggregated_df = aggregate(best_per_round_df)

    aggregated_path = os.path.join(out_dir, f'gap_by_{args.aggregation}.csv')
    aggregated_df.to_csv(aggregated_path, index=False)
    print(f"   Saved to {aggregated_path}")

    # Step 3: plot the trend.
    print("\n Visualizing trends...")
    viz_path = os.path.join(out_dir, 'trend_visualization.png')
    visualize_trend_table(aggregated_df, viz_path, args.dpi)

    print(f"\n Processing complete! Outputs saved to: {out_dir}")


if __name__ == '__main__':
    main()

