import re
import argparse
from pathlib import Path

import pandas as pd


def load_csvs_with_regex(pattern: str, directory: str = ".") -> pd.DataFrame:
    """Load every matching CSV in `directory` into a single DataFrame.

    Only files directly inside `directory` with a ``.csv`` suffix are
    considered. A file is selected when `pattern` (a regular expression)
    is found anywhere in its *filename* via ``re.search``. Selected files
    are read in sorted path order and stacked row-wise with a fresh
    0..n-1 index.

    Raises:
        FileNotFoundError: if no CSV filename matches `pattern`.
    """
    root = Path(directory)
    name_filter = re.compile(pattern)

    # Gather the candidates first, then order them so the load sequence
    # (and therefore the row order of the result) is deterministic.
    matches = [
        candidate
        for candidate in root.glob("*.csv")
        if name_filter.search(candidate.name)
    ]
    matches.sort()

    if len(matches) == 0:
        raise FileNotFoundError(
            f"No CSV files in {root} match regex pattern: {pattern!r}"
        )

    frames = []
    for csv_file in matches:
        # Announce each file before reading it so a failing read is easy to trace.
        print(f"Loading: {csv_file}")
        frames.append(pd.read_csv(csv_file))

    stacked = pd.concat(frames, ignore_index=True)
    print(f"Loaded {len(matches)} files, combined shape = {stacked.shape}")
    return stacked

def add_derived_columns(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of `df` with two derived columns appended.

    Added columns:
    - total_time_needed: row-wise sum of the per-run stage times
        compute_distance_matrix_time_needed + sat_based_time_needed
      plus compute_shortest_paths_time_needed and
      compute_graph_points_time_needed when those columns are present.
    - matching_length_delta = matching_length_final - matching_length_initial

    NOTE: compute_matching_time_needed is *not* included in total_time_needed,
    because it's already accounted for inside sat_based_time_needed.

    Every column that feeds a derived value is coerced to numeric first;
    cells that cannot be parsed become NaN. In the total, NaN stage values
    are skipped, and a row's total is NaN only when all of its stage values
    are NaN. The delta is NaN whenever either matching length is NaN.

    The input DataFrame is not modified.

    Raises:
        KeyError: if any of the two mandatory time columns or the two
            matching-length columns is missing.
    """
    result = df.copy()

    # MUST exist
    mandatory_time_cols = [
        "compute_distance_matrix_time_needed",
        "sat_based_time_needed",  # includes matching time internally
    ]

    # Optional in the total
    optional_time_cols = [
        "compute_shortest_paths_time_needed",
        "compute_graph_points_time_needed",
    ]

    length_cols = [
        "matching_length_final",
        "matching_length_initial",
    ]

    # Fail early, naming every missing column at once.
    missing_required = [
        c for c in mandatory_time_cols + length_cols if c not in result.columns
    ]
    if missing_required:
        raise KeyError(f"Missing required columns in DataFrame: {missing_required}")

    # Time columns actually used in the total (matching time explicitly excluded)
    time_cols_present = mandatory_time_cols + [
        c for c in optional_time_cols if c in result.columns
    ]

    # Coerce all inputs, not just the time columns: a single stray
    # non-numeric cell would otherwise make the subtraction below raise.
    numeric_inputs = time_cols_present + length_cols
    result[numeric_inputs] = result[numeric_inputs].apply(
        pd.to_numeric, errors="coerce"
    )

    # min_count=1 keeps an all-NaN row as NaN instead of reporting 0.
    result["total_time_needed"] = result[time_cols_present].sum(axis=1, min_count=1)

    result["matching_length_delta"] = (
        result["matching_length_final"] - result["matching_length_initial"]
    )

    return result

def summarize_numeric(df: pd.DataFrame) -> pd.DataFrame:
    """Build a per-column table of descriptive statistics.

    Only the numeric columns of `df` are summarised. The result has one row
    per numeric column and the columns mean, median, std (sample, ddof=1),
    min, max, Q1_25%, Q3_75% and IQR, in that order.

    Raises:
        ValueError: if the numeric selection of `df` is empty.
    """
    numeric = df.select_dtypes(include="number")
    if numeric.empty:
        raise ValueError("No numeric columns found in concatenated DataFrame.")

    # Quartiles are computed once and reused for the IQR.
    lower_quartile = numeric.quantile(0.25)
    upper_quartile = numeric.quantile(0.75)

    # Insertion order here defines the column order of the summary table.
    measures = {
        "mean": numeric.mean(),
        "median": numeric.median(),
        "std": numeric.std(ddof=1),
        "min": numeric.min(),
        "max": numeric.max(),
        "Q1_25%": lower_quartile,
        "Q3_75%": upper_quartile,
        "IQR": upper_quartile - lower_quartile,
    }

    summary = pd.DataFrame(index=numeric.columns)
    for label, values in measures.items():
        summary[label] = values

    return summary


def main(argv=None):
    """Command-line entry point: load matching CSVs and print summary stats.

    Steps:
    1. Parse the command line (``--pattern`` is required).
    2. Load and concatenate every CSV whose filename matches the pattern.
    3. Add the derived total-time and matching-length-delta columns.
    4. Summarise the numeric columns, keep only the selected metrics and
       statistics, round to 3 decimals, and print the table.
    5. Optionally write the table to the path given with ``--output``.

    Args:
        argv: argument list to parse instead of ``sys.argv[1:]``; ``None``
            (the default) reads the real command line.

    Exceptions raised by the loading and derivation steps (for example when
    no file matches or a required column is missing) are not caught here.
    """
    parser = argparse.ArgumentParser(
        description="Summarize numeric columns for all CSV files "
                    "whose filenames match a regex."
    )
    parser.add_argument(
        "--pattern", "-p",
        # Without a pattern there is nothing to compile; let argparse
        # report that instead of failing later inside re.compile(None).
        required=True,
        help=r"Regex to match CSV filenames, e.g. 'FePc-.*percentage0\.001'"
    )
    parser.add_argument(
        "-d", "--directory",
        default=".",
        help="Directory to search for CSV files (default: current directory)."
    )
    parser.add_argument(
        "-o", "--output",
        default=None,
        help="Optional: path to save the stats as CSV."
    )

    args = parser.parse_args(argv)

    combined_df = load_csvs_with_regex(args.pattern, args.directory)
    combined_df = add_derived_columns(combined_df)
    stats_df = summarize_numeric(combined_df)

    wanted_rows = [
        "sat_based_iterations",
        "total_time_needed",
        "matching_length_delta",
    ]
    # Only keep those that actually exist (avoid KeyError if one is missing)
    present_rows = [r for r in wanted_rows if r in stats_df.index]

    # Only these statistics are reported, in this order; the others
    # (mean, std, IQR) are intentionally left out of the table.
    report_columns = ["median", "Q1_25%", "Q3_75%", "min", "max"]

    stats_df = stats_df.loc[present_rows, report_columns].round(3)

    print("\n=== Summary statistics (selected metrics) ===")
    print(stats_df)

    if args.output:
        stats_df.to_csv(args.output)
        print(f"\nSaved stats to {args.output}")


# Run the command-line interface only when this file is executed as a
# script; importing it as a module just makes the functions available.
if __name__ == "__main__":
    main()

