import os, sys
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import importlib
from calflops import calculate_flops
from utils import *
from dataloader import TimeSeriesDataset, TimeSeriesLoader
from types import SimpleNamespace
from matplotlib.ticker import FuncFormatter
from types import SimpleNamespace
from analysis.model_flops_calculation import ML_Flops, model_flops

# Detector families. Family membership selects the line style and width that
# build_global_model_styles() assigns to a model.
ML_MODELS = {
    'HBOS', 'LODA', 'ABOD', 'PCA', 'LOF',
    'Hotelling', 'IForest', 'HSTree', 'CBLOF',
}
DL_MODELS = {
    'USAD', 'DAGMM', 'LUAD', 'lstmAE', 'lstmVAE',
    'OmniAnomaly', 'DeepSVDD', 'AnomalyTransformer', 'TimesNet',
}

# Legend text for models whose identifier should not be shown verbatim;
# models missing from this table are labelled with their identifier.
PRETTY_LABEL = dict(
    lstmAE="LSTM-AE",
    lstmVAE="LSTM-VAE",
    HSTree="HS-Tree",
    AnomalyTransformer="Anomaly Transformer",
)

# Alphabetical list of every known model; a model's position in this list
# determines its colour and marker.
ALL_MODELS = sorted(ML_MODELS | DL_MODELS)

# Colours of several qualitative colormaps chained together, in this order.
_PALETTE = [
    rgb
    for cmap_name in ('tab10', 'Set2', 'Dark2', 'tab20')
    for rgb in plt.get_cmap(cmap_name).colors
]

# Marker symbols cycled through by model position.
_MARKERS = [
    'o', 's', 'D', '^', 'v', '>', '<',
    'P', 'X', '*', 'h', 'H', '8', 'p',
    'd', '1', '2', '3', '4', '|', '_',
]


def build_global_model_styles(all_models=ALL_MODELS):
    """Build one fixed plotting style per model.

    Colour and marker are chosen by the model's position in ``all_models``
    (wrapping around ``_PALETTE`` / ``_MARKERS``). Models listed in
    ``DL_MODELS`` get a dashed, thicker line; every other model gets a solid,
    thinner one. The legend label comes from ``PRETTY_LABEL`` when present,
    otherwise it is the model name itself.

    Args:
        all_models: Iterable of model names; defaults to ``ALL_MODELS``.

    Returns:
        dict mapping model name -> dict with the keys ``color``, ``marker``,
        ``linestyle``, ``linewidth`` and ``label``.
    """
    def _style_for(position, name):
        is_deep = name in DL_MODELS
        return {
            "color": _PALETTE[position % len(_PALETTE)],
            "marker": _MARKERS[position % len(_MARKERS)],
            "linestyle": '--' if is_deep else '-',
            "linewidth": 2.3 if is_deep else 1.6,
            "label": PRETTY_LABEL.get(name, name),
        }

    return {name: _style_for(pos, name) for pos, name in enumerate(all_models)}

# Style table shared by all plotting functions in this module, built once at import.
MODEL_STYLES = build_global_model_styles(all_models=ALL_MODELS)

def _ensure_dir(path: str) -> None:
    """Create directory ``path`` with any missing parents; a no-op if it already exists."""
    os.makedirs(name=path, exist_ok=True)

def _safe_int(x):
    """Round ``x`` with the built-in ``round`` and return it as an int of at least 1."""
    nearest = round(x)
    # Anything that does not exceed 1 (zero, negatives, 1 itself) collapses to 1.
    clamped = nearest if nearest > 1 else 1
    return int(clamped)

def _slice_dataset(ds, n_rows):
    """Return a lightweight stand-in for ``ds`` holding only its first rows.

    Args:
        ds: Object exposing an array-like ``data`` attribute (``shape``, ``ndim``
            and slicing are used).
        n_rows: Number of leading rows requested; capped at the rows available.

    Returns:
        ``SimpleNamespace`` with ``data`` (the leading rows), ``train_len`` and
        ``test_len`` (both the number of rows kept) and ``input_dim`` (number of
        columns, or 1 when ``data`` is one-dimensional).
    """
    limit = min(n_rows, ds.data.shape[0])
    head = ds.data[:limit]
    kept = head.shape[0]
    width = head.shape[1] if head.ndim > 1 else 1
    return SimpleNamespace(data=head, train_len=kept, test_len=kept, input_dim=width)

