import os, sys
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import importlib
from calflops import calculate_flops
from utils import *
from dataloader import TimeSeriesDataset, TimeSeriesLoader
from types import SimpleNamespace
from matplotlib.ticker import FuncFormatter
from types import SimpleNamespace
from analysis.model_flops_calculation import ML_Flops, model_flops

# Names of the models handled through the analytic ML FLOPs formulas.
ML_MODELS = {
    'HBOS', 'LODA', 'ABOD', 'PCA', 'LOF',
    'Hotelling', 'IForest', 'HSTree', 'CBLOF',
}

# Names of the models profiled as deep-learning detectors.
DL_MODELS = {
    'USAD', 'DAGMM', 'LUAD', 'lstmAE',
    'lstmVAE', 'OmniAnomaly', 'DeepSVDD',
}

# Display names for models whose internal identifier differs from the label
# shown in figure legends; models not listed here are shown under their own name.
PRETTY_LABEL = dict(
    lstmAE="LSTM-AE",
    lstmVAE="LSTM-VAE",
    HSTree="HS-Tree",
)

# Deterministic, alphabetically sorted roster of every known model name.
ALL_MODELS = sorted(ML_MODELS.union(DL_MODELS))

# Color pool: the colors of four qualitative matplotlib colormaps, concatenated
# in this order, so that every model can be assigned its own entry.
_PALETTE = [
    rgb
    for cmap_name in ('tab10', 'Set2', 'Dark2', 'tab20')
    for rgb in plt.get_cmap(cmap_name).colors
]

# Marker pool, cycled by model index in the same way as the color pool.
_MARKERS = [
    'o', 's', 'D', '^', 'v', '>', '<',
    'P', 'X', '*', 'h', 'H', '8', 'p',
    'd', '1', '2', '3', '4', '|', '_',
]

def build_global_model_styles(all_models=ALL_MODELS):
    """Map every model name to the plot style used for its curve.

    Colors and markers are taken from ``_PALETTE`` and ``_MARKERS`` by the
    model's position in ``all_models`` (wrapping around when a pool is
    exhausted). Models listed in ``DL_MODELS`` are drawn dashed and thicker;
    all others are drawn solid and thinner.

    Args:
        all_models: Iterable of model names; its order fixes which color and
            marker each model receives.

    Returns:
        dict mapping model name to a dict with the keys ``color``,
        ``marker``, ``linestyle``, ``linewidth`` and ``label``.
    """
    def _style_for(position, name):
        is_deep = name in DL_MODELS
        return {
            "color": _PALETTE[position % len(_PALETTE)],
            "marker": _MARKERS[position % len(_MARKERS)],
            "linestyle": '--' if is_deep else '-',
            "linewidth": 2.3 if is_deep else 1.6,
            "label": PRETTY_LABEL.get(name, name),
        }

    return {name: _style_for(position, name) for position, name in enumerate(all_models)}

# Style table built once at import time from the full model roster, so a model
# keeps the same color/marker regardless of which subset of models is plotted.
MODEL_STYLES = build_global_model_styles()

def _ensure_dir(path: str) -> None:
    """Create directory ``path`` (and missing parents); no-op if it already exists."""
    os.makedirs(path, exist_ok=True)

def _safe_int(x):
    """Round ``x`` with the built-in ``round`` and return it as an int, never below 1."""
    nearest = round(x)
    if nearest > 1:
        return int(nearest)
    return 1

def _slice_dataset(ds, n_rows):
    """Return a lightweight stand-in holding only the first ``n_rows`` rows of ``ds``.

    Args:
        ds: Object exposing a ``data`` array (indexable, with ``shape``/``ndim``).
        n_rows: Requested number of leading rows; capped at the rows available.

    Returns:
        SimpleNamespace with ``data`` (the leading slice), ``train_len`` and
        ``test_len`` (both the number of rows kept) and ``input_dim`` (second
        axis length, or 1 when the data is one-dimensional).
    """
    available = ds.data.shape[0]
    head = ds.data[:min(n_rows, available)]
    kept = head.shape[0]
    if head.ndim > 1:
        width = head.shape[1]
    else:
        width = 1
    return SimpleNamespace(data=head, train_len=kept, test_len=kept, input_dim=width)

