"""
Combine outer-fold CSVs for a dataset into one aggregate CSV.

Supports multiple path and filename patterns:
- LRT format: results/LRT_depth{D}/{dataset}/results_linear_optuna_d{D}_outer{outer}.csv
- CR format: results/CR_depth5/{dataset}/results_CR_d5_outer{outer}.csv

Example:
    # Combine CR format for energe_c (depth 5, single depth)
    python3 script/utils/combine_one_dataset.py energe_c --path-pattern CR_depth5 --filename-pattern results_CR_d5_outer --output-pattern results_CR_d5_aggregate --min-depth 5 --max-depth 5
"""
import os
import argparse
import pandas as pd


def _fill_depth(pattern: str, depth: int) -> str:
    """Return *pattern* with its ``{depth}`` placeholder filled in, if any."""
    # Patterns without the placeholder (e.g. "CR_depth5") are used verbatim,
    # so they never go through str.format at all.
    return pattern.format(depth=depth) if "{depth}" in pattern else pattern


def combine_one_depth(
    dataset_name: str,
    depth: int,
    n_outers: int = 5,
    root: str = "results",
    path_pattern: str = "LRT_depth{depth}",
    filename_pattern: str = "results_linear_optuna_d{depth}_outer",
    output_pattern: str = "results_linear_optuna_d{depth}_aggregate",
) -> None:
    """Combine outer CSVs for a single depth into an aggregate CSV.

    Reads ``<filename_pattern><outer>.csv`` for ``outer`` in
    ``range(n_outers)`` from ``<root>/<path_pattern>/<dataset_name>``, stacks
    the rows, appends one "mean" row and one "std" row per group, and writes
    everything to ``<output_pattern>.csv`` in the same folder.

    Args:
        dataset_name: Sub-folder name of the dataset; also written into the
            ``dataset`` column of every row.
        depth: Depth substituted for ``{depth}`` in the three patterns.
        n_outers: Number of outer folds to look for (indices 0..n_outers-1).
        root: Top-level results directory.
        path_pattern: Directory name under ``root``; may contain ``{depth}``.
        filename_pattern: Input file prefix (the outer index and ``.csv`` are
            appended); may contain ``{depth}``.
        output_pattern: Output file name without ``.csv``; may contain
            ``{depth}``.

    Raises:
        FileNotFoundError: If the dataset folder does not exist.
        RuntimeError: If none of the expected outer CSVs exists.
    """
    dataset_dir = os.path.join(root, _fill_depth(path_pattern, depth), dataset_name)
    if not os.path.isdir(dataset_dir):
        raise FileNotFoundError(f"Dataset folder not found: {dataset_dir}")

    # The filename prefix does not depend on the outer index: format it once.
    base_filename = _fill_depth(filename_pattern, depth)

    # 1. Read outer CSVs for this dataset / depth
    frames = []
    for outer in range(n_outers):
        csv_path = os.path.join(dataset_dir, f"{base_filename}{outer}.csv")
        if not os.path.exists(csv_path):
            # Missing folds are tolerated; only "no fold at all" is an error.
            print(f"WARNING: file not found, skip: {csv_path}")
            continue

        df = pd.read_csv(csv_path)
        # Plain assignment sets or overwrites the columns, so it works whether
        # or not the CSV already contains them.
        df["dataset"] = dataset_name
        df["outer"] = f"outer_{outer}"
        frames.append(df)

    if not frames:
        raise RuntimeError(
            f"No CSV files found for dataset {dataset_name} at depth {depth}"
        )

    raw_df = pd.concat(frames, ignore_index=True)

    # Preferred column order; columns missing from the data are skipped and
    # any extra columns are kept at the end in their original order.
    desired_cols = [
        "dataset",
        "outer",
        "method",
        "depth",
        "lambda",
        "ridge_penalty",
        "lasso_penalty",
        "leaves",
        "r2_train",
        "r2_test",
        "mse_train",
        "mse_test",
        "train_time_s",
    ]
    leading = [c for c in desired_cols if c in raw_df.columns]
    trailing = [c for c in raw_df.columns if c not in desired_cols]
    raw_df = raw_df[leading + trailing]

    # 2. Compute mean & std over outers.
    # Group only by the key columns that are actually present: some result
    # formats do not have all of them (e.g. no "lambda"), and grouping by a
    # missing column would raise KeyError. "dataset" is always present because
    # it is assigned above.
    group_cols = [
        c for c in ("dataset", "method", "depth", "lambda") if c in raw_df.columns
    ]
    excluded = set(group_cols) | {"outer"}
    numeric_cols = [
        c
        for c in raw_df.columns
        if c not in excluded and pd.api.types.is_numeric_dtype(raw_df[c])
    ]

    # dropna=False keeps groups whose key is NaN (e.g. an empty lambda).
    grouped = raw_df.groupby(group_cols, dropna=False)[numeric_cols]

    mean_df = grouped.mean().reset_index()
    mean_df["outer"] = "mean"

    # Sample standard deviation; NaN when a group has a single outer fold.
    std_df = grouped.std(ddof=1).reset_index()
    std_df["outer"] = "std"

    # Make the summary rows follow the same column order as the raw rows.
    summary_cols = [c for c in raw_df.columns if c in mean_df.columns]
    mean_df = mean_df[summary_cols]
    std_df = std_df[summary_cols]

    # 3. Save aggregate file for this dataset / depth
    final_df = pd.concat([raw_df, mean_df, std_df], ignore_index=True)

    out_path = os.path.join(dataset_dir, f"{_fill_depth(output_pattern, depth)}.csv")
    final_df.to_csv(out_path, index=False)
    print(f"[depth={depth}] Saved aggregated results to:\n  {out_path}")


