from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
from nn_compression.evaluation import plot_df_cv
from functools import partial
import matplotlib.pyplot as plt
from data_utils.experiments import pareto_front as _pareto_front
import pandas as pd
from cyclopts import App
import yaml
from plot_params import *

# Command-line application object; main() below is registered as its default command.
app = App()

# Reference accuracy per network: drawn as the "Original" line in the accuracy
# plots and used as the baseline for the accuracy threshold in the LaTeX table.
BASE_PERFORMANCE = dict(
    ResNet18=0.95125,
    ResNet34=0.95546875,
    ResNet50=0.94125,
    VGG16=0.71675,
)

# Durations of the individual pipeline stages, in seconds (they are plotted on a
# "Time [s]" axis). The aggregate entries of TIMES are derived from these so the
# totals cannot drift away from their components.
_HESSIAN_S = 1526
_OPTQ_RD_QUANT_S = 153
_GPTQ_QUANT_S = 55
_DEEPCABAC_ENCODE_S = 27
_DEEPCABAC_DECODE_S = 5

# Key order matters: figure_times() draws one bar per entry in this order.
TIMES = {
    "Hessian": _HESSIAN_S,
    "OPTQ-RD Quant.": _OPTQ_RD_QUANT_S,
    "GPTQ Quant.": _GPTQ_QUANT_S,
    "GPTQ Vanilla": _HESSIAN_S + _GPTQ_QUANT_S,
    "OPTQ-RD": (
        _OPTQ_RD_QUANT_S + _HESSIAN_S + _DEEPCABAC_ENCODE_S + _DEEPCABAC_DECODE_S
    ),
    "DeepCABAC Encode": _DEEPCABAC_ENCODE_S,
    "DeepCABAC Decode": _DEEPCABAC_DECODE_S,
}


# Pareto-front helper preconfigured for the accuracy results: the x values come
# from the "entropy_deepcabac" column and the y values from "acc".
# NOTE(review): the exact meaning of mode/grid/nsteps is defined by
# data_utils.experiments.pareto_front; mode="over" presumably means "higher y is
# better" -- confirm against that helper.
pareto_front = partial(
    _pareto_front,
    x_column="entropy_deepcabac",
    y_column="acc",
    mode="over",
    grid="linear",
    nsteps=50,
)
# Same helper bound to the "ppl" column with mode="under" and the helper's default
# grid settings. Not referenced anywhere else in the visible part of this file.
pareto_front_nlp = partial(
    _pareto_front, x_column="entropy_deepcabac", y_column="ppl", mode="under"
)

# Root folder of the experiment outputs; collect_results() scans it and caches the
# merged table there as all_results.csv.
results_folder = Path("out/publication")
# Destination of the generated figures and the LaTeX table; created by main().
plots_folder = Path("../nn-compression-paper/ICLR 2025 Template/plots")
# plot_df_cv bound to the entropy/accuracy columns. Not referenced anywhere else in
# the visible part of this file.
plot_df = partial(plot_df_cv, key_x="entropy_deepcabac", key_y="acc")
def sort(x):
    """Return a copy of ``x`` sorted by its ``entropy_deepcabac`` column, ascending."""
    return x.sort_values("entropy_deepcabac")


def _experiment_dirs(exp, experiments):
    """Yield the ``{exp}_*`` directories of ``results_folder`` that belong to ``exp``.

    A directory that also matches a longer experiment name from ``experiments``
    (``ablation_nbatches_resnet_*`` also matches ``ablation_nbatches_*``) is left
    to that longer name, so its rows are not collected twice under two labels.
    """
    longer_names = [
        other for other in experiments if other != exp and other.startswith(f"{exp}_")
    ]
    for exp_dir in results_folder.glob(f"{exp}_*"):
        if any(exp_dir.name.startswith(f"{other}_") for other in longer_names):
            continue
        yield exp_dir


def _load_run_config(run_dir):
    """Return the parsed ``config.yaml`` of ``run_dir``, or ``None`` if it is missing.

    An existing but empty config file yields an empty dict.
    """
    config_path = run_dir / "config.yaml"
    if not config_path.exists():
        return None
    with open(config_path) as config_file:
        return yaml.safe_load(config_file) or {}