def run_scalability_experiment(
    datasets,
    ml_model_names,
    dl_model_names,
    percents=None,
    save_dir="./analysis/results",
    save_csv_name="scalability_flops.csv",
    size_mode="split",
):
    """Estimate train/inference FLOPs of every model at increasing data sizes.

    For each dataset and each percentage in ``percents`` a target train size
    and test size are derived. DL models are measured once per dataset through
    ``model_flops(dataset).flops(name)`` and the result is scaled linearly by
    ``target size / full size``. ML models are evaluated per target size through
    the matching ``ML_Flops.FLOPs_*`` function, configured from
    ``ModelConfig("ML_calflops").get_param(dataset)``. A model that raises is
    logged and recorded with NaN FLOPs instead of aborting the run.

    Args:
        datasets: Dataset identifiers accepted by ``TimeSeriesDataset``.
        ml_model_names: ML model names to evaluate.
        dl_model_names: DL model names to evaluate.
        percents: Percentages of the data to evaluate; 10, 20, ..., 100 when None.
        save_dir: Directory (created if needed) that receives the CSV.
        save_csv_name: File name of the CSV written into ``save_dir``.
        size_mode: ``"min"`` derives both train and test targets from the smaller
            of the two split lengths; any other value scales each split by its
            own length.

    Returns:
        Tuple ``(df_all, ds_meta, target_map)``: the result frame, per-dataset
        size metadata, and per-dataset ``(train_targets, test_targets)`` lists.
    """
    percents = list(range(10, 101, 10)) if percents is None else percents

    _ensure_dir(save_dir)
    rows = []

    def _record(dataset, model_name, idx, n_tr, n_te, train_flops, inference_flops, total_flops):
        # Key order here fixes the column order of the resulting DataFrame/CSV.
        rows.append({
            "dataset": dataset, "model_name": model_name,
            "percent": percents[idx],
            "train_size": n_tr, "test_size": n_te,
            "train_flops": train_flops,
            "inference_flops": inference_flops,
            "total_flops": total_flops,
        })

    def _ml_flops(ml_model_name, cfg, n_tr, n_te, d, train_ds, test_ds):
        # Dispatch to the FLOPs formula of one ML model; returns (train, inference).
        if ml_model_name == "IForest":
            subsample = max(2, min(cfg["max_samples"], n_tr))
            return ML_Flops.FLOPs_IForest(
                num_train=n_tr, num_inference=n_te,
                n_estimators=cfg["n_estimators"], max_samples=subsample,
            )

        if ml_model_name == "HSTree":
            window = min(cfg["ref_window_size"], n_tr)
            return ML_Flops.FLOPs_HSTree(
                num_train=n_tr, num_inference=n_te,
                n_estimators=cfg["n_estimators"],
                max_depth=cfg["max_depth"], ref_window_size=window,
            )

        if ml_model_name == "CBLOF":
            # CBLOF is evaluated on actual truncated data rather than on sizes alone.
            tr_subset = _slice_dataset(train_ds, n_tr)
            te_subset = _slice_dataset(test_ds, n_te)
            return ML_Flops.FLOPs_CBLOF(
                train_set=tr_subset, test_set=te_subset,
                model_params=cfg.get('model_params', cfg),
            )

        common = {"num_train": n_tr, "num_inference": n_te, "input_dim": d}

        if ml_model_name == "HBOS":
            return ML_Flops.FLOPs_HBOS(**common, bins=cfg["bins"])

        if ml_model_name == "LODA":
            return ML_Flops.FLOPs_LODA(
                **common, n_random_cuts=cfg["n_random_cuts"], bins=cfg["bins"]
            )

        if ml_model_name == "ABOD":
            k_eff = min(cfg["n_neighbors"], max(1, n_tr - 1), max(1, n_te - 1))
            if n_tr < 2 or n_te < 2:
                return (np.nan, np.nan)
            return ML_Flops.FLOPs_ABOD(**common, n_neighbors=k_eff)

        if ml_model_name == "PCA":
            return ML_Flops.FLOPs_PCA(**common, n_components=cfg["n_components"])

        if ml_model_name == "LOF":
            k_eff = min(cfg["n_neighbors"], max(1, n_tr - 1))
            if n_tr < 2 or n_te < 2:
                return (np.nan, np.nan)
            return ML_Flops.FLOPs_LOF(**common, n_neighbors=k_eff)

        if ml_model_name == "Hotelling":
            return ML_Flops.FLOPs_Hotelling(n_tr, n_te, d)

        raise KeyError(f"Unsupported ML model for scaling: {ml_model_name}")

    # ---- Pass 1: split lengths and dimensionality of every dataset ----
    ds_meta = {}
    for dataset in datasets:
        train_part = TimeSeriesDataset(dataset, train=True)
        test_part = TimeSeriesDataset(dataset, train=False)
        full_tr = train_part.train_len
        full_te = test_part.test_len
        full_total = int(full_tr + full_te)
        ds_meta[dataset] = {
            "train_len": full_tr,
            "test_len": full_te,
            "total_len": full_total,
            "train_ratio": (full_tr / full_total) if full_total > 0 else 0.5,
            "input_dim": train_part.input_dim,
        }

    # ---- Pass 2: target train/test sizes for every percentage ----
    target_map = {}
    for dataset in datasets:
        full_tr = ds_meta[dataset]["train_len"]
        full_te = ds_meta[dataset]["test_len"]
        fractions = [p / 100.0 for p in percents]

        if size_mode == "min":
            shared = max(1, min(full_tr, full_te))
            tr_sizes = [_safe_int(shared * frac) for frac in fractions]
            te_sizes = list(tr_sizes)
        else:
            tr_sizes = [_safe_int(full_tr * frac) for frac in fractions]
            te_sizes = [_safe_int(full_te * frac) for frac in fractions]

        target_map[dataset] = (
            [max(1, size) for size in tr_sizes],
            [max(1, size) for size in te_sizes],
        )

    # ---- Pass 3: FLOPs per dataset / size / model ----
    for dataset in datasets:
        print(f"[SCALING] dataset={dataset}")
        runner = model_flops(dataset)

        base_train_ds = TimeSeriesDataset(dataset, train=True)
        base_test_ds = TimeSeriesDataset(dataset, train=False)
        ml_cfg_root = ModelConfig("ML_calflops").get_param(dataset)

        full_tr = ds_meta[dataset]["train_len"]
        full_te = ds_meta[dataset]["test_len"]

        # DL models are measured once on the full dataset; None marks a failed measurement.
        dl_cache = {}
        for dl_model_name in dl_model_names:
            try:
                epoch_tr, epoch_inf, _unused_total, full_training = runner.flops(dl_model_name)
            except Exception as e:
                print(f"[DL][SCALING][ERROR] dataset={dataset}, model={dl_model_name}: {e}")
                dl_cache[dl_model_name] = None
            else:
                dl_cache[dl_model_name] = {
                    "train_epoch_flops": epoch_tr,
                    "inference_epoch_flops": epoch_inf,
                    "full_training_flops": full_training,
                    "base_train_len": full_tr,
                    "base_test_len": full_te,
                }

        tr_sizes, te_sizes = target_map[dataset]

        for idx, (n_tr, n_te) in enumerate(zip(tr_sizes, te_sizes)):
            d = base_train_ds.input_dim

            for dl_model_name in dl_model_names:
                cached = dl_cache.get(dl_model_name, None)
                if cached is None:
                    _record(dataset, dl_model_name, idx, n_tr, n_te, np.nan, np.nan, np.nan)
                    continue

                ref_tr = cached["base_train_len"]
                ref_te = cached["base_test_len"]
                tr_scale = (n_tr / ref_tr) if ref_tr > 0 else 0.0
                te_scale = (n_te / ref_te) if ref_te > 0 else 0.0

                scaled_train = cached["full_training_flops"] * tr_scale
                scaled_inf = cached["inference_epoch_flops"] * te_scale
                _record(
                    dataset, dl_model_name, idx, n_tr, n_te,
                    scaled_train, scaled_inf, scaled_train + scaled_inf,
                )

            for ml_model_name in ml_model_names:
                try:
                    cfg = dict(ml_cfg_root[ml_model_name])
                    flops_train, flops_inf = _ml_flops(
                        ml_model_name, cfg, n_tr, n_te, d, base_train_ds, base_test_ds
                    )
                    if np.isfinite(flops_train) and np.isfinite(flops_inf):
                        flops_total = flops_train + flops_inf
                    else:
                        flops_total = np.nan
                    _record(
                        dataset, ml_model_name, idx, n_tr, n_te,
                        flops_train, flops_inf, flops_total,
                    )
                except Exception as e:
                    print(f"[ML][SCALING][ERROR] dataset={dataset}, model={ml_model_name}, p={percents[idx]}: {e}")
                    _record(dataset, ml_model_name, idx, n_tr, n_te, np.nan, np.nan, np.nan)

    df_all = pd.DataFrame(rows)
    csv_path = os.path.join(save_dir, save_csv_name)
    df_all.to_csv(csv_path, index=False)
    print(f"[CSV] saved: {csv_path}")

    return df_all, ds_meta, target_map


