import pandas as pd
import os
import numpy as np
import argparse
import seaborn as sns
from matplotlib import pyplot as plt

# Registry of the evaluation metrics handled by this script.
#   key        -> display label used in the generated tables and plots
#   "name"     -> raw identifier found in the `metric` column of the result
#                 CSVs (the functions below invert this dict to map raw
#                 names to display labels)
#   "maximize" -> True when larger values are better; read by
#                 get_rank_table and get_metric_table to choose the
#                 ranking / highlighting direction
METRIC_INFO = {
    "Shape": {"name": "similarity.shape", "maximize": True},
    "Trend": {"name": "similarity.trend", "maximize": True},
    "Detection": {"name": "classifiertest.auc", "maximize": False},
    "DCR": {"name": "dcr.score", "maximize": True},
    # "train-synthetic-test-real" entries are the MLE metrics of the models.
    "MLE (R2)": {"name": "mle.train-synthetic-test-real.r2", "maximize": True},
    "MLE (AUC)": {"name": "mle.train-synthetic-test-real.auc", "maximize": True},
    "MLE (F1)": {"name": "mle.train-synthetic-test-real.f1", "maximize": True},
    "MLE (RMSE)": {"name": "mle.train-synthetic-test-real.rmse", "maximize": False},
    # "train-real-test-real" entries are relabelled to the pseudo-model
    # "Real" by get_metric_table and filtered out everywhere else.
    "MLE (Real, R2)": {"name": "mle.train-real-test-real.r2", "maximize": True},
    "MLE (Real, AUC)": {"name": "mle.train-real-test-real.auc", "maximize": True},
    "MLE (Real, F1)": {"name": "mle.train-real-test-real.f1", "maximize": True},
    "MLE (Real, RMSE)": {
        "name": "mle.train-real-test-real.rmse",
        "maximize": False,
    },
    "Inference Time": {"name": "inference_time", "maximize": False},
    "Training Time": {"name": "training_time", "maximize": False},
    "aPrecision": {"name": "alphaprecision.naive.score", "maximize": True},
    "bRecall": {"name": "betacoverage.naive.score", "maximize": True},
}

# Dataset identifiers of the small benchmark. They are used verbatim as the
# file-name prefix of the result CSVs ("<dataset>_<model>.csv"); the
# diffusion ablation and the dropout plot iterate over this list.
SMALL_DS = [
    "iris",
    "wine",
    "california",
    "parkinsons",
    "climate_model_crashes",
    "concrete_compression",
    "yacht_hydrodynamics",
    "airfoil_self_noise",
    "connectionist_bench_sonar",
    "ionosphere",
    "qsar_biodegradation",
    "seeds",
    "glass",
    "ecoli",
    "yeast",
    "libras",
    "planning_relax",
    "blood_transfusion",
    "breast_cancer_diagnostic",
    "connectionist_bench_vowel",
    "concrete_slump",
    "wine_quality_red",
    "wine_quality_white",
    "bean",
    "tictactoe",
    "congress",
    "car",
    # "higgs",
]
# Dataset identifiers of the large benchmark; iterated by the sampling and
# AR ablations and by the dequantization plot. load_results scans both lists.
LARGE_DS = [
    "churn",
    "nmes",
    "lending",
    "adult",
    "default",
    "bank",
    "beijing",
    "news",
    "diabetes",
    "covertype",
    "acsincome",
]

# Raw model key (the "<model>" part of the result file names read by
# load_results) -> display name used in the tables. Every display name here
# also appears in MODEL_ORDER.
MODEL_MAPPER = {
    "xgenboost_diffusion_vddpm": "XGenB (V-DDPM)",
    "xgenboost_diffusion_vddim": "XGenB (V-DDIM)",
    "xgenboost_diffusion_xddim": "XGenB (X-DDIM)",
    "xgenboost_diffusion_xddpm": "XGenB (X-DDPM)",
    "xgenboost_ar": "XGenB-AR",
    "xgenboost_multiclass": "XGenB-MC",
    "ctgan": "CTGAN",
    "tvae": "TVAE",
    "tabddpm": "TabDDPM",
    "tabsyn": "TabSyn",
    "unmaskingtrees": "UT",
    "forestdiffusion_flow": "FF",
    "forestdiffusion": "FD",
    "smote": "SMOTE",
    "arf": "ARF",
}


def model_mapper(model: str):
    """
    Translate a raw model key into its display name.

    The lookup is case-insensitive: the key is lower-cased before it is
    looked up in MODEL_MAPPER. Unknown keys raise ``KeyError``.
    """
    normalized_key = model.lower()
    return MODEL_MAPPER[normalized_key]


# Display labels of the metrics that may appear in a table, in the order in
# which they are laid out. "MLE" is not a key of METRIC_INFO: it is the
# label that "MLE (R2)" and "MLE (AUC)" are merged into by the rank table
# and the ablation reports.
METRIC_ORDER = [
    "Shape",
    "Trend",
    "Detection",
    "aPrecision",
    "bRecall",
    "MLE",
    "MLE (R2)",
    "MLE (AUC)",
    "MLE (F1)",
    "MLE (RMSE)",
    "DCR",
    "Training Time",
    "Inference Time",
]

