#!/usr/bin/env python3

"""
Usage:
python collect_summary.py -d <base_dir> -s <sort_by> -a
python collect_summary.py -d <base_dir> -s <sort_by> -a -l
python collect_summary.py -d <base_dir> -s <sort_by> -a -l -o <output_file>
python collect_summary.py -d <base_dir> -s <sort_by> -a -l -o <output_file> -p <plot_file>
"""

import argparse
import csv
from pathlib import Path
from typing import Optional


def find_summary_files(base_dir: Path) -> list[tuple[str, Path]]:
    """Locate summary CSV files below ``base_dir``.

    Two locations are searched:

    * every ``iteration_*`` sub-directory, with the pattern
      ``eval_validation/**/summary/summary_*.csv``;
    * the ``final_test`` sub-directory, with ``**/summary/summary_*.csv``.

    Args:
        base_dir: Directory whose immediate children are scanned.

    Returns:
        ``(label, path)`` pairs, where ``label`` is the iteration directory
        name or the literal ``"final_test"``. Iterations come first, ordered
        by numeric suffix (``iteration_2`` before ``iteration_10``); names
        without a numeric suffix follow alphabetically. Files found in one
        directory are sorted by path. The list is empty when ``base_dir`` is
        not a directory.
    """
    if not base_dir.is_dir():
        # Nothing to scan: let the caller report "no files" instead of
        # failing inside iterdir().
        return []

    prefix = "iteration_"

    def iteration_order(path: Path) -> tuple[int, int, str]:
        # Numeric suffixes sort by value; anything else sorts after them.
        suffix = path.name[len(prefix):]
        if suffix.isdecimal():
            return (0, int(suffix), path.name)
        return (1, 0, path.name)

    iteration_dirs = [
        child
        for child in base_dir.iterdir()
        if child.is_dir() and child.name.startswith(prefix)
    ]

    results: list[tuple[str, Path]] = []
    for iteration_dir in sorted(iteration_dirs, key=iteration_order):
        # glob() yields entries in filesystem order; sort for stable output.
        matches = sorted(
            iteration_dir.glob("eval_validation/**/summary/summary_*.csv")
        )
        results.extend((iteration_dir.name, match) for match in matches)

    final_test_dir = base_dir / "final_test"
    if final_test_dir.is_dir():
        matches = sorted(final_test_dir.glob("**/summary/summary_*.csv"))
        results.extend(("final_test", match) for match in matches)

    return results


def parse_summary_csv(file_path: Path) -> dict[str, float]:
    """Read one summary CSV into a ``{metric name: value}`` mapping.

    The first row is the header. The metric name is read from the ``metric``
    column and the value from the last column named in the header. Rows with
    an empty name, an empty value, or a value that ``float()`` cannot parse
    are skipped.

    Args:
        file_path: Path of the CSV file to read.

    Returns:
        Mapping from metric name to numeric value. It is empty when the file
        has no header or no data rows. If a metric name appears on several
        rows, the last parsable one wins.
    """
    metrics: dict[str, float] = {}

    # "utf-8-sig" drops a leading byte-order mark, which would otherwise be
    # glued onto the first header name and hide the "metric" column.
    # newline="" is what the csv module asks for when it is handed a file.
    with open(file_path, "r", encoding="utf-8-sig", newline="") as f:
        reader = csv.DictReader(f)
        rows = list(reader)
        fieldnames = reader.fieldnames

    if not rows or not fieldnames:
        return metrics

    value_col = fieldnames[-1]
    for row in rows:
        metric_name = row.get("metric", "")
        value_str = row.get(value_col, "")
        # Short rows give None for the missing cells; treat them like "".
        if not metric_name or not value_str:
            continue
        try:
            metrics[metric_name] = float(value_str)
        except ValueError:
            # Best-effort: leave out values that are not numeric.
            continue

    return metrics


def collect_and_sort(
    base_dir: str,
    sort_by: Optional[str] = None,
    ascending: bool = False
) -> list[dict]:
    """Gather the metrics of every summary file found below ``base_dir``.

    Each summary file becomes one record holding ``"iteration"`` (the label
    returned by ``find_summary_files``), ``"file_path"`` (the path as a
    string) and one entry per parsed metric.

    Args:
        base_dir: Directory to scan.
        sort_by: Record key to order by. ``None`` keeps discovery order.
        ascending: Sort smallest-first when true, largest-first otherwise.

    Returns:
        The list of records. When ``sort_by`` is given, records that have
        that key are ordered by it and records that lack it are placed at
        the end in discovery order. An empty list is returned (after printing
        a message) when no summary file is found.
    """
    summary_files = find_summary_files(Path(base_dir))

    if not summary_files:
        print("not found any summary files.")
        return []

    records = []
    for iteration_name, file_path in summary_files:
        try:
            metrics = parse_summary_csv(file_path)
        except Exception as e:
            # Best-effort: one unreadable file must not abort the whole scan.
            print(f"Failed to parse file {file_path}: {e}")
            continue
        records.append({
            "iteration": iteration_name,
            "file_path": str(file_path),
            **metrics,
        })

    if sort_by is not None:
        # A missing metric must not be ranked as if it were 0: that would put
        # such records first in ascending order for any positive metric.
        ranked = [record for record in records if sort_by in record]
        unranked = [record for record in records if sort_by not in record]
        ranked.sort(key=lambda record: record[sort_by], reverse=not ascending)
        records = ranked + unranked

    return records