def plot_one_dataset_row(
    df_all: pd.DataFrame,
    ds_meta: dict,
    dataset: str,
    save_dir="./analysis/results",
    logy=True,
    legend_lines=None,
):
    """Plot train and inference GFLOPs against data size for one dataset.

    Draws two side-by-side panels (train FLOPs vs. ``train_size``, inference
    FLOPs vs. ``test_size``) with one line per model, a shared y-range and a
    single figure-level legend, then saves the figure as
    ``scalability_{dataset}_row.pdf`` inside ``save_dir``.

    Args:
        df_all: Result frame with at least the columns ``dataset``,
            ``model_name``, ``train_size``, ``test_size``, ``train_flops`` and
            ``inference_flops``.
        ds_meta: Mapping dataset -> metadata dict; ``input_dim`` is shown in the
            panel titles.
        dataset: Dataset to plot; must be a key of ``ds_meta``.
        save_dir: Output directory, created if missing.
        logy: Use a logarithmic y-axis when true.
        legend_lines: Optional list of rows of model names fixing the legend
            order; the length of the first row is used as the legend column
            count. Names without a plotted line are ignored.

    Returns:
        None. When no row of the dataset has a finite FLOPs value a warning is
        printed and nothing is drawn.

    Raises:
        ValueError: If ``dataset`` is not a key of ``ds_meta``.
    """
    if dataset not in ds_meta:
        raise ValueError(f"Unknown dataset: {dataset}")

    df_ds = df_all[df_all["dataset"] == dataset].copy()
    df_ds = df_ds[np.isfinite(df_ds["train_flops"]) | np.isfinite(df_ds["inference_flops"])]

    # With no rows left, min()/max() below are NaN and int(NaN) raises; bail out
    # the same way plot_one_dataset_dim does.
    if df_ds.empty:
        print(f"[WARN] No data to plot for dataset={dataset}")
        return

    x_tr_min, x_tr_max = int(df_ds["train_size"].min()), int(df_ds["train_size"].max())
    x_te_min, x_te_max = int(df_ds["test_size"].min()),  int(df_ds["test_size"].max())

    def _ylim(series1, series2, logy):
        # Common y-range in GFLOPs over the positive, finite values of both series.
        y_pos = (pd.concat([series1, series2]) / 1e9)
        y_pos = y_pos[(y_pos > 0) & np.isfinite(y_pos)]
        if len(y_pos) == 0:
            return (1.0, 1.0) if logy else (0.0, 1.0)
        y_min = float(y_pos.min()); y_max = float(y_pos.max())
        return ((y_min/1.5, y_max*1.2) if logy else (0.0, y_max*1.05))

    y_min, y_max = _ylim(df_ds["train_flops"], df_ds["inference_flops"], logy)

    unique_models = sorted(df_ds["model_name"].dropna().unique())

    fig, (ax_tr, ax_in) = plt.subplots(1, 2, figsize=(18, 7.5), sharex=False, sharey=False, gridspec_kw={'wspace': 0.1})

    def _fmt(ax, title_suffix, xlim, ylim, show_ylabel=False, xlabel="Data size (instances)"):
        # Axis limits, scale, grid, tick formatting, title and labels of one panel.
        ax.set_xlim(*xlim)
        if logy:
            ax.set_yscale("log")
        ax.set_ylim(*ylim)
        ax.minorticks_off()
        ax.grid(True, which="major", linestyle="--", linewidth=0.6, alpha=0.6)
        # Two decimals up to 1 GFLOP, whole numbers above.
        ax.yaxis.set_major_formatter(FuncFormatter(lambda x, _: f"{x:.2f}" if x <= 1.0 else f"{int(x)}"))
        ax.tick_params(axis='both', which='major', labelsize=18)
        ax.tick_params(axis='x', labelrotation=15)
        dim = ds_meta[dataset]["input_dim"]
        ax.set_title(f"{title_suffix} — {dataset} (dim={dim})", pad=8, fontdict={'fontsize': 20})
        ax.set_xlabel(xlabel, fontdict={'fontsize': 20})
        if show_ylabel:
            ax.set_ylabel("GFLOPs", fontdict={'fontsize': 20})
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)

    def _plot(ax, ycol, xcol):
        # One line per model, styled from MODEL_STYLES (generic style for unknown models).
        for model in unique_models:
            sub = df_ds[df_ds["model_name"] == model].sort_values(xcol)
            if sub.empty:
                continue
            st = MODEL_STYLES.get(model, {
                "color": None, "marker": 'o', "linestyle": '-', "linewidth": 1.6, "label": PRETTY_LABEL.get(model, model)
            })
            ax.plot(
                sub[xcol].values,
                (sub[ycol].values / 1e9),
                label=st["label"],
                linestyle=st["linestyle"],
                linewidth=st["linewidth"],
                marker=st["marker"],
                ms=10,
                color=st["color"],
                alpha=0.95
            )

    _plot(ax_tr, "train_flops", "train_size")
    _plot(ax_in, "inference_flops", "test_size")

    _fmt(ax_tr, "Train FLOPs", (x_tr_min, x_tr_max), (y_min, y_max), show_ylabel=True,  xlabel="Data size (instances)")
    _fmt(ax_in, "Inference FLOPs", (x_te_min, x_te_max), (y_min, y_max), show_ylabel=False, xlabel="Data size (instances)")

    # Both panels draw the same models; keep the first handle seen for each label.
    label_to_handle = {}
    for ax in (ax_tr, ax_in):
        for h, lab in zip(*ax.get_legend_handles_labels()):
            label_to_handle.setdefault(lab, h)
    l_uniq = list(label_to_handle)
    h_uniq = list(label_to_handle.values())

    if legend_lines is not None:
        # Reorder the legend to follow legend_lines, skipping models that were not plotted.
        ordered = [
            (label_to_handle[lab], lab)
            for row in legend_lines
            for lab in (PRETTY_LABEL.get(m, m) for m in row)
            if lab in label_to_handle
        ]
        if ordered:
            h_uniq = [h for h, _ in ordered]
            l_uniq = [lab for _, lab in ordered]

        ncol = len(legend_lines[0]) if legend_lines and legend_lines[0] else len(h_uniq)
    else:
        ncol = min(5, len(h_uniq))

    fig.subplots_adjust(left=0.10, right=0.98, top=0.78, bottom=0.16, wspace=0.12)
    fig.legend(
        h_uniq, l_uniq,
        loc="lower center", bbox_to_anchor=(0.53, 0.85),
        ncol=ncol, frameon=False,
        columnspacing=1.2, handletextpad=0.6,
        borderaxespad=0.2, fontsize=20
    )
    # The right panel uses the same y-range, so its tick labels are hidden.
    ax_in.tick_params(labelleft=False)

    os.makedirs(save_dir, exist_ok=True)
    out = os.path.join(save_dir, f"scalability_{dataset}_row.pdf")
    fig.savefig(out, dpi=600, bbox_inches="tight", pad_inches=0.20)
    plt.close(fig)
    print(f"[PLOT] saved: {out}")