def _scaling_row(dataset, model_name, percent, n_tr, n_te,
                 train_flops=np.nan, inference_flops=np.nan, total_flops=np.nan):
    """Assemble one result record; FLOPs fields default to NaN (failed or undefined)."""
    return {
        "dataset": dataset, "model_name": model_name,
        "percent": percent,
        "train_size": n_tr, "test_size": n_te,
        "train_flops": train_flops,
        "inference_flops": inference_flops,
        "total_flops": total_flops,
    }


def _profile_dl_models(dataset, dl_model_names, runner, base_train_len, base_test_len):
    """Measure the single-pass FLOPs of each DL model once on ``dataset``.

    Returns a dict mapping model name to its profiling stats, or to ``None``
    when anything failed for that model (the error is printed). Failures are
    isolated per model, so one broken config does not affect the others.
    """
    flops_kwargs = dict(
        forward_mode="forward",
        print_detailed=False, print_results=False, output_as_string=False,
    )
    dl_cache = {}
    for dl_model_name in dl_model_names:
        try:
            loader_cfg, model_cfg, train_cfg = ModelConfig(dl_model_name).resolve(dataset)
            # Work on copies so the resolved configs are not mutated.
            model_cfg, train_cfg = dict(model_cfg), dict(train_cfg)

            loader = TimeSeriesLoader(dataset_name=dataset, **loader_cfg)
            input_dim = loader.input_dim
            model_cfg["input_dim"] = input_dim

            # DAGMM and DeepSVDD are profiled on a flattened window; every
            # other model is profiled on a single time step.
            if dl_model_name in ("DAGMM", "DeepSVDD"):
                input_shape = (1, input_dim * loader_cfg["window_size"])
            else:
                input_shape = (1, 1, input_dim)

            mod_name, cls_name = runner.MODEL_DISPATCH[dl_model_name]
            detector_cls = getattr(importlib.import_module(mod_name), cls_name)
            detector = detector_cls(loader, **model_cfg)

            # Training pass = forward + backward; inference pass = forward only.
            F_tr_pass, _, _ = calculate_flops(
                detector.model, input_shape=input_shape,
                include_backPropagation=True, **flops_kwargs,
            )
            F_inf_pass, _, _ = calculate_flops(
                detector.model, input_shape=input_shape,
                include_backPropagation=False, **flops_kwargs,
            )

            dl_cache[dl_model_name] = dict(
                F_train_pass=F_tr_pass,
                F_infer_pass=F_inf_pass,
                epochs=int(train_cfg.get('epochs', 1)),
                base_train_w=loader.train_window_len,
                base_test_w=loader.test_window_len,
                base_train_len=base_train_len,
                base_test_len=base_test_len,
            )
        except Exception as e:
            print(f"[DL][SCALING][ERROR] dataset={dataset}, model={dl_model_name}: {e}")
            dl_cache[dl_model_name] = None
    return dl_cache