def _frames_from_csv(csv, exp, network, dataset, cfg):
    """Load one result CSV and return the frames it contributes, in append order."""
    df = pd.read_csv(csv)
    if "sparsity" in df.columns:
        df = df.assign(sparsity_inv=1 - df.sparsity)
    if csv.stem == "nncodec-results":
        # Bring the nncodec output onto the column layout of the other methods.
        df = df.assign(
            name="nncodec",
            layer_name="all",
            entropy_deepcabac=df.entropy,
            bits=0,
            ppl=df.acc,
        )

    train_dataset = cfg.get("train_dataset")
    df = df.assign(
        exp=exp,
        network=network.stem,
        dataset=dataset.stem,
        nbatches=cfg.get("nbatches", 0),
        # Without an explicit train_dataset the run's own dataset is recorded.
        train_dataset=train_dataset if train_dataset is not None else dataset.stem,
    )

    frames = []
    if csv.stem == "gptq":
        # Extra copy of the GPTQ rows whose rate is the nominal "bits" value.
        frames.append(df.assign(name="gptq_vanilla", entropy_deepcabac=df.bits))
        if exp == "main-findings":
            # Extra copy whose rate is the "entropy_bz2" column.
            frames.append(
                df.assign(name="gptq_bz2", entropy_deepcabac=df.entropy_bz2)
            )
    frames.append(df)
    return frames


def collect_results(force=False):
    """Return all experiment results as one DataFrame, using a CSV cache.

    The cache lives at ``results_folder / "all_results.csv"``. It is read when it
    exists and ``force`` is false; otherwise every
    ``{experiment}_*/<network>/<dataset>/<bits>/*.csv`` file is loaded, annotated
    with its experiment, network, dataset and config values, concatenated and
    written back to the cache.

    Args:
        force: Rebuild the table from the individual CSV files even if the
            cache file exists.

    Returns:
        The concatenated results table.

    Raises:
        AssertionError: If a run directory has no ``config.yaml`` but contains a
            CSV other than ``nncodec-results.csv``.
        ValueError: If no result CSV was found at all.
    """
    path = results_folder / "all_results.csv"
    if path.exists() and not force:
        print("Loading results...")
        return pd.read_csv(path)

    print("Collecting results...")
    experiments = [
        "main-findings",
        "main-findings-rw",
        "main-findings-blocks",
        "ablation_nbatches_resnet",
        "ablation_nbatches",
        "ablation_train_test_shift",
    ]
    dfs = []
    for exp in experiments:
        for exp_dir in _experiment_dirs(exp, experiments):
            for network in exp_dir.glob("*"):
                if network.stem.startswith("gpt2"):
                    continue  # no experiments for now
                for dataset in network.glob("*"):
                    for bits in dataset.glob("*"):
                        # The config is per run directory: read it once, not per CSV.
                        cfg = _load_run_config(bits)
                        for csv in bits.glob("*.csv"):
                            if cfg is None and csv.stem != "nncodec-results":
                                # Raised explicitly so the check survives "python -O".
                                raise AssertionError(
                                    f"{bits} has no config.yaml but contains {csv.name}"
                                )
                            dfs.extend(
                                _frames_from_csv(csv, exp, network, dataset, cfg or {})
                            )

    if not dfs:
        raise ValueError(f"No result CSV files found under {results_folder}")

    df_all = pd.concat(dfs).reset_index(drop=True)
    # index=False keeps a reloaded cache free of a spurious "Unnamed: 0" column.
    df_all.to_csv(path, index=False)
    return df_all


def label_cv():
    """Label the current axes with the two entries of ``AXIS_LABELS_CV`` (x, then y)."""
    x_label, y_label = AXIS_LABELS_CV[0], AXIS_LABELS_CV[1]
    plt.xlabel(x_label)
    plt.ylabel(y_label)


