{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib\n",
    "import seaborn as sns\n",
    "from pathlib import Path\n",
    "from scipy.stats import wasserstein_distance as WD\n",
    "matplotlib.rcParams[\"text.usetex\"] = True\n",
    "matplotlib.rcParams[\"font.family\"] = \"serif\"\n",
    "matplotlib.rcParams[\"font.size\"] = \"35\"\n",
    "\n",
    "PATH = Path.cwd() / \"fig\" / \"circle\"\n",
    "PATH.mkdir(exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x_mnf = np.load(\"data/mnf_toy_circle_x2.npy\")\n",
    "nll = np.load(\"data/mnf_toy_circle_l2.npy\")\n",
    "x_mnf_marg = np.load(\"data/mnf_toy_circle_marg.npy\")\n",
    "x_profiti = np.load(\"data/profiti_toy_circle_x.npy\")\n",
    "x_profiti_marg = np.load(\"data/profiti_toy_circle_x_nc.npy\")\n",
    "x_gmix_1 = np.load(\"data/gmix_toy_circle_1_x.npy\")\n",
    "x_gmix_1_marg = np.load(\"data/gmix_toy_circle_1_marg.npy\")\n",
    "x_gmix_5 = np.load(\"data/gmix_toy_circle_5_x.npy\")\n",
    "x_gmix_10 = np.load(\"data/gmix_toy_circle_10_x.npy\")\n",
    "x_gmix_15 = np.load(\"data/gmix_toy_circle_15_x.npy\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "MI_gau_mnf = 0.5*(WD(x_mnf[:,0], x_mnf_marg[:,0]) + WD(x_mnf[:,1],x_mnf_marg[:,1]))\n",
    "MI_gau_profiti = 0.5*(WD(x_profiti[:,0], x_profiti_marg[:,0]) + WD(x_profiti[:,1],x_profiti_marg[:,1]))\n",
    "MI_gau_gmix = 0.5*(WD(x_gmix_1[:,0], x_gmix_1_marg[:,0]) + WD(x_gmix_1[:,1],x_gmix_1_marg[:,1]))\n",
    "print(MI_gau_mnf, MI_gau_profiti, MI_gau_gmix)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "XLIM = np.array([-3, +3])\n",
    "YLIM = np.array([-3, +3])\n",
    "PLIM = np.array([0, +1.25])\n",
    "\n",
    "XTICKS = np.array([-2, -1, 0, 1, 2])\n",
    "YTICKS = np.array([-2, -1, 0, 1, 2])\n",
    "PTICKS = np.array([0, 0.25, 0.5, 0.75, 1])\n",
    "\n",
    "XTICKLABELS = []\n",
    "YTICKLABELS = []\n",
    "PTICKLABELS = []\n",
    "\n",
    "FIGSIZE = (3, 3)\n",
    "GRID = 1000  # gridsize for kde plot\n",
    "NUM = None  # How many samples to plot (None=all)\n",
    "\n",
    "nsams = 2000\n",
    "x_ = np.random.uniform(-1, 1, nsams)\n",
    "y = np.sqrt(1 - x_**2)\n",
    "y[nsams // 2 :] *= -1\n",
    "x_orig = np.concatenate([x_, y], 0) + np.random.randn(nsams * 2) * 0.05\n",
    "y_orig = np.concatenate([y, x_], 0) + np.random.randn(nsams * 2) * 0.05"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def set_lim_and_ticks(ax):\n",
    "    ax.set_xlim(XLIM)\n",
    "    ax.set_ylim(YLIM)\n",
    "    ax.set_xticks(XTICKS)\n",
    "    ax.set_yticks(YTICKS)\n",
    "    ax.set_xticklabels(XTICKLABELS)\n",
    "    ax.set_yticklabels(XTICKLABELS)\n",
    "\n",
    "\n",
    "fig, ax = plt.subplots(figsize=FIGSIZE, constrained_layout=True)\n",
    "ax.scatter(x_orig, y_orig, s=1, c=\"orange\")\n",
    "set_lim_and_ticks(ax)\n",
    "fig.savefig(PATH / \"gt_circle_joint.pdf\")\n",
    "\n",
    "fig, ax = plt.subplots(figsize=FIGSIZE, constrained_layout=True)\n",
    "ax.scatter(x_mnf[:, 0], x_mnf[:, 1], s=1, c=\"orange\")\n",
    "set_lim_and_ticks(ax)\n",
    "fig.savefig(PATH / \"mnf_circle_joint.pdf\")\n",
    "\n",
    "fig, ax = plt.subplots(figsize=FIGSIZE, constrained_layout=True)\n",
    "ax.scatter(x_profiti[:, 0], x_profiti[:, 1], s=1, c=\"orange\")\n",
    "set_lim_and_ticks(ax)\n",
    "fig.savefig(PATH / \"profiti_circle_joint.pdf\")\n",
    "\n",
    "fig, ax = plt.subplots(figsize=FIGSIZE, constrained_layout=True)\n",
    "ax.scatter(x_gmix_1[:, 0], x_gmix_1[:, 1], s=1, c=\"orange\")\n",
    "set_lim_and_ticks(ax)\n",
    "fig.savefig(PATH / \"gmix_circle_joint.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def format_ax(ax):\n",
    "    ax.set_xlim(XLIM)\n",
    "    ax.set_ylim(PLIM)\n",
    "    ax.set_xticks(XTICKS)\n",
    "    ax.set_yticks(PTICKS)\n",
    "    ax.set_xticklabels(XTICKLABELS)\n",
    "    ax.set_yticklabels(PTICKLABELS)\n",
    "    ax.set_ylabel(\"\")\n",
    "    ax.legend([], [], frameon=False)\n",
    "\n",
    "\n",
    "nbins = 100\n",
    "\n",
    "fig, ax = fig, ax = plt.subplots(figsize=FIGSIZE, constrained_layout=True)\n",
    "sns.kdeplot(x_orig, color=\"green\", fill=True, label=\"y1\")\n",
    "format_ax(ax)\n",
    "fig.savefig(PATH / \"gt_circle_y1.pdf\")\n",
    "\n",
    "fig, ax = plt.subplots(figsize=FIGSIZE, constrained_layout=True)\n",
    "sns.kdeplot(y_orig, color=\"green\", fill=True, label=\"y2\")\n",
    "format_ax(ax)\n",
    "fig.savefig(PATH / \"gt_circle_y2.pdf\")\n",
    "\n",
    "fig, ax = plt.subplots(figsize=FIGSIZE, constrained_layout=True)\n",
    "# ax.hist(x_marg[:2000,0], bins = nbins, density=True, color='green', alpha=0.4);\n",
    "sns.kdeplot(x_mnf_marg[:NUM, 0], color=\"red\", fill=True, label=\"predicted marginal\")\n",
    "sns.kdeplot(x_mnf[:NUM, 0], color=\"blue\", fill=True, label=\"integrated marginal\")\n",
    "format_ax(ax)\n",
    "fig.savefig(PATH / \"mnf_circle_y1_both.pdf\")\n",
    "\n",
    "fig, ax = plt.subplots(figsize=FIGSIZE, constrained_layout=True)\n",
    "# ax.hist(x_marg[:2000,1], bins = nbins, density=True, color='blue', alpha=0.4);\n",
    "sns.kdeplot(x_mnf_marg[:NUM, 1], color=\"red\", fill=True, label=\"predicted marginal\")\n",
    "sns.kdeplot(x_mnf[:NUM, 1], color=\"blue\", fill=True, label=\"integrated marginal\")\n",
    "format_ax(ax)\n",
    "fig.savefig(PATH / \"mnf_circle_y2_both.pdf\")\n",
    "\n",
    "fig, ax = plt.subplots(figsize=FIGSIZE, constrained_layout=True)\n",
    "# ax.hist(x[:2000,0], bins = nbins, density=True, color='green', alpha=0.4);\n",
    "sns.kdeplot(x_mnf[:NUM, 0], color=\"blue\", fill=True, gridsize=GRID)\n",
    "format_ax(ax)\n",
    "fig.savefig(PATH / \"mnf_circle_int_y1.pdf\")\n",
    "\n",
    "fig, ax = plt.subplots(figsize=FIGSIZE, constrained_layout=True)\n",
    "# ax.hist(x[:2000,1], bins = nbins, density=True, color='blue', alpha=0.4);\n",
    "sns.kdeplot(x_mnf[:NUM, 1], color=\"blue\", fill=True, gridsize=GRID)\n",
    "format_ax(ax)\n",
    "fig.savefig(PATH / \"mnf_circle_int_y2.pdf\")\n",
    "\n",
    "fig, ax = plt.subplots(figsize=FIGSIZE, constrained_layout=True)\n",
    "# ax.hist(x_marg[:2000,0], bins = nbins, density=True, color='green', alpha=0.4);\n",
    "sns.kdeplot(x_mnf_marg[:NUM, 0], color=\"red\", fill=True, gridsize=GRID)\n",
    "format_ax(ax)\n",
    "fig.savefig(PATH / \"mnf_circle_pred_y1.pdf\")\n",
    "\n",
    "fig, ax = plt.subplots(figsize=FIGSIZE, constrained_layout=True)\n",
    "# ax.hist(x_marg[:2000,1], bins = nbins, density=True, color='blue', alpha=0.4);\n",
    "sns.kdeplot(x_mnf_marg[:NUM, 1], color=\"red\", fill=True, gridsize=GRID)\n",
    "format_ax(ax)\n",
    "fig.savefig(PATH / \"mnf_circle_pred_y2.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(figsize=FIGSIZE, constrained_layout=True)\n",
    "# ax.hist(x_marg[:2000,0], bins = nbins, density=True, color='green', alpha=0.4);\n",
    "sns.kdeplot(\n",
    "    x_profiti_marg[:NUM, 0],\n",
    "    color=\"red\",\n",
    "    fill=True,\n",
    "    label=\"predicted marginal\",\n",
    "    gridsize=GRID,\n",
    ")\n",
    "sns.kdeplot(\n",
    "    x_profiti[:NUM, 0],\n",
    "    color=\"blue\",\n",
    "    fill=True,\n",
    "    label=\"integrated marginal\",\n",
    "    gridsize=GRID,\n",
    ")\n",
    "format_ax(ax)\n",
    "fig.savefig(PATH / \"profiti_circle_y1_both.pdf\")\n",
    "\n",
    "fig, ax = plt.subplots(figsize=FIGSIZE, constrained_layout=True)\n",
    "# ax.hist(x_marg[:2000,1], bins = nbins, density=True, color='blue', alpha=0.4);\n",
    "sns.kdeplot(\n",
    "    x_profiti_marg[:NUM, 1],\n",
    "    color=\"red\",\n",
    "    fill=True,\n",
    "    label=\"predicted marginal\",\n",
    "    gridsize=GRID,\n",
    ")\n",
    "sns.kdeplot(\n",
    "    x_profiti[:NUM, 1],\n",
    "    color=\"blue\",\n",
    "    fill=True,\n",
    "    label=\"integrated marginal\",\n",
    "    gridsize=GRID,\n",
    ")\n",
    "format_ax(ax)\n",
    "fig.savefig(PATH / \"profiti_circle_y2_both.pdf\")\n",
    "\n",
    "fig, ax = plt.subplots(figsize=FIGSIZE, constrained_layout=True)\n",
    "# ax.hist(x[:2000,0], bins = nbins, density=True, color='green', alpha=0.4);\n",
    "sns.kdeplot(x_profiti[:NUM, 0], color=\"blue\", fill=True, gridsize=GRID)\n",
    "format_ax(ax)\n",
    "fig.savefig(PATH / \"profiti_circle_int_y1.pdf\")\n",
    "\n",
    "fig, ax = plt.subplots(figsize=FIGSIZE, constrained_layout=True)\n",
    "# ax.hist(x[:2000,1], bins = nbins, density=True, color='blue', alpha=0.4);\n",
    "sns.kdeplot(x_profiti[:NUM, 1], color=\"blue\", fill=True, gridsize=GRID)\n",
    "format_ax(ax)\n",
    "fig.savefig(PATH / \"profiti_circle_int_y2.pdf\")\n",
    "\n",
    "fig, ax = plt.subplots(figsize=FIGSIZE, constrained_layout=True)\n",
    "# ax.hist(x_marg[:2000,0], bins = nbins, density=True, color='green', alpha=0.4);\n",
    "sns.kdeplot(x_profiti_marg[:NUM, 0], color=\"red\", fill=True, gridsize=GRID)\n",
    "format_ax(ax)\n",
    "fig.savefig(PATH / \"profiti_circle_pred_y1.pdf\")\n",
    "\n",
    "fig, ax = plt.subplots(figsize=FIGSIZE, constrained_layout=True)\n",
    "# ax.hist(x_marg[:2000,1], bins = nbins, density=True, color='blue', alpha=0.4);\n",
    "sns.kdeplot(x_profiti_marg[:NUM, 1], color=\"red\", fill=True, gridsize=GRID)\n",
    "format_ax(ax)\n",
    "fig.savefig(PATH / \"profiti_circle_pred_y2.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(figsize=FIGSIZE, constrained_layout=True)\n",
    "# ax.hist(x_marg[:2000,0], bins = nbins, density=True, color='green', alpha=0.4);\n",
    "sns.kdeplot(\n",
    "    x_gmix_1_marg[:NUM, 0],\n",
    "    color=\"red\",\n",
    "    fill=True,\n",
    "    label=\"predicted marginal\",\n",
    "    gridsize=GRID,\n",
    ")\n",
    "sns.kdeplot(\n",
    "    x_gmix_1_marg[:NUM, 0],\n",
    "    color=\"blue\",\n",
    "    fill=True,\n",
    "    label=\"integrated marginal\",\n",
    "    gridsize=GRID,\n",
    ")\n",
    "format_ax(ax)\n",
    "fig.savefig(PATH / \"gmix_circle_y1_both.pdf\")\n",
    "\n",
    "fig, ax = plt.subplots(figsize=FIGSIZE, constrained_layout=True)\n",
    "# ax.hist(x_marg[:2000,1], bins = nbins, density=True, color='blue', alpha=0.4);\n",
    "sns.kdeplot(\n",
    "    x_gmix_1_marg[:NUM, 1],\n",
    "    color=\"red\",\n",
    "    fill=True,\n",
    "    label=\"predicted marginal\",\n",
    "    gridsize=GRID,\n",
    ")\n",
    "sns.kdeplot(\n",
    "    x_gmix_1_marg[:NUM, 1],\n",
    "    color=\"blue\",\n",
    "    fill=True,\n",
    "    label=\"integrated marginal\",\n",
    "    gridsize=GRID,\n",
    ")\n",
    "format_ax(ax)\n",
    "fig.savefig(PATH / \"gmix_circle_y2_both.pdf\")\n",
    "\n",
    "fig, ax = plt.subplots(figsize=FIGSIZE, constrained_layout=True)\n",
    "# ax.hist(x[:2000,0], bins = nbins, density=True, color='green', alpha=0.4);\n",
    "sns.kdeplot(x_gmix_1_marg[:NUM, 0], color=\"blue\", fill=True, gridsize=GRID)\n",
    "format_ax(ax)\n",
    "fig.savefig(PATH / \"gmix_circle_int_y1.pdf\")\n",
    "\n",
    "fig, ax = plt.subplots(figsize=FIGSIZE, constrained_layout=True)\n",
    "# ax.hist(x[:2000,1], bins = nbins, density=True, color='blue', alpha=0.4);\n",
    "sns.kdeplot(x_gmix_1_marg[:NUM, 1], color=\"blue\", fill=True, gridsize=GRID)\n",
    "format_ax(ax)\n",
    "fig.savefig(PATH / \"gmix_circle_int_y2.pdf\")\n",
    "\n",
    "fig, ax = plt.subplots(figsize=FIGSIZE, constrained_layout=True)\n",
    "# ax.hist(x_marg[:2000,0], bins = nbins, density=True, color='green', alpha=0.4);\n",
    "sns.kdeplot(x_gmix_1_marg[:NUM, 0], color=\"red\", fill=True, gridsize=GRID)\n",
    "format_ax(ax)\n",
    "fig.savefig(PATH / \"gmix_circle_pred_y1.pdf\")\n",
    "\n",
    "fig, ax = plt.subplots(figsize=FIGSIZE, constrained_layout=True)\n",
    "# ax.hist(x_marg[:2000,1], bins = nbins, density=True, color='blue', alpha=0.4);\n",
    "sns.kdeplot(x_gmix_1_marg[:NUM, 1], color=\"red\", fill=True, gridsize=GRID)\n",
    "format_ax(ax)\n",
    "fig.savefig(PATH / \"gmix_circle_pred_y2.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# gmix\n",
    "fig, ax = plt.subplots(figsize=(3, 3), constrained_layout=True)\n",
    "ax.scatter(x_orig, y_orig, s=1, c=\"orange\")\n",
    "ax.set_xlim(-1.2, 1.2)\n",
    "ax.set_ylim(-1.2, 1.2)\n",
    "ax.set_xticks([], [])\n",
    "ax.set_yticks([], [])\n",
    "fig.savefig(PATH / \"true_dist_circle.pdf\")\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(3, 3), constrained_layout=True)\n",
    "ax.scatter(x_mnf[:4000, 0], x_mnf[:4000, 1], s=1, c=\"orange\")\n",
    "ax.set_xlim(-1.2, 1.2)\n",
    "ax.set_ylim(-1.2, 1.2)\n",
    "ax.set_xticks([], [])\n",
    "ax.set_yticks([], [])\n",
    "fig.savefig(PATH / \"mymodel_circle.pdf\")\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(3, 3), constrained_layout=True)\n",
    "ax.scatter(x_gmix_1[:4000, 0], x_gmix_1[:4000, 1], s=1, c=\"orange\")\n",
    "ax.set_xlim(-1.2, 1.2)\n",
    "ax.set_ylim(-1.2, 1.2)\n",
    "ax.set_xticks([], [])\n",
    "ax.set_yticks([], [])\n",
    "fig.savefig(PATH / \"g_mix_1_circle.pdf\")\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(3, 3), constrained_layout=True)\n",
    "ax.scatter(x_gmix_5[:4000, 0], x_gmix_5[:4000, 1], s=1, c=\"orange\")\n",
    "ax.set_xlim(-1.2, 1.2)\n",
    "ax.set_ylim(-1.2, 1.2)\n",
    "ax.set_xticks([], [])\n",
    "ax.set_yticks([], [])\n",
    "fig.savefig(PATH / \"g_mix_5_circle.pdf\")\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(3, 3), constrained_layout=True)\n",
    "ax.scatter(x_gmix_10[:4000, 0], x_gmix_10[:4000, 1], s=1, c=\"orange\")\n",
    "ax.set_xlim(-1.2, 1.2)\n",
    "ax.set_ylim(-1.2, 1.2)\n",
    "ax.set_xticks([], [])\n",
    "ax.set_yticks([], [])\n",
    "fig.savefig(PATH / \"g_mix_10_circle.pdf\")\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(3, 3), constrained_layout=True)\n",
    "ax.scatter(x_gmix_15[:4000, 0], x_gmix_15[:4000, 1], s=1, c=\"orange\")\n",
    "ax.set_xlim(-1.2, 1.2)\n",
    "ax.set_ylim(-1.2, 1.2)\n",
    "ax.set_xticks([], [])\n",
    "ax.set_yticks([], [])\n",
    "fig.savefig(PATH / \"g_mix_15_circle.pdf\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "torch",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