def _estimate_ml_flops(ml_model_name, cfg, n_tr, n_te, input_dim, base_train_ds, base_test_ds):
    """Return ``(train_flops, inference_flops)`` for one ML model at one data size.

    Raises:
        KeyError: if ``ml_model_name`` has no FLOPs formula here, or a required
            hyper-parameter is missing from ``cfg``.
    """
    sizes = dict(num_train=n_tr, num_inference=n_te)

    if ml_model_name == "IForest":
        # Sub-sample size cannot exceed the training size and is kept >= 2.
        psi = max(2, min(cfg["max_samples"], n_tr))
        return ML_Flops.FLOPs_IForest(n_estimators=cfg["n_estimators"], max_samples=psi, **sizes)

    if ml_model_name == "HSTree":
        psi = min(cfg["ref_window_size"], n_tr)
        return ML_Flops.FLOPs_HSTree(
            n_estimators=cfg["n_estimators"], max_depth=cfg["max_depth"],
            ref_window_size=psi, **sizes,
        )

    if ml_model_name == "CBLOF":
        # CBLOF's estimate is computed from the data itself, so hand it
        # truncated views of the real train/test sets.
        return ML_Flops.FLOPs_CBLOF(
            train_set=_slice_dataset(base_train_ds, n_tr),
            test_set=_slice_dataset(base_test_ds, n_te),
            model_params=cfg.get('model_params', cfg),
        )

    args = dict(sizes, input_dim=input_dim)

    if ml_model_name == "HBOS":
        return ML_Flops.FLOPs_HBOS(bins=cfg["bins"], **args)

    if ml_model_name == "LODA":
        return ML_Flops.FLOPs_LODA(n_random_cuts=cfg["n_random_cuts"], bins=cfg["bins"], **args)

    if ml_model_name == "ABOD":
        k_eff = min(cfg["n_neighbors"], max(1, n_tr - 1), max(1, n_te - 1))
        if n_tr < 2 or n_te < 2:
            # Fewer than two points on either side: reported as undefined.
            return (np.nan, np.nan)
        return ML_Flops.FLOPs_ABOD(n_neighbors=k_eff, **args)

    if ml_model_name == "PCA":
        return ML_Flops.FLOPs_PCA(n_components=cfg["n_components"], **args)

    if ml_model_name == "LOF":
        k_eff = min(cfg["n_neighbors"], max(1, n_tr - 1))
        if n_tr < 2 or n_te < 2:
            return (np.nan, np.nan)
        return ML_Flops.FLOPs_LOF(n_neighbors=k_eff, **args)

    if ml_model_name == "Hotelling":
        return ML_Flops.FLOPs_Hotelling(n_tr, n_te, input_dim)

    raise KeyError(f"Unsupported ML model for scaling: {ml_model_name}")


