#!/usr/bin/env python3
"""
Process the CSV output of compare_experiments_best.py or find_h_best.py (TSP
experiments) and aggregate each experiment's row with the TSP helper
aggregate_by_dataset_groups_with_custom_bins.

Usage:
  python heupsro/problems/tsp_gls/testing/aggregate_comparison_results.py \
    --input experiments_comparison_best.csv \
    --output experiments_comparison_aggregated.csv \
    --experiments-root /path/to/experiments 
"""

import os
import sys
import argparse
import pandas as pd
from pathlib import Path

# Walk up from this file's location and put the resulting directory on
# sys.path so the absolute `heupsro...` import below resolves when this file
# is run directly as a script (layout as shown in the module docstring:
# heupsro/problems/tsp_gls/testing/<this file>).
_script_dir = os.path.dirname(os.path.abspath(__file__))  # directory holding this file
_problem_dir = os.path.dirname(_script_dir)  # one level above _script_dir
_heupsro_dir = os.path.dirname(os.path.dirname(_problem_dir))  # two levels above _problem_dir
_project_root = os.path.dirname(_heupsro_dir)  # parent of _heupsro_dir
# Insert at the front (once) so this checkout takes precedence on import.
if _project_root not in sys.path:
    sys.path.insert(0, _project_root)

from heupsro.problems.tsp_gls.testing.process_and_visualize_trends import aggregate_by_dataset_groups_with_custom_bins


def detect_tsp_experiment(experiment_name: str, experiments_root: str = None) -> bool:
    """Return whether an experiment looks like a TSP experiment.

    The check reads ``<experiments_root>/<experiment_name>/config.json`` and
    reports whether the parsed JSON object contains an ``n_cities`` key.

    Whenever the type cannot be determined -- no ``experiments_root`` given,
    no ``config.json`` present, or a config file that cannot be read or
    parsed -- the experiment is assumed to be a TSP experiment.

    Args:
        experiment_name: Name of the experiment directory. Non-string values
            (e.g. numeric identifiers read from a CSV) are converted with
            ``str()`` before the path is built.
        experiments_root: Directory containing one sub-directory per
            experiment, or ``None``/empty to skip the check.

    Returns:
        ``False`` only when a readable ``config.json`` exists and does not
        describe an object with an ``n_cities`` key; ``True`` otherwise.
    """
    if not experiments_root:
        return True

    config_path = os.path.join(experiments_root, str(experiment_name), "config.json")
    if not os.path.exists(config_path):
        return True

    import json
    try:
        with open(config_path, 'r', encoding='utf-8') as f:
            config_data = json.load(f)
    except (OSError, ValueError):
        # json.JSONDecodeError and UnicodeDecodeError are ValueError
        # subclasses. An unreadable config tells us nothing about the
        # experiment type, so fall back to the same default as a missing one.
        return True

    # Only a JSON object can carry an ``n_cities`` key; any other top-level
    # value (list, number, null, ...) is not a TSP config.
    return isinstance(config_data, dict) and 'n_cities' in config_data


def _resolve_experiment_name(row, idx, has_experiment_col):
    """Return a string identifier for one CSV row.

    Uses the ``experiment`` column when the CSV has one, otherwise the
    ``solver`` column, otherwise the row index. The value is always returned
    as ``str`` because it is later used to build a filesystem path; a missing
    (NaN) cell falls back to ``experiment_<idx>``.
    """
    if has_experiment_col:
        raw_name = row.get('experiment')
    elif 'solver' in row.index:
        raw_name = row.get('solver')
    else:
        raw_name = None

    if raw_name is None or pd.isna(raw_name):
        return f'experiment_{idx}'
    return str(raw_name)


