import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
# from opacus import PrivacyEngine
from pathlib import Path
# Human-readable label for each model / mechanism key, used as subplot titles.
# Several keys deliberately share a label: both MLP implementations render as
# "DP-MLP", and both normal-noise logistic-regression variants render as
# "LogRegNormal".
model_map = dict(
    logreg="LogReg",
    priv_mlp="DP-MLP",
    tf_mlp="DP-MLP",
    mlp="MLP",
    beta="BetaNoised",
    beta_unbiased="BetaDebiased",
    norm_dwork="LogRegNormal",
    norm_borja="LogRegNormal",
    lapl="LogRegLaplace",
    generator="True",
    none="None",
)


def kde2d_plot(features, target=None, weights=None, name="test"):
    """Save a filled 2-D KDE of columns ``0`` and ``1`` of ``features``.

    The plot is drawn on a fresh figure and written to ``output/{name}.png``.
    The ``output`` directory is created if needed. Both axes are limited
    to ``[0, 1]``.

    Args:
        features: Table whose columns ``0`` and ``1`` are plotted as x and y.
        target: Optional hue vector. It must support ``.rename``, e.g. a
            pandas Series (assumed; confirm against callers). A renamed copy
            titled ``"target"`` is used, so the caller's object is untouched.
        weights: Optional per-sample weights forwarded to ``sns.kdeplot``.
        name: Output file stem.
    """
    if target is not None:
        # Rename a copy instead of assigning ``.name`` so the caller's
        # object is not mutated as a side effect of plotting.
        target = target.rename("target")

    # Use an explicit figure rather than pyplot's "current" figure, so the
    # plot never lands on a figure some other code left open.
    fig, ax = plt.subplots()
    try:
        sns.kdeplot(
            x=0,
            y=1,
            data=features,
            fill=True,
            hue=target,
            alpha=0.6,
            weights=weights,
            ax=ax,
        )
        ax.set_xlim([0, 1])
        ax.set_ylim([0, 1])
        ax.set_xlabel("$x_0$")
        ax.set_ylabel("$x_1$")

        out_dir = Path("output")
        out_dir.mkdir(parents=True, exist_ok=True)
        fig.savefig(out_dir / f"{name}.png")
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)


def kde2d_plot_all_weights(
    features,
    target=None,
    log_weights_dict=None,
    features_true=None,
    target_true=None,
    name="test",
):
    """Plot each model's weight histogram above its weighted 2-D KDE.

    The figure has two rows and one column per entry of ``log_weights_dict``.
    When ``features_true`` is given, there is one extra leading column.

    * Top row: histogram of the model's log-weights, titled with the display
      name from ``model_map``. Keys missing from ``model_map`` use the raw key.
    * Bottom row: filled KDE of columns ``0``/``1`` of ``features``, coloured
      by ``target`` and weighted by ``exp(log_weights)``. Only the right-most
      panel shows the hue legend.

    When ``features_true`` is given, the first bottom panel shows its
    unweighted KDE (coloured by ``target_true``), and the empty top-left
    panel is removed. All KDE panels are limited to ``[0, 1]`` on both axes.
    The figure is saved to ``output/viz/{name}_all.pdf`` and then closed.

    Args:
        features: Table whose columns ``0`` and ``1`` are plotted.
        target: Optional hue vector for ``features``.
        log_weights_dict: Mapping from model key to log-weights for the
            samples in ``features``. ``None`` is treated as an empty mapping.
        features_true: Optional reference sample with columns ``0``/``1``.
        target_true: Optional hue vector for ``features_true``.
        name: Output file stem; ``"_all"`` is appended.

    Raises:
        ValueError: If there is nothing to plot (no weights and no
            ``features_true``).
    """
    if log_weights_dict is None:
        log_weights_dict = {}

    # The reference sample, if present, occupies column 0.
    offset = int(features_true is not None)
    ncols = len(log_weights_dict) + offset
    if ncols == 0:
        raise ValueError(
            "kde2d_plot_all_weights needs a non-empty log_weights_dict "
            "or features_true"
        )

    # NOTE: this updates matplotlib's global rcParams, so the smaller font
    # also applies to figures created after this call.
    plt.rcParams.update({"font.size": 8})
    # squeeze=False keeps ``axes`` 2-D even with a single column, so the
    # ``axes[row, col]`` indexing below is always valid.
    fig, axes = plt.subplots(nrows=2, ncols=ncols, squeeze=False)

    try:
        for col, (key, log_w) in enumerate(log_weights_dict.items(), start=offset):
            axes[0, col].hist(log_w)
            axes[0, col].set_title(model_map.get(key, key))

            ax = axes[1, col]
            sns.kdeplot(
                x=0,
                y=1,
                data=features,
                fill=True,
                hue=target,
                alpha=0.6,
                weights=np.exp(log_w),
                ax=ax,
                # Only the right-most panel carries the hue legend.
                legend=(col == ncols - 1),
            )
            # Axis labels are drawn once, on the first bottom panel (below).
            ax.set_xlabel("")
            ax.set_ylabel("")
            ax.set_xlim([0, 1])
            ax.set_ylim([0, 1])

        if features_true is not None:
            ax = axes[1, 0]
            sns.kdeplot(
                x=0,
                y=1,
                data=features_true,
                fill=True,
                hue=target_true,
                alpha=0.6,
                ax=ax,
                legend=False,
            )
            # The reference column has no weights, so its top panel is unused.
            fig.delaxes(axes[0, 0])
            ax.set_xlim([0, 1])
            ax.set_ylim([0, 1])

        axes[1, 0].set_xlabel("$x_0$")
        axes[1, 0].set_ylabel("$x_1$")

        fig.set_size_inches((10, 3))
        fig.tight_layout()

        out_dir = Path("output") / "viz"
        out_dir.mkdir(parents=True, exist_ok=True)
        fig.savefig(out_dir / f"{name}_all.pdf", format="pdf")
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)
