#!/usr/bin/env python3
# fano_ccdf_from_timeseries.py
#
# HOW TO RUN (examples):
#   python fano_ccdf_from_timeseries.py \
#     --config datasets.yaml \
#     --out_dir results_fano \
#     --xlog \
#     --single_series_from PINOT_IP \
#     --single_min_len 100 \
#     --single_min_nz_frac 0.30
#
# Outputs (added):
#   - single_series_acf.csv                    : dataset, series_id, series_number, lag, acf
#   - acf_single_<NAME>_series_<NN>.png/.pdf   : ACF of each selected series (up to 20 series, lags 0..100)
#
# Existing outputs unchanged:
#   - per_series_fano.csv, fano_ccdf.png/.pdf, mean_acf.csv, acf_mean.png/.pdf
#
from __future__ import annotations

import argparse
import os
from typing import List, Tuple, Dict, Optional
from pathlib import Path
from concurrent.futures import ProcessPoolExecutor, as_completed
import multiprocessing as mp

import numpy as np
import pandas as pd
import pyarrow.parquet as pq
import pyarrow as pa
import matplotlib.pyplot as plt

# Set the base font size to 12 for every matplotlib figure. This runs at import
# time and mutates the process-wide rcParams.
plt.rcParams.update({'font.size': 12})


# ----------------------------
# Datapath loaders (DeepAR-style)
# ----------------------------
def to_series_list_from_df(df: pd.DataFrame) -> list[tuple[str, np.ndarray]]:
    """
    Split a long-format frame into one float32 array per series.

    Expects the columns ``series_id``, ``time_idx`` and ``value``. Groups are
    visited in order of first appearance, each group is ordered by
    ``time_idx``, and a series is kept only when it has at least 8 values and
    all of them are finite. Series ids are returned as strings.
    """
    collected = []
    grouped = df.groupby("series_id", sort=False)
    for series_key, frame in grouped:
        ordered = frame.sort_values("time_idx")
        values = ordered["value"].to_numpy(dtype=np.float32)
        # Skip short series and series containing NaN/inf.
        if len(values) < 8 or not np.isfinite(values).all():
            continue
        collected.append((str(series_key), values))
    return collected


def from_node_features_to_df(table: pa.Table) -> pd.DataFrame:
    """
    Fallback loader for a Parquet table with a 'node_features' column (list-of-lists).

    Every inner list is treated as one series. Entries equal to -1 are dropped
    (presumably a padding sentinel - confirm with the data producer), and a
    series is kept only when at least 8 values remain. The series id is
    ``"<row index>_<position inside the row>"`` and ``time_idx`` counts from 0
    over the values that were kept.

    Returns:
        A long-format frame with columns ``series_id``, ``time_idx`` (int64)
        and ``value`` (float32); an empty frame with those columns when no
        series qualifies.
    """
    node_features = table.column("node_features")
    sid_parts = []
    time_parts = []
    value_parts = []
    for i in range(len(node_features)):
        # Convert each row to Python objects exactly once.
        row = node_features[i].as_py()
        if row is None:
            # Null row: nothing to extract.
            continue
        for j, raw in enumerate(row):
            if raw is None:
                # Null inner list; keep j so ids of later series are unchanged.
                continue
            arr = [x for x in raw if x != -1]
            if len(arr) < 8:
                continue
            sid_parts.append(np.full(len(arr), f"{i}_{j}", dtype=object))
            time_parts.append(np.arange(len(arr), dtype=np.int64))
            value_parts.append(np.array(arr, dtype=np.float32))
    if not value_parts:
        return pd.DataFrame(columns=["series_id", "time_idx", "value"])
    # Build the frame in one go instead of concatenating one DataFrame per series.
    return pd.DataFrame({
        "series_id": np.concatenate(sid_parts),
        "time_idx": np.concatenate(time_parts),
        "value": np.concatenate(value_parts),
    })