# Display names of the models in the order used for table rows/columns.
# Rows whose model is not listed here are dropped by get_rank_table and
# get_metric_table. "Real" is the pseudo-model that get_metric_table assigns
# to the "MLE (Real, ...)" baseline rows.
MODEL_ORDER = [
    "Real",
    "SMOTE",
    "ARF",
    "UT",
    "FD",
    "FF",
    "CTGAN",
    "TVAE",
    "TabDDPM",
    "TabSyn",
    "XGenB-AR",
    "XGenB-MC",
    "XGenB (X-DDPM)",
    "XGenB (V-DDPM)",
    "XGenB (X-DDIM)",
    "XGenB (V-DDIM)",
]


def latex_bold_best(
    df: pd.DataFrame,
    axis: int = 0,
    best: str = "max",
    precision: int = 3,
    underline_second_best: bool = True,
):
    """
    Returns a DataFrame with LaTeX-formatted strings where best values
    are boldfaced and (optionally) second-best values are underlined.

    Parameters
    ----------
    df:
        Table whose cells can be cast to ``float`` (numeric strings such as
        ``"0.123"`` and ``"nan"`` are accepted).
    axis:
        ``0`` highlights the best value of every column, ``1`` the best value
        of every row.
    best:
        ``"max"`` if larger values are better, ``"min"`` otherwise.
    precision:
        Number of decimals in the formatted output.
    underline_second_best:
        Whether the second-best distinct value is wrapped in ``\\underline``.

    Returns
    -------
    pd.DataFrame
        Object-dtype frame with the same index and columns as ``df`` holding
        the formatted strings. Ties share the highlighting. NaN cells are
        rendered as ``"nan"`` and never highlighted; a row/column consisting
        only of NaN is returned without any highlighting.

    Raises
    ------
    ValueError
        If ``best`` is not ``"max"``/``"min"`` or ``axis`` is not ``0``/``1``.
    """

    if best not in ("max", "min"):
        raise ValueError("Parameter 'best' must be 'max' or 'min'.")
    if axis not in (0, 1):
        raise ValueError("Axis must be 0 (columns) or 1 (rows).")

    # Work positionally on a plain float array: this is independent of
    # duplicated index/column labels and avoids writing strings into a
    # float-typed DataFrame.
    values = df.astype(float).to_numpy()

    # Helper to get best and second-best (unique, NaN ignored)
    def best_and_second(line):
        uniq = np.unique(line[~np.isnan(line)])  # sorted ascending
        if best == "max":
            uniq = uniq[::-1]
        top = uniq[0] if uniq.size > 0 else None
        runner_up = uniq[1] if uniq.size > 1 else None
        return top, runner_up

    # Formatting
    def format_entry(x, is_best, is_second):
        if is_best:
            return f"\\mathbf{{{x:.{precision}f}}}"
        if underline_second_best and is_second:
            return f"\\underline{{{x:.{precision}f}}}"
        return f"{x:.{precision}f}"

    def format_line(line):
        top, runner_up = best_and_second(line)
        return [
            format_entry(
                v,
                top is not None and v == top,
                runner_up is not None and v == runner_up,
            )
            for v in line
        ]

    formatted = np.empty(values.shape, dtype=object)
    if axis == 0:  # best per column
        for j in range(values.shape[1]):
            formatted[:, j] = format_line(values[:, j])
    else:  # best per row
        for i in range(values.shape[0]):
            formatted[i, :] = format_line(values[i, :])

    return pd.DataFrame(formatted, index=df.index, columns=df.columns)


def load_results(dir: str):
    """
    Load every available ``<dataset>_<model>.csv`` result file from ``dir``.

    All combinations of the datasets in SMALL_DS + LARGE_DS and the model
    keys of MODEL_MAPPER are tried; combinations without a file are skipped.
    Each loaded frame gets a ``dataset`` column (underscores replaced by
    spaces) and a ``model`` column (the raw model key).

    Returns
    -------
    pd.DataFrame
        Row-wise concatenation of all loaded files (the original per-file
        row labels are kept), or an empty DataFrame if no file was found.
    """
    frames = []

    for ds in SMALL_DS + LARGE_DS:
        for model in MODEL_MAPPER:
            path = os.path.join(dir, f"{ds}_{model}.csv")
            try:
                df_temp = pd.read_csv(path)
            except FileNotFoundError:
                # not every model was run on every dataset
                continue
            df_temp["dataset"] = ds.replace("_", " ")
            df_temp["model"] = model
            frames.append(df_temp)

    if not frames:
        return pd.DataFrame()

    # Concatenate once: re-concatenating the growing frame for every file
    # copies all previously loaded rows each time.
    return pd.concat(frames)


def _append_worst_rank_rows(ranked: pd.DataFrame, keys: list, worst_rank: float):
    """Append a row with ``rank = worst_rank`` for every absent combination of ``keys``."""
    present = set(zip(*(ranked[k] for k in keys)))
    combos = pd.MultiIndex.from_product([ranked[k].unique() for k in keys])
    missing = [combo for combo in combos if combo not in present]
    if not missing:
        return ranked
    filler = pd.DataFrame(missing, columns=keys)
    filler["rank"] = float(worst_rank)
    return pd.concat([ranked, filler], ignore_index=True)


