import argparse
import pickle
import numpy as np
import os
import matplotlib.pyplot as plt
from concurrent.futures import ProcessPoolExecutor, as_completed
from functools import partial
from scipy.stats import wasserstein_distance as WD

# Optional path to a pickle holding an iterable of keys to evaluate. When set,
# merge_timeseries drops every key not in that collection; when None, every key
# present (with non-empty sequences) in both input maps is evaluated.
RELEVANT_IPs = None #"/PATH_TO_PKL"
def merge_using_list(iei_forecast, compressed_forecast):
    """Build the pieces of a series that interleaves values with zero gaps.

    For each position ``i`` of *iei_forecast* this emits a one-element array
    holding ``compressed_forecast[i]`` followed by
    ``max(int(iei_forecast[i]), 1) - 1`` zeros (so spacings <= 1 add no
    zeros). A final one-element array with ``compressed_forecast[-1]`` is
    always appended. The caller is expected to ``np.concatenate`` the
    returned list.

    NOTE(review): when both inputs have the same length, the last compressed
    value appears twice (once from the loop, once from the final append).
    """
    pieces = []
    for idx, spacing in enumerate(iei_forecast):
        gap_len = max(int(spacing), 1) - 1
        pieces += [np.array([compressed_forecast[idx]]), np.zeros(gap_len)]
    pieces.append(np.array([compressed_forecast[-1]]))
    return pieces

def compute_mase(forecast, actual, eps=1e-8):
    """Return the Mean Absolute Scaled Error of *forecast* against *actual*.

    The forecast MAE is divided by the in-sample MAE of the naive lag-1
    forecast of *actual*, i.e. ``mean(|diff(actual)|)``. (The previous
    implementation divided by ``actual + eps`` elementwise, which is not MASE
    and exploded to ~1e8 * MAE whenever ``actual`` contained zeros.)

    Args:
        forecast: Predicted values, aligned with *actual*.
        actual: Observed values.
        eps: Added to the naive-forecast scale so a constant *actual* does
            not cause a division by zero.

    Returns:
        The MASE, or 0 when *actual* has fewer than two points (the naive
        lag-1 scale is undefined there).
    """
    forecast = np.asarray(forecast, dtype=float)
    actual = np.asarray(actual, dtype=float)
    if len(actual) < 2:
        return 0
    mae = np.mean(np.abs(forecast - actual))
    # Scale: mean absolute error of predicting each point by its predecessor.
    naive_scale = np.mean(np.abs(np.diff(actual)))
    return mae / (naive_scale + eps)

def _truncate_common(a, b):
    """Truncate *a* and *b* to the length of the shorter one."""
    n = min(len(a), len(b))
    return a[:n], b[:n]


def _wd_or_nan(u, v):
    """Wasserstein distance between two samples; NaN if either is empty."""
    if len(u) == 0 or len(v) == 0:
        return np.nan
    return WD(u, v)