def save(func):
    """Decorate a figure function so the current figure is saved after it runs.

    After ``func`` returns, the current matplotlib figure is written to
    ``plots_folder`` twice, as ``<func.__name__>.png`` and ``<func.__name__>.pdf``.

    Args:
        func: Function that draws a figure with matplotlib.

    Returns:
        A wrapper with the same name, docstring and call signature as ``func``
        that forwards its return value.
    """
    # Imported locally because the module itself only imports ``partial``.
    from functools import wraps

    @wraps(func)
    def wrapper(*args, **kwargs):
        result = func(*args, **kwargs)
        for suffix in ("png", "pdf"):
            plt.savefig(plots_folder / f"{func.__name__}.{suffix}")
        return result

    return wrapper


@save
def figure_ablation_train_test_shift(df):
    """Plot VGG16 accuracy over rate for different calibration datasets.

    Five series are drawn, each keyed into ``TRAIN_TEST_SHIFT_LABELS`` and
    ``TRAIN_TEST_SHIFT_PARAMS``: the ``uniform`` and ``gptq`` methods from the
    COCO train/test-shift ablation, the same two methods from the main
    experiment, and the nncodec results.

    Args:
        df: Results table as returned by ``collect_results``.
    """
    plt.figure()
    print("Plotting the performance of OPTQ-RD on different training sets...")
    vgg16 = df.query("network == 'vgg16' & layer_name == 'all'")

    # Insertion order is the drawing order and therefore the legend order.
    series = {
        "uniform_coco": vgg16.query(
            "train_dataset == 'coco' & name == 'uniform' & exp == 'ablation_train_test_shift'"
        ).sort_values("entropy_deepcabac"),
        "uniform_imagenet": vgg16.query(
            "name == 'uniform' & exp == 'main-findings' & bits == 4.0"
        ).sort_values("entropy_deepcabac"),
        # The nncodec rows are plotted in table order, without sorting.
        "nncodec": vgg16.query("name == 'nncodec'"),
        "gptq_coco": vgg16.query(
            "name == 'gptq' & exp == 'ablation_train_test_shift' & train_dataset == 'coco'"
        ).sort_values("entropy_deepcabac"),
        "gptq_imagenet": vgg16.query(
            "name == 'gptq' & exp == 'main-findings'"
        ).sort_values("entropy_deepcabac"),
    }
    for key, rows in series.items():
        plt.plot(
            rows.entropy_deepcabac,
            rows.acc,
            label=TRAIN_TEST_SHIFT_LABELS[key],
            **TRAIN_TEST_SHIFT_PARAMS[key],
        )

    plt.legend()
    plt.xlim(0, 2)
    label_cv()


@save
def figure_ablation_nbatches(df):
    """Plot one VGG16 Pareto front per ``nbatches`` value of the nbatches ablation.

    The group with ``nbatches == 50_000`` is left out. Each remaining group is
    labelled ``nbatches * 8`` samples (presumably a batch size of 8 -- confirm)
    and styled with consecutive ``PLOT_PARAMS_NBATCHES`` entries.

    Args:
        df: Results table as returned by ``collect_results``.
    """
    plt.figure()
    print("Plotting the performance of OPTQ-RD on different training set sizes...")
    selection = df.query(
        "network == 'vgg16' & layer_name == 'all' & exp == 'ablation_nbatches' & train_dataset == 'imagenet'"
    )
    plotted_groups = (
        (nbatches, rows)
        for nbatches, rows in selection.groupby("nbatches")
        if nbatches != 50_000
    )
    # The style index counts plotted groups only, so skipping one leaves no gap.
    for style_idx, (nbatches, rows) in enumerate(plotted_groups):
        front = pareto_front(sort(rows))
        plt.plot(
            front.entropy_deepcabac,
            front.acc,
            label=f"{nbatches * 8} samples",
            **PLOT_PARAMS_NBATCHES[style_idx],  # type: ignore
        )
    plt.legend()
    label_cv()