def get_rank_table(df: pd.DataFrame, metric_info: dict):
    """
    Build the table of mean (and std) model ranks per metric as LaTeX cells.

    Parameters
    ----------
    df:
        Long-format results with the columns ``dataset``, ``model``,
        ``metric`` (display labels) and ``value``.
    metric_info:
        Mapping ``label -> {"name": ..., "maximize": bool}``; labels with
        ``maximize == False`` are ranked ascending (smallest value = rank 1).

    Procedure
    ---------
    * Rows are restricted to MODEL_ORDER / METRIC_ORDER; F1, RMSE and
      real-data baseline metrics are dropped and "MLE (R2)" / "MLE (AUC)"
      are merged into "MLE".
    * Values are averaged per (dataset, model, metric) and ranked within each
      (dataset, metric) group.
    * A (dataset, model, metric) combination without any row receives the
      worst rank (number of distinct models in ``df``). Rows that exist but
      hold NaN keep a NaN rank and are ignored by the mean.
    * "Training Time" is ranked separately on the raw rows per dataset and
      appended as an additional column.

    Returns
    -------
    pd.DataFrame
        Models (rows, in MODEL_ORDER) x metrics (columns, in METRIC_ORDER)
        with cells of the form ``$<mean>_{\\pm <std>}$``; the best and
        second-best mean of each column are highlighted.
    """
    data_ori = df[df.model.isin(MODEL_ORDER)]
    data_ori = data_ori[data_ori.metric.isin(METRIC_ORDER)].copy()

    is_training_time = data_ori.metric.str.contains("Training Time")

    data = data_ori[~is_training_time]
    for pattern in ("F1", "RMSE", "Real,"):
        data = data[~data.metric.str.contains(pattern)]

    # Work on an explicit copy and assign through .loc: chained assignment
    # (data.metric[mask] = ...) does not modify the frame under
    # pandas Copy-on-Write.
    data = data.copy()
    data.loc[data.metric.isin(["MLE (R2)", "MLE (AUC)"]), "metric"] = "MLE"

    avg = (
        data.groupby(["dataset", "model", "metric"])
        .mean(numeric_only=True)
        .reset_index()
    )

    # Rank within every (dataset, metric) group. Metrics that are minimised
    # are ranked ascending directly instead of reversing descending ranks
    # with a global maximum, which is wrong for groups of different size.
    minimized = {label for label, info in metric_info.items() if not info["maximize"]}
    lower_is_better = avg.metric.isin(minimized)
    by_group = avg.groupby(["dataset", "metric"])["value"]
    avg["rank"] = by_group.rank(ascending=False).where(
        ~lower_is_better, by_group.rank(ascending=True)
    )

    # add worst rank for missing values
    avg = _append_worst_rank_rows(
        avg, ["dataset", "model", "metric"], df["model"].nunique()
    )

    def _pivot_and_format(table):
        wide = table.pivot(index="model", columns="metric", values="rank")
        return wide.round(3).map(lambda x: f"{x:.3f}")

    rank_by_model = avg.groupby(["model", "metric"])["rank"]
    avg_ranks = _pivot_and_format(rank_by_model.mean().reset_index())
    avg_ranks = latex_bold_best(avg_ranks, axis=0, best="min")
    std_ranks = _pivot_and_format(rank_by_model.std().reset_index())

    # now add training time (ranked on the raw rows of each dataset)
    tt = data_ori[is_training_time].copy()
    tt["rank"] = tt.groupby(["dataset"])["value"].rank(ascending=True)
    tt = _append_worst_rank_rows(
        tt, ["dataset", "model"], data_ori["model"].nunique()
    )

    def _training_time_table(stat):
        stat = stat.reset_index()
        stat["metric"] = "Training Time"
        return _pivot_and_format(stat)

    tt_by_model = tt.groupby(["model"])["rank"]
    avg_tt_ranks = _training_time_table(tt_by_model.mean())
    avg_tt_ranks = latex_bold_best(avg_tt_ranks, axis=0, best="min")
    std_tt_ranks = _training_time_table(tt_by_model.std())

    avg_ranks = pd.concat([avg_ranks, avg_tt_ranks.astype(str)], axis=1)
    std_ranks = pd.concat([std_ranks, std_tt_ranks], axis=1)

    full_table = "$" + avg_ranks.astype(str) + "_{\\pm " + std_ranks.astype(str) + "}$"

    full_table = full_table[[x for x in METRIC_ORDER if x in avg_ranks.columns]]
    full_table = full_table.loc[[x for x in MODEL_ORDER if x in avg_ranks.index]]

    full_table.index.name = None

    return full_table