def process_key(k, iei_pair, cmp_pair, threshold=0.0, horizon=2):
    """Compute all metrics for a single key and return them as a dict.

    Args:
        k: Identifier of the series; returned unchanged under ``"key"``.
        iei_pair: ``(forecast, ground_truth)`` IEI sequences. Each value is
            used by ``merge_using_list`` as the spacing (in steps) between
            consecutive events of the merged series.
        cmp_pair: ``(forecast, ground_truth)`` compressed-value sequences.
            Values are divided by 1e6 and clipped to ``[0, 1]``
            (presumably bytes -> MB, saturating at 1 MB -- confirm).
        threshold: Currently unused; kept so existing callers keep working.
        horizon: Number of leading steps of every sequence to evaluate.
            Defaults to 2 (the previously hard-coded value); ``None`` keeps
            the full sequences.

    Returns:
        dict with per-stream MAE / MASE / WD (``*_iei``, ``*_cmp``,
        ``series_*``), the two oracle WDs, the ground-truth merged series
        (``original_sample``) and raw absolute-error arrays (``ae_*``).
    """
    iei_fc = np.asarray(iei_pair[0])[:horizon]
    iei_gt = np.asarray(iei_pair[1])[:horizon]
    # Divide by 1e6 and clip to [0, 1]: raw values >= 1e6 all saturate at 1.
    cmp_fc = np.clip(np.asarray(cmp_pair[0]) / 1_000_000, 0, 1)[:horizon]
    cmp_gt = np.clip(np.asarray(cmp_pair[1]) / 1_000_000, 0, 1)[:horizon]

    # Merged series: compressed values interleaved with zero-filled gaps.
    merged_fc = np.concatenate(merge_using_list(iei_fc, cmp_fc))
    merged_gt = np.concatenate(merge_using_list(iei_gt, cmp_gt))
    # Oracles: swap one forecast stream for its ground truth to isolate the
    # error contributed by the other stream.
    oracle_iei = np.concatenate(merge_using_list(iei_gt, cmp_fc))  # true spacing
    oracle_cmp = np.concatenate(merge_using_list(iei_fc, cmp_gt))  # true values

    # IEI-only
    yhat_iei, y_iei = _truncate_common(iei_fc, iei_gt)
    ae_iei = np.abs(yhat_iei - y_iei)
    mae_iei = np.mean(ae_iei)
    mase_iei = compute_mase(yhat_iei, y_iei)
    wd_iei = _wd_or_nan(yhat_iei, y_iei)

    # Compressed-only
    yhat_cmp, y_cmp = _truncate_common(cmp_fc, cmp_gt)
    ae_cmp = np.abs(yhat_cmp - y_cmp)
    mae_cmp = np.mean(ae_cmp)
    mase_cmp = compute_mase(yhat_cmp, y_cmp)
    wd_cmp = _wd_or_nan(yhat_cmp, y_cmp)

    # Merged series: pointwise metrics need aligned (truncated) series ...
    yhat_ser, y_ser = _truncate_common(merged_fc, merged_gt)
    ae_series = np.abs(yhat_ser - y_ser)
    series_mae = np.mean(ae_series)
    series_mase = compute_mase(yhat_ser, y_ser)
    # ... while WD compares value distributions, so it uses the full series.
    # Both oracle WDs use the full series too, so they stay directly
    # comparable to series_wd (previously the CMP oracle used prefixes).
    series_wd = _wd_or_nan(merged_fc, merged_gt)
    oracle_iei_chronos_wd = _wd_or_nan(oracle_iei, merged_gt)
    oracle_cmp_chronos_wd = _wd_or_nan(oracle_cmp, merged_gt)

    return {
        "key": k,
        "mae_iei": mae_iei, "mase_iei": mase_iei, "wd_iei": wd_iei,
        "mae_cmp": mae_cmp, "mase_cmp": mase_cmp, "wd_cmp": wd_cmp,
        "series_mae": series_mae, "series_mase": series_mase, "series_wd": series_wd,
        "original_sample": merged_gt,
        "oracle_iei_chronos_wd": oracle_iei_chronos_wd,
        "oracle_cmp_chronos_wd": oracle_cmp_chronos_wd,
        # Raw absolute errors, kept for later distribution analysis.
        "ae_iei": ae_iei,
        "ae_cmp": ae_cmp,
        "ae_series": ae_series,
    }

import os, glob, pickle

def load_dir_map(dir_path: str):
    """
    Merge every ``*.pkl`` / ``*.pickle`` file directly inside *dir_path*.

    Files are read in sorted path order and merged with ``dict.update``, so
    for a key defined in several files the file whose path sorts last wins.

    Raises:
        ValueError: *dir_path* is not an existing directory.
        TypeError: a file unpickles to something other than a dict.
    """
    if not os.path.isdir(dir_path):
        raise ValueError(f"Expected a directory, got {dir_path}")

    candidates = []
    for pattern in ("*.pkl", "*.pickle"):
        candidates += glob.glob(os.path.join(dir_path, pattern))

    combined = {}
    for path in sorted(candidates):
        # NOTE: pickle.load can execute arbitrary code; only use trusted files.
        with open(path, "rb") as fh:
            payload = pickle.load(fh)
        if not isinstance(payload, dict):
            raise TypeError(f"File {path} did not contain a dict")
        combined.update(payload)
    return combined

def _select_items(iei_map, cmp_map, good_keys=None):
    """Return ``(key, iei_pair, cmp_pair)`` triples that should be scored.

    A key is kept when it is in both maps, is in *good_keys* (if given), and
    none of its four sequences (IEI/CMP forecast and ground truth) is empty.
    Iteration follows *iei_map*'s insertion order so repeated runs over the
    same files produce the same ordering.
    """
    allowed = None if good_keys is None else set(good_keys)  # build once
    items = []
    for k, iei_val in iei_map.items():
        if k not in cmp_map:
            continue
        if allowed is not None and k not in allowed:
            continue
        cmp_val = cmp_map[k]
        if min(len(iei_val[0]), len(iei_val[1]), len(cmp_val[0]), len(cmp_val[1])) == 0:
            continue
        items.append((k, iei_val, cmp_val))
    return items