def _slice_dataset_dim(ds, n_rows, feature_idx):
    """Return a stand-in for ``ds`` restricted to leading rows and chosen columns.

    Args:
        ds: Object exposing a NumPy-style ``data`` array.
        n_rows: Number of leading rows requested; capped at the rows available.
        feature_idx: Column indices to keep, converted to an integer array.

    Returns:
        ``SimpleNamespace`` with ``data`` (the selected sub-array), ``train_len``
        and ``test_len`` (both its row count) and ``input_dim`` (its column
        count, or 1 if the selection is one-dimensional).
    """
    limit = min(n_rows, ds.data.shape[0])
    rows_kept = ds.data[:limit]

    # Treat a 1-D series as a single-column matrix so column indexing works.
    if rows_kept.ndim == 1:
        rows_kept = rows_kept[:, np.newaxis]

    columns = np.array(feature_idx, dtype=int)
    selected = rows_kept[:, columns]

    n_kept = selected.shape[0]
    width = selected.shape[1] if selected.ndim > 1 else 1
    return SimpleNamespace(data=selected, train_len=n_kept, test_len=n_kept, input_dim=width)

def run_dim_scalability_experiment(
    datasets,
    ml_model_names,
    dl_model_names,
    dim_percents=None,
    save_dir="./analysis/results",
    save_csv_name="scalability_dim_flops.csv",
    seed=42,
):
    """Estimate train/inference FLOPs of every model at increasing feature counts.

    The full train and test lengths of each dataset are used throughout; only
    the number of features varies. For each percentage a feature count is
    derived from the dataset's ``input_dim`` and that many distinct feature
    indices are drawn from a generator seeded with ``seed``. DL models are
    measured once per dataset through ``model_flops(dataset).flops(name)`` and
    scaled linearly by the feature ratio. ML models are evaluated through the
    matching ``ML_Flops.FLOPs_*`` function; of these only CBLOF receives the
    drawn feature indices, via ``_slice_dataset_dim``. A model that raises is
    logged and recorded with NaN FLOPs.

    Args:
        datasets: Dataset identifiers accepted by ``TimeSeriesDataset``.
        ml_model_names: ML model names to evaluate.
        dl_model_names: DL model names to evaluate.
        dim_percents: Percentages of the feature count; 10, 20, ..., 100 when None.
        save_dir: Directory (created if needed) that receives the CSV.
        save_csv_name: File name of the CSV written into ``save_dir``.
        seed: Seed of the NumPy generator used to draw feature subsets.

    Returns:
        Tuple ``(df_all, ds_meta)``: the result frame and per-dataset metadata.
    """
    dim_percents = list(range(10, 101, 10)) if dim_percents is None else dim_percents

    _ensure_dir(save_dir)
    rows = []

    def _record(dataset, model_name, p, d_sub, n_tr, n_te, train_flops, inference_flops, total_flops):
        # Key order here fixes the column order of the resulting DataFrame/CSV.
        rows.append({
            "dataset": dataset,
            "model_name": model_name,
            "dim_percent": p,
            "dim": d_sub,
            "train_size": n_tr,
            "test_size": n_te,
            "train_flops": train_flops,
            "inference_flops": inference_flops,
            "total_flops": total_flops,
        })

    def _ml_flops(ml_model_name, cfg, n_tr, n_te, d, train_ds, test_ds, feature_idx):
        # Dispatch to the FLOPs formula of one ML model; returns (train, inference).
        if ml_model_name == "IForest":
            subsample = max(2, min(cfg["max_samples"], n_tr))
            return ML_Flops.FLOPs_IForest(
                num_train=n_tr,
                num_inference=n_te,
                n_estimators=cfg["n_estimators"],
                max_samples=subsample,
            )

        if ml_model_name == "HSTree":
            window = min(cfg["ref_window_size"], n_tr)
            return ML_Flops.FLOPs_HSTree(
                num_train=n_tr,
                num_inference=n_te,
                n_estimators=cfg["n_estimators"],
                max_depth=cfg["max_depth"],
                ref_window_size=window,
            )

        if ml_model_name == "CBLOF":
            # CBLOF is evaluated on actual data restricted to the drawn features.
            tr_subset = _slice_dataset_dim(train_ds, n_tr, feature_idx)
            te_subset = _slice_dataset_dim(test_ds, n_te, feature_idx)
            return ML_Flops.FLOPs_CBLOF(
                train_set=tr_subset,
                test_set=te_subset,
                model_params=cfg.get('model_params', cfg),
            )

        common = {"num_train": n_tr, "num_inference": n_te, "input_dim": d}

        if ml_model_name == "HBOS":
            return ML_Flops.FLOPs_HBOS(**common, bins=cfg["bins"])

        if ml_model_name == "LODA":
            return ML_Flops.FLOPs_LODA(
                **common, n_random_cuts=cfg["n_random_cuts"], bins=cfg["bins"]
            )

        if ml_model_name == "ABOD":
            k_eff = min(cfg["n_neighbors"], max(1, n_tr - 1), max(1, n_te - 1))
            if n_tr < 2 or n_te < 2:
                return (np.nan, np.nan)
            return ML_Flops.FLOPs_ABOD(**common, n_neighbors=k_eff)

        if ml_model_name == "PCA":
            return ML_Flops.FLOPs_PCA(**common, n_components=cfg["n_components"])

        if ml_model_name == "LOF":
            k_eff = min(cfg["n_neighbors"], max(1, n_tr - 1))
            if n_tr < 2 or n_te < 2:
                return (np.nan, np.nan)
            return ML_Flops.FLOPs_LOF(**common, n_neighbors=k_eff)

        if ml_model_name == "Hotelling":
            return ML_Flops.FLOPs_Hotelling(n_tr, n_te, d)

        raise KeyError(f"Unsupported ML model for dim scaling: {ml_model_name}")

    # ---- Pass 1: split lengths and dimensionality of every dataset ----
    ds_meta = {}
    for dataset in datasets:
        train_part = TimeSeriesDataset(dataset, train=True)
        test_part = TimeSeriesDataset(dataset, train=False)
        full_tr = train_part.train_len
        full_te = test_part.test_len
        full_total = int(full_tr + full_te)
        ds_meta[dataset] = {
            "train_len": full_tr,
            "test_len": full_te,
            "total_len": full_total,
            "train_ratio": (full_tr / full_total) if full_total > 0 else 0.5,
            "input_dim": train_part.input_dim,
        }

    # One generator for the whole run: feature draws depend on seed and call order.
    rng = np.random.default_rng(seed)

    # ---- Pass 2: FLOPs per dataset / feature count / model ----
    for dataset in datasets:
        print(f"[DIM-SCALING] dataset={dataset}")
        runner = model_flops(dataset)

        base_train_ds = TimeSeriesDataset(dataset, train=True)
        base_test_ds = TimeSeriesDataset(dataset, train=False)
        ml_cfg_root = ModelConfig("ML_calflops").get_param(dataset)

        meta = ds_meta[dataset]
        full_tr = meta["train_len"]
        full_te = meta["test_len"]
        full_dim = meta["input_dim"]

        # Row counts stay at the full split lengths for every feature count.
        n_tr = full_tr
        n_te = full_te
        print(f"  -> using n_tr={n_tr}, n_te={n_te}")

        # DL models are measured once on the full dataset; None marks a failed measurement.
        dl_cache = {}
        for dl_model_name in dl_model_names:
            try:
                epoch_tr, epoch_inf, _unused_total, full_training = runner.flops(dl_model_name)
            except Exception as e:
                print(f"[DL][DIM-SCALING][ERROR] dataset={dataset}, model={dl_model_name}: {e}")
                dl_cache[dl_model_name] = None
            else:
                dl_cache[dl_model_name] = {
                    "train_epoch_flops": epoch_tr,
                    "inference_epoch_flops": epoch_inf,
                    "full_training_flops": full_training,
                    "base_train_len": full_tr,
                    "base_test_len": full_te,
                    "base_dim": full_dim,
                }

        for p in dim_percents:
            d_sub = max(1, min(_safe_int(full_dim * (p / 100.0)), full_dim))
            feature_idx = rng.choice(full_dim, size=d_sub, replace=False)

            for dl_model_name in dl_model_names:
                cached = dl_cache.get(dl_model_name, None)
                if cached is None:
                    _record(dataset, dl_model_name, p, d_sub, n_tr, n_te, np.nan, np.nan, np.nan)
                    continue

                ref_tr = cached["base_train_len"]
                ref_te = cached["base_test_len"]
                ref_dim = cached["base_dim"]

                tr_scale = (n_tr / ref_tr) if ref_tr > 0 else 0.0
                te_scale = (n_te / ref_te) if ref_te > 0 else 0.0
                dim_scale = (d_sub / ref_dim) if ref_dim > 0 else 1.0

                scaled_train = cached["full_training_flops"] * tr_scale * dim_scale
                scaled_inf = cached["inference_epoch_flops"] * te_scale * dim_scale
                _record(
                    dataset, dl_model_name, p, d_sub, n_tr, n_te,
                    scaled_train, scaled_inf, scaled_train + scaled_inf,
                )

            for ml_model_name in ml_model_names:
                try:
                    cfg = dict(ml_cfg_root[ml_model_name])
                    flops_train, flops_inf = _ml_flops(
                        ml_model_name, cfg, n_tr, n_te, d_sub,
                        base_train_ds, base_test_ds, feature_idx,
                    )
                    if np.isfinite(flops_train) and np.isfinite(flops_inf):
                        flops_total = flops_train + flops_inf
                    else:
                        flops_total = np.nan
                    _record(
                        dataset, ml_model_name, p, d_sub, n_tr, n_te,
                        flops_train, flops_inf, flops_total,
                    )
                except Exception as e:
                    print(f"[ML][DIM-SCALING][ERROR] dataset={dataset}, model={ml_model_name}, p={p}: {e}")
                    _record(dataset, ml_model_name, p, d_sub, n_tr, n_te, np.nan, np.nan, np.nan)

    df_all = pd.DataFrame(rows)
    csv_path = os.path.join(save_dir, save_csv_name)
    df_all.to_csv(csv_path, index=False)
    print(f"[CSV] saved: {csv_path}")

    return df_all, ds_meta