def get_metric_table(df, metric_info: dict, metric_name: str):
    """
    Build the per-dataset table of one metric as LaTeX cells.

    Parameters
    ----------
    df:
        Long-format results with the columns ``dataset``, ``model``,
        ``metric`` (display labels) and ``value``.
    metric_info:
        Mapping ``label -> {"name": ..., "maximize": bool}``; must contain
        ``metric_name`` (``KeyError`` otherwise).
    metric_name:
        Display label of the metric to tabulate.

    Returns
    -------
    pd.DataFrame
        Datasets (rows) x models (columns, in MODEL_ORDER). Cells read
        ``$<mean>_{\\pm <std>}$`` (only ``$<mean>$`` for "Training Time");
        cells containing "nan" are replaced by "-". The best/second-best
        model per dataset is highlighted; the "Real" baseline is shown but
        never takes part in the highlighting.
    """
    info = metric_info[metric_name]

    data = df.copy()

    # For MLE metrics the train-real-test-real scores become the rows of a
    # pseudo-model "Real" under the corresponding synthetic-data label.
    if metric_name.startswith("MLE"):
        for m in ["R2", "AUC", "F1", "RMSE"]:
            is_real = data.metric == f"MLE (Real, {m})"
            data.loc[is_real, "model"] = "Real"
            data.loc[is_real, "metric"] = f"MLE ({m})"

    data = data[data.model.isin(MODEL_ORDER)]
    data = data[data.metric.isin(METRIC_ORDER)]

    data = data[data.metric == metric_name]
    data = data.drop(columns=["metric"])

    per_cell = data.groupby(["dataset", "model"])

    def _wide(stat):
        table = stat.reset_index().pivot(
            index="model", columns="dataset", values="value"
        )
        return table.round(3).map(lambda x: f"{x:.3f}")

    avg = _wide(per_cell.mean(numeric_only=True))

    avg_no_real = avg[avg.index != "Real"]
    avg_no_real = latex_bold_best(
        avg_no_real, axis=0, best="max" if info["maximize"] else "min"
    )
    avg = pd.concat([avg_no_real, avg[avg.index == "Real"]], axis=0)

    if metric_name == "Training Time":
        # training times are reported without a standard deviation
        full_table = "$" + avg.astype(str) + "$"
    else:
        std = _wide(per_cell.std(numeric_only=True))
        full_table = "$" + avg.astype(str) + "_{\\pm " + std.astype(str) + "}$"

    full_table = full_table.mask(
        full_table.apply(lambda col: col.str.contains("nan", regex=False)), "-"
    )
    full_table = full_table.reset_index(drop=False)
    full_table = full_table.set_index("model")

    full_table = full_table.loc[[x for x in MODEL_ORDER if x in full_table.index]]

    full_table = full_table.T

    full_table.index.name = None

    return full_table


def print_rank_table(df: pd.DataFrame, metric_info: dict):
    """
    Print the rank table produced by get_rank_table as a LaTeX ``table*``
    float whose tabular is scaled to ``\\linewidth``.
    """
    tabular = get_rank_table(df, metric_info).to_latex(escape=False)
    pieces = [
        "\\begin{table*}[!htbp]\n",
        "\\caption{Metric ranks (mean and std)}\n",
        "\\resizebox{\\linewidth}{!}{",
        tabular,
        "}\n\\end{table*}",
    ]
    print("".join(pieces))


def print_metric_tables(df: pd.DataFrame, metric_info: dict, metric: str = "all"):
    """
    Print one LaTeX ``table`` float per metric.

    ``metric="all"`` prints every label of METRIC_ORDER; any other value is
    taken as the single metric label to print. The label "MLE" is always
    skipped (it is the merged label and has no entry in the module's
    METRIC_INFO).
    """
    selected = METRIC_ORDER if metric == "all" else [metric]

    for label in selected:
        if label == "MLE":
            continue

        tabular = get_metric_table(df, metric_info, label).to_latex(escape=False)
        header = "\\begin{table}[!htbp]\n" + f"\\caption{{{label}}}\n"
        body = "\\resizebox{\\linewidth}{!}{" + tabular + "}\n\\end{table}"
        print(header + body)


def print_fun(filepath: str, table_type: str = "rank"):
    """
    Load all results below ``filepath`` and print the requested table.

    Raw metric names are mapped to display labels (rows with unknown metrics
    are dropped), training times are converted from seconds to minutes, raw
    model keys are mapped to display names and the training time of SMOTE is
    blanked out.

    ``table_type`` is ``"rank"`` (rank table) or ``"metric"`` (one table per
    metric); anything else raises ``ValueError``.
    """
    results = load_results(filepath)

    label_by_name = {info["name"]: label for label, info in METRIC_INFO.items()}
    results["metric"] = results.metric.map(label_by_name)

    # convert training times to minutes
    is_training_time = results.metric == "Training Time"
    results.loc[is_training_time, "value"] = (
        results.loc[is_training_time, "value"] / 60
    )

    results = results.dropna(subset=["metric"])

    results["model"] = results.model.apply(model_mapper)

    # ensure SMOTE has no training times
    smote_training_time = (results.model == "SMOTE") & (
        results.metric == "Training Time"
    )
    results.loc[smote_training_time, "value"] = np.nan

    printers = {"rank": print_rank_table, "metric": print_metric_tables}
    if table_type not in printers:
        raise ValueError(f"Invalid table type: {table_type}")
    printers[table_type](results, METRIC_INFO)


def print_cat_merge_ablation(filepath: str):
    """
    Print the LaTeX table of the categorical-merging ablation.

    For each of four datasets the ``<ds>_xgenboost_ar_cat.csv`` results
    (labelled "Naive") are compared with the ``<ds>_xgenboost_ar.csv``
    results (labelled "Clustering"). Cells read ``$<mean>_{\\pm <std>}$``
    per metric (rows) and (dataset, model) pair (columns).

    As in the other reports of this module, F1, RMSE and real-data baseline
    scores are excluded and only "MLE (R2)" / "MLE (AUC)" are merged into the
    "MLE" row, so the row never mixes differently scaled scores or the
    model-independent baselines.
    """
    dss = ["lending", "adult", "diabetes", "acsincome"]
    variants = (("_cat", "Naive"), ("", "Clustering"))

    frames = []
    for ds in dss:
        for suffix, label in variants:
            frame = pd.read_csv(f"{filepath}/{ds}_xgenboost_ar{suffix}.csv")
            frame["dataset"] = ds
            frame["model"] = label
            frames.append(frame)
    df = pd.concat(frames)

    metric_mapper = {v["name"]: k for k, v in METRIC_INFO.items()}
    df.metric = df.metric.map(metric_mapper)

    df = df.dropna(subset=["metric"])
    for pattern in ("F1", "RMSE", "Real,"):
        df = df[~df.metric.str.contains(pattern)]
    df = df.copy()
    df.loc[df.metric.isin(["MLE (R2)", "MLE (AUC)"]), "metric"] = "MLE"

    per_cell = df.groupby(["dataset", "model", "metric"])

    def _wide(stat):
        return stat.reset_index().pivot(
            index="metric", columns=["dataset", "model"], values="value"
        )

    avg = _wide(per_cell.mean(numeric_only=True))
    std = _wide(per_cell.std(numeric_only=True))
    order = [x for x in METRIC_ORDER if x in avg.index]
    avg = avg.loc[order].round(3).map(lambda x: f"{x:.3f}")
    std = std.loc[order].round(3).map(lambda x: f"{x:.3f}")

    full_table = "$" + avg.astype(str) + "_{\\pm " + std.astype(str) + "}$"
    print(full_table.to_latex(escape=False))