def run_scalability_experiment(
    datasets,
    ml_model_names,
    dl_model_names,
    percents=None,
    save_dir="./analysis/results",
    save_csv_name="scalability_flops.csv",
    size_mode="split",
):
    """Estimate train/inference FLOPs of ML and DL models at several data sizes.

    For each dataset, per-percent target sizes are derived from the train and
    test lengths. Each DL model's single-pass FLOPs are measured once and then
    scaled by the (proportionally scaled) number of windows and the number of
    epochs. Each ML model's FLOPs come from the ``ML_Flops`` formulas evaluated
    at the target sizes. A model that fails is reported with NaN FLOPs instead
    of aborting the run. All records are written to a CSV file.

    Args:
        datasets: Iterable of dataset names. It is iterated several times, so
            pass a list or tuple rather than a one-shot iterator.
        ml_model_names: Names of the ML models to evaluate.
        dl_model_names: Names of the DL models to evaluate.
        percents: Iterable of percentages of the base size to evaluate;
            defaults to 10, 20, ..., 100.
        save_dir: Directory for the CSV output (created if missing).
        save_csv_name: File name of the CSV inside ``save_dir``.
        size_mode: ``"min"`` uses ``min(train_len, test_len)`` as the common
            base for both train and test sizes; any other value scales the
            train and test lengths independently.

    Returns:
        Tuple ``(df_all, ds_meta, target_map)``: the DataFrame of all records,
        the per-dataset metadata dict, and a dict mapping each dataset to its
        ``(train_targets, test_targets)`` lists.
    """
    # Materialize so generators/ranges work and the sequence can be re-iterated.
    percents = list(range(10, 101, 10)) if percents is None else list(percents)

    _ensure_dir(save_dir)
    rows = []

    # Pass 1: collect length/dimension metadata for every dataset.
    ds_meta = {}
    for dataset in datasets:
        tr_ds = TimeSeriesDataset(dataset, train=True)
        te_ds = TimeSeriesDataset(dataset, train=False)
        tr = tr_ds.train_len
        te = te_ds.test_len
        tot = int(tr + te)
        ds_meta[dataset] = dict(
            train_len=tr, test_len=te, total_len=tot,
            train_ratio=(tr / tot if tot > 0 else 0.5),
            input_dim=tr_ds.input_dim,
        )

    # Pass 2: derive the target sizes (``_safe_int`` already clamps to >= 1).
    target_map = {}
    for dataset in datasets:
        tr_len = ds_meta[dataset]["train_len"]
        te_len = ds_meta[dataset]["test_len"]

        if size_mode == "min":
            base = max(1, min(tr_len, te_len))
            train_targets = [_safe_int(base * (p / 100.0)) for p in percents]
            test_targets = list(train_targets)  # identical sizes for train and test
        else:
            train_targets = [_safe_int(tr_len * (p / 100.0)) for p in percents]
            test_targets = [_safe_int(te_len * (p / 100.0)) for p in percents]

        target_map[dataset] = (train_targets, test_targets)

    # Pass 3: compute FLOPs per dataset, size and model.
    for dataset in datasets:
        print(f"[SCALING] dataset={dataset}")
        runner = model_flops(dataset)

        dl_cache = _profile_dl_models(
            dataset, dl_model_names, runner,
            ds_meta[dataset]["train_len"], ds_meta[dataset]["test_len"],
        )

        # The metadata pass does not retain the dataset objects, so they are
        # loaded again here for the data-dependent CBLOF estimate.
        base_train_ds = TimeSeriesDataset(dataset, train=True)
        base_test_ds = TimeSeriesDataset(dataset, train=False)
        ml_cfg_root = ModelConfig("ML_calflops").get_param(dataset)
        d = base_train_ds.input_dim

        train_targets, test_targets = target_map[dataset]

        for percent, n_tr, n_te in zip(percents, train_targets, test_targets):

            for dl_model_name in dl_model_names:
                cache = dl_cache.get(dl_model_name)
                if cache is None:
                    rows.append(_scaling_row(dataset, dl_model_name, percent, n_tr, n_te))
                    continue

                base_tr_len = cache["base_train_len"]
                base_te_len = cache["base_test_len"]

                # Scale the window counts by the same ratio as the raw lengths;
                # ``_safe_int`` keeps at least one window.
                s_tr = (n_tr / base_tr_len) if base_tr_len > 0 else 0.0
                s_te = (n_te / base_te_len) if base_te_len > 0 else 0.0
                scaled_train_w = _safe_int(max(0, cache["base_train_w"] * s_tr))
                scaled_test_w = _safe_int(max(0, cache["base_test_w"] * s_te))

                train_flops_full = cache["F_train_pass"] * scaled_train_w * cache["epochs"]
                inference_flops = cache["F_infer_pass"] * scaled_test_w

                rows.append(_scaling_row(
                    dataset, dl_model_name, percent, n_tr, n_te,
                    train_flops=train_flops_full,
                    inference_flops=inference_flops,
                    total_flops=train_flops_full + inference_flops,
                ))

            for ml_model_name in ml_model_names:
                try:
                    cfg = dict(ml_cfg_root[ml_model_name])
                    flops_train, flops_inf = _estimate_ml_flops(
                        ml_model_name, cfg, n_tr, n_te, d, base_train_ds, base_test_ds,
                    )
                    both_finite = np.isfinite(flops_train) and np.isfinite(flops_inf)
                    rows.append(_scaling_row(
                        dataset, ml_model_name, percent, n_tr, n_te,
                        train_flops=flops_train,
                        inference_flops=flops_inf,
                        total_flops=(flops_train + flops_inf) if both_finite else np.nan,
                    ))
                except Exception as e:
                    print(f"[ML][SCALING][ERROR] dataset={dataset}, model={ml_model_name}, p={percent}: {e}")
                    rows.append(_scaling_row(dataset, ml_model_name, percent, n_tr, n_te))

    df_all = pd.DataFrame(rows)
    csv_path = os.path.join(save_dir, save_csv_name)
    df_all.to_csv(csv_path, index=False)
    print(f"[CSV] saved: {csv_path}")

    return df_all, ds_meta, target_map