@save
def figure_main_results_zoomed(df):
    """Draw the four-network accuracy-over-rate figure restricted to x in (0, 1.05).

    Same methods and baseline line as ``figure_main_results``; additionally the
    ``gptq_bz2`` markers are annotated (``add_nbins=True``).

    Args:
        df: Results table as returned by ``collect_results``.
    """
    print("Plotting main results...")
    compared_methods = [
        "nncodec",
        "direct_rd",
        "gptq",
        "gptq_bz2",
        "uniform",
        "alpha_inv_tr",
    ]
    four_panel(
        df,
        x_axis="entropy_deepcabac",
        y_axis="acc",
        xlabel=AXIS_LABELS_CV[0],
        ylabel=AXIS_LABELS_CV[1],
        methods=compared_methods,
        xlim=(0, 1.05),
        ylim=(-0.05, 1.05),
        include_baseline=True,
        add_nbins=True,
    )


@save
def figure_main_results(df):
    """Draw the four-network accuracy-over-rate figure for x in (0, 5.3).

    Each panel shows the listed methods plus the baseline accuracy line.

    Args:
        df: Results table as returned by ``collect_results``.
    """
    print("Plotting main results...")
    compared_methods = [
        "nncodec",
        "direct_rd",
        "gptq",
        "gptq_bz2",
        "uniform",
        "alpha_inv_tr",
    ]
    four_panel(
        df,
        x_axis="entropy_deepcabac",
        y_axis="acc",
        xlabel=AXIS_LABELS_CV[0],
        ylabel=AXIS_LABELS_CV[1],
        methods=compared_methods,
        xlim=(0, 5.3),
        ylim=(-0.05, 1.05),
        include_baseline=True,
    )


@save
def figure_main_results_per_grid(df):
    """Plot the VGG16 ``uniform`` results per ``bits`` value, their front and GPTQ.

    Only main-experiment rows with ``bits > 2.23`` are used. Each ``bits`` group
    is one curve labelled with ``round(bits ** 2)`` grid points; the Pareto front
    over all groups and the GPTQ curve are drawn on top.

    Args:
        df: Results table as returned by ``collect_results``.
    """
    _, ax = plt.subplots()
    ax.set_xlabel(AXIS_LABELS_CV[0])
    ax.set_ylabel(AXIS_LABELS_CV[1])

    uniform_rows = df.query(
        "name == 'uniform' & network == 'vgg16' & exp == 'main-findings' & bits > 2.23"
    )
    # The per-grid curves start at the third PLOT_PARAMS_NBATCHES style.
    for style_idx, (bits, rows) in enumerate(uniform_rows.groupby("bits"), start=2):
        rows = rows.sort_values("entropy_deepcabac")
        ax.plot(
            rows.entropy_deepcabac,
            rows.acc,
            label=f"{int(round(bits**2)):d} grid points",
            **PLOT_PARAMS_NBATCHES[style_idx],
        )

    front = pareto_front(uniform_rows.sort_values("entropy_deepcabac"))
    ax.plot(
        front.entropy_deepcabac,
        front.acc,
        label="OPTQ-RD Optimized RD",
        **PLOT_PARAMS_METHODS["uniform"],
    )

    gptq_rows = df.query(
        "name == 'gptq' & network == 'vgg16' & exp == 'main-findings' & bits > 2.23"
    ).sort_values("entropy_deepcabac")
    ax.plot(
        gptq_rows.entropy_deepcabac,
        gptq_rows.acc,
        label=PLOT_LABELS["gptq"],
        **PLOT_PARAMS_METHODS["gptq"],
    )
    ax.legend()


@save
def figure_sparsity_vgg16(df):
    """Draw a single VGG16 panel with sparsity on the x axis and accuracy on the y axis.

    Args:
        df: Results table as returned by ``collect_results``.
    """
    _, ax = plt.subplots()
    one_panel(
        subset=df,
        ax=ax,
        net="VGG16",
        is_first=True,
        x_axis="sparsity",
        y_axis="acc",
        xlabel=SPARSITY_LABEL,
        ylabel="Top-1 Accuracy",
        methods=["nncodec", "gptq", "uniform", "alpha_inv_tr"],
        xlim=(0, 1.05),
        ylim=(0, 1.05),
    )