def print_sampling_ablation(filepath: str):
    """
    Print the LaTeX table of the binning / sampling ablation of XGenB-AR.

    For every dataset in LARGE_DS and every (binning, sampling) combination
    the file ``<ds>_xgenboost_ar_<binning>_<sampling>.csv`` is read; the
    combination ("q", "eqf") is read from the un-suffixed
    ``<ds>_xgenboost_ar.csv``. Values are averaged per dataset first and then
    across datasets; cells read ``$<mean>_{\\pm <std over datasets>}$``.

    The best configuration of each metric row is highlighted, using the
    ``maximize`` flag of METRIC_INFO to decide whether the largest or the
    smallest mean is best (e.g. "Detection" is minimised).
    """
    binning = ["q", "u", "k"]
    sampling = ["eqf", "u"]
    config = ["binning", "sampling"]

    frames = []
    for ds in LARGE_DS:
        for b in binning:
            for s in sampling:
                postfix = "" if (b == "q" and s == "eqf") else f"_{b}_{s}"
                df_temp = pd.read_csv(f"{filepath}/{ds}_xgenboost_ar{postfix}.csv")
                df_temp["dataset"] = ds
                df_temp["binning"] = b
                df_temp["sampling"] = s
                frames.append(df_temp)
    df = pd.concat(frames)

    metric_mapper = {v["name"]: k for k, v in METRIC_INFO.items()}
    df.metric = df.metric.map(metric_mapper)
    df = df.dropna(subset=["metric"])

    for pattern in ("F1", "RMSE", "Real,"):
        df = df[~df.metric.str.contains(pattern)]
    df = df.copy()
    df.loc[df.metric.isin(["MLE (R2)", "MLE (AUC)"]), "metric"] = "MLE"

    # average per dataset
    avg_per_dataset = (
        df.groupby(["dataset", *config, "metric"])
        .mean(numeric_only=True)
        .reset_index()
    )

    # mean / std over all datasets
    over_datasets = avg_per_dataset.groupby([*config, "metric"])

    def _wide(stat):
        table = stat.reset_index().pivot(
            index="metric", columns=config, values="value"
        )
        return table.round(3).map(lambda x: f"{x:.3f}")

    avg_over_datasets = _wide(over_datasets.mean(numeric_only=True))
    std_over_datasets = _wide(over_datasets.std(numeric_only=True))

    # Highlight per row with the direction appropriate for that metric.
    minimized = {k for k, v in METRIC_INFO.items() if not v["maximize"]}
    lower_is_better = avg_over_datasets.index.isin(minimized)
    highlighted = []
    for rows, direction in (
        (avg_over_datasets[~lower_is_better], "max"),
        (avg_over_datasets[lower_is_better], "min"),
    ):
        if not rows.empty:
            highlighted.append(latex_bold_best(rows, axis=1, best=direction))
    if highlighted:
        avg_over_datasets = pd.concat(highlighted).loc[avg_over_datasets.index]

    full_table = (
        "$"
        + avg_over_datasets.astype(str)
        + "_{\\pm "
        + std_over_datasets.astype(str)
        + "}$"
    )
    full_table = full_table.reindex(columns=["u", "eqf"], level="sampling")
    full_table = full_table.reindex(columns=["q", "u", "k"], level="binning")
    full_table = full_table.loc[
        ["Shape", "Trend", "Detection", "aPrecision", "bRecall", "MLE", "DCR"]
    ]
    print(full_table.to_latex(escape=False))


