#!/usr/bin/env python3
import os
import sys
import argparse
import pandas as pd
from pathlib import Path

# Add the project root to sys.path so the absolute `heupsro.*` import below
# resolves when this file is executed directly as a script.
# The root is located by walking up from this file's own directory:
#   _script_dir   -> directory containing this script
#   _problem_dir  -> parent of _script_dir
#   _heupsro_dir  -> two levels above _problem_dir
#   _project_root -> parent of _heupsro_dir
_script_dir = os.path.dirname(os.path.abspath(__file__))
_problem_dir = os.path.dirname(_script_dir)
_heupsro_dir = os.path.dirname(os.path.dirname(_problem_dir))
_project_root = os.path.dirname(_heupsro_dir)
# Prepend (rather than append) so this checkout takes precedence, and only
# when it is not already present to avoid duplicate entries.
if _project_root not in sys.path:
    sys.path.insert(0, _project_root)

from heupsro.problems.cvrp.testing.process_and_visualize_trends import (
    aggregate_by_dataset,
    aggregate_by_dataset_groups,
    aggregate_by_dataset_size_bins
)


def detect_cvrp_experiment(experiment_name: str, experiments_root: str = None) -> bool:
    """Return whether an experiment looks like a CVRP experiment.

    The check reads ``<experiments_root>/<experiment_name>/config.json`` and
    reports True when the loaded config contains a ``num_customers`` or a
    ``vehicle_capacity`` key.

    Args:
        experiment_name: Name of the experiment (its sub-directory name).
        experiments_root: Directory holding one sub-directory per experiment.
            When falsy, no detection is attempted.

    Returns:
        False only when a config file was read successfully and contains
        neither key. True in every other case, including when the type cannot
        be determined (no root given, no config file, unreadable or malformed
        config), because this script is CVRP-specific.
    """
    if experiments_root:
        config_path = os.path.join(experiments_root, experiment_name, "config.json")
        if os.path.exists(config_path):
            import json
            try:
                # JSON is UTF-8 by specification; do not depend on the locale.
                with open(config_path, 'r', encoding='utf-8') as f:
                    config_data = json.load(f)
                return 'num_customers' in config_data or 'vehicle_capacity' in config_data
            except (OSError, ValueError, TypeError) as e:
                # OSError: file vanished / unreadable; ValueError: invalid JSON
                # or bad encoding; TypeError: config is a scalar/null, so the
                # membership test is meaningless. A single bad config must not
                # abort the whole run, so fall through to the default below.
                print(f"   {experiment_name}: could not inspect config.json ({e}), assuming CVRP")
    # Detection was not possible: assume CVRP (this is the CVRP script).
    return True


def _move_priority_columns_first(frame, priority_cols):
    """Reorder ``frame`` so the ``priority_cols`` present in it come first.

    Returns a ``(reordered_frame, n_leading)`` tuple where ``n_leading`` is
    the number of priority columns that actually exist in ``frame``.
    """
    leading = [c for c in priority_cols if c in frame.columns]
    trailing = [c for c in frame.columns if c not in priority_cols]
    return frame[leading + trailing], len(leading)