def plot_one_dataset_dim(
    df_all: pd.DataFrame,
    ds_meta: dict,
    dataset: str,
    save_dir="./analysis/results",
    logy=True,
    legend_lines=None,
):
    """Plot train and inference GFLOPs against feature count for one dataset.

    Draws two side-by-side panels (train FLOPs and inference FLOPs, both against
    the ``dim`` column) with one line per model, a shared y-range and a single
    figure-level legend, then saves the figure as
    ``scalability_dim_{dataset}_row.pdf`` inside ``save_dir``.

    Args:
        df_all: Result frame with at least the columns ``dataset``,
            ``model_name``, ``dim``, ``train_flops`` and ``inference_flops``.
        ds_meta: Mapping dataset -> metadata dict; ``input_dim`` is shown in the
            panel titles.
        dataset: Dataset to plot; must be a key of ``ds_meta``.
        save_dir: Output directory, created if missing.
        logy: Use a logarithmic y-axis when true.
        legend_lines: Optional list of rows of model names fixing the legend
            order; the length of the first row is used as the legend column
            count. Names without a plotted line are ignored.

    Returns:
        None. When no row of the dataset has a finite FLOPs value a warning is
        printed and nothing is drawn.

    Raises:
        ValueError: If ``dataset`` is not a key of ``ds_meta``.
    """
    if dataset not in ds_meta:
        raise ValueError(f"Unknown dataset: {dataset}")

    frame = df_all[df_all["dataset"] == dataset].copy()
    has_value = np.isfinite(frame["train_flops"]) | np.isfinite(frame["inference_flops"])
    frame = frame[has_value]

    if frame.empty:
        print(f"[WARN] No data to plot for dataset={dataset}")
        return

    x_range = (int(frame["dim"].min()), int(frame["dim"].max()))

    # Shared y-range in GFLOPs, from the positive finite values of both columns.
    gflops = pd.concat([frame["train_flops"], frame["inference_flops"]]) / 1e9
    gflops = gflops[(gflops > 0) & np.isfinite(gflops)]
    if len(gflops) == 0:
        y_range = (1.0, 1.0) if logy else (0.0, 1.0)
    else:
        lowest = float(gflops.min())
        highest = float(gflops.max())
        y_range = (lowest / 1.5, highest * 1.2) if logy else (0.0, highest * 1.05)

    model_names = sorted(frame["model_name"].dropna().unique())

    fig, (ax_tr, ax_in) = plt.subplots(
        1, 2, figsize=(18, 7.5),
        sharex=False, sharey=False,
        gridspec_kw={'wspace': 0.1}
    )

    def _draw_lines(ax, ycol):
        # One line per model, styled from MODEL_STYLES (generic style for unknown models).
        for name in model_names:
            per_model = frame[frame["model_name"] == name].sort_values("dim")
            if per_model.empty:
                continue
            if name in MODEL_STYLES:
                style = MODEL_STYLES[name]
            else:
                style = {
                    "color": None,
                    "marker": 'o',
                    "linestyle": '-',
                    "linewidth": 1.6,
                    "label": PRETTY_LABEL.get(name, name),
                }
            ax.plot(
                per_model["dim"].values,
                per_model[ycol].values / 1e9,
                label=style["label"],
                linestyle=style["linestyle"],
                linewidth=style["linewidth"],
                marker=style["marker"],
                ms=10,
                color=style["color"],
                alpha=0.95,
            )

    def _tick_text(value, _pos):
        # Two decimals up to 1 GFLOP, whole numbers above.
        if value <= 1.0:
            return f"{value:.2f}"
        return f"{int(value)}"

    def _decorate(ax, title_suffix, show_ylabel):
        # Axis limits, scale, grid, tick formatting, title and labels of one panel.
        ax.set_xlim(*x_range)
        if logy:
            ax.set_yscale("log")
        ax.set_ylim(*y_range)
        ax.minorticks_off()
        ax.grid(True, which="major", linestyle="--", linewidth=0.6, alpha=0.6)
        ax.yaxis.set_major_formatter(FuncFormatter(_tick_text))
        ax.tick_params(axis='both', which='major', labelsize=18)
        ax.tick_params(axis='x', labelrotation=15)

        base_dim = ds_meta[dataset]["input_dim"]
        ax.set_title(
            f"{title_suffix} — {dataset} (base dim={base_dim})",
            pad=8,
            fontdict={'fontsize': 20}
        )
        ax.set_xlabel("Dimension (features)", fontdict={'fontsize': 20})
        if show_ylabel:
            ax.set_ylabel("GFLOPs", fontdict={'fontsize': 20})
        for side in ('top', 'right'):
            ax.spines[side].set_visible(False)

    # Draw both panels first, then fix limits and cosmetics.
    _draw_lines(ax_tr, "train_flops")
    _draw_lines(ax_in, "inference_flops")
    _decorate(ax_tr, "Train FLOPs", True)
    _decorate(ax_in, "Inference FLOPs", False)

    # Both panels draw the same models; keep the first handle seen for each label.
    first_handle = {}
    for ax in (ax_tr, ax_in):
        panel_handles, panel_labels = ax.get_legend_handles_labels()
        for handle, label in zip(panel_handles, panel_labels):
            if label not in first_handle:
                first_handle[label] = handle
    l_uniq = list(first_handle.keys())
    h_uniq = list(first_handle.values())

    if legend_lines is None:
        ncol = min(5, len(h_uniq))
    else:
        # Reorder the legend to follow legend_lines, skipping models that were not plotted.
        ordered_handles, ordered_labels = [], []
        for row in legend_lines:
            for model in row:
                wanted = PRETTY_LABEL.get(model, model)
                if wanted in first_handle:
                    ordered_handles.append(first_handle[wanted])
                    ordered_labels.append(wanted)

        if ordered_handles:
            h_uniq, l_uniq = ordered_handles, ordered_labels

        ncol = len(legend_lines[0]) if legend_lines and legend_lines[0] else len(h_uniq)

    fig.subplots_adjust(left=0.10, right=0.98, top=0.78, bottom=0.16, wspace=0.12)
    fig.legend(
        h_uniq,
        l_uniq,
        loc="lower center",
        bbox_to_anchor=(0.53, 0.85),
        ncol=ncol,
        frameon=False,
        columnspacing=1.2,
        handletextpad=0.6,
        borderaxespad=0.2,
        fontsize=20,
    )
    # The right panel uses the same y-range, so its tick labels are hidden.
    ax_in.tick_params(labelleft=False)

    os.makedirs(save_dir, exist_ok=True)
    out = os.path.join(save_dir, f"scalability_dim_{dataset}_row.pdf")
    fig.savefig(out, dpi=600, bbox_inches="tight", pad_inches=0.20)
    plt.close(fig)
    print(f"[PLOT] saved: {out}")