def from_inbound_only(table: pa.Table) -> list[tuple[str, np.ndarray]]:
    """
    Build one series per table row from its ``inbound`` array.

    Supported key layouts, tried in this order:
      - ip, service_port, source_file
      - ip, source_file
      - subnet, source_file

    Rows whose ``inbound`` cell is null are skipped. Inside a cell, ``None``
    and the sentinel ``-1`` are removed while zeros are kept. A series is
    returned only when at least 8 values remain and all are finite. The
    series id joins the key values with ``"::"``.

    Raises:
        ValueError: if none of the supported layouts is present.
    """
    cols = set(table.column_names)

    # Most specific layout first, so ip+service_port wins over plain ip.
    layouts = (
        ["ip", "service_port", "source_file"],
        ["ip", "source_file"],
        ["subnet", "source_file"],
    )
    key_cols = None
    for layout in layouts:
        if cols.issuperset(layout + ["inbound"]):
            key_cols = layout
            break
    if key_cols is None:
        raise ValueError(f"Table schema not recognized for inbound loader. Columns: {cols}")

    key_values = [table.column(c).to_pylist() for c in key_cols]
    inbound_col = table.column("inbound")

    collected = []
    for i in range(len(inbound_col)):
        cell = inbound_col[i]
        inb = cell.as_py() if hasattr(cell, "as_py") else cell
        if inb is None:
            continue
        cleaned = [float(x) for x in inb if x is not None and x != -1]
        arr = np.array(cleaned, dtype=np.float32)
        if len(arr) < 8 or not np.isfinite(arr).all():
            continue
        # e.g. "ip::src" or "ip::port::src", depending on the detected layout.
        sid = "::".join(str(column[i]) for column in key_values)
        collected.append((sid, arr))

    return collected


def load_timeseries_one_path(input_parquet: str) -> list[tuple[str, np.ndarray]]:
    """
    Read one Parquet file and return its series as (series_id, values) pairs.

    The schema is detected from the column names, checked in this order:
      1) any table with an ``inbound`` column -> from_inbound_only
      2) (series_id, time_idx, value)         -> rows with null or non-finite
                                                 values are removed first
      3) node_features (list-of-lists)        -> from_node_features_to_df

    Raises:
        ValueError: if the table matches none of the layouts above.
    """
    table = pq.read_table(input_parquet)
    cols = set(table.column_names)

    # The inbound layout takes precedence over the other two.
    if "inbound" in cols:
        return from_inbound_only(table)

    if cols.issuperset(("series_id", "time_idx", "value")):
        frame = table.to_pandas().dropna(subset=["value"])
        finite_rows = np.isfinite(frame["value"])
        return to_series_list_from_df(frame[finite_rows])

    if "node_features" in cols:
        return to_series_list_from_df(from_node_features_to_df(table))

    raise ValueError(
        "Unsupported schema. Expected one of: "
        "(series_id,time_idx,value) OR node_features OR (ip,source_file,inbound[,outbound])."
    )


# ----------------------------
# Metrics helpers
# ----------------------------
def fano_factor(x: np.ndarray, eps: float = 1e-12) -> Optional[float]:
    """
    Return the Fano factor (population variance / mean) of the finite values in x.

    Returns None when fewer than two finite values exist, when the mean is not
    finite or not greater than ``eps``, or when the variance is not finite.
    """
    data = np.asarray(x, dtype=np.float64)
    finite = data[np.isfinite(data)]
    if finite.size < 2:
        return None

    mean_value = finite.mean()
    if not (np.isfinite(mean_value) and mean_value > eps):
        return None

    # Population variance (ddof=0).
    variance = finite.var(ddof=0)
    if np.isfinite(variance):
        return float(variance / mean_value)
    return None