@save
def figure_sparsity(df):
    """Draw the four-network figure with accuracy on the x axis and sparsity on the y axis.

    No baseline line is drawn.

    Args:
        df: Results table as returned by ``collect_results``.
    """
    print("Plotting sparsity...")
    four_panel(
        df,
        x_axis="acc",
        y_axis="sparsity",
        xlabel="Top-1 Accuracy",
        ylabel=SPARSITY_LABEL,
        methods=["nncodec", "gptq", "uniform", "alpha_inv_tr"],
        xlim=(0, 1.05),
        ylim=(0, 1.05),
        include_baseline=False,
    )


def four_panel(
    df,
    x_axis,
    y_axis,
    xlabel,
    ylabel,
    methods,
    xlim,
    ylim,
    include_baseline=False,
    add_nbins=False,
):
    """Draw a 2x2 grid of ``one_panel`` plots, one per network, with a shared legend.

    The panels are filled row by row with ResNet18, ResNet34, ResNet50 and VGG16
    and share both axes. Only the first panel registers legend labels, which the
    figure-level legend above the grid then shows.

    Args:
        df: Results table as returned by ``collect_results``.
        x_axis: Column plotted on the x axis.
        y_axis: Column plotted on the y axis.
        xlabel: Label of the x axis.
        ylabel: Label of the y axis.
        methods: Method names to draw in every panel.
        xlim: Arguments for ``set_xlim``.
        ylim: Arguments for ``set_ylim``.
        include_baseline: Whether each panel draws the baseline line.
        add_nbins: Forwarded to ``one_panel``.
    """
    networks = ("ResNet18", "ResNet34", "ResNet50", "VGG16")
    fig, axs = plt.subplots(
        2, 2, figsize=(2 * fwidth, 2 * fheight), sharex=True, sharey=True
    )

    # axs.flat walks the grid row by row, matching the order of ``networks``.
    for position, (panel_ax, network) in enumerate(zip(axs.flat, networks)):
        one_panel(
            df,
            panel_ax,
            network,
            position == 0,
            x_axis,
            y_axis,
            xlabel,
            ylabel,
            methods,
            xlim,
            ylim,
            include_baseline,
            add_nbins,
        )

    # One legend column per method, plus one for the baseline entry when drawn.
    legend_columns = len(methods) + 1 if include_baseline else len(methods)
    fig.legend(loc="upper center", ncol=legend_columns)
    # Leave the top 5% of the figure free for the legend.
    plt.tight_layout(rect=[0, 0, 1, 0.95])  # type: ignore


def one_panel(
    subset,
    ax,
    net,
    is_first: bool,
    x_axis,
    y_axis,
    xlabel,
    ylabel,
    methods,
    xlim,
    ylim,
    include_baseline=False,
    add_nbins=False,
):
    """Plot the main-experiment results of one network onto ``ax``.

    Methods whose name starts with ``gptq`` are drawn sorted by
    ``entropy_deepcabac``; all others are first reduced with ``pareto_front``.

    Args:
        subset: Results table; filtered here to ``net`` (lower-cased), the
            ``main-findings`` experiment and ``layer_name == 'all'``.
        ax: Matplotlib axes to draw on.
        net: Network name; key into ``BASE_PERFORMANCE`` and the panel title.
        is_first: Whether this panel registers the legend labels.
        x_axis: Column plotted on the x axis.
        y_axis: Column plotted on the y axis.
        xlabel: Label of the x axis.
        ylabel: Label of the y axis.
        methods: Method names to draw.
        xlim: Arguments for ``ax.set_xlim``.
        ylim: Arguments for ``ax.set_ylim``.
        include_baseline: Draw a horizontal line at ``BASE_PERFORMANCE[net]``.
        add_nbins: Annotate every ``gptq_bz2`` marker with ``round(bits ** 2)``.
    """
    panel_rows = subset.query(
        f"network == '{net.lower()}' & exp == 'main-findings' & layer_name == 'all'"
    )

    if include_baseline:
        baseline_label = "Original" if is_first else None
        ax.axhline(
            y=BASE_PERFORMANCE[net],
            **PLOT_PARAMS_METHODS["base_performance"],
            label=baseline_label,
        )

    for method in methods:
        rows = panel_rows.query(f"name == '{method}'")
        rows = sort(rows) if method.startswith("gptq") else pareto_front(rows)

        xs, ys = rows[x_axis], rows[y_axis]
        ax.plot(
            xs,
            ys,
            label=PLOT_LABELS[method] if is_first else None,
            **PLOT_PARAMS_METHODS[method],
        )

        if not (add_nbins and method == "gptq_bz2"):
            continue
        # Small number next to each marker; VGG16 points at or below 0.2 get a
        # different offset than all other points.
        for xi, yi, bits in zip(xs, ys, rows.bits):
            low_vgg16_point = net == "VGG16" and not yi > 0.2
            offset = (-3, 3) if low_vgg16_point else (5, -6)
            ax.annotate(
                f"{int(round(bits**2)):d}",
                (xi, yi),
                textcoords="offset points",
                xytext=offset,
                ha="center",
                fontsize=7,
                color=PLOT_PARAMS_METHODS[method]["color"],
                fontweight="bold",
            )

    ax.set_xlim(*xlim)
    ax.set_ylim(*ylim)
    ax.set(title=net, xlabel=xlabel, ylabel=ylabel)