def _cdf_xy(vals, nbins=200):
    """Return ``(x, CDF(x))`` for the finite, non-negative entries of *vals*.

    The histogram spans [min, 99.9th percentile]; values above that are
    excluded, so the curve reaches 1.0 at the 99.9th percentile. Returns
    None when no usable values remain.
    """
    v = vals[np.isfinite(vals)]
    v = v[v >= 0]
    if v.size == 0:
        return None
    lo, hi = np.percentile(v, 0.0), np.percentile(v, 99.9)
    if hi <= lo:
        hi = lo + 1e-8  # degenerate (constant) sample: give the bins some width
    edges = np.linspace(lo, hi, nbins + 1)
    counts, edges = np.histogram(v, bins=edges, density=True)
    cdf = np.cumsum(counts * np.diff(edges))
    return edges[1:], cdf


def _ccdf_xy(vals, nbins=200):
    """Return ``(x, 1 - CDF(x))`` built from :func:`_cdf_xy`; None if no data."""
    out = _cdf_xy(vals, nbins=nbins)
    if out is None:
        return None
    x, cdf = out
    return x, 1.0 - cdf


def _plot_mase_curves(series, xy_fn, ylabel, title, out_file, log_y=False):
    """Step-plot ``xy_fn(arr)`` for each ``(arr, label, color)`` and save it."""
    fig = plt.figure(figsize=(7, 5))
    for arr, label, color in series:
        out = xy_fn(arr)
        if out is None:
            continue
        x, y = out
        plt.step(x, y, where="post", label=label, color=color, alpha=0.9)
    plt.xlabel("MASE (per example)")
    plt.ylabel(ylabel)
    plt.title(title)
    if log_y:
        plt.yscale("log")  # log-Y makes CCDF tails visible
        plt.grid(True, which="both", linestyle="--", linewidth=0.5)
    else:
        plt.grid(True, linestyle="--", linewidth=0.5)
    plt.legend()
    plt.tight_layout()
    plt.savefig(out_file)
    plt.close(fig)  # release the figure instead of accumulating open figures


