"""
This script aggregates and visualizes results from sequential testing experiments.
It scans per-dataset subfolders for e-process result CSVs
(e_process_info_dataset_{dataset}_{method}_seed_*.csv), extracts summary statistics,
and produces two types of outputs:

1. Summary CSVs with mean ± std for TPR, FPR, FNR, and average stopping time.
2. Wealth trajectory plots for Members vs. Non-members (geometric mean across
   seeds with a ±1 standard deviation band computed in log space), including
   significance threshold lines drawn at 1/alpha.

Plots are generated without legends; a single shared legend file is created once
at the root directory.

Usage:
    python aggregate_results.py --root ./results \
                                [--alpha1 0.05] [--alpha2 0.01] [--png]

Arguments:
    --root   Root directory containing per-dataset subfolders with result CSVs.
    --alpha1 Significance level for dashed black threshold line (default: 0.05).
    --alpha2 Significance level for dash-dot green threshold line (default: 0.01).
    --png    Save figures as PNG instead of PDF (default: PDF).

Outputs are written into each dataset’s folder, with a shared legend saved at root.
"""

import os
import re
import glob
import argparse
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from io import StringIO
from matplotlib.lines import Line2D

# Separator between the meta table and the per-sample series: a newline,
# optional whitespace, and another newline (i.e. at least one blank line).
BLANK_SPLIT_RE = re.compile(r"\r?\n\s*\r?\n", flags=re.MULTILINE)

def read_two_part_csv_loose(path):
    """
    Read a two-part e-process CSV file.

    Expected layout:
      Part 1: meta table (normally a single row)
      (blank line)
      Part 2: per-sample series (sample_index, power, wealth, power_fp, wealth_fp)

    Args:
        path: Path of the CSV file to read.

    Returns:
        (meta_df, series_df). If the series part is missing or blank,
        series_df is an empty DataFrame with the expected series columns.

    Raises:
        ValueError: if the file is empty or contains only whitespace.
    """
    # "utf-8-sig" decodes plain UTF-8 as well, but drops a leading BOM so it
    # cannot end up glued to the first column name of the meta table.
    with open(path, "r", encoding="utf-8-sig") as f:
        text = f.read().strip()

    # re.split always returns at least one element, so emptiness has to be
    # checked on the text itself rather than on the split result.
    if not text:
        raise ValueError(f"{path}: file is empty.")

    parts = BLANK_SPLIT_RE.split(text, maxsplit=1)
    meta_df = pd.read_csv(StringIO(parts[0].strip()))

    series_text = parts[1].strip() if len(parts) > 1 else ""
    if not series_text:
        series_df = pd.DataFrame(columns=["sample_index", "power", "wealth", "power_fp", "wealth_fp"])
        return meta_df, series_df

    series_df = pd.read_csv(StringIO(series_text))
    return meta_df, series_df

def pad_to(arr, n):
    """
    Flatten `arr` to a 1-D float array of exactly `n` entries.

    Inputs that are at least `n` long are truncated to their first `n`
    values; shorter inputs are extended on the right with NaN.
    """
    flat = np.asarray(arr, dtype=float).ravel()
    missing = n - flat.size
    if missing <= 0:
        # Already long enough: keep only the leading n values.
        return flat[:n]
    return np.concatenate([flat, np.full(missing, np.nan, dtype=float)])

def mean_std_numeric(df, col):
    """
    Return (mean, sample std) of column `col` of `df` as plain floats.

    Values that cannot be parsed as numbers are coerced to NaN and ignored.
    A missing column is treated as an empty float series, which yields
    (nan, nan). The standard deviation uses ddof=1.
    """
    fallback = pd.Series([], dtype=float)
    raw = df.get(col, fallback)
    numeric = pd.to_numeric(raw, errors="coerce")
    average = float(numeric.mean())
    spread = float(numeric.std(ddof=1))
    return average, spread