def aggregate_comparison_results(input_csv: str, output_csv: str, experiments_root: str = None):
    """Aggregate every row of a comparison CSV and write the result to a new CSV.

    Each input row is treated as one experiment. Its metadata columns
    (``experiment``, ``solver``, ``best_solver_idx``, ``best_solver_name``)
    are set aside, the remaining columns are passed as a single-row DataFrame
    (plus a ``solver`` column) to
    ``aggregate_by_dataset_groups_with_custom_bins``, and the first row of the
    returned frame is combined with the metadata again.

    Rows whose aggregation raises, or yields an empty frame, are reported and
    skipped. If no row can be aggregated, nothing is written.

    Args:
        input_csv: Path of the CSV file to read.
        output_csv: Path of the CSV file to write.
        experiments_root: Optional experiments directory. When given, rows for
            which ``detect_tsp_experiment`` returns ``False`` are skipped.

    Returns:
        None. The result is written to ``output_csv`` and progress is printed.
    """
    df = pd.read_csv(input_csv)

    has_experiment_col = 'experiment' in df.columns
    if not has_experiment_col:
        print("  CSV file has no 'experiment' column, will use 'solver' column or row index as identifier")

    print(f"  processing {len(df)} experiments...")

    # Every row shares the frame's columns, so work out once which metadata
    # columns are actually present.
    metadata_cols = ['experiment', 'solver', 'best_solver_idx', 'best_solver_name']
    present_metadata_cols = [col for col in metadata_cols if col in df.columns]

    # Aggregate each row (each experiment).
    aggregated_rows = []

    for idx, row in df.iterrows():
        # Prefer the experiment column, then the solver column, then the row
        # index. The name is coerced to str: pandas parses numeric or empty
        # cells as int/float/NaN, which cannot be joined into a path.
        experiment_name = _resolve_experiment_name(row, idx, has_experiment_col)

        if experiments_root and not detect_tsp_experiment(experiment_name, experiments_root):
            print(f"  {experiment_name}: not a TSP experiment, skip")
            continue

        metadata = {col: row[col] for col in present_metadata_cols}
        # Make sure the output always carries an experiment identifier.
        metadata.setdefault('experiment', experiment_name)

        single_row_df = pd.DataFrame([row.drop(labels=present_metadata_cols)])
        # 'solver' was stripped with the other metadata columns above, so it
        # is always re-attached here for the aggregation helper.
        single_row_df['solver'] = metadata.get('solver', experiment_name)

        try:
            row_result = aggregate_by_dataset_groups_with_custom_bins(single_row_df)
            if len(row_result) > 0:
                aggregated_row = row_result.iloc[0].to_dict()
                # Metadata wins over same-named columns from the helper.
                aggregated_row.update(metadata)
                aggregated_rows.append(aggregated_row)
                print(f"  {experiment_name}: aggregated")
            else:
                print(f"  {experiment_name}: empty aggregated result")
        except Exception as e:
            # Deliberately broad: one bad experiment must not abort the
            # whole batch; report it and move on.
            print(f"  {experiment_name}: aggregation failed - {e}")
            continue

    if not aggregated_rows:
        print("  no successful aggregation")
        return

    # Create the aggregated DataFrame.
    aggregated_df = pd.DataFrame(aggregated_rows)

    # Reorder columns: experiment, solver, then the aggregated data columns.
    cols = list(aggregated_df.columns)
    priority_cols = ['experiment', 'solver']
    leading_cols = [c for c in priority_cols if c in cols]  # only columns that exist
    ordered_cols = leading_cols + [c for c in cols if c not in priority_cols]
    aggregated_df = aggregated_df[ordered_cols]

    aggregated_df.to_csv(output_csv, index=False)
    print(f"\n  aggregated result saved: {output_csv}")
    print(f"    contains {len(aggregated_df)} experiments, {len(aggregated_df.columns)-len(leading_cols)} aggregated columns")


def main():
    """Command-line entry point: parse arguments and run the aggregation.

    When ``--output`` is omitted, the output path is derived from the input
    path as ``<input file stem>_aggregated.csv`` in the input file's
    directory.
    """
    parser = argparse.ArgumentParser(
        description="aggregate TSP experiment comparison results (by dataset type and size)"
    )
    parser.add_argument('--input', '-i', required=True, help='input CSV file path')
    parser.add_argument(
        '--output', '-o', default=None,
        help='output CSV file path (default: <input file stem>_aggregated.csv '
             'in the same directory as the input file)',
    )
    parser.add_argument(
        '--experiments-root', default=None,
        help='experiment root directory (optional, for detecting experiment type)',
    )

    args = parser.parse_args()

    output_csv = args.output
    if output_csv is None:
        # Default: sibling of the input file, e.g. results.csv -> results_aggregated.csv
        input_path = Path(args.input)
        output_csv = str(input_path.parent / f"{input_path.stem}_aggregated.csv")

    aggregate_comparison_results(args.input, output_csv, args.experiments_root)


# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()