def _combined_gflops_ylim(*series_list, logy=True):
    """Return a ``(low, high)`` y-range, in GFLOPs, covering all given FLOP series.

    Only strictly positive, finite values are considered. When nothing
    qualifies, a fixed non-degenerate range is returned so that the axis
    limits stay valid (identical low/high limits make matplotlib warn).
    """
    gflops = pd.concat(series_list) / 1e9  # FLOPs -> GFLOPs
    gflops = gflops[(gflops > 0) & np.isfinite(gflops)]
    if len(gflops) == 0:
        return (0.1, 10.0) if logy else (0.0, 1.0)
    lowest, highest = float(gflops.min()), float(gflops.max())
    if logy:
        # Multiplicative padding keeps the margins even on a log axis.
        return (lowest / 1.5, highest * 1.2)
    return (0.0, highest * 1.05)


def _combined_gflops_tick(value, _pos=None):
    """Format one y-axis tick value (GFLOPs) for display.

    Values above 1 are shown as integers and values down to 0.01 with two
    decimals. Smaller positive values use the general format so that
    log-axis ticks such as 1e-3 or 1e-4 are not all rendered as "0.00".
    """
    if value > 1.0:
        return f"{int(value)}"
    if 0.0 < value < 0.01:
        return f"{value:g}"
    return f"{value:.2f}"


def _combined_legend_entries(axes, legend_lines):
    """Collect de-duplicated legend entries from ``axes``.

    Returns ``(handles, labels, ncol)``. The first handle seen for each label
    wins. When ``legend_lines`` is given, entries are re-ordered to follow it
    (flattened row by row, model names mapped through ``PRETTY_LABEL``) and
    only the listed models that were actually drawn are kept; if none of them
    were drawn, the de-duplicated entries are used unchanged.
    """
    first_handle = {}
    for ax in axes:
        for handle, label in zip(*ax.get_legend_handles_labels()):
            first_handle.setdefault(label, handle)

    labels = list(first_handle)
    handles = [first_handle[label] for label in labels]

    if legend_lines is None:
        return handles, labels, max(1, min(5, len(labels)))

    wanted = [PRETTY_LABEL.get(m, m) for row in legend_lines for m in row]
    drawn = [label for label in wanted if label in first_handle]
    if drawn:
        labels = drawn
        handles = [first_handle[label] for label in drawn]

    # NOTE(review): matplotlib appears to fill legend entries column by
    # column, so the rows of `legend_lines` may not end up as visual rows of
    # the rendered legend -- confirm against the produced figures.
    if legend_lines and legend_lines[0]:
        ncol = len(legend_lines[0])
    else:
        ncol = len(labels)
    # A legend cannot be laid out with zero columns.
    return handles, labels, max(1, ncol)


