{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from jax import numpy as np\n",
    "from jax import random"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "key = random.PRNGKey(0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Generate a mixture of Gaussians"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "key+=1\n",
    "\n",
    "\n",
    "# define mean and variances of two Gaussians\n",
    "m1 = -5.0; s1 = 1.5\n",
    "m2 = 2.0; s2 = 1.1\n",
    "m3 = 3.0; s3 = 0.6\n",
    "m4 = 7.0; s4 = 1.7\n",
    "m5 = 9.0; s5 = 1.4\n",
    "m6 = 15; s6 = 0.9\n",
    "m7 = 18; s7 = 1.4\n",
    "m8 = 23; s7 = 0.5\n",
    "\n",
    "# Define weights\n",
    "w1 = 3\n",
    "w2 = 4\n",
    "w3 = 2\n",
    "w4 = 6\n",
    "w5 = 3\n",
    "w6 = 4\n",
    "w7 = 3\n",
    "w8 = 2\n",
    "norm = w1 + w2 + w3 + w4 + w5 + w6 + w7 + w8\n",
    "w1, w2, w3, w4, w5, w6, w7, w8 = w1/norm, w2/norm, w3/norm, w4/norm, w5/norm, w6/norm, w7/norm, w8/norm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# number of samples\n",
    "n_samples = 100000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generate some sample data, with proportions given by w1, w2, w3, ...\n",
    "Effect = np.concatenate([m1 + s1*random.normal(key, (int(w1*n_samples),)), \n",
    "                         m2 + s2*random.normal(key, (int(w2*n_samples),)),\n",
    "                         m3 + s3*random.normal(key, (int(w3*n_samples),)),\n",
    "                         m4 + s4*random.normal(key, (int(w4*n_samples),)),\n",
    "                         m5 + s5*random.normal(key, (int(w5*n_samples),)),\n",
    "                         m6 + s6*random.normal(key, (int(w6*n_samples),)),\n",
    "                         m7 + s7*random.normal(key, (int(w7*n_samples),))                         \n",
    "                        ]\n",
    "                         )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# plot histogram\n",
    "import matplotlib.pyplot as plt\n",
    "plt.hist(Effect, bins=50)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Effect.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sort the values and normalize them between 0 and 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.min(Effect)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.max(Effect)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Effect = np.sort(Effect)\n",
    "Effect-=np.min(Effect)\n",
    "Effect/=np.max(Effect)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(np.min(Effect))\n",
    "print(np.max(Effect))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compute the empirical CDF"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from statsmodels.distributions.empirical_distribution import ECDF"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ecdf = ECDF(Effect)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Cause = ecdf(Effect)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "Make x-y plot, also showing the densities.\n",
    "\n",
    "Reference: https://matplotlib.org/stable/gallery/lines_bars_and_markers/scatter_hist.html"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from scipy import stats\n",
    "\n",
    "# Cause in the x-axis, Effect in the y-axis\n",
    "x = Cause\n",
    "y = Effect\n",
    "\n",
    "# Add Kernel Density Estimation for smoother densities on x and y\n",
    "kde_x = stats.gaussian_kde(x)\n",
    "kde_y = stats.gaussian_kde(y)\n",
    "\n",
    "xx = np.linspace(0, 1, 1000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import seaborn as sns\n",
    "sns.set_style(\"white\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "figure_path = \"/Users/luigigresele/Documents/Plots_IMA\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib.ticker import MaxNLocator, MultipleLocator, Locator, FixedLocator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib as mpl\n",
    "mpl.rcParams['axes.linewidth'] = 1.7\n",
    "mpl.rcParams['axes.edgecolor'] = 'b'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# definitions for the axes\n",
    "left = 0.4\n",
    "bottom = 0.4\n",
    "\n",
    "# Width and height of the x-y plot\n",
    "width, height = 1.0, 0.65\n",
    "\n",
    "# Spacing between the axes of the x-y plot and the histograms:\n",
    "spacing = 0.03\n",
    "\n",
    "\n",
    "rect_scatter = [left, bottom, width, height]\n",
    "\n",
    "# Distance from axes\n",
    "distance = 0.15\n",
    "\n",
    "rect_histy = [left - distance - spacing, bottom, 0.15, height]\n",
    "\n",
    "# start with a rectangular Figure\n",
    "plt.figure(figsize=(8, 8))\n",
    "\n",
    "with sns.axes_style(\"white\"):\n",
    "    ax_scatter = plt.axes(rect_scatter)\n",
    "    ax_scatter.yaxis.set_ticks([])\n",
    "    # Remove ticks from x axis\n",
    "    ax_scatter.xaxis.set_ticks([])\n",
    "#     Ticks only in 0.0 and 1.0\n",
    "    ax_scatter.xaxis.set_major_locator(FixedLocator([0.0, 1.0]))\n",
    "\n",
    "ax_scatter.tick_params(direction='in', top=True, right=True, labelsize=30)\n",
    "\n",
    "with sns.axes_style(\"white\"):\n",
    "    ax_histy = plt.axes(rect_histy)\n",
    "    ax_histy.xaxis.set_ticks([])\n",
    "\n",
    "# ax_histy.yaxis.grid(True) # Hide the horizontal gridlines\n",
    "# ax_histy.xaxis.grid(False) # Show the vertical gridlines\n",
    "\n",
    "ax_histy.tick_params(direction='in', labelleft=False)\n",
    "\n",
    "# the scatter plot:\n",
    "# ax_scatter.scatter(x, y)\n",
    "with sns.axes_style(\"white\"):\n",
    "    ax_scatter.plot(x, y, linewidth=4)\n",
    "\n",
    "# now determine nice limits by hand:\n",
    "binwidth = 0.01\n",
    "lim = 1\n",
    "ax_scatter.set_xlim((0, 1))\n",
    "ax_scatter.set_ylim((0, 1))\n",
    "\n",
    "bins = np.arange(0, 1+binwidth, binwidth)\n",
    "\n",
    "# Add KDE plot of the density for y\n",
    "ax_histy.plot(kde_y(xx), xx,  color='C0', linewidth=4)\n",
    "ax_histy.fill_betweenx(xx, kde_y(xx), # xx,\n",
    "              color='C0', alpha=0.3)\n",
    "\n",
    "# ax_histx.set_xlim(ax_scatter.get_xlim())\n",
    "ax_histy.set_ylim(ax_scatter.get_ylim())\n",
    "\n",
    "# Insert text here\n",
    "fontsize_math = 25\n",
    "plt.text(0.6, 0.3, r\"$p_x(x)$\", fontsize=fontsize_math)\n",
    "plt.text(10.5, 0.04, r\"$p_y(y)=1$\", fontsize=fontsize_math)\n",
    "plt.text(9.0, 0.65, r\"$x=g^{-1}(y)$\", fontsize=fontsize_math)\n",
    "plt.text(11.5, 0.5, r\"$y=g(x):\\!\\!\\!= P(X \\leq x)$\", fontsize=fontsize_math)\n",
    "\n",
    "plt.savefig(os.path.join(figure_path, 'IGCI_plot.png'), dpi=None, facecolor='w', edgecolor='w',\n",
    "            orientation='portrait', format=None,\n",
    "            transparent=True, bbox_inches='tight', pad_inches=0.1, metadata=None)\n",
    "\n",
    "plt.savefig(os.path.join(figure_path, 'IGCI_plot.pdf'), dpi=None, facecolor='w', edgecolor='w',\n",
    "            orientation='portrait', format=None,\n",
    "            transparent=True, bbox_inches='tight', pad_inches=0.1, metadata=None)\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Maybe we can still insert some text (not sure what the best notation is, cause/effect, observation/source, or something neutral), e.g.:\n",
    "\n",
    "# p_x(x) for the density on the y-axis\n",
    "# g(x) := F_X(x) = P(X<=x) for the function\n",
    "# p_y(y) = 1 for the implied uniform density on the x-axis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