def aggregate_method(root, dataset_name, method, out_dir, alpha1=0.05, alpha2=0.01, save_pdf=True):
    """
    Aggregate all seed files of one (dataset, method) pair.

    Looks for: e_process_info_dataset_{dataset_name}_{method}_seed_*.csv in `root`.
    Produces in out_dir:
      - {dataset}_{method}_summary.csv  with mean/std for TPR, FPR, FNR, avg_stop_time
      - {dataset}_{method}_wealth_plot_with_CI.(pdf/png) showing the geometric mean
        across seeds with a band of +/- one standard deviation in log space
    Plots have NO legend (shared legend produced separately).

    Args:
        root: Directory that is searched for the seed CSV files.
        dataset_name: Dataset name used in the file pattern and output names.
        method: Method name used in the file pattern and output names.
        out_dir: Directory for the outputs (created if missing).
        alpha1: Level of the black dashed threshold line at 1/alpha1
            (None or <= 0 disables the line).
        alpha2: Level of the green dash-dot threshold line at 1/alpha2
            (None or <= 0 disables the line).
        save_pdf: Save the figure as PDF when True, otherwise as PNG.

    Returns:
        {} when no file matched or none could be read; otherwise a dict with
        "summary_csv" and, when wealth data was available, "plot_path".
    """
    pattern = os.path.join(root, f"e_process_info_dataset_{dataset_name}_{method}_seed_*.csv")
    files = sorted(glob.glob(pattern))
    if not files:
        print(f"[WARN] No files found for {method} at: {pattern}")
        return {}

    meta_rows = []
    W_members = []
    W_nonmem = []

    for f in files:
        # Best effort: one unreadable seed file must not abort the aggregation.
        try:
            meta, series = read_two_part_csv_loose(f)
        except Exception as e:
            print(f"[WARN] Skipping {f}: {e}")
            continue

        # Guarantee the summary columns exist so the concat below is uniform.
        for col in ["TPR", "FPR", "FNR", "avg_stop_time", "alpha"]:
            if col not in meta.columns:
                meta[col] = np.nan

        meta_rows.append(meta)

        w = series["wealth"].to_numpy() if "wealth" in series.columns else np.array([])
        wf = series["wealth_fp"].to_numpy() if "wealth_fp" in series.columns else np.array([])
        W_members.append(np.asarray(w, dtype=float))
        W_nonmem.append(np.asarray(wf, dtype=float))

    if not meta_rows:
        print(f"[WARN] No valid meta rows for {method}.")
        return {}

    meta_df = pd.concat(meta_rows, ignore_index=True)

    metrics = ["TPR", "FPR", "FNR", "avg_stop_time"]
    stats = [mean_std_numeric(meta_df, metric) for metric in metrics]

    summary = pd.DataFrame({
        "metric": metrics,
        "mean":   [mean for mean, _ in stats],
        "std":    [std for _, std in stats],
        "n_seeds": [len(meta_df)] * len(metrics),
    })

    os.makedirs(out_dir, exist_ok=True)
    summary_path = os.path.join(out_dir, f"{dataset_name}_{method}_summary.csv")
    summary.to_csv(summary_path, index=False)
    print(f"[INFO] Saved summary: {summary_path}")

    max_len = max((len(w) for w in W_members), default=0)
    max_len_fp = max((len(w) for w in W_nonmem), default=0)
    n = max(max_len, max_len_fp)

    if n == 0:
        print(f"[WARN] No wealth data to plot for {method}. Skipping plot.")
        return {"summary_csv": summary_path}

    # Seeds may have different lengths; shorter ones are right-padded with NaN.
    W_mat = np.vstack([pad_to(w, n) for w in W_members])
    Wfp_mat = np.vstack([pad_to(wf, n) for wf in W_nonmem])

    mean_W, lower_W, upper_W = _geometric_mean_band(W_mat)
    mean_Wfp, lower_Wfp, upper_Wfp = _geometric_mean_band(Wfp_mat)

    sns.set(font_scale=1.8, style='whitegrid', rc={"grid.linewidth": 1., "lines.linewidth": 1.5})
    fig, ax = plt.subplots(figsize=(10, 6))
    x = np.arange(n)

    df_plot = pd.DataFrame({
        "Number of Samples": x,
        "Members":     mean_W,
        "Non-members": mean_Wfp
    })
    df_melt = df_plot.melt(
        id_vars="Number of Samples",
        value_vars=["Members", "Non-members"],
        var_name="Group",
        value_name="e-value"
    )

    palette = {"Members": "orangered", "Non-members": "blue"}
    dashes_map = {"Members": (), "Non-members": (4, 2)}

    sns.lineplot(
        data=df_melt,
        x="Number of Samples",
        y="e-value",
        hue="Group",
        style="Group",
        palette=palette,
        dashes=dashes_map,
        linewidth=2,
        legend=False,
        ax=ax
    )

    ax.fill_between(x, lower_W, upper_W, color=palette["Members"], alpha=0.15)
    ax.fill_between(x, lower_Wfp, upper_Wfp, color=palette["Non-members"], alpha=0.15)

    # Rejection thresholds are drawn at 1/alpha.
    thresholds = []
    for level, color, linestyle in ((alpha1, "black", '--'), (alpha2, "green", '-.')):
        if level is not None and level > 0:
            thresholds.append(1 / level)
            ax.axhline(y=1 / level, color=color, linestyle=linestyle, linewidth=2)

    ax.set_xlabel("Number of Samples")
    ax.set_ylabel("e-value")
    ax.set_xlim(left=0)
    # Keep the historical fixed range of 120 unless a threshold line would
    # fall outside of it; in that case grow the axis so the line stays visible.
    highest = max(thresholds, default=0.0)
    y_top = 120 if highest < 120 else 1.1 * highest
    ax.set_ylim(bottom=0, top=y_top)

    fig.tight_layout()
    ext = "pdf" if save_pdf else "png"
    plot_path = os.path.join(out_dir, f"{dataset_name}_{method}_wealth_plot_with_CI.{ext}")
    fig.savefig(plot_path, bbox_inches='tight', pad_inches=0.05, format=ext)
    plt.close(fig)
    print(f"[INFO] Saved plot: {plot_path}")

    return {"summary_csv": summary_path, "plot_path": plot_path}