def aggregate_comparison_results(input_csv: str, output_csv: str, experiments_root: str = None, aggregation: str = None):
    """Aggregate each row of a comparison-results CSV in two ways and save both.

    Two files are written next to ``input_csv``:

    1. ``<stem>_aggregated_dataset.csv`` - aggregated by dataset
    2. ``<stem>_aggregated_dataset_size_bins.csv`` - aggregated by dataset
       and size bins

    Each result is written whenever it has at least one row, independently of
    the other, so a failure in one aggregation does not discard the other.

    Args:
        input_csv: Path of the CSV to read. It must contain an 'experiment'
            column, or a 'solver' column (baseline format, where the solver
            name doubles as the experiment name).
        output_csv: Optional extra path that receives a copy of the size-bin
            aggregation. Ignored when falsy or when it points at one of the
            two default output files.
        experiments_root: Optional experiments directory passed to
            ``detect_cvrp_experiment``; rows it rejects are skipped. Baseline
            rows are never checked.
        aggregation: Unused; kept for backward compatibility.

    Raises:
        ValueError: If the CSV has neither an 'experiment' nor a 'solver'
            column.
    """
    import traceback

    df = pd.read_csv(input_csv)
    is_baseline_format = 'experiment' not in df.columns and 'solver' in df.columns

    if not is_baseline_format and 'experiment' not in df.columns:
        raise ValueError("CSV must contain 'experiment' column or 'solver' column")

    # One aggregated row per input row (experiment / baseline method).
    aggregated_rows_dataset = []
    aggregated_rows_dataset_size_bins = []

    for idx, row in df.iterrows():
        # Baseline format has no 'experiment' column: the solver is the name.
        name_col = 'solver' if is_baseline_format else 'experiment'
        experiment_name = row.get(name_col, f'{name_col}_{idx}')

        # Baseline data always skips experiment-type detection.
        if not is_baseline_format and not detect_cvrp_experiment(experiment_name, experiments_root):
            print(f"   {experiment_name}: not a CVRP experiment, skip")
            continue

        # Separate metadata columns from the data columns to aggregate.
        if is_baseline_format:
            metadata_cols = ['solver']
            metadata = {'experiment': experiment_name, 'solver': experiment_name}
        else:
            metadata_cols = ['experiment', 'solver', 'best_solver_idx', 'best_solver_name']
            metadata = {col: row.get(col) for col in metadata_cols if col in row}

        # Build a one-row DataFrame holding only the data columns.
        data_row = row.drop(labels=[col for col in metadata_cols if col in row.index])
        single_row_df = pd.DataFrame([data_row])

        # The aggregation functions need a 'solver' column.
        if 'solver' not in single_row_df.columns:
            single_row_df['solver'] = metadata.get('solver', experiment_name)

        try:
            # 1. aggregate by dataset
            aggregated_df_dataset = aggregate_by_dataset(single_row_df)
            if len(aggregated_df_dataset) > 0:
                aggregated_row_dataset = aggregated_df_dataset.iloc[0].to_dict()
                aggregated_row_dataset.update(metadata)
                aggregated_rows_dataset.append(aggregated_row_dataset)

            # 2. aggregate by dataset and size bins
            aggregated_df_size_bins = aggregate_by_dataset_size_bins(single_row_df)
            if len(aggregated_df_size_bins) > 0:
                aggregated_row_size_bins = aggregated_df_size_bins.iloc[0].to_dict()
                aggregated_row_size_bins.update(metadata)
                aggregated_rows_dataset_size_bins.append(aggregated_row_size_bins)

            print(f"   {experiment_name}: aggregated (two ways)")
        except Exception as e:
            # Best effort: report the failing row and keep processing the rest.
            print(f"   {experiment_name}: aggregation failed - {e}")
            traceback.print_exc()
            continue

    # Give up only when *neither* aggregation produced anything; a result
    # that exists is still saved even if the other one is empty.
    if not aggregated_rows_dataset and not aggregated_rows_dataset_size_bins:
        print(" no successful aggregation")
        return

    # Column order in the outputs: experiment, solver, then everything else.
    priority_cols = ['experiment', 'solver']

    # Default outputs live next to the input file.
    input_path = Path(input_csv)
    output_dir = input_path.parent
    output_dataset = output_dir / f"{input_path.stem}_aggregated_dataset.csv"
    output_size_bins = output_dir / f"{input_path.stem}_aggregated_dataset_size_bins.csv"

    # 1. aggregate by dataset
    if aggregated_rows_dataset:
        aggregated_df_dataset, n_leading = _move_priority_columns_first(
            pd.DataFrame(aggregated_rows_dataset), priority_cols
        )
        aggregated_df_dataset.to_csv(output_dataset, index=False)
        print(f"\n aggregated by dataset result saved: {output_dataset}")
        print(f"    contains {len(aggregated_df_dataset)} experiments, {len(aggregated_df_dataset.columns) - n_leading} aggregated columns")
    else:
        print("\n aggregated by dataset result is empty, not saved")

    # 2. aggregate by dataset and size bins
    aggregated_df_size_bins = None
    if aggregated_rows_dataset_size_bins:
        aggregated_df_size_bins, n_leading = _move_priority_columns_first(
            pd.DataFrame(aggregated_rows_dataset_size_bins), priority_cols
        )
        aggregated_df_size_bins.to_csv(output_size_bins, index=False)
        print(f" aggregated by dataset and size bins result saved: {output_size_bins}")
        print(f"    contains {len(aggregated_df_size_bins)} experiments, {len(aggregated_df_size_bins.columns) - n_leading} aggregated columns")
    else:
        print(" aggregated by dataset and size bins result is empty, not saved")

    # Optional extra copy of the size-bin result. Compare resolved paths, not
    # raw strings, so a differently spelled path to a default output (e.g. a
    # leading "./") cannot overwrite the dataset file with size-bin data.
    if output_csv:
        default_outputs = {output_dataset.resolve(), output_size_bins.resolve()}
        if Path(output_csv).resolve() not in default_outputs:
            if aggregated_df_size_bins is None:
                print(f" default aggregated result is empty, not saved: {output_csv}")
            else:
                aggregated_df_size_bins.to_csv(output_csv, index=False)
                print(f" default aggregated result saved: {output_csv}")


def main():
    """Command-line entry point.

    Parses ``--input`` (required) plus the optional ``--output``,
    ``--experiments-root`` and ``--aggregation`` flags, then forwards them to
    ``aggregate_comparison_results``.
    """
    cli = argparse.ArgumentParser(
        description="aggregate CVRP experiment comparison results (generate two aggregated results: by dataset and by dataset and size bins)"
    )
    cli.add_argument('--input', '-i', required=True, help='input CSV file path')

    # Every remaining flag is optional and defaults to None; registration
    # order is kept because it determines the order shown in --help.
    optional_flags = (
        (('--output', '-o'),
         'output CSV file path (optional, default will generate two files: *_aggregated_dataset.csv and *_aggregated_dataset_size_bins.csv)'),
        (('--experiments-root',),
         'experiment root directory (optional, for detecting experiment type)'),
        (('--aggregation',),
         'deprecated: now always generate two aggregated results'),
    )
    for flag_names, help_text in optional_flags:
        cli.add_argument(*flag_names, default=None, help=help_text)

    options = cli.parse_args()

    aggregate_comparison_results(
        input_csv=options.input,
        output_csv=options.output,
        experiments_root=options.experiments_root,
        aggregation=options.aggregation,
    )


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()