def plot_one_dataset_row(
    df_all: pd.DataFrame,
    ds_meta: dict,
    dataset: str,
    save_dir="./analysis/results",
    logy=True):
    """Plot train and inference GFLOPs versus data size for one dataset.

    Draws two side-by-side panels (train on the left, inference on the right)
    with one curve per model and a shared legend above the panels, then saves
    the figure as ``scalability_<dataset>_row.pdf`` inside ``save_dir``.

    Args:
        df_all: Records with the columns ``dataset``, ``model_name``,
            ``train_size``, ``test_size``, ``train_flops`` and
            ``inference_flops`` (FLOPs are divided by 1e9 for display).
        ds_meta: Per-dataset metadata; ``ds_meta[dataset]["input_dim"]`` is
            shown in the panel titles.
        dataset: Name of the dataset to plot.
        save_dir: Output directory (created if missing).
        logy: Use a logarithmic y axis when true, a linear one otherwise.

    Raises:
        ValueError: if ``dataset`` is not in ``ds_meta``, or if it has no row
            with a finite train or inference FLOPs value.
    """
    if dataset not in ds_meta:
        raise ValueError(f"Unknown dataset: {dataset}")

    df_ds = df_all[df_all["dataset"] == dataset].copy()
    df_ds = df_ds[np.isfinite(df_ds["train_flops"]) | np.isfinite(df_ds["inference_flops"])]
    if df_ds.empty:
        # Without this guard the axis-limit computation below fails with an
        # opaque "cannot convert float NaN to integer" ValueError.
        raise ValueError(f"No finite FLOPs rows to plot for dataset: {dataset}")

    x_tr_min, x_tr_max = int(df_ds["train_size"].min()), int(df_ds["train_size"].max())
    x_te_min, x_te_max = int(df_ds["test_size"].min()),  int(df_ds["test_size"].max())

    def _ylim(series1, series2, log_scale):
        # Common y-range (in GFLOPs) over both panels, from positive finite values only.
        y_pos = (pd.concat([series1, series2]) / 1e9)
        y_pos = y_pos[(y_pos > 0) & np.isfinite(y_pos)]
        if len(y_pos) == 0:
            return (1.0, 1.0) if log_scale else (0.0, 1.0)
        lo = float(y_pos.min()); hi = float(y_pos.max())
        return ((lo/1.5, hi*1.2) if log_scale else (0.0, hi*1.05))

    y_min, y_max = _ylim(df_ds["train_flops"], df_ds["inference_flops"], logy)

    unique_models = sorted(df_ds["model_name"].dropna().unique())

    def _tick_label(value, _pos):
        # Two decimals would print every tick below 0.01 as "0.00"; use the
        # compact general format there so distinct ticks stay distinguishable.
        if 0 < value < 0.01:
            return f"{value:g}"
        return f"{value:.2f}" if value <= 1.0 else f"{int(value)}"

    def _fmt(ax, title_suffix, xlim, ylim, show_ylabel=False, xlabel="Data size (instances)"):
        ax.set_xlim(*xlim)
        if logy: ax.set_yscale("log")
        ax.set_ylim(*ylim)
        ax.minorticks_off()
        ax.grid(True, which="major", linestyle="--", linewidth=0.6, alpha=0.6)
        ax.yaxis.set_major_formatter(FuncFormatter(_tick_label))
        ax.tick_params(axis='both', which='major', labelsize=18)
        ax.tick_params(axis='x', labelrotation=15)
        dim = ds_meta[dataset]["input_dim"]
        ax.set_title(f"{title_suffix} — {dataset} (dim={dim})", pad=8, fontdict={'fontsize': 20})
        ax.set_xlabel(xlabel, fontdict={'fontsize': 20})
        if show_ylabel: ax.set_ylabel("GFLOPs", fontdict={'fontsize': 20})
        ax.spines['top'].set_visible(False); ax.spines['right'].set_visible(False)

    def _plot(ax, ycol, xcol):
        for model in unique_models:
            sub = df_ds[df_ds["model_name"] == model].sort_values(xcol)
            if sub.empty:
                continue
            # Fall back to a neutral style for models missing from the global table.
            st = MODEL_STYLES.get(model, {
                "color": None, "marker": 'o', "linestyle": '-', "linewidth": 1.6, "label": PRETTY_LABEL.get(model, model)
            })
            ax.plot(
                sub[xcol].values,
                (sub[ycol].values / 1e9),
                label=st["label"],
                linestyle=st["linestyle"],
                linewidth=st["linewidth"],
                marker=st["marker"],
                ms=10,
                color=st["color"],
                alpha=0.95
            )

    fig, (ax_tr, ax_in) = plt.subplots(1, 2, figsize=(18, 7.5), sharex=False, sharey=False, gridspec_kw={'wspace': 0.1})
    try:
        _plot(ax_tr, "train_flops", "train_size")
        _plot(ax_in, "inference_flops", "test_size")

        _fmt(ax_tr, "Train FLOPs", (x_tr_min, x_tr_max), (y_min, y_max), show_ylabel=True,  xlabel="Data size (instances)")
        _fmt(ax_in, "Inference FLOPs", (x_te_min, x_te_max), (y_min, y_max), show_ylabel=False, xlabel="Data size (instances)")

        # Merge the legend entries of both panels, keeping the first handle per label.
        handles, labels = [], []
        for ax in (ax_tr, ax_in):
            h, l = ax.get_legend_handles_labels(); handles += h; labels += l
        seen = set(); h_uniq = []; l_uniq = []
        for h, l in zip(handles, labels):
            if l not in seen: seen.add(l); h_uniq.append(h); l_uniq.append(l)

        fig.subplots_adjust(left=0.10, right=0.98, top=0.78, bottom=0.16, wspace=0.12)
        fig.legend(h_uniq, l_uniq, loc="lower center", bbox_to_anchor=(0.53, 0.85),
                   ncol=4, frameon=False, columnspacing=1.2, handletextpad=0.6, borderaxespad=0.2, fontsize=20)
        # Both panels use the same y-limits, so the right panel omits its tick labels.
        ax_in.tick_params(labelleft=False)

        os.makedirs(save_dir, exist_ok=True)
        out = os.path.join(save_dir, f"scalability_{dataset}_row.pdf")
        fig.savefig(out, dpi=600, bbox_inches="tight", pad_inches=0.20)
    finally:
        # Release the figure even when plotting or saving fails.
        plt.close(fig)
    print(f"[PLOT] saved: {out}")