def plot_scalability_combined(
    df_size: pd.DataFrame,     # as produced by run_scalability_experiment
    df_dim: pd.DataFrame,      # as produced by run_dim_scalability_experiment
    ds_meta: dict,
    dataset: str,
    save_dir="./analysis/results",
    logy=True,
    legend_lines=None,
):
    """Plot a 2x2 FLOPs scalability figure for one dataset and save it as PDF.

    The top row shows train / inference FLOPs against data size, the bottom
    row shows them against the number of features. All four panels share the
    same y-range (in GFLOPs), computed over both input frames.

    Args:
        df_size: Rows with columns ``dataset``, ``model_name``, ``train_size``,
            ``test_size``, ``train_flops`` and ``inference_flops``.
        df_dim: Rows with columns ``dataset``, ``model_name``, ``dim``,
            ``train_flops`` and ``inference_flops``.
        ds_meta: Mapping keyed by dataset name; only used here to check that
            ``dataset`` is known.
        dataset: Name of the dataset to plot.
        save_dir: Output directory (created if missing). The file is named
            ``scalability_combined_<dataset>.pdf``.
        logy: Use a logarithmic y-axis when true.
        legend_lines: Optional list of lists of model names fixing the legend
            order; the length of the first list is used as the column count.

    Raises:
        ValueError: If ``dataset`` is not a key of ``ds_meta``.

    Returns:
        None. If either frame has no usable rows for ``dataset``, a warning is
        printed and nothing is drawn.
    """
    if dataset not in ds_meta:
        raise ValueError(f"Unknown dataset: {dataset}")

    def _rows_with_flops(df):
        # Keep this dataset's rows that have at least one finite FLOP count.
        rows = df[df["dataset"] == dataset].copy()
        return rows[np.isfinite(rows["train_flops"]) | np.isfinite(rows["inference_flops"])]

    df_size_ds = _rows_with_flops(df_size)
    df_dim_ds = _rows_with_flops(df_dim)

    if df_size_ds.empty:
        print(f"[WARN] No size-scaling data to plot for dataset={dataset}")
        return
    if df_dim_ds.empty:
        print(f"[WARN] No dim-scaling data to plot for dataset={dataset}")
        return

    x_tr_lim = (int(df_size_ds["train_size"].min()), int(df_size_ds["train_size"].max()))
    x_te_lim = (int(df_size_ds["test_size"].min()), int(df_size_ds["test_size"].max()))
    x_dim_lim = (int(df_dim_ds["dim"].min()), int(df_dim_ds["dim"].max()))

    y_lim = _combined_gflops_ylim(
        df_size_ds["train_flops"], df_size_ds["inference_flops"],
        df_dim_ds["train_flops"], df_dim_ds["inference_flops"],
        logy=logy,
    )

    unique_models = sorted(
        set(df_size_ds["model_name"].dropna().unique())
        | set(df_dim_ds["model_name"].dropna().unique())
    )

    def _fmt(ax, title, xlim, ylim, show_ylabel=False, xlabel=""):
        # Shared cosmetics for every panel.
        ax.set_xlim(*xlim)
        if logy:
            ax.set_yscale("log")
        ax.set_ylim(*ylim)
        ax.minorticks_off()
        ax.grid(True, which="major", linestyle="--", linewidth=0.6, alpha=0.6)
        ax.yaxis.set_major_formatter(FuncFormatter(_combined_gflops_tick))
        ax.tick_params(axis='both', which='major', labelsize=16)
        ax.tick_params(axis='x', labelrotation=15)
        ax.set_title(title, pad=6, fontsize=18)
        ax.set_xlabel(xlabel, fontsize=18)
        if show_ylabel:
            ax.set_ylabel("GFLOPs", fontsize=18)
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)

    def _plot(df_source, ax, ycol, xcol):
        # One line per model, drawn in sorted model order.
        for model in unique_models:
            sub = df_source[df_source["model_name"] == model].sort_values(xcol)
            if sub.empty:
                continue
            st = MODEL_STYLES.get(
                model,
                {
                    "color": None,
                    "marker": 'o',
                    "linestyle": '-',
                    "linewidth": 1.6,
                    "label": PRETTY_LABEL.get(model, model),
                },
            )
            ax.plot(
                sub[xcol].values,
                (sub[ycol].values / 1e9),
                label=st["label"],
                linestyle=st["linestyle"],
                linewidth=st["linewidth"],
                marker=st["marker"],
                ms=8,
                color=st["color"],
                alpha=0.95,
            )

    # ----- Figure & Axes -----
    fig, axes = plt.subplots(
        2, 2,
        figsize=(18, 12),
        sharex=False, sharey=False,
        gridspec_kw={"hspace": 0.28, "wspace": 0.10}
    )
    # Close the figure even if drawing or saving fails, so repeated calls do
    # not accumulate open figures.
    try:
        ax_tr_size, ax_in_size = axes[0]
        ax_tr_dim, ax_in_dim = axes[1]

        # (axes, data, y column, x column, title, x limits, left column?, x label)
        panels = (
            (ax_tr_size, df_size_ds, "train_flops", "train_size",
             "Train FLOPs vs Instances", x_tr_lim, True, "Data size (instances)"),
            (ax_in_size, df_size_ds, "inference_flops", "test_size",
             "Inference FLOPs vs Instances", x_te_lim, False, "Data size (instances)"),
            (ax_tr_dim, df_dim_ds, "train_flops", "dim",
             "Train FLOPs vs Dimension", x_dim_lim, True, "Dimension (features)"),
            (ax_in_dim, df_dim_ds, "inference_flops", "dim",
             "Inference FLOPs vs Dimension", x_dim_lim, False, "Dimension (features)"),
        )
        for ax, frame, ycol, xcol, title, xlim, is_left, xlabel in panels:
            _plot(frame, ax, ycol, xcol)
            _fmt(ax, title, xlim, y_lim, show_ylabel=is_left, xlabel=xlabel)

        h_uniq, l_uniq, ncol = _combined_legend_entries(
            (ax_tr_size, ax_in_size, ax_tr_dim, ax_in_dim), legend_lines
        )

        fig.subplots_adjust(left=0.10, right=0.98, top=0.93, bottom=0.14, hspace=0.30, wspace=0.12)
        fig.legend(
            h_uniq,
            l_uniq,
            loc="upper center",
            bbox_to_anchor=(0.53, 1.06),
            ncol=ncol,
            frameon=False,
            columnspacing=1.2,
            handletextpad=0.6,
            borderaxespad=0.2,
            fontsize=18,
        )

        # The right-hand panels reuse the left panels' y-range, so their tick
        # labels are hidden.
        ax_in_size.tick_params(labelleft=False)
        ax_in_dim.tick_params(labelleft=False)

        os.makedirs(save_dir, exist_ok=True)
        out = os.path.join(save_dir, f"scalability_combined_{dataset}.pdf")
        fig.savefig(out, dpi=600, bbox_inches="tight", pad_inches=0.20)
    finally:
        plt.close(fig)
    print(f"[PLOT] saved: {out}")