def write_block(
    f, net, method_names, entropies, performances, sparsities, include_cf: bool
):
    """Write the LaTeX table rows of one network, one row per method.

    Rows are ordered by the display label of the method. The row with the lowest
    entropy gets its entropy (and compression factor, if shown) set in bold. The
    network name is printed in the first row only. Nothing is written when
    ``method_names`` is empty.

    Args:
        f: Writable text file object.
        net: Network name shown in the first column.
        method_names: Method keys; translated through ``PLOT_LABELS``.
        entropies: Entropy per method, aligned with ``method_names``.
        performances: Performance per method, aligned with ``method_names``.
        sparsities: Sparsity per method, aligned with ``method_names``, or
            ``None`` to omit the sparsity column.
        include_cf: Whether to add a compression-factor column
            (``32 / entropy``; presumably relative to 32-bit weights -- confirm).
    """
    if len(method_names) == 0:
        # np.argmin below raises on an empty sequence; an empty block has no rows.
        return

    labels = [PLOT_LABELS[method] for method in method_names]

    # Reorder every per-method column with the same permutation.
    order = np.argsort(labels)
    labels = [labels[i] for i in order]
    entropies = [entropies[i] for i in order]
    performances = [performances[i] for i in order]
    if sparsities is not None:
        sparsities = [sparsities[i] for i in order]

    best_entropy_idx = np.argmin(entropies)
    for k, label in enumerate(labels):
        entropy_str = f"{entropies[k]:.2f}"
        # Only divide when the column is printed.
        cf_str = f"{32/entropies[k]:.2f}" if include_cf else ""
        if k == best_entropy_idx:
            entropy_str = "\\textbf{" + entropy_str + "}"
            if include_cf:
                cf_str = "\\textbf{" + cf_str + "}"
        if include_cf:
            cf_str = f"& {cf_str}"
        sparsity_str = f"& {sparsities[k]:.2f}" if sparsities is not None else ""
        block_str = f"{net if k == 0 else ''} & {label} & {entropy_str} {cf_str} & {performances[k]:.2f} {sparsity_str}"
        block_str += " \\\\\n"
        f.write(block_str)


