{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib\n",
    "import seaborn as sns\n",
    "from pathlib import Path\n",
    "from scipy.stats import wasserstein_distance as WD\n",
    "\n",
    "# Global figure typography: LaTeX text rendering with a large serif font.\n",
    "matplotlib.rcParams[\"text.usetex\"] = True\n",
    "matplotlib.rcParams[\"font.family\"] = \"serif\"\n",
    "matplotlib.rcParams[\"font.size\"] = \"35\"\n",
    "\n",
    "# Directory (below the current working directory) that receives every saved PDF.\n",
    "PATH = Path.cwd() / \"fig\" / \"toy\"\n",
    "# parents=True also creates the intermediate fig/ directory; without it\n",
    "# mkdir raises FileNotFoundError when fig/ does not exist yet.\n",
    "PATH.mkdir(parents=True, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pre-computed arrays, read from data/ relative to the working directory.\n",
    "\n",
    "# mnf: samples plotted as the joint, the companion \"l2\" array (presumably\n",
    "# negative log-likelihoods; nll is not used again in this notebook) and the\n",
    "# samples plotted as the \"predicted marginal\".\n",
    "x_mnf, nll, x_mnf_marg = (\n",
    "    np.load(\"data/mnf_toy_gaussian2_x2.npy\"),\n",
    "    np.load(\"data/mnf_toy_gaussian2_l2.npy\"),\n",
    "    np.load(\"data/mnf_toy_gaussian2_marg.npy\"),\n",
    ")\n",
    "\n",
    "# profiti: joint samples and marginal samples.\n",
    "x_profiti, x_profiti_marg = (\n",
    "    np.load(\"data/profiti_toy_gaussian2_x.npy\"),\n",
    "    np.load(\"data/profiti_toy_gaussian2_x_nc.npy\"),\n",
    ")\n",
    "\n",
    "# gmix: joint samples and marginal samples.\n",
    "x_gmix_1, x_gmix_1_marg = (\n",
    "    np.load(\"data/gmix_toy_gaussian2_1_x.npy\"),\n",
    "    np.load(\"data/gmix_toy_gaussian2_1_marg.npy\"),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _mean_marginal_wd(joint, marg):\n",
    "    \"\"\"Mean over columns 0 and 1 of the 1-D Wasserstein distance between\n",
    "    ``joint[:, k]`` and ``marg[:, k]``.\"\"\"\n",
    "    first = WD(joint[:, 0], marg[:, 0])\n",
    "    second = WD(joint[:, 1], marg[:, 1])\n",
    "    return 0.5 * (first + second)\n",
    "\n",
    "\n",
    "# One score per model: how far the joint samples' columns are from the\n",
    "# separately provided marginal samples.\n",
    "MI_gau_mnf = _mean_marginal_wd(x_mnf, x_mnf_marg)\n",
    "MI_gau_profiti = _mean_marginal_wd(x_profiti, x_profiti_marg)\n",
    "MI_gau_gmix = _mean_marginal_wd(x_gmix_1, x_gmix_1_marg)\n",
    "print(MI_gau_mnf, MI_gau_profiti, MI_gau_gmix)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Axis ranges: x/y for the joint scatter plots, p for the density axis.\n",
    "XLIM = np.array([-10, +10])\n",
    "YLIM = np.array([-10, +10])\n",
    "PLIM = np.array([0, +1.25])\n",
    "\n",
    "# Tick positions (x/y ticks are spaced 10/3 apart around zero).\n",
    "XTICKS = np.array([-20 / 3, -10 / 3, 0, 10 / 3, 20 / 3])\n",
    "YTICKS = np.array([-20 / 3, -10 / 3, 0, 10 / 3, 20 / 3])\n",
    "PTICKS = np.array([0, 0.25, 0.5, 0.75, 1])\n",
    "\n",
    "# Empty label lists: ticks are drawn without text.\n",
    "XTICKLABELS = []\n",
    "YTICKLABELS = []\n",
    "PTICKLABELS = []\n",
    "\n",
    "FIGSIZE = (3, 3)\n",
    "GRID = 1000  # gridsize for kde plot\n",
    "NUM = None  # How many samples to plot (None=all); not referenced elsewhere in this notebook\n",
    "\n",
    "# Ground-truth toy samples. A fixed seed makes the saved ground-truth\n",
    "# figures identical from one run to the next.\n",
    "SEED = 0\n",
    "rng = np.random.default_rng(SEED)\n",
    "\n",
    "n = 6000\n",
    "x_orig = rng.normal(0, 1, (n, 2, 1))  # n standard-normal column vectors\n",
    "# Mixing matrix applied to each column vector: the second output coordinate\n",
    "# is the sum of both inputs, which couples the two coordinates.\n",
    "cov = np.array([[1, 0], [1, 1]])\n",
    "\n",
    "y_orig = np.matmul(cov, x_orig)  # (2, 2) @ (n, 2, 1) -> (n, 2, 1)\n",
    "y_orig = np.sign(y_orig) * (y_orig**2)  # signed square of every entry\n",
    "# Split into the two coordinates as flat length-n arrays.\n",
    "x_orig = y_orig[:, 0].squeeze()\n",
    "y_orig = y_orig[:, 1].squeeze()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# RESCALE (so that same axes as circle plot)\n",
    "# NOTE: everything below is rescaled in place, so run this cell exactly once\n",
    "# after the cells that define these names; re-running it compounds the factor.\n",
    "SCALE = 3 / 10\n",
    "\n",
    "# Axis limits and tick positions.\n",
    "XLIM = XLIM * SCALE\n",
    "YLIM = YLIM * SCALE\n",
    "XTICKS = XTICKS * SCALE\n",
    "YTICKS = YTICKS * SCALE\n",
    "\n",
    "# Every sample array that is later drawn against those limits.\n",
    "x_orig = x_orig * SCALE\n",
    "y_orig = y_orig * SCALE\n",
    "x_mnf = x_mnf * SCALE\n",
    "x_mnf_marg = x_mnf_marg * SCALE\n",
    "x_profiti = x_profiti * SCALE\n",
    "x_profiti_marg = x_profiti_marg * SCALE\n",
    "# The gmix samples are plotted with the same rescaled limits as the other\n",
    "# models, so they get the same factor.\n",
    "# NOTE(review): assumes the gmix .npy files are stored on the same scale as\n",
    "# the mnf/profiti ones - confirm.\n",
    "x_gmix_1 = x_gmix_1 * SCALE\n",
    "x_gmix_1_marg = x_gmix_1_marg * SCALE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def set_lim_and_ticks(ax):\n",
    "    \"\"\"Apply the shared limits, tick positions and tick labels to a joint\n",
    "    scatter axis.\n",
    "\n",
    "    Reads XLIM/YLIM, XTICKS/YTICKS and XTICKLABELS/YTICKLABELS.\n",
    "    \"\"\"\n",
    "    ax.set_xlim(XLIM)\n",
    "    ax.set_ylim(YLIM)\n",
    "    ax.set_xticks(XTICKS)\n",
    "    ax.set_yticks(YTICKS)\n",
    "    ax.set_xticklabels(XTICKLABELS)\n",
    "    # The y axis takes its own label list (it previously reused XTICKLABELS).\n",
    "    ax.set_yticklabels(YTICKLABELS)\n",
    "\n",
    "\n",
    "# One joint scatter figure per source: (output file, x values, y values).\n",
    "joint_figures = [\n",
    "    (\"gt_toy_joint.pdf\", x_orig, y_orig),\n",
    "    (\"mnf_toy_joint.pdf\", x_mnf[:, 0], x_mnf[:, 1]),\n",
    "    (\"profiti_toy_joint.pdf\", x_profiti[:, 0], x_profiti[:, 1]),\n",
    "    (\"gmix_toy_joint.pdf\", x_gmix_1[:, 0], x_gmix_1[:, 1]),\n",
    "]\n",
    "\n",
    "for filename, xs, ys in joint_figures:\n",
    "    fig, ax = plt.subplots(figsize=FIGSIZE, constrained_layout=True)\n",
    "    ax.scatter(xs, ys, s=1, c=\"orange\")\n",
    "    set_lim_and_ticks(ax)\n",
    "    fig.savefig(PATH / filename)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def format_ax(ax):\n",
    "    \"\"\"Apply the shared density-plot styling to ``ax``.\n",
    "\n",
    "    Sets the x range/ticks from XLIM/XTICKS and the y range/ticks from\n",
    "    PLIM/PTICKS, applies the tick-label lists, clears the y label and\n",
    "    installs an empty, frameless legend.\n",
    "    \"\"\"\n",
    "    # horizontal axis\n",
    "    ax.set_xlim(XLIM)\n",
    "    ax.set_xticks(XTICKS)\n",
    "    ax.set_xticklabels(XTICKLABELS)\n",
    "    # vertical (density) axis\n",
    "    ax.set_ylim(PLIM)\n",
    "    ax.set_yticks(PTICKS)\n",
    "    ax.set_yticklabels(PTICKLABELS)\n",
    "    ax.set_ylabel(\"\")\n",
    "    # empty handle/label lists: no legend entries are shown\n",
    "    ax.legend([], [], frameon=False)\n",
    "\n",
    "\n",
    "# Ground-truth and mnf density figures.\n",
    "# Each entry: (output file, [(samples, kdeplot style kwargs), ...]); the\n",
    "# curves of one entry are drawn in order on the same axis.\n",
    "mnf_kde_figures = [\n",
    "    (\"gt_toy_y1.pdf\", [(x_orig, {\"color\": \"green\", \"label\": \"y1\"})]),\n",
    "    (\"gt_toy_y2.pdf\", [(y_orig, {\"color\": \"green\", \"label\": \"y2\"})]),\n",
    "    (\n",
    "        \"mnf_toy_y1_both.pdf\",\n",
    "        [\n",
    "            (x_mnf_marg[:, 0], {\"color\": \"red\", \"label\": \"predicted marginal\"}),\n",
    "            (x_mnf[:, 0], {\"color\": \"blue\", \"label\": \"integrated marginal\"}),\n",
    "        ],\n",
    "    ),\n",
    "    (\n",
    "        \"mnf_toy_y2_both.pdf\",\n",
    "        [\n",
    "            (x_mnf_marg[:, 1], {\"color\": \"red\", \"label\": \"predicted marginal\"}),\n",
    "            (x_mnf[:, 1], {\"color\": \"blue\", \"label\": \"integrated marginal\"}),\n",
    "        ],\n",
    "    ),\n",
    "    (\"mnf_toy_int_y1.pdf\", [(x_mnf[:, 0], {\"color\": \"blue\"})]),\n",
    "    (\"mnf_toy_int_y2.pdf\", [(x_mnf[:, 1], {\"color\": \"blue\"})]),\n",
    "    (\"mnf_toy_pred_y1.pdf\", [(x_mnf_marg[:, 0], {\"color\": \"red\"})]),\n",
    "    (\"mnf_toy_pred_y2.pdf\", [(x_mnf_marg[:, 1], {\"color\": \"red\"})]),\n",
    "]\n",
    "\n",
    "for filename, curves in mnf_kde_figures:\n",
    "    fig, ax = plt.subplots(figsize=FIGSIZE, constrained_layout=True)\n",
    "    for samples, style in curves:\n",
    "        # ax=ax pins the curve to this figure instead of relying on the\n",
    "        # implicit \"current axes\".\n",
    "        sns.kdeplot(samples, fill=True, gridsize=GRID, ax=ax, **style)\n",
    "    format_ax(ax)\n",
    "    fig.savefig(PATH / filename)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# profiti density figures, built from the first 2000 rows of each array.\n",
    "# Each entry: (output file, [(samples, kdeplot style kwargs), ...]); the\n",
    "# curves of one entry are drawn in order on the same axis.\n",
    "profiti_kde_figures = [\n",
    "    (\n",
    "        \"profiti_toy_y1_both.pdf\",\n",
    "        [\n",
    "            (x_profiti_marg[:2000, 0], {\"color\": \"red\", \"label\": \"predicted marginal\"}),\n",
    "            (x_profiti[:2000, 0], {\"color\": \"blue\", \"label\": \"integrated marginal\"}),\n",
    "        ],\n",
    "    ),\n",
    "    (\n",
    "        \"profiti_toy_y2_both.pdf\",\n",
    "        [\n",
    "            (x_profiti_marg[:2000, 1], {\"color\": \"red\", \"label\": \"predicted marginal\"}),\n",
    "            (x_profiti[:2000, 1], {\"color\": \"blue\", \"label\": \"integrated marginal\"}),\n",
    "        ],\n",
    "    ),\n",
    "    (\"profiti_toy_int_y1.pdf\", [(x_profiti[:2000, 0], {\"color\": \"blue\"})]),\n",
    "    (\"profiti_toy_int_y2.pdf\", [(x_profiti[:2000, 1], {\"color\": \"blue\"})]),\n",
    "    (\"profiti_toy_pred_y1.pdf\", [(x_profiti_marg[:2000, 0], {\"color\": \"red\"})]),\n",
    "    (\"profiti_toy_pred_y2.pdf\", [(x_profiti_marg[:2000, 1], {\"color\": \"red\"})]),\n",
    "]\n",
    "\n",
    "for filename, curves in profiti_kde_figures:\n",
    "    fig, ax = plt.subplots(figsize=FIGSIZE, constrained_layout=True)\n",
    "    for samples, style in curves:\n",
    "        sns.kdeplot(samples, fill=True, gridsize=GRID, ax=ax, **style)\n",
    "    format_ax(ax)\n",
    "    fig.savefig(PATH / filename)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE: equivalent to the format_ax defined in an earlier cell; this\n",
    "# re-definition shadows it.\n",
    "def format_ax(ax):\n",
    "    \"\"\"Apply the shared density-plot limits, ticks and tick labels to ``ax``,\n",
    "    clear its y label and install an empty, frameless legend.\"\"\"\n",
    "    ax.set_xlim(XLIM)\n",
    "    ax.set_ylim(PLIM)\n",
    "    ax.set_xticks(XTICKS)\n",
    "    ax.set_yticks(PTICKS)\n",
    "    ax.set_xticklabels(XTICKLABELS)\n",
    "    ax.set_yticklabels(PTICKLABELS)\n",
    "    ax.set_ylabel(\"\")\n",
    "    ax.legend([], [], frameon=False)\n",
    "\n",
    "\n",
    "# gmix density figures, built from the first 2000 rows of each array.\n",
    "# As for the other models, red (\"predicted marginal\") curves come from the\n",
    "# *_marg array and blue (\"integrated marginal\") curves from the joint samples.\n",
    "# The blue curves previously also read x_gmix_1_marg, which made the \"int\"\n",
    "# figures copies of the \"pred\" ones.\n",
    "# Each entry: (output file, [(samples, kdeplot style kwargs), ...]).\n",
    "gmix_kde_figures = [\n",
    "    (\n",
    "        \"gmix_toy_y1_both.pdf\",\n",
    "        [\n",
    "            (x_gmix_1_marg[:2000, 0], {\"color\": \"red\", \"label\": \"predicted marginal\"}),\n",
    "            (x_gmix_1[:2000, 0], {\"color\": \"blue\", \"label\": \"integrated marginal\"}),\n",
    "        ],\n",
    "    ),\n",
    "    (\n",
    "        \"gmix_toy_y2_both.pdf\",\n",
    "        [\n",
    "            (x_gmix_1_marg[:2000, 1], {\"color\": \"red\", \"label\": \"predicted marginal\"}),\n",
    "            (x_gmix_1[:2000, 1], {\"color\": \"blue\", \"label\": \"integrated marginal\"}),\n",
    "        ],\n",
    "    ),\n",
    "    (\"gmix_toy_int_y1.pdf\", [(x_gmix_1[:2000, 0], {\"color\": \"blue\"})]),\n",
    "    (\"gmix_toy_int_y2.pdf\", [(x_gmix_1[:2000, 1], {\"color\": \"blue\"})]),\n",
    "    (\"gmix_toy_pred_y1.pdf\", [(x_gmix_1_marg[:2000, 0], {\"color\": \"red\"})]),\n",
    "    (\"gmix_toy_pred_y2.pdf\", [(x_gmix_1_marg[:2000, 1], {\"color\": \"red\"})]),\n",
    "]\n",
    "\n",
    "for filename, curves in gmix_kde_figures:\n",
    "    fig, ax = plt.subplots(figsize=FIGSIZE, constrained_layout=True)\n",
    "    for samples, style in curves:\n",
    "        sns.kdeplot(samples, fill=True, gridsize=GRID, ax=ax, **style)\n",
    "    format_ax(ax)\n",
    "    fig.savefig(PATH / filename)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Density of column 1 of the profiti marginal samples, using every row\n",
    "# (the saved profiti figures use only the first 2000). Displayed only; this\n",
    "# cell writes no file.\n",
    "fig, ax = plt.subplots(constrained_layout=True, figsize=FIGSIZE)\n",
    "sns.kdeplot(x_profiti_marg[:, 1], ax=ax, gridsize=GRID, fill=True, color=\"red\")\n",
    "format_ax(ax)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "torch",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