def print_diffusion_ablation(filepath: str):
    """
    Print the LaTeX table of the diffusion objective / sampler ablation.

    For every dataset in SMALL_DS and every combination of objective
    ("x", "v") and sampler ("ddpm", "ddim") the file
    ``<ds>_xgenboost_diffusion_<objective><sampler>.csv`` is read. Values are
    averaged per dataset first and then across datasets; cells read
    ``$<mean>_{\\pm <std over datasets>}$``.

    The best configuration of each metric row is highlighted, using the
    ``maximize`` flag of METRIC_INFO to decide whether the largest or the
    smallest mean is best (e.g. "Detection" is minimised).
    """
    objective = ["x", "v"]
    sampler = ["ddpm", "ddim"]

    frames = []
    for ds in SMALL_DS:
        for o in objective:
            for s in sampler:
                df_temp = pd.read_csv(f"{filepath}/{ds}_xgenboost_diffusion_{o}{s}.csv")
                df_temp["dataset"] = ds
                df_temp["objective"] = o
                df_temp["sampler"] = s
                frames.append(df_temp)
    df = pd.concat(frames)

    metric_mapper = {v["name"]: k for k, v in METRIC_INFO.items()}
    df.metric = df.metric.map(metric_mapper)
    df = df.dropna(subset=["metric"])

    for pattern in ("F1", "RMSE", "Real,"):
        df = df[~df.metric.str.contains(pattern)]
    df = df.copy()
    df.loc[df.metric.isin(["MLE (R2)", "MLE (AUC)"]), "metric"] = "MLE"

    # average per dataset
    avg_per_dataset = (
        df.groupby(["dataset", "objective", "sampler", "metric"])
        .mean(numeric_only=True)
        .reset_index()
    )

    # mean / std over all datasets
    over_datasets = avg_per_dataset.groupby(["objective", "sampler", "metric"])

    def _wide(stat):
        table = stat.reset_index().pivot(
            index="metric", columns=["sampler", "objective"], values="value"
        )
        return table.round(3).map(lambda x: f"{x:.3f}")

    avg_over_datasets = _wide(over_datasets.mean(numeric_only=True))
    std_over_datasets = _wide(over_datasets.std(numeric_only=True))

    # Highlight per row with the direction appropriate for that metric.
    minimized = {k for k, v in METRIC_INFO.items() if not v["maximize"]}
    lower_is_better = avg_over_datasets.index.isin(minimized)
    highlighted = []
    for rows, direction in (
        (avg_over_datasets[~lower_is_better], "max"),
        (avg_over_datasets[lower_is_better], "min"),
    ):
        if not rows.empty:
            highlighted.append(latex_bold_best(rows, axis=1, best=direction))
    if highlighted:
        avg_over_datasets = pd.concat(highlighted).loc[avg_over_datasets.index]

    full_table = (
        "$"
        + avg_over_datasets.astype(str)
        + "_{\\pm "
        + std_over_datasets.astype(str)
        + "}$"
    )
    full_table = full_table.reindex(columns=["ddpm", "ddim"], level="sampler")
    full_table = full_table.reindex(columns=["x", "v"], level="objective")
    full_table = full_table.loc[
        ["Shape", "Trend", "Detection", "aPrecision", "bRecall", "MLE", "DCR"]
    ]
    print(full_table.to_latex(escape=False))


def print_ar_ablation(
    filepath: str,
    tie_eps_pct: float = 0.0,  # treat |pct| <= eps as tie (in percentage points)
):
    """
    Print a win/tie/loss LaTeX table comparing XGenB-MC against XGenB-AR.

    For every dataset in LARGE_DS the files ``<ds>_xgenboost_ar.csv`` ("AR")
    and ``<ds>_xgenboost_multiclass.csv`` ("MC") are read and averaged per
    (dataset, metric). The relative change of MC with respect to AR,
    ``(MC - AR) / |AR|`` in percent, is oriented so that a positive value
    always means that MC is better (the sign is flipped for metrics where
    lower is better). Dividing by ``|AR|`` keeps this orientation when the
    AR score itself is negative.

    Per metric the table reports
      * Win:  number of datasets where MC is better by more than
              ``tie_eps_pct`` and the mean improvement over those datasets,
      * Tie:  number of datasets with ``|change| <= tie_eps_pct``,
      * Loss: number of datasets where MC is worse by more than
              ``tie_eps_pct`` and the mean deterioration (as a positive
              number) over those datasets.

    (dataset, metric) pairs that are missing for either model are ignored.
    """
    frames = []
    for ds in LARGE_DS:
        mc = pd.read_csv(f"{filepath}/{ds}_xgenboost_multiclass.csv")
        ar = pd.read_csv(f"{filepath}/{ds}_xgenboost_ar.csv")
        ar["dataset"] = ds
        ar["model"] = "AR"
        mc["dataset"] = ds
        mc["model"] = "MC"
        frames.extend([ar, mc])
    df = pd.concat(frames, ignore_index=True)

    metric_mapper = {v["name"]: k for k, v in METRIC_INFO.items()}
    df.metric = df.metric.map(metric_mapper)
    df = df.dropna(subset=["metric"])
    for pattern in ("F1", "RMSE", "Real,"):
        df = df[~df.metric.str.contains(pattern)]
    df = df.copy()
    df.loc[df.metric.isin(["MLE (R2)", "MLE (AUC)"]), "metric"] = "MLE"

    avg = df.groupby(["dataset", "model", "metric"]).mean(numeric_only=True)[["value"]]
    wide = avg["value"].unstack("model")  # columns: AR, MC
    wide = wide.dropna(subset=["AR", "MC"])

    lower_is_better = {"Detection", "Inference Time", "Training Time"}

    # Relative change of MC w.r.t. AR; normalising by |AR| keeps the sign
    # meaningful when AR is negative.
    rel_change = (wide["MC"] - wide["AR"]) / wide["AR"].abs()

    # flip for LOWER-is-better so that + always means MC better
    is_lower = wide.index.get_level_values("metric").isin(lower_is_better)
    rel_impr = rel_change.where(~is_lower, -rel_change)

    pct_impr = 100.0 * rel_impr  # + => MC better, - => MC worse
    pct_impr.name = "pct_improvement"

    # Build table per metric: win / tie / loss over the datasets
    rows = []
    for metric, s in pct_impr.groupby(level="metric"):
        wins = s[s > tie_eps_pct]
        losses = s[s < -tie_eps_pct]
        n_win = len(wins)
        n_loss = len(losses)
        n_tie = int((s.abs() <= tie_eps_pct).sum())

        rows.append(
            {
                "metric": metric,
                "W": n_win,
                # mean of an empty selection is NaN
                "W_mean_%": float(wins.mean()),
                "T": n_tie,
                "L": n_loss,
                # report magnitude as positive "reduction" on loss side
                "L_mean_%": float((-losses).mean()),
                "N": n_win + n_tie + n_loss,
            }
        )

    out = pd.DataFrame(rows).set_index("metric")

    # Order rows
    order = [x for x in METRIC_ORDER if x in out.index]
    out = out.loc[order] if order else out.sort_index()

    # Pretty formatting:
    # Win column: "k (m%)" where m is mean over wins
    # Loss column: "k (m%)" where m is mean deterioration magnitude over losses
    def fmt_count_mean(k, m):
        if k == 0 or np.isnan(m):
            return f"{k}"
        return f"{k} ({m:.2f}\\%)"

    win_col = [fmt_count_mean(int(k), m) for k, m in zip(out["W"], out["W_mean_%"])]
    loss_col = [fmt_count_mean(int(k), m) for k, m in zip(out["L"], out["L_mean_%"])]

    pretty = pd.DataFrame(
        {
            "Win": win_col,
            "Tie": out["T"].astype(int).astype(str),
            "Loss": loss_col,
        },
        index=out.index,
    )

    print(pretty.T.to_latex(escape=False))