def parse_args(argv: Optional[list[str]] = None) -> argparse.Namespace:
    """Parse the command-line options of the script.

    Args:
        argv: Argument strings to parse, without the program name. ``None``
            (the default) makes argparse read ``sys.argv[1:]``.

    Returns:
        Namespace with ``dir`` (str, default ``"."``), ``sort_by`` (str or
        ``None``), ``ascending`` (bool) and ``list_keys`` (bool).
    """
    parser = argparse.ArgumentParser(
        description="Traverse all iteration_x directories and sort summary CSV files by metric"
    )
    parser.add_argument(
        "-d", "--dir",
        type=str,
        default=".",
        help="Base directory to scan (default: current directory)"
    )
    parser.add_argument(
        "-s", "--sort-by",
        type=str,
        default=None,
        # main() falls back to the first entry of the alphabetically sorted
        # metric names, so say exactly that.
        help="Metric name to sort by (default: the alphabetically first metric)"
    )
    parser.add_argument(
        "-a", "--ascending",
        action="store_true",
        help="Sort in ascending order (default: descending)"
    )
    parser.add_argument(
        "-l", "--list-keys",
        action="store_true",
        help="List all available metric names and exit"
    )
    return parser.parse_args(argv)


def get_all_metric_keys(records: list[dict]) -> list[str]:
    """Return every metric name used by at least one record, sorted.

    The bookkeeping keys ``"iteration"`` and ``"file_path"`` are not metrics
    and are left out.
    """
    bookkeeping = ("iteration", "file_path")
    names = {
        name
        for entry in records
        for name in entry
        if name not in bookkeeping
    }
    return sorted(names)


def main():
    """Run the command-line tool.

    Collects all summary records below the chosen directory, then either
    lists the available metric names (``--list-keys``) or prints a table
    sorted by the chosen metric, reports the best record, and writes every
    record to ``summary_comparison.csv`` inside the base directory.

    When ``--sort-by`` is omitted or names an unknown metric, the
    alphabetically first metric is used. Records that do not have the sort
    metric are listed last and are never reported as the best iteration.
    """
    args = parse_args()
    base_dir = Path(args.dir).resolve()
    ascending = args.ascending

    # Sorting is done below, once the sort metric has been validated.
    records = collect_and_sort(str(base_dir), sort_by=None, ascending=ascending)
    if not records:
        return

    metric_keys = get_all_metric_keys(records)
    if not metric_keys:
        print("No metric data found")
        return

    # Handle --list-keys option
    if args.list_keys:
        print("Available metric names:")
        for key in metric_keys:
            print(f"  - {key}")
        return

    # Determine sort key
    sort_by = args.sort_by
    if sort_by is None:
        sort_by = metric_keys[0]
    elif sort_by not in metric_keys:
        print(f"Warning: Metric '{sort_by}' not found. Available metrics: {', '.join(metric_keys)}")
        print(f"Using default metric: {metric_keys[0]}")
        sort_by = metric_keys[0]

    records = _rank_by_metric(records, sort_by, ascending)

    order_desc = "Ascending" if ascending else "Descending"
    print("=" * 80)
    print("Traverse all iteration_x directories and summary files")
    print(f"Directory: {base_dir}")
    print(f"Sort by: {sort_by} ({order_desc})")
    print("=" * 80)

    print(f"\nResults sorted by {sort_by} ({order_desc}):\n")
    _print_table(records, metric_keys)

    # sort_by is one of metric_keys, so at least one record has it and the
    # ranking puts such a record first.
    print("\n" + "=" * 80)
    best_row = records[0]
    print(f"Best iteration (by {sort_by}): {best_row['iteration']}")
    for key in metric_keys:
        value = best_row.get(key, "N/A")
        print(f"  - {key}: {value}")
    print(f"  - File path: {best_row['file_path']}")
    print("=" * 80)

    # Save full results to CSV; metrics a record lacks are written as empty
    # cells (DictWriter's default for missing keys).
    output_file = base_dir / "summary_comparison.csv"
    with open(output_file, "w", encoding="utf-8", newline="") as f:
        fieldnames = ["iteration"] + metric_keys + ["file_path"]
        writer = csv.DictWriter(f, fieldnames=fieldnames, extrasaction="ignore")
        writer.writeheader()
        writer.writerows(records)

    print(f"\nFull results saved to: {output_file}")


def _rank_by_metric(records: list[dict], metric: str, ascending: bool) -> list[dict]:
    """Return ``records`` ordered by ``metric``, with records lacking it last."""
    # Treating a missing metric as 0 would rank it ahead of real values
    # (e.g. first in ascending order), so keep those records out of the sort.
    ranked = [record for record in records if metric in record]
    unranked = [record for record in records if metric not in record]
    ranked.sort(key=lambda record: record[metric], reverse=not ascending)
    return ranked + unranked


def _print_table(records: list[dict], metric_keys: list[str]) -> None:
    """Print an aligned table of ``records``, showing N/A for absent metrics."""
    # Each column is at least 10 characters wide, or as wide as its name.
    col_widths = {key: max(len(key), 10) for key in metric_keys}

    header_parts = [f"{'iteration':<15}"]
    header_parts.extend(f"{key:>{col_widths[key]}}" for key in metric_keys)
    header = " ".join(header_parts)
    print(header)
    print("-" * len(header))

    for record in records:
        iteration = record.get("iteration", "N/A")
        row_parts = [f"{iteration:<15}"]
        for key in metric_keys:
            width = col_widths[key]
            if key in record:
                row_parts.append(f"{record[key]:>{width}.2f}")
            else:
                row_parts.append(f"{'N/A':>{width}}")
        print(" ".join(row_parts))


# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