def _geometric_mean_band(values, eps=1e-12):
    """Return per-column (geometric mean, lower, upper) of `values`, ignoring NaN entries."""
    import warnings

    # Floor at eps so log() is defined for zero/negative wealth. np.maximum
    # propagates NaN, so the padding of shorter seeds stays NaN and is skipped
    # by the nan-aware reductions instead of being counted as eps.
    logs = np.log(np.maximum(values, eps))
    with warnings.catch_warnings():
        # Columns without any data are expected to come out as NaN; silence
        # the "empty slice" RuntimeWarnings numpy emits for them.
        warnings.simplefilter("ignore", category=RuntimeWarning)
        center = np.nanmean(logs, axis=0)
        spread = np.nanstd(logs, axis=0)
    return np.exp(center), np.exp(center - spread), np.exp(center + spread)

def save_shared_legend(root, alpha1=0.05, alpha2=0.01, save_pdf=True):
    """
    Create a standalone legend figure (no data), matching the plot styles:
      - Members (orangered, solid)
      - Non-members (blue, dashed)
      - Significance level alpha1 (black dashed)
      - Significance level alpha2 (green dash-dot)
    Saves to root as shared_legend.(pdf/png).

    Args:
        root: Directory the legend file is written to.
        alpha1: Level shown for the black dashed line; None or <= 0 omits it.
        alpha2: Level shown for the green dash-dot line; None or <= 0 omits it.
        save_pdf: Save as PDF when True, otherwise as PNG.
    """
    palette = {"Members": "orangered", "Non-members": "blue"}

    handles = [
        Line2D([0], [0], color=palette["Members"], linewidth=2, linestyle='-', label="Members"),
        Line2D([0], [0], color=palette["Non-members"], linewidth=2, linestyle=(0, (4, 2)), label="Non-members"),
    ]

    for level, color, linestyle in ((alpha1, "black", '--'), (alpha2, "green", '-.')):
        if level is not None and level > 0:
            # The "g" format rounds to 6 significant digits instead of
            # truncating, so 0.29 is labelled "29%" (int() gave 28) and
            # 0.005 is labelled "0.5%" (int() gave 0).
            label = f"Significance level {level * 100:g}%"
            handles.append(Line2D([0], [0], color=color, linewidth=2, linestyle=linestyle, label=label))

    fig_legend = plt.figure(figsize=(6.5, 1.0))
    # The axes only exists to be switched off, leaving the legend on a blank canvas.
    ax = fig_legend.add_subplot(111)
    ax.axis('off')
    fig_legend.legend(handles=handles, loc="center", ncol=len(handles), frameon=False)
    ext = "pdf" if save_pdf else "png"
    out_path = os.path.join(root, f"shared_legend.{ext}")
    fig_legend.savefig(out_path, bbox_inches='tight', pad_inches=0.05, format=ext)
    plt.close(fig_legend)
    print(f"[INFO] Saved shared legend: {out_path}")