def ninety_five_percent_performance_table(
    df, threshold=0.95, include_sparsity: bool = False, include_cf=True
):
    """Write ``095_table.tex`` with the lowest-rate acceptable result per method.

    For each network and method, the rows whose accuracy is at least
    ``threshold * BASE_PERFORMANCE[network]`` are kept and the one with the
    smallest ``entropy_deepcabac`` is reported. Methods without such a row are
    announced on stdout and left out; a network without any such method
    contributes no rows.

    Args:
        df: Results table as returned by ``collect_results``.
        threshold: Required fraction of the baseline accuracy.
        include_sparsity: Add a sparsity column (requires a ``sparsity`` column
            in ``df``).
        include_cf: Add a compression-factor column.
    """
    methods = ["nncodec", "gptq", "gptq_bz2", "direct_rd", "uniform", "alpha_inv_tr"]
    # NOTE: plot_methods, threshold and max_perf are read by the query strings
    # below through "@name"; they must stay plain local variables of this function.
    plot_methods = list(PLOT_LABELS.keys())
    nmetrics = 2
    if include_sparsity:
        nmetrics += 1
    if include_cf:
        nmetrics += 1
    header_format = f"ll{'c' * nmetrics}"
    headers = "Network & Method & Bits-Per-Weight $\\downarrow$"

    if include_cf:
        headers += " & Compression Factor $\\uparrow$"
    headers += "& Top-1 Accuracy $\\uparrow$"
    if include_sparsity:
        headers += " & Sparsity"
    headers += " \\\\ \n"

    with open(plots_folder / "095_table.tex", "w") as f:
        f.write("\\begin{center}\n")
        f.write("\\begin{tabular}{" + header_format + "}\n")
        f.write("\\toprule\n")
        f.write(headers)
        f.write("\\midrule\n")
        for r in ["ResNet18", "ResNet34", "ResNet50", "VGG16"]:
            subset = df.query(
                f"network == '{r.lower()}' & exp == 'main-findings' & layer_name == 'all' & train_dataset == dataset & name in @plot_methods"
            )
            max_perf = BASE_PERFORMANCE[r]

            names = []
            entropies = []
            performances = []
            sparsities = [] if include_sparsity else None

            for n, g in subset.groupby("name", sort=True):
                if n not in methods:
                    continue
                acceptable = g.query("acc >= @threshold * @max_perf")
                if acceptable.empty:
                    print(f"No acceptable results for {r} and method {n}")
                    continue
                best_entropy_row = acceptable.iloc[
                    acceptable.entropy_deepcabac.argmin()
                ]
                names.append(n)
                entropies.append(best_entropy_row.entropy_deepcabac)
                performances.append(best_entropy_row.acc)
                if sparsities is not None:
                    # Only touch the sparsity column when it is actually reported.
                    sparsities.append(best_entropy_row.sparsity)

            if names:
                write_block(
                    f,
                    r,
                    names,
                    entropies,
                    performances,
                    sparsities,
                    include_cf,
                )

            # A \midrule follows every network block, including the last one.
            f.write("\\midrule\n")

        f.write("\\bottomrule\n")
        f.write("\\end{tabular}\n")
        f.write("\\end{center}\n")


@save
def figure_transpose_scan_order():
    """Plot accuracy over entropy for the two ResNet18 nncodec result files.

    The ``resnet18-transpose`` results are labelled "Column-Major" and the
    ``resnet18`` results "Row-Major"; both are read from fixed paths under
    ``out/cluster/17``.
    """
    row_major = pd.read_csv("out/cluster/17/resnet18/nncodec-results.csv")
    column_major = pd.read_csv("out/cluster/17/resnet18-transpose/nncodec-results.csv")

    plt.figure()
    curves = ((column_major, "Column-Major"), (row_major, "Row-Major"))
    for results, label in curves:
        plt.plot(results.entropy, results.acc, "--.", label=label)
    label_cv()
    plt.legend()


def substitute_gptq(df):
    """Rename, in place, every ``uniform`` row whose ``lm`` value is 0 to ``gptq``.

    Args:
        df: Table with ``name`` and ``lm`` columns; modified in place.
    """
    is_uniform = df["name"] == "uniform"
    has_zero_lm = df["lm"] == 0
    df.loc[is_uniform & has_zero_lm, "name"] = "gptq"