def ccdf(values: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """
    Empirical complementary CDF of the finite entries of ``values``.

    Returns (xs, ys): xs holds the finite values in ascending order and
    ys[i] = 1 - (i + 1) / N, so the last point is 0. Two empty arrays are
    returned when there is no finite value.
    """
    data = np.asarray(values, dtype=np.float64)
    finite_sorted = np.sort(data[np.isfinite(data)])
    count = finite_sorted.size
    if count == 0:
        return np.array([]), np.array([])
    ranks = np.arange(1, count + 1, dtype=np.float64)
    return finite_sorted, 1.0 - (ranks / count)


def acf_1d(x: np.ndarray, max_lag: int = 200) -> np.ndarray:
    """
    Autocorrelation of the finite values of x for lags 0..max_lag.

    Uses the normalized unbiased estimator
      r[k] = sum_{t=0}^{N-k-1} (x_t - mu) * (x_{t+k} - mu) / ((N - k) * var)
    with var the population variance.

    The result always has shape (max_lag + 1,). It is all NaN when fewer than
    two finite values exist or the variance is not positive. Otherwise
    r[0] = 1 and lags beyond N - 1 stay NaN.
    """
    result = np.full(max_lag + 1, np.nan, dtype=np.float64)

    samples = np.asarray(x, dtype=np.float64)
    samples = samples[np.isfinite(samples)]
    n = samples.size
    if n < 2:
        return result

    centered = samples - samples.mean()
    var = np.mean(centered * centered)
    if not np.isfinite(var) or var <= 0:
        return result

    result[0] = 1.0
    # A lag needs at least one overlapping pair, hence the n - 1 cap.
    last_lag = min(max_lag, n - 1)
    for k in range(1, last_lag + 1):
        denom = (n - k) * var
        if denom > 0:
            result[k] = np.dot(centered[:-k], centered[k:]) / denom
    return result


def golden_ratio_figsize(width, fraction=1.0):
    """
    Return a (width, height) figure size with a golden-ratio aspect.

    The height is ``width / phi * fraction`` with phi = (1 + sqrt(5)) / 2
    (about 1.618); ``fraction`` scales the height only.
    """
    phi = (1 + 5**0.5) / 2
    return (width, width / phi * fraction)

# Figure width in inches for a single-column plot. main() passes this value to
# golden_ratio_figsize() for every figure it creates.
width = 3.5
# Example usage:
#   fig, ax = plt.subplots(figsize=golden_ratio_figsize(width))

#   fig, ax = plt.subplots(figsize=golden_ratio_figsize(width, fraction=0.5))


def subsample_ccdf(xs: np.ndarray, ys: np.ndarray, max_points: int = 30) -> tuple[np.ndarray, np.ndarray]:
    """
    Thin a CCDF curve to at most ``max_points`` points.

    Curves that already fit are returned unchanged. Longer curves are sampled
    at evenly spaced integer positions that always include the first and the
    last point.
    """
    total = len(xs)
    if total > max_points:
        picks = np.linspace(0, total - 1, max_points, dtype=int)
        return xs[picks], ys[picks]
    return xs, ys


# ----------------------------
# Worker (parallel per dataset)
# ----------------------------
def load_dataset_worker(args_tuple: tuple) -> tuple[str, list, list, np.ndarray]:
    """
    Load one dataset and compute its per-series Fano factors and mean ACF.

    Args:
        args_tuple: ``(name, path, max_series, seed)``. ``path`` is handed to
            load_timeseries_one_path. When ``max_series`` is not None and more
            series are available, ``max_series`` of them are drawn without
            replacement using ``np.random.default_rng(seed)``.

    Returns:
        ``(dataset_name, per_series_rows, fano_values, mean_acf)`` where
          - per_series_rows is a list of dicts with keys dataset, series_id,
            length, mean, var, fano (only series with a defined Fano factor);
          - fano_values is the list of those Fano factors;
          - mean_acf has shape (501,): the per-lag mean over all series whose
            ACF is finite at that lag, NaN where no series contributed.

    Any exception is reported on stdout and turned into an empty result
    (no rows, no Fano values, all-NaN ACF), so one bad dataset does not abort
    the whole run.
    """
    name, path, max_series, seed = args_tuple
    # One definition of the ACF length, used by both the success and error paths.
    max_lag = 500
    try:
        print(f"[INFO] Loading series for dataset '{name}' from {path} ...")
        sid_series = load_timeseries_one_path(path)

        # Keep only series with at least 8 points, all finite.
        sid_series = [(sid, s) for sid, s in sid_series if len(s) >= 8 and np.isfinite(s).all()]

        # Optional cap on the number of series.
        if max_series is not None and len(sid_series) > max_series:
            rng = np.random.default_rng(seed)
            idx = rng.choice(len(sid_series), size=max_series, replace=False)
            sid_series = [sid_series[i] for i in idx]

        per_rows = []
        f_values = []
        acf_sum = np.zeros(max_lag + 1, dtype=np.float64)
        acf_cnt = np.zeros(max_lag + 1, dtype=np.int64)

        for sid, s in sid_series:
            # Promote to float64 so the reported mean/var use the same
            # precision as fano_factor; float32 accumulation loses accuracy.
            s64 = np.asarray(s, dtype=np.float64)
            mu = float(s64.mean())
            var = float(s64.var())
            fan = fano_factor(s64)
            if fan is not None and np.isfinite(fan):
                per_rows.append({
                    "dataset": name, "series_id": sid, "length": len(s),
                    "mean": mu, "var": var, "fano": fan
                })
                f_values.append(fan)

            # ACF over the full series (no trimming), also when Fano is undefined.
            r = acf_1d(s64, max_lag=max_lag)
            # Each lag averages only over series that are finite at that lag.
            mask = np.isfinite(r)
            acf_sum[mask] += r[mask]
            acf_cnt[mask] += 1

        # Mean ACF per lag where at least one series contributed.
        mean_acf = np.full(max_lag + 1, np.nan, dtype=np.float64)
        nz = acf_cnt > 0
        mean_acf[nz] = acf_sum[nz] / acf_cnt[nz]

        med = (np.nanmedian(f_values) if f_values else np.nan)
        print(f"[INFO] {name}: series used={len(f_values)} | fano median={med:.4f} | acf N_series={acf_cnt[0]}")

        return (name, per_rows, f_values, mean_acf)

    except Exception as e:
        # Deliberate best-effort: report the failure and keep the result shape.
        print(f"[ERROR] Failed to load dataset '{name}': {e}")
        return (name, [], [], np.full(max_lag + 1, np.nan, dtype=np.float64))


# ----------------------------
# Single-series search & ACF
# ----------------------------
def find_single_series_for_acf(dataset_name: str,
                               dataset_path: str,
                               min_len: int = 100,
                               min_nz_frac: float = 0.30) -> tuple[Optional[str], Optional[np.ndarray]]:
    """
    Return the first series in the file that is suitable for an ACF plot.

    A series qualifies when all values are finite, its length is strictly
    greater than ``min_len`` and its share of non-zero values is at least
    ``min_nz_frac``. ``dataset_name`` is accepted but not used here.

    Returns:
        (series_id, values) of the first match, or (None, None) if none matches.
    """
    for sid, s in load_timeseries_one_path(dataset_path):
        n = len(s)
        if n <= min_len or not np.isfinite(s).all():
            continue
        nonzero_share = (np.count_nonzero(s) / n) if n > 0 else 0.0
        if nonzero_share >= min_nz_frac:
            return sid, s
    return None, None


def find_multiple_series_for_acf(dataset_name: str,
                                dataset_path: str,
                                num_series: int = 20,
                                min_len: int = 100,
                                min_nz_frac: float = 0.30) -> list[tuple[str, np.ndarray]]:
    """
    Collect the first ``num_series`` series suitable for individual ACF plots.

    A series qualifies when all values are finite, its length is strictly
    greater than ``min_len`` and its share of non-zero values is at least
    ``min_nz_frac``. Series are taken in file order. ``dataset_name`` is
    accepted but not used here.

    Returns:
        A list of (series_id, values) with at most ``num_series`` entries.
    """
    picked = []
    for sid, s in load_timeseries_one_path(dataset_path):
        # Stop scanning as soon as enough series have been collected.
        if len(picked) >= num_series:
            break
        n = len(s)
        if n <= min_len or not np.isfinite(s).all():
            continue
        nonzero_share = (np.count_nonzero(s) / n) if n > 0 else 0.0
        if nonzero_share >= min_nz_frac:
            picked.append((sid, s))

    return picked


# ----------------------------
# Main
# ----------------------------
# Marker cycle shared by every figure so a dataset keeps the same marker everywhere.
_PLOT_MARKERS = ["o", "s", "^", "D", "v", "<", ">", "p", "*", "h", "+", "x"]

# Preferred plotting order; datasets missing from this list are appended afterwards.
_PREFERRED_DATASET_ORDER = [
    "MAWI_IP", "MAWI_SUBNET", "M_SERV",
    "PINOT_IP", "PINOT_SUBNET", "P_SERV",
    "WEATHER", "TAXI", "ETT", "ELECTRICITY", "EXCHANGE_RATE",
]

# Datasets drawn in the mean-ACF figure (a subset of everything loaded).
_MEAN_ACF_DATASETS = ("M_SERV", "P_SERV", "ELECTRICITY", "TAXI", "WEATHER", "ETT")


def _dataset_color(name: str) -> str:
    """Return "red" for MAWI/PINOT-family datasets (incl. M_*/P_* names), else "black"."""
    upper = name.upper()
    is_network = "MAWI" in upper or "PINOT" in upper or upper.startswith(("M_", "P_"))
    return "red" if is_network else "black"


def _save_and_close_figure(out_dir: Path, stem: str) -> tuple[Path, Path]:
    """Write the current figure as <stem>.png and <stem>.pdf, close it, return both paths."""
    png_path = out_dir / f"{stem}.png"
    pdf_path = out_dir / f"{stem}.pdf"
    plt.savefig(png_path, dpi=300, bbox_inches="tight")
    plt.savefig(pdf_path, dpi=300, bbox_inches="tight")
    # Release the figure; otherwise every figure stays alive until exit.
    plt.close()
    return png_path, pdf_path


def _gather_dataset_specs(config_path: Optional[str],
                          dataset_args: Optional[List[str]]) -> Dict[str, str]:
    """Merge the YAML mapping and the NAME=PATH CLI specs into {name: path} (CLI wins)."""
    name_to_path: Dict[str, str] = {}

    # Option 1: YAML
    if config_path:
        import yaml  # imported lazily: only needed when --config is given
        with open(config_path, "r") as f:
            cfg = yaml.safe_load(f)
        if not isinstance(cfg, dict):
            raise ValueError("YAML must be a mapping of name -> parquet_path")
        name_to_path.update({str(k): str(v) for k, v in cfg.items()})

    # Option 2: CLI --dataset NAME=PATH
    for spec in dataset_args or []:
        if "=" not in spec:
            raise ValueError("--dataset must be NAME=PATH")
        name, path = spec.split("=", 1)
        name_to_path[str(name)] = str(path)

    if not name_to_path:
        raise ValueError("Provide at least one dataset via --config or --dataset NAME=PATH")
    return name_to_path


def _run_dataset_workers(worker_args: list, n_processes: int) -> dict:
    """Run load_dataset_worker for every spec; return {name: (per_rows, f_values, mean_acf)}."""
    results = {}
    # A single dataset, or a non-positive/one process count, runs in-process.
    if len(worker_args) == 1 or n_processes <= 1:
        for wa in worker_args:
            name, per_rows, f_values, mean_acf = load_dataset_worker(wa)
            results[name] = (per_rows, f_values, mean_acf)
        return results

    with ProcessPoolExecutor(max_workers=n_processes) as ex:
        fut2name = {ex.submit(load_dataset_worker, wa): wa[0] for wa in worker_args}
        for fut in as_completed(fut2name):
            name = fut2name[fut]
            try:
                dname, per_rows, f_values, mean_acf = fut.result()
            except Exception as e:
                print(f"[ERROR] Dataset {name} failed: {e}")
                continue
            results[dname] = (per_rows, f_values, mean_acf)
    return results


def _export_single_series_acfs(name: str, path: str, out_dir: Path,
                               min_len: int, min_nz_frac: float,
                               marker: str, color: str) -> None:
    """Plot the ACF (lags 0..100) of up to 20 series of one dataset and write them to CSV."""
    series_list = find_multiple_series_for_acf(
        dataset_name=name,
        dataset_path=path,
        num_series=20,
        min_len=min_len,
        min_nz_frac=min_nz_frac
    )
    if not series_list:
        print(f"[WARN] No series in '{name}' met length>{min_len} and non-zero≥{min_nz_frac:.2f}.")
        return

    print(f"[INFO] Found {len(series_list)} series from '{name}' for individual ACF plots")

    all_rows = []
    for plot_idx, (sid, s) in enumerate(series_list):
        # Only the first 1000 points of a series enter its ACF.
        r = acf_1d(s[:1000], max_lag=100)

        for lag in range(r.shape[0]):
            all_rows.append({
                "dataset": name,
                "series_id": sid,
                "series_number": plot_idx + 1,
                "lag": lag,
                "acf": float(r[lag]) if np.isfinite(r[lag]) else np.nan
            })

        plt.figure(figsize=golden_ratio_figsize(width))
        lag_idx = np.arange(r.shape[0])
        mask = np.isfinite(r)
        plt.plot(lag_idx[mask], r[mask], marker=marker, markersize=5, linewidth=2, color=color,
                 label=f"Example {name}")
        plt.xlabel("Lag")
        plt.ylabel("Autocorrelation")
        plt.xlim(0, 100)
        plt.ylim(-0.3, 0.5)
        plt.legend()
        plt.tight_layout(pad=0.2)

        out_single_png, _ = _save_and_close_figure(
            out_dir, f"acf_single_{name}_series_{plot_idx + 1:02d}"
        )
        print(f"[INFO] Saved single-series ACF plot {plot_idx + 1}/{len(series_list)} for '{name}' to {out_single_png}")

    pd.DataFrame(all_rows).to_csv(out_dir / "single_series_acf.csv", index=False)
    print(f"[INFO] Saved ACF data for all {len(series_list)} series to single_series_acf.csv")


def main():
    """
    Command-line entry point.

    Loads every configured dataset (in parallel when several are given), then
    writes into ``--out_dir``:
      - per_series_fano.csv and fano_ccdf.png/.pdf
      - mean_acf.csv and acf_mean.png/.pdf (mean ACF, lags 0..500, every 5th lag plotted)
      - with ``--single_series_from NAME``: single_series_acf.csv and
        acf_single_<NAME>_series_<NN>.png/.pdf for up to 20 series

    Results are processed in the order the datasets were specified, so CSV row
    order and marker assignment do not depend on which worker finishes first.

    Raises:
        ValueError: for a malformed --config / --dataset or when no dataset is given.
    """
    ap = argparse.ArgumentParser(
        description="Fano CCDF + mean ACF (lags up to 500), plus per-series ACF plots "
                    "(lags up to 100) for a chosen dataset."
    )
    ap.add_argument("--config", type=str, default=None, help="YAML with mapping name->path (parquet).")
    ap.add_argument("--dataset", action="append", default=None, help="Add a dataset as NAME=PARQUET_PATH. Repeatable.")
    ap.add_argument("--out_dir", type=str, required=True, help="Output directory.")
    ap.add_argument("--xlog", action="store_true", help="CCDF: log scale on X-axis (Fano).")
    ap.add_argument("--xmin", type=float, default=None)
    ap.add_argument("--xmax", type=float, default=None)
    ap.add_argument("--ymin", type=float, default=0.0)
    ap.add_argument("--ymax", type=float, default=1.0)
    ap.add_argument("--max_series_per_dataset", type=int, default=None, help="Optional cap to speed up.")
    ap.add_argument("--seed", type=int, default=42)
    ap.add_argument("--title", type=str, default="CCDF of Fano Factors across Datasets")
    ap.add_argument("--n_processes", type=int, default=None, help="Number of processes (default: CPU count, capped at 8)")

    # Single-series search controls
    ap.add_argument("--single_series_from", type=str, default=None,
                    help="Dataset NAME to search for a single series (e.g., PINOT_IP).")
    ap.add_argument("--single_min_len", type=int, default=100,
                    help="Minimum required length (> this) for the single series.")
    ap.add_argument("--single_min_nz_frac", type=float, default=0.10,
                    help="Minimum fraction of non-zero values (>= this) for the single series.")

    args = ap.parse_args()
    os.makedirs(args.out_dir, exist_ok=True)
    out_dir = Path(args.out_dir)

    if args.n_processes is None:
        args.n_processes = min(mp.cpu_count(), 8)
    print(f"[INFO] Using {args.n_processes} processes for parallel data loading")

    name_to_path = _gather_dataset_specs(args.config, args.dataset)

    # Each dataset gets its own seed so subsampling differs between datasets.
    worker_args = [
        (name, path, args.max_series_per_dataset, args.seed + i)
        for i, (name, path) in enumerate(name_to_path.items())
    ]
    results = _run_dataset_workers(worker_args, args.n_processes)

    # Consume results in specification order, not completion order.
    all_per_rows = []
    ccdf_data: Dict[str, tuple[np.ndarray, np.ndarray]] = {}
    acf_means: Dict[str, np.ndarray] = {}
    for name in name_to_path:
        if name not in results:
            continue
        per_rows, f_values, mean_acf = results[name]
        all_per_rows.extend(per_rows)
        if len(f_values):
            ccdf_data[name] = ccdf(np.asarray(f_values, dtype=np.float64))
        acf_means[name] = mean_acf

    # --- Save per-series Fano CSV (an empty file when nothing was computed)
    df = pd.DataFrame(all_per_rows)
    per_csv = out_dir / "per_series_fano.csv"
    if df.empty:
        per_csv.write_text("")
    else:
        df.to_csv(per_csv, index=False)
    print(f"[INFO] Saved per-series Fano CSV to {per_csv}")

    # --- Save mean ACF CSV (long format)
    acf_rows = []
    for name, vec in acf_means.items():
        for lag in range(vec.shape[0]):
            val = float(vec[lag]) if np.isfinite(vec[lag]) else np.nan
            acf_rows.append({"dataset": name, "lag": lag, "mean_acf": val})
    acf_csv = out_dir / "mean_acf.csv"
    pd.DataFrame(acf_rows).to_csv(acf_csv, index=False)
    print(f"[INFO] Saved mean ACF CSV to {acf_csv}")

    # Marker order covers every loaded dataset, so a dataset without Fano
    # values still gets a marker and can appear in the ACF figure.
    dataset_order = [n for n in _PREFERRED_DATASET_ORDER if n in acf_means]
    dataset_order.extend(n for n in acf_means if n not in _PREFERRED_DATASET_ORDER)

    # --- Plot CCDF of Fano factors for all datasets
    plt.figure(figsize=golden_ratio_figsize(width))
    for i, name in enumerate(dataset_order):
        if name not in ccdf_data:
            continue
        xs, ys = ccdf_data[name]
        if xs.size == 0:
            continue
        xs_plot, ys_plot = subsample_ccdf(xs, ys, max_points=30)
        plt.plot(xs_plot, ys_plot, marker=_PLOT_MARKERS[i % len(_PLOT_MARKERS)], markersize=5,
                 label=name, linewidth=2, color=_dataset_color(name))

    if args.xlog:
        plt.xscale("log")
    if args.xmin is not None or args.xmax is not None:
        plt.xlim(left=args.xmin, right=args.xmax)
    if args.ymin is not None or args.ymax is not None:
        plt.ylim(bottom=args.ymin, top=args.ymax)

    plt.xlabel("Fano factor (variance/mean)")
    plt.ylabel("CCDF  P(Fano ≥ x)")
    plt.legend(loc='upper right')
    plt.tight_layout(pad=0.2)

    out_png, out_pdf = _save_and_close_figure(out_dir, "fano_ccdf")
    print(f"[INFO] Saved plot to {out_png} and {out_pdf}")

    # --- Plot mean ACF for the selected datasets (same marker index as in the CCDF plot)
    plt.figure(figsize=golden_ratio_figsize(width))
    for i, name in enumerate(dataset_order):
        if name not in _MEAN_ACF_DATASETS:
            continue
        vec = acf_means[name]

        # Plot every 5th lag up to 500: [0, 5, 10, ..., 500]
        lag_idx = np.arange(0, min(vec.shape[0], 501), 5)
        sampled_vec = vec[lag_idx]
        mask = np.isfinite(sampled_vec)
        if not np.any(mask):
            continue
        plt.plot(lag_idx[mask], sampled_vec[mask], marker=_PLOT_MARKERS[i % len(_PLOT_MARKERS)],
                 markersize=5, linewidth=2, label=name, color=_dataset_color(name))

    plt.xlabel("Lag")
    plt.ylabel("Mean autocorrelation")
    plt.xlim(0, 500)
    plt.ylim(-0.5, 1.0)
    plt.legend(ncol=2)
    plt.tight_layout(pad=0.2)

    out_png_acf, out_pdf_acf = _save_and_close_figure(out_dir, "acf_mean")
    print(f"[INFO] Saved ACF plot to {out_png_acf} and {out_pdf_acf}")

    # --- Individual ACF plots for up to 20 series of one chosen dataset
    if not args.single_series_from:
        return
    name = args.single_series_from
    if name not in name_to_path:
        print(f"[WARN] --single_series_from={name} not found among datasets. Skipping single-series ACF.")
        return

    # Reuse the dataset's marker from the main plots (first marker if it was not loaded).
    dataset_index = dataset_order.index(name) if name in dataset_order else 0
    _export_single_series_acfs(
        name=name,
        path=name_to_path[name],
        out_dir=out_dir,
        min_len=args.single_min_len,
        min_nz_frac=args.single_min_nz_frac,
        marker=_PLOT_MARKERS[dataset_index % len(_PLOT_MARKERS)],
        color=_dataset_color(name),
    )


# Run the command-line entry point only when this file is executed as a script.
if __name__ == "__main__":
    main()