def dropout_plot(filepath: str):
    """
    Plot the effect of the dropout rate for XGenB (V-DDIM) on SMALL_DS.

    Reads ``<ds>_xgenboost_diffusion_vddim_dropout_<rate>.csv`` (rate written
    without the decimal point, e.g. ``005`` for 0.05) for every dataset and
    rate, keeps the quality metrics (no F1/RMSE/real baselines/timings,
    R2 and AUC merged into "MLE") and draws one violin plot of the values
    per dropout rate for each metric. The figure is saved to
    ``results/dropout_plot.pdf`` and shown.
    """
    rates = [0.0, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3]

    frames = []
    for ds in SMALL_DS:
        for rate in rates:
            rate_tag = str(rate).replace(".", "")
            frame = pd.read_csv(
                f"{filepath}/{ds}_xgenboost_diffusion_vddim_dropout_{rate_tag}.csv"
            )
            frame["dataset"] = ds
            frame["dropout"] = rate
            frames.append(frame)
    df = pd.concat([pd.DataFrame(), *frames])

    label_by_name = {info["name"]: label for label, info in METRIC_INFO.items()}
    df.metric = df.metric.map(label_by_name)
    df = df.dropna(subset=["metric"])
    for excluded in ("F1", "RMSE", "Real,", "Training Time", "Inference Time"):
        df = df[~df.metric.str.contains(excluded)]
    df.loc[df.metric.isin(["MLE (R2)", "MLE (AUC)"]), "metric"] = "MLE"

    # facet order follows METRIC_ORDER, restricted to the metrics present
    present = df.metric.unique()
    order = [x for x in METRIC_ORDER if x in present]

    # switch the two Greek-letter metrics to their LaTeX spelling
    tex_labels = {
        "bRecall": r"$\beta$-Recall",
        "aPrecision": r"$\alpha$-Precision",
    }
    for plain, tex in tex_labels.items():
        df.loc[df.metric == plain, "metric"] = tex
    order = [tex_labels.get(x, x) for x in order]

    df["metric"] = pd.Categorical(df["metric"], categories=order, ordered=True)

    grid = sns.FacetGrid(
        df,
        col="metric",
        # hue="dataset",
        col_wrap=4,
        height=3,
        aspect=1.5,
        sharey=False,
        sharex=True,
    )
    grid.map(sns.violinplot, "dropout", "value")

    plt.tight_layout()
    plt.savefig("results/dropout_plot.pdf")
    plt.show()