# ==========================================================================================================================


# ---------------------------------------------------------------------------
# Experiment configuration: benchmark datasets and model rosters.
# ---------------------------------------------------------------------------
datasets = ['SMD', 'SMAP', 'MSL', 'SWaT', 'WADI', 'PSM']
dl_model_names = [
    'USAD', 'DAGMM', 'LUAD', 'lstmAE', 'lstmVAE',
    'OmniAnomaly', 'DeepSVDD', 'AnomalyTransformer', 'TimesNet',
]
ml_model_names = [
    'HBOS', 'LODA', 'ABOD', 'PCA', 'LOF',
    'Hotelling', 'IForest', 'HSTree', 'CBLOF',
]

# Reduced rosters used for the "select" figures.
dl_model_names_select = ['LUAD', 'lstmAE', 'OmniAnomaly', 'AnomalyTransformer', 'TimesNet']
ml_model_names_select = ['HBOS', 'LOF', 'IForest', 'HSTree', 'Hotelling']

dim_percents = list(range(10, 101, 10))   # 10%, 20%, ..., 100%

# All CSV outputs of the experiments below are written to this directory.
_REBUTTAL_DIR = "./analysis/results/rebuttal"


def _run_size_experiment(ml_names, dl_names, csv_name):
    """Run the data-size scalability experiment for the given model rosters."""
    return run_scalability_experiment(
        datasets=datasets,
        ml_model_names=ml_names,
        dl_model_names=dl_names,
        percents=list(range(10, 101, 10)),
        save_dir=_REBUTTAL_DIR,
        save_csv_name=csv_name,
        size_mode="min"
    )


def _run_dim_experiment(ml_names, dl_names, csv_name):
    """Run the feature-dimension scalability experiment for the given model rosters."""
    return run_dim_scalability_experiment(
        datasets=datasets,
        ml_model_names=ml_names,
        dl_model_names=dl_names,
        dim_percents=dim_percents,
        save_dir=_REBUTTAL_DIR,
        save_csv_name=csv_name,
        seed=42,
    )


# ---------------------------------------------------------------------------
# Experiments: size scaling first (all, then select), then dimension scaling.
# ---------------------------------------------------------------------------
(scalability_df_min_all,
 ds_meta_min_all,
 target_lengths_min_all) = _run_size_experiment(
    ml_model_names, dl_model_names, "scalability_flops_min_all.csv")

(scalability_df_min_select,
 ds_meta_min_select,
 target_lengths_min_select) = _run_size_experiment(
    ml_model_names_select, dl_model_names_select, "scalability_flops_min_select.csv")

scalability_df_dim_all, ds_meta_dim_all = _run_dim_experiment(
    ml_model_names, dl_model_names, "scalability_dim_flops_all.csv")

scalability_df_dim_select, ds_meta_dim_select = _run_dim_experiment(
    ml_model_names_select, dl_model_names_select, "scalability_dim_flops_select.csv")

# ---------------------------------------------------------------------------
# Legend layouts: each inner list is one requested line of legend entries.
# ---------------------------------------------------------------------------
line1 = ['HBOS', 'AnomalyTransformer', 'Hotelling', 'lstmAE', 'HSTree']
line2 = ['LUAD', 'IForest', 'OmniAnomaly', 'LOF', 'TimesNet']
selected_lines = [line1, line2]

line3 = ['ABOD', 'LODA', 'lstmAE', 'CBLOF', 'LOF', 'lstmVAE']
line4 = ['HBOS', 'PCA', 'LUAD', 'Hotelling', 'AnomalyTransformer', 'OmniAnomaly']
line5 = ['HSTree', 'DAGMM', 'TimesNet', 'IForest', 'DeepSVDD', 'USAD']
all_lines = [line3, line4, line5]

# ---------------------------------------------------------------------------
# Combined figures: every dataset for the full roster, then for the selection.
# ---------------------------------------------------------------------------
_plot_jobs = (
    (scalability_df_min_all, scalability_df_dim_all, ds_meta_min_all,
     "./analysis/results/rebuttal/scalability_combined_all", all_lines),
    (scalability_df_min_select, scalability_df_dim_select, ds_meta_min_select,
     "./analysis/results/rebuttal/scalability_combined_select", selected_lines),
)
for _size_frame, _dim_frame, _meta, _out_dir, _legend_rows in _plot_jobs:
    for ds in datasets:
        plot_scalability_combined(
            df_size=_size_frame,
            df_dim=_dim_frame,
            ds_meta=_meta,
            dataset=ds,
            save_dir=_out_dir,
            logy=True,
            legend_lines=_legend_rows,
        )


# for ds in datasets:
#     plot_one_dataset_row(
#         df_all=scalability_df_min_all,
#         ds_meta=ds_meta_min_all,
#         dataset=ds,
#         save_dir="./analysis/results/rebuttal/scalability_final_all",
#         logy=True,
#         legend_lines=all_lines
#     )

# for ds in datasets:
#     plot_one_dataset_row(
#         df_all=scalability_df_min_select,
#         ds_meta=ds_meta_min_select,
#         dataset=ds,
#         save_dir="./analysis/results/rebuttal/scalability_final_select",
#         logy=True,
#         legend_lines=selected_lines
#     )

# # Figure (all models)
# # =====================================================================
# for ds in datasets:
#     plot_one_dataset_dim(
#         df_all=scalability_df_dim_all,
#         ds_meta=ds_meta_dim_all,
#         dataset=ds,
#         save_dir="./analysis/results/rebuttal/scalability_dim_all",
#         logy=True,
#         legend_lines=all_lines
#     )

# # =====================================================================
# # Figure (select models)
# # =====================================================================
# for ds in datasets:
#     plot_one_dataset_dim(
#         df_all=scalability_df_dim_select,
#         ds_meta=ds_meta_dim_select,
#         dataset=ds,
#         save_dir="./analysis/results/rebuttal/scalability_dim_select",
#         logy=True,
#         legend_lines=selected_lines,
#     )