# ==========================================================================================================================


# Datasets and the full model rosters evaluated by the experiment.
datasets = ['SMD', 'SMAP', 'MSL', 'SWaT', 'WADI', 'PSM']
dl_model_names = ['USAD', 'DAGMM', 'LUAD', 'lstmAE', 'lstmVAE', 'OmniAnomaly', 'DeepSVDD']
ml_model_names = ['HBOS', 'LODA', 'ABOD', 'PCA', 'LOF', 'Hotelling', 'IForest', 'HSTree', 'CBLOF']

# Reduced rosters used for the second, "select" set of figures.
dl_model_names_select = ['LUAD', 'lstmAE', 'OmniAnomaly']
ml_model_names_select = ['HBOS', 'LOF', 'IForest', 'HSTree', 'Hotelling']

# Run 1: every model, train/test sizes derived from the shorter split ("min" mode).
(scalability_df_min_all,
 ds_meta_min_all,
 target_lengths_min_all) = run_scalability_experiment(
    datasets=datasets,
    ml_model_names=ml_model_names,
    dl_model_names=dl_model_names,
    percents=list(range(10, 101, 10)),
    save_dir="./analysis/results",
    save_csv_name="scalability_flops_min_all.csv",
    size_mode="min",
)

# Run 2: same settings, restricted to the selected models.
(scalability_df_min_select,
 ds_meta_min_select,
 target_lengths_min_select) = run_scalability_experiment(
    datasets=datasets,
    ml_model_names=ml_model_names_select,
    dl_model_names=dl_model_names_select,
    percents=list(range(10, 101, 10)),
    save_dir="./analysis/results",
    save_csv_name="scalability_flops_min_select.csv",
    size_mode="min",
)

# One figure per dataset for the full roster ...
for ds in datasets:
    plot_one_dataset_row(
        scalability_df_min_all, ds_meta_min_all, ds,
        save_dir="./analysis/results/scalability_final_all", logy=True,
    )

# ... and one per dataset for the selected roster.
for ds in datasets:
    plot_one_dataset_row(
        scalability_df_min_select, ds_meta_min_select, ds,
        save_dir="./analysis/results/scalability_final_select", logy=True,
    )