def main():
    """
    Command-line entry point.

    Parses the arguments, aggregates every dataset subfolder of --root for the
    methods ONLINE and STATIC_M250 (outputs go into the subfolder itself), and
    finally writes the shared legend into the root directory. Exits with an
    argparse usage error when --root is not an existing directory.
    """
    parser = argparse.ArgumentParser(
        description="Aggregate results for all datasets under a root folder. "
                    "Each subfolder is one dataset; outputs saved into each subfolder. "
                    "Plots have NO legend; a shared legend file is created once in the root."
    )
    parser.add_argument("--root", required=True, help="Root directory that contains per-dataset subfolders.")
    parser.add_argument("--alpha1", type=float, default=0.05, help="Significance level for black dashed line (default: 0.05).")
    parser.add_argument("--alpha2", type=float, default=0.01, help="Significance level for green dash-dot line (default: 0.01).")
    parser.add_argument("--png", action="store_true", help="Save figures as PNG instead of PDF.")
    args = parser.parse_args()

    root = os.path.abspath(args.root)
    if not os.path.isdir(root):
        # Report a usage error instead of a raw traceback from os.listdir.
        parser.error(f"--root is not an existing directory: {root}")
    save_pdf = not args.png

    print(f"[INFO] Root: {root}")

    methods = ["ONLINE", "STATIC_M250"]
    subfolders = [d for d in sorted(os.listdir(root)) if os.path.isdir(os.path.join(root, d))]
    if not subfolders:
        print(f"[WARN] No subfolders found in {root}. Nothing to do.")
    for ds in subfolders:
        ds_dir = os.path.join(root, ds)

        # The subfolder name doubles as the dataset name in the file pattern.
        dataset_name = ds
        print(f"[INFO] Processing dataset: {dataset_name}  (dir: {ds_dir})")

        any_output = False
        for method in methods:
            patt = os.path.join(ds_dir, f"e_process_info_dataset_{dataset_name}_{method}_seed_*.csv")
            if not glob.glob(patt):
                print(f"[WARN] No files for {method} at {patt}")
                continue

            result = aggregate_method(
                ds_dir, dataset_name, method, out_dir=ds_dir,
                alpha1=args.alpha1, alpha2=args.alpha2, save_pdf=save_pdf
            )
            # aggregate_method returns {} when every matching file was
            # unreadable; only count the method if something was written.
            if result:
                any_output = True

        if not any_output:
            print(f"[WARN] Skipped {dataset_name}: no usable CSVs for {'/'.join(methods)}.")

    save_shared_legend(root, alpha1=args.alpha1, alpha2=args.alpha2, save_pdf=save_pdf)

    print("[DONE] Aggregation complete.")

# Run the aggregation only when the file is executed as a script, not on import.
if __name__ == "__main__":
    main()