def dequantization_plot(filepath: str):
    """
    Plot quantization vs. de-quantization strategies of XGenB-AR on LARGE_DS.

    Reads ``<ds>_xgenboost_ar_<quant>_<dequant>.csv`` for every dataset,
    quantization ("q", "u", "k") and de-quantization ("eqf", "u"), keeps the
    quality metrics (no F1/RMSE/real baselines/timings, R2 and AUC merged
    into "MLE") and draws, per metric, split violins over the quantization
    method with the de-quantization method as hue. The figure is saved to
    ``results/dequantization_plot.pdf`` and shown.
    """
    quant = {
        "q": "Quantile",
        "u": "Uniform",
        "k": "KMeans",
    }
    dequant = {
        "eqf": "EQF-sampling",
        "u": "Uniform-sampling",
    }

    frames = []
    for ds in LARGE_DS:
        for q_key, q_label in quant.items():
            for d_key, d_label in dequant.items():
                frame = pd.read_csv(f"{filepath}/{ds}_xgenboost_ar_{q_key}_{d_key}.csv")
                frame["dataset"] = ds
                frame["quant"] = q_label
                frame["dequant"] = d_label
                frames.append(frame)
    df = pd.concat([pd.DataFrame(), *frames])

    label_by_name = {info["name"]: label for label, info in METRIC_INFO.items()}
    df.metric = df.metric.map(label_by_name)
    df = df.dropna(subset=["metric"])
    for excluded in ("F1", "RMSE", "Real,", "Training Time", "Inference Time"):
        df = df[~df.metric.str.contains(excluded)]
    df.loc[df.metric.isin(["MLE (R2)", "MLE (AUC)"]), "metric"] = "MLE"

    # facet order follows METRIC_ORDER, restricted to the metrics present
    present = df.metric.unique()
    order = [x for x in METRIC_ORDER if x in present]

    # switch the two Greek-letter metrics to their LaTeX spelling
    tex_labels = {
        "bRecall": r"$\beta$-Recall",
        "aPrecision": r"$\alpha$-Precision",
    }
    for plain, tex in tex_labels.items():
        df.loc[df.metric == plain, "metric"] = tex
    order = [tex_labels.get(x, x) for x in order]

    df["metric"] = pd.Categorical(df["metric"], categories=order, ordered=True)

    # keep dequant as the split hue (must be exactly 2 levels)
    df["dequant"] = pd.Categorical(
        df["dequant"], categories=["EQF-sampling", "Uniform-sampling"], ordered=True
    )
    df["quant"] = pd.Categorical(
        df["quant"],
        categories=["Quantile", "Uniform", "KMeans"],
        ordered=True,
    )

    grid = sns.FacetGrid(
        df,
        col="metric",
        height=3,
        aspect=1.5,
        col_wrap=4,
        sharey=False,
        sharex=True,
    )

    grid.map_dataframe(
        sns.violinplot,
        x="quant",
        y="value",
        hue="dequant",
        palette="bright",
        split=True,
        dodge=True,
        alpha=0.5,
    )

    grid.add_legend(bbox_to_anchor=(0.9, 0.2))
    for ax in grid.axes.flat:
        ax.set_xlabel("")
    plt.tight_layout()
    plt.savefig("results/dequantization_plot.pdf")
    plt.show()


def training_time_comparison(small_path: str, big_path: str):
    """
    Draw the mean training time per dataset for both benchmarks.

    The upper panel shows ``xgenboost_diffusion_vddim`` on the results in
    ``small_path``, the lower panel ``xgenboost_ar`` on the results in
    ``big_path``. Each panel is a horizontal bar chart (minutes, one decimal)
    sorted from slowest to fastest dataset. The figure is saved to
    ``results/training_time_comparison.pdf`` and shown.
    """
    # benchmark kind -> (results directory, raw model key, panel title)
    benchmarks = {
        "small": (small_path, "xgenboost_diffusion_vddim", "Small Benchmark (XGenB-DF)"),
        "big": (big_path, "xgenboost_ar", "Big Benchmark (XGenB-AR)"),
    }

    def plot_training_time(kind: str, ax):
        if kind not in benchmarks:
            raise ValueError(f"Invalid type: {kind}")
        path, model, title = benchmarks[kind]

        results = load_results(path)
        selected = (results.model == model) & (results.metric == "training_time")
        times = results[selected].copy()

        # ensure numeric + drop missing
        times["value"] = pd.to_numeric(times["value"], errors="coerce")
        times = times.dropna(subset=["value"])

        # seconds -> minutes, rounded to one decimal
        times["value"] = (times["value"] / 60).round(1)

        # aggregate to one row per dataset (mean over repeats/seeds)
        per_dataset = times.groupby("dataset", as_index=False)["value"].mean()

        slowest_first = per_dataset.sort_values("value", ascending=False)
        order = slowest_first["dataset"].tolist()

        bars = sns.barplot(
            data=per_dataset,
            x="value",
            y="dataset",
            order=order,
            orient="h",
            ax=ax,
            errorbar=None,
            dodge=False,
        )
        # print the value next to each bar since the x axis is hidden below
        for container in bars.containers:
            bars.bar_label(container, fmt="%.1f")

        sns.despine(ax=ax, bottom=True, left=True)

        ax.set_xlabel("")
        ax.set_ylabel("")
        ax.set_xticks([])
        ax.set_title(title, fontsize=10, fontweight="bold")

    fig, axs = plt.subplots(2, 1, figsize=(6, 10), sharex=False)
    upper, lower = axs.flatten()[0], axs.flatten()[1]
    plot_training_time("small", upper)
    plot_training_time("big", lower)

    plt.xlabel("Training Time (minutes)")

    plt.tight_layout()
    plt.savefig("results/training_time_comparison.pdf")
    plt.show()


if __name__ == "__main__":
    # Reports that only need the results directory. The remaining table
    # types ("rank", "metric") are handled by print_fun.
    standalone_reports = {
        "cat_merge_ablation": print_cat_merge_ablation,
        "sampling_ablation": print_sampling_ablation,
        "diffusion_ablation": print_diffusion_ablation,
        "ar_ablation": print_ar_ablation,
        "dropout_plot": dropout_plot,
        "dequantization_plot": dequantization_plot,
    }

    parser = argparse.ArgumentParser()
    parser.add_argument("--filepath", default="results/big/", type=str)
    parser.add_argument(
        "--table_type",
        default="rank",
        type=str,
        choices=["rank", "metric", *standalone_reports],
    )
    args = parser.parse_args()

    report = standalone_reports.get(args.table_type)
    if report is None:
        print_fun(args.filepath, args.table_type)
    else:
        report(args.filepath)