@save
def figure_times_stacked():
    """Draw one stacked bar per pipeline showing the time of each of its stages.

    The three bars are "OPTQ", "OPTQ-RD (encode)" and "OPTQ-RD (decode)"; their
    segments are taken from ``TIMES`` and coloured via ``COLORS_BARS``. Segments
    of at most 5 s are drawn 20 units tall so they stay visible and carry no
    per-segment number. The number above each bar is the true total.
    """
    labels = ["OPTQ", "OPTQ-RD (encode)", "OPTQ-RD (decode)"]
    step_labels = [
        ["Hessian", "GPTQ Quant."],
        ["Hessian", "OPTQ-RD Quant.", "DeepCABAC Encode"],
        ["DeepCABAC Decode"],
    ]
    small_step_limit = 5  # steps up to this duration are drawn enlarged
    enlarged_height = 20  # drawn height of such a step
    handles = {}

    fig, ax = plt.subplots()
    for i, steps in enumerate(step_labels):
        total = 0  # true cumulative time, printed above the bar
        drawn_top = 0  # cumulative drawn height, used for all placement
        for step in steps:
            v = TIMES[step]
            c = COLORS_BARS[step]
            is_small = v <= small_step_limit
            v_d = enlarged_height if is_small else v

            # Stack on the drawn height so bars and their numbers stay aligned
            # even when an enlarged segment sits below another one.
            handles[step] = ax.bar(i, v_d, 0.5, bottom=drawn_top, color=c, label=step)
            if not is_small:
                ax.text(
                    i + 0.37,
                    drawn_top + v_d / 2,
                    f"{v:d}",
                    ha="center",
                    va="center",
                    color=c,
                    fontsize=6,
                )
            total += v
            drawn_top += v_d
        ax.text(
            i, drawn_top + 5, f"{total:d}", ha="center", va="bottom", fontweight="heavy"
        )

    # Labels and formatting
    ax.set_xticks(range(3))
    ax.set_xticklabels(labels)
    ax.set_ylabel("Time [s]")
    ax.set_ylim(0, 1900)
    handles_keys_reordered = [
        "Hessian",
        "GPTQ Quant.",
        "OPTQ-RD Quant.",
        "DeepCABAC Encode",
        "DeepCABAC Decode",
    ]
    handles_reordered = [handles[k] for k in handles_keys_reordered]
    fig.legend(
        loc="upper center",
        handles=handles_reordered,
        labels=handles_keys_reordered,
        ncol=3,
        bbox_to_anchor=(0.53, 1),
        title_fontsize=7,
        columnspacing=0.65,
        handletextpad=0.4,
    )
    ax.grid(False)

    # Leave the top 13% of the figure free for the legend.
    plt.tight_layout(rect=[0, 0, 1, 0.87])  # type: ignore


@save
def figure_times():
    """Draw one bar per ``TIMES`` entry with its integer value written on top."""
    _, ax = plt.subplots()
    for bar in ax.bar(TIMES.keys(), TIMES.values()):
        top = bar.get_height()
        center = bar.get_x() + bar.get_width() / 2
        ax.text(center, top, f"{int(top):d}", ha="center", va="bottom")
    ax.set_ylim(0, 2000)
    ax.set_ylabel("Time [s]")
    ax.set_xticklabels(TIMES.keys(), rotation=-45)
    ax.grid(False)


@app.default
def main(force: bool = False):
    """Regenerate the LaTeX table and all figures from the experiment results.

    Args:
        force: Rebuild the results table from the individual CSV files instead
            of loading the cached ``all_results.csv``.
    """
    plots_folder.mkdir(parents=True, exist_ok=True)
    print(f"Using {results_folder} as results folder, {plots_folder} as plots folder.")
    df = collect_results(force)
    ninety_five_percent_performance_table(df, include_sparsity=False)

    # (figure function, whether it takes the results table), in drawing order.
    figure_steps = (
        (figure_main_results, True),
        (figure_main_results_zoomed, True),
        (figure_sparsity, True),
        (figure_sparsity_vgg16, True),
        (figure_ablation_train_test_shift, True),
        (figure_ablation_nbatches, True),
        (figure_transpose_scan_order, False),
        (figure_main_results_per_grid, True),
        (figure_times, False),
        (figure_times_stacked, False),
    )
    for make_figure, needs_results in figure_steps:
        if needs_results:
            make_figure(df)
        else:
            make_figure()


# Script entry point: hand control to the command-line app, which runs main().
if __name__ == "__main__":
    app()