def main(
    dataset_name: str,
    min_depth: int,
    max_depth: int,
    n_outers: int = 5,
    path_pattern: str = "LRT_depth{depth}",
    filename_pattern: str = "results_linear_optuna_d{depth}_outer",
    output_pattern: str = "results_linear_optuna_d{depth}_aggregate",
) -> None:
    """Aggregate one dataset's results for every depth in [min_depth, max_depth].

    Each depth is handled independently by ``combine_one_depth``. A depth
    whose dataset folder is missing, or which has no outer CSVs, is reported
    on stdout and skipped so the remaining depths are still processed.
    """
    # Options that are identical for every depth.
    shared_options = {
        "n_outers": n_outers,
        "path_pattern": path_pattern,
        "filename_pattern": filename_pattern,
        "output_pattern": output_pattern,
    }

    depth = min_depth
    while depth <= max_depth:
        try:
            combine_one_depth(dataset_name, depth, **shared_options)
        except FileNotFoundError as missing_folder:
            print(f"[depth={depth}] SKIP (missing folder): {missing_folder}")
        except RuntimeError as no_inputs:
            print(f"[depth={depth}] SKIP (no CSVs): {no_inputs}")
        depth += 1


if __name__ == "__main__":
    # Command-line entry point: parse the options and aggregate one dataset.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "dataset",
        help=(
            "Dataset name under results/{path_pattern} "
            "(e.g. california_housing, electricity, energe_c)"
        ),
    )

    # (flag, type, default, help) for each optional argument, listed in the
    # order they should appear in --help.
    option_specs = [
        (
            "--n-outers",
            int,
            5,
            "Number of outer splits (default: 5 -> outer0..outer4)",
        ),
        ("--min-depth", int, 4, "Minimum depth (inclusive, default: 4)"),
        ("--max-depth", int, 4, "Maximum depth (inclusive, default: 4)"),
        (
            "--path-pattern",
            str,
            "LRT_depth{depth}",
            "Path pattern under results/ directory. Use {depth} placeholder. "
            "Default: LRT_depth{depth}. Example: CR_depth5",
        ),
        (
            "--filename-pattern",
            str,
            "results_linear_optuna_d{depth}_outer",
            "Filename pattern for input CSVs. Use {depth} placeholder. "
            "Default: results_linear_optuna_d{depth}_outer. "
            "Example: results_CR_d5_outer",
        ),
        (
            "--output-pattern",
            str,
            "results_linear_optuna_d{depth}_aggregate",
            "Filename pattern for output CSV. Use {depth} placeholder. "
            "Default: results_linear_optuna_d{depth}_aggregate. "
            "Example: results_CR_d5_aggregate",
        ),
    ]
    for flag, value_type, default_value, help_text in option_specs:
        cli.add_argument(flag, type=value_type, default=default_value, help=help_text)

    opts = cli.parse_args()

    main(
        dataset_name=opts.dataset,
        min_depth=opts.min_depth,
        max_depth=opts.max_depth,
        n_outers=opts.n_outers,
        path_pattern=opts.path_pattern,
        filename_pattern=opts.filename_pattern,
        output_pattern=opts.output_pattern,
    )