def merge_timeseries(iei_path, compressed_path, output_path, n_jobs=None,
                     threshold=0.0, max_series_mae=1e6):
    """Score every key shared by two directories of pickled series.

    Loads and merges all pickles in *iei_path* and *compressed_path* (via
    ``load_dir_map``), keeps keys present in both (optionally restricted to
    the keys pickled at ``RELEVANT_IPs``) whose sequences are non-empty,
    scores each with ``process_key`` in a process pool, prints averages and
    MASE summary statistics, and writes into *output_path*:
    ``original_samples.pkl``, ``per_example_mase.pkl``,
    ``per_example_oracle_metrics.pkl``, ``per_example_ase.pkl``,
    ``mase_cdf_per_example.png`` and ``mase_ccdf_per_example.png``.

    Args:
        iei_path: Directory of pickled ``{key: (forecast, ground_truth)}``
            IEI dicts.
        compressed_path: Same layout for the compressed-value series.
        output_path: Output directory (created if missing).
        n_jobs: ``max_workers`` for the process pool (None = executor default).
        threshold: Forwarded to ``process_key``.
        max_series_mae: Results whose ``series_mae`` exceeds this are dropped
            as outliers; NaN results are kept (averages use ``nanmean``).
    """
    iei_map = load_dir_map(iei_path)
    cmp_map = load_dir_map(compressed_path)
    good_keys = None
    if RELEVANT_IPs is not None:
        with open(RELEVANT_IPs, "rb") as f:
            good_keys = pickle.load(f)

    merged_items = _select_items(iei_map, cmp_map, good_keys)
    print(f"Processing {len(merged_items)} keys...")

    # Parallel execution; results are gathered in submission order (not
    # completion order) so every saved per-example list is reproducible.
    results = []
    with ProcessPoolExecutor(max_workers=n_jobs) as exe:
        futures = [
            exe.submit(process_key, k, iei_pair, cmp_pair, threshold)
            for k, iei_pair, cmp_pair in merged_items
        ]
        for fut in futures:
            res = fut.result()
            if res is not None:
                results.append(res)

    # Drop outliers with series_mae > max_series_mae. The previous truthiness
    # test instead discarded perfect forecasts (MAE == 0). `not (x > limit)`
    # deliberately keeps NaN.
    results = [r for r in results if not (r["series_mae"] > max_series_mae)]

    os.makedirs(output_path, exist_ok=True)

    # Ground-truth merged series, one per kept example.
    with open(os.path.join(output_path, 'original_samples.pkl'), 'wb') as f:
        pickle.dump([r['original_sample'] for r in results], f)

    def agg(name):
        return np.nanmean([r[name] for r in results])

    print("==== AVERAGES ====")
    print(f"IEI   MAE {agg('mae_iei'):.5f}, MASE {agg('mase_iei'):.5f}, WD {agg('wd_iei'):.5f}")
    print(f"CMP   MAE {agg('mae_cmp'):.5f}, MASE {agg('mase_cmp'):.5f}, WD {agg('wd_cmp'):.5f}")
    print(f"SER   MAE {agg('series_mae'):.5f}, MASE {agg('series_mase'):.5f}, WD {agg('series_wd'):.5f}")
    print(f"ORCL_IEI  WD {agg('oracle_iei_chronos_wd'):.5f}")
    print(f"ORCL_CMP WD {agg('oracle_cmp_chronos_wd'):.5f}")

    def get_valid(name):
        """Finite per-example values of metric *name* as a float array."""
        arr = np.array([r[name] for r in results], dtype=float)
        return arr[np.isfinite(arr)]

    mase_series = get_valid("series_mase")
    mase_iei = get_valid("mase_iei")
    mase_cmp = get_valid("mase_cmp")

    with open(os.path.join(output_path, "per_example_mase.pkl"), "wb") as f:
        pickle.dump(
            {"series_mase": mase_series, "mase_iei": mase_iei, "mase_cmp": mase_cmp}, f
        )

    oracle_metrics_path = os.path.join(output_path, "per_example_oracle_metrics.pkl")
    with open(oracle_metrics_path, "wb") as f:
        pickle.dump(
            {
                "series_mase": mase_series,
                "series_wd": get_valid("series_wd"),
                "oracle_iei_wd": get_valid("oracle_iei_chronos_wd"),
                "oracle_cmp_wd": get_valid("oracle_cmp_chronos_wd"),
            },
            f,
        )
    print(f"Saved per-example oracle metrics to {oracle_metrics_path}")

    # Summary stats for MASEs: mean, median, 90th and 95th percentiles.
    mase_sets = [
        ("IEI MASE", mase_iei),
        ("CMP MASE", mase_cmp),
        ("Series MASE", mase_series),
    ]
    print("==== MASE SUMMARY ====")
    for name, vals in mase_sets:
        if vals.size == 0:
            print(f"{name}: no finite values")
            continue
        p50, p90, p95 = np.percentile(vals, [50, 90, 95])
        print(f"{name}: mean {vals.mean():.5f}, median {p50:.5f}, "
              f"p90 {p90:.5f}, p95 {p95:.5f}")

    series = [
        (mase_series, "MASE (Merged series)", "tab:blue"),
        (mase_iei,    "MASE (IEI)",           "tab:orange"),
        (mase_cmp,    "MASE (CMP)",           "tab:purple"),
    ]
    _plot_mase_curves(
        series, _cdf_xy, "CDF", "CDF of MASE across examples",
        os.path.join(output_path, "mase_cdf_per_example.png"),
    )
    _plot_mase_curves(
        series, _ccdf_xy, "CCDF (1 - CDF)", "CCDF of MASE across examples",
        os.path.join(output_path, "mase_ccdf_per_example.png"), log_y=True,
    )

    ase_store = {
        "ae_iei":    [r["ae_iei"] for r in results],
        "ae_cmp":    [r["ae_cmp"] for r in results],
        "ae_series": [r["ae_series"] for r in results],
    }
    ase_out_path = os.path.join(output_path, "per_example_ase.pkl")
    with open(ase_out_path, "wb") as f:
        pickle.dump(ase_store, f)
    print(f"Saved ASE arrays to {ase_out_path}")

def main():
    """Command-line entry point: parse arguments and run merge_timeseries."""
    cli = argparse.ArgumentParser()
    # The three input/output locations are all mandatory.
    for flag in ('--iei_timeseries', '--compressed_timeseries', '--output_path'):
        cli.add_argument(flag, required=True)
    cli.add_argument('--threshold', type=float, default=0.0)
    cli.add_argument(
        '--n_jobs', type=int, default=None,
        help="# of parallel processes (default=CPU count)",
    )
    opts = cli.parse_args()

    merge_timeseries(
        opts.iei_timeseries,
        opts.compressed_timeseries,
        opts.output_path,
        n_jobs=opts.n_jobs,
        threshold=opts.threshold,
    )

if __name__ == "__main__":
    main()
