{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "import pandas as pd\n",
    "import matplotlib as mpl\n",
    "from mpl_toolkits.mplot3d.art3d import Poly3DCollection\n",
    "from scipy import stats"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Make ground-truth distribution for the plot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# p hat is a mixture of Gaussians\n",
    "ws_hat = [.6,.4]\n",
    "sds_hat = [1,2]\n",
    "us_hat = [4,-6]\n",
    "\n",
    "u_hat = 0\n",
    "\n",
    "#calculate its standard deviation\n",
    "sd_hat = np.sqrt(ws_hat[0]*sds_hat[0]**2 + ws_hat[1]*sds_hat[1]**2 +ws_hat[0]*us_hat[0]**2+ws_hat[1]*us_hat[1]**2-(ws_hat[0]*us_hat[0]+ws_hat[1]*us_hat[1])**2)\n",
    "\n",
    "def p_hat(x):\n",
    "    return np.dot(ws_hat, [np.exp(-.5*((x-u)/sd)**2)/np.sqrt(2*np.pi*sd**2) for u,sd in zip(us_hat,sds_hat)]) \n",
    "\n",
    "# p_1 and p_2 are  both Gaussians, one with matching standard deviation, both with matching means\n",
    "sd_1=sd_hat\n",
    "u1 = 0\n",
    "sd_2=3\n",
    "u2 = 0\n",
    "\n",
    "def p_1(x):\n",
    "    return np.exp(-.5*((x-u1)/sd_1)**2)/np.sqrt(2*np.pi*sd_1**2)\n",
    "\n",
    "def p_2(x):\n",
    "    return np.exp(-.5*((x-u2)/sd_2)**2)/np.sqrt(2*np.pi*sd_2**2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define colors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "c1 = '#6FABD3'#'#3B6FA7'\n",
    "c2 = '#F26C5D'#9A3740'\n",
    "c3 = '#A0D39C'#'#8BCB8E' #347E52\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plot each distribution"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with mpl.rc_context(fname='../../../matplotlibrc'):\n",
    "    x = np.linspace(-20,20,1000)\n",
    "    alpha = .2\n",
    "    lw = 1.5\n",
    "    plt.figure(figsize=(2,1.5))\n",
    "    plt.plot(x,p_hat(x), color =c1,lw=lw, label =r'$p_{true}(x)$')\n",
    "    plt.fill_between(x,p_hat(x),alpha=alpha,color=c1)\n",
    "    plt.plot(x,p_1(x),color=c2,lw=lw, label =r'$p_1(x)$')\n",
    "    plt.fill_between(x,p_1(x),alpha=alpha,color = c2)\n",
    "    plt.plot(x,p_2(x),color=c3,lw=lw, label =r'$p_2(x)$')\n",
    "    plt.fill_between(x,p_2(x),alpha=alpha,color=c3)\n",
    "    plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)\n",
    "    plt.xlim(-15,15)\n",
    "    plt.ylim(-.1,.25)\n",
    "    plt.xticks([])\n",
    "    plt.yticks([])\n",
    "    plt.axhline(0, lw=1.5,color='black',zorder =20)\n",
    "    plt.ylabel(r\"$p(x)$\")\n",
    "\n",
    "    # set z spine off\n",
    "    plt.gca().spines['bottom'].set_visible(False)\n",
    "    plt.xlabel(r\"$x$\")\n",
    "    #plt.savefig(\"mixture.pdf\",bbox_inches='tight',transparent=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Make panel with identity mapping"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with mpl.rc_context(fname='../../../matplotlibrc'):\n",
    "    x = np.linspace(-20,20,1000)\n",
    "\n",
    "    alpha = .2\n",
    "    lw = 1.5\n",
    "    plt.figure(figsize=(2,1.5))\n",
    "    plt.plot(x,p_hat(x), color =c1,lw=lw)\n",
    "    plt.fill_between(x,p_hat(x),alpha=alpha,color=c1)\n",
    "    plt.plot(x,p_1(x),color=c2,lw=lw)\n",
    "    plt.fill_between(x,p_1(x),alpha=alpha,color = c2)\n",
    "    plt.plot(x,p_2(x),color=c3,lw=lw)\n",
    "    plt.fill_between(x,p_2(x),alpha=alpha,color=c3)\n",
    "    plt.xlim(-15,15)\n",
    "    plt.ylim(-.1,.25)\n",
    "    plt.xticks([])\n",
    "    plt.yticks([])\n",
    "    plt.axvline(0, lw=1.5,ls='--',color='grey',zorder =-20)\n",
    "    plt.axhline(0, lw=1.5,color='black',zorder =20)\n",
    "\n",
    "    m1 = '+'\n",
    "    m2='2'\n",
    "    m3='*'\n",
    "    lw1=3\n",
    "    lw2=1.5\n",
    "    lw3=.2\n",
    "    ms1=100\n",
    "    ms2=120\n",
    "    ms3=70\n",
    "\n",
    "    plt.scatter([0],[-.05],color = c1,marker =m1,s=ms1,lw=lw1, zorder = -10,label=r'$\\mathbb{E}_{p_{true}}[\\phi]$')\n",
    "    plt.scatter([0],[-.05],color = c2,marker =m2, s=ms2,zorder=2, lw=lw2,label=r'$\\mathbb{E}_{p_1}[\\phi]$')\n",
    "    plt.scatter([0],[-.05],color = c3,marker = m3,s=ms3,zorder=2, lw=lw3, label=r'$\\mathbb{E}_{p_2}[\\phi]$')\n",
    "    plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)\n",
    "    plt.title(r\"$\\phi(x) = x$\")\n",
    "\n",
    "    # set z spine off\n",
    "    plt.gca().spines[['bottom']].set_visible(False)\n",
    "    plt.xlabel(r\"$\\phi(x)$\")\n",
    "    plt.ylabel(r\"$p(x)$\")\n",
    "    #plt.savefig(\"mixture2.pdf\",bbox_inches='tight',transparent=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#FUNCTION FROM: https://github.com/artmenlope/matplotlib-fill_between-in-3D\n",
    "\n",
    "\"\"\"\n",
    "@author: artmenlope\n",
    "\"\"\"\n",
    "\n",
    "def fill_between_3d(ax,x1,y1,z1,x2,y2,z2,mode=1,c='steelblue',alpha=0.6):\n",
    "    \"\"\"\n",
    "    Function similar to the matplotlib.pyplot.fill_between function but \n",
    "    for 3D plots.\n",
    "       \n",
    "    input:\n",
    "        \n",
    "        ax -> The axis where the function will plot.\n",
    "        \n",
    "        x1 -> 1D array. x coordinates of the first line.\n",
    "        y1 -> 1D array. y coordinates of the first line.\n",
    "        z1 -> 1D array. z coordinates of the first line.\n",
    "        \n",
    "        x2 -> 1D array. x coordinates of the second line.\n",
    "        y2 -> 1D array. y coordinates of the second line.\n",
    "        z2 -> 1D array. z coordinates of the second line.\n",
    "    \n",
    "    modes:\n",
    "\n",
    "        mode = 1 -> Fill between the lines using the shortest distance between \n",
    "                    both. Makes a lot of single trapezoids in the diagonals \n",
    "                    between lines and then adds them into a single collection.\n",
    "                    \n",
    "        mode = 2 -> Uses the lines as the edges of one only 3d polygon.\n",
    "           \n",
    "    Other parameters (for matplotlib): \n",
    "        \n",
    "        c -> the color of the polygon collection.\n",
    "        alpha -> transparency of the polygon collection.  \n",
    "    \"\"\"\n",
    "\n",
    "    if mode == 1:\n",
    "        \n",
    "        for i in range(len(x1)-1):\n",
    "            \n",
    "            verts = [(x1[i],y1[i],z1[i]), (x1[i+1],y1[i+1],z1[i+1])] + \\\n",
    "                    [(x2[i+1],y2[i+1],z2[i+1]), (x2[i],y2[i],z2[i])]\n",
    "            \n",
    "            ax.add_collection3d(Poly3DCollection([verts],\n",
    "                                                 alpha=alpha,\n",
    "                                                 linewidths=0,\n",
    "                                                 color=c))\n",
    "\n",
    "    if mode == 2:\n",
    "        \n",
    "        verts = [(x1[i],y1[i],z1[i]) for i in range(len(x1))] + \\\n",
    "                [(x2[i],y2[i],z2[i]) for i in range(len(x2))]\n",
    "                \n",
    "        ax.add_collection3d(Poly3DCollection([verts],alpha=alpha,color=c))\n",
    "            "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Apply polynomial mapping and estimate the new PDF"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_samples = 10000\n",
    "sample_p_hat = np.concatenate([np.random.normal(us_hat[0],sds_hat[0],int(n_samples*ws_hat[0])),\n",
    "                                 np.random.normal(us_hat[1],sds_hat[1],int(n_samples*ws_hat[1]))])\n",
    "\n",
    "sample_p1 = np.random.normal(0,sd_1,n_samples)\n",
    "sample_p2 = np.random.normal(0,sd_2,n_samples)\n",
    "\n",
    "data_p_hat = np.array([sample_p_hat,sample_p_hat**2]).T\n",
    "data_p1 = np.array([sample_p1,sample_p1**2]).T\n",
    "data_p2 = np.array([sample_p2,sample_p2**2]).T\n",
    "\n",
    "df_p_hat = pd.DataFrame(data_p_hat, columns=[r'$\\phi_1$',r'$\\phi_2$'])\n",
    "df_p1 = pd.DataFrame(data_p1, columns=[r'$\\phi_1$',r'$\\phi_2$'])\n",
    "df_p2 = pd.DataFrame(data_p2, columns=[r'$\\phi_1$',r'$\\phi_2$'])\n",
    "\n",
    "mean_p_hat = data_p_hat.mean(axis=0)\n",
    "mean_p1 = data_p1.mean(axis=0)\n",
    "mean_p2 = data_p2.mean(axis=0)\n",
    "\n",
    "width = .3\n",
    "phi_hat = stats.gaussian_kde(data_p_hat[:,1],bw_method=width).pdf\n",
    "phi_p1 = stats.gaussian_kde(data_p1[:,1],bw_method=width).pdf\n",
    "phi_p2 = stats.gaussian_kde(data_p2[:,1],bw_method=width).pdf"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Make panel with polynomial mapping"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "eps = 1\n",
    "color='grey'\n",
    "y_move =10\n",
    "\n",
    "xlims = [-12,12]\n",
    "ylims = [eps,55]\n",
    "\n",
    "with mpl.rc_context(fname='../../../matplotlibrc'):\n",
    "\n",
    "    fig = plt.figure(figsize=(3,2))\n",
    "    ax = fig.add_subplot(111, projection='3d')\n",
    "    ax.set_box_aspect(aspect=(2, 2, 1))\n",
    "\n",
    "    x = np.linspace(xlims[0],xlims[1],1000)\n",
    "\n",
    "    y = np.zeros_like(x)\n",
    "    y1 = y+eps\n",
    "    y2 = y+eps*2\n",
    "    set_ground = [-x,y,np.zeros_like(x)]\n",
    "    set_ground1 = [-x,y1,np.zeros_like(x)]\n",
    "    set_ground2 = [-x,y2,np.zeros_like(x)]\n",
    "    set_p_hat = [-x,y,p_hat(x)]\n",
    "    set_p1 = [-x,y1,p_1(x)]\n",
    "    set_p2 = [-x,y2,p_2(x)]\n",
    "\n",
    "    ax.plot(*set_ground2,color=color,lw=lw)\n",
    "    ax.plot(*set_p_hat,lw=lw,color=c1)\n",
    "    ax.plot(*set_p1,lw=lw,color=c2)\n",
    "    ax.plot(*set_p2,lw=lw,color=c3)\n",
    "\n",
    "    alpha = .1\n",
    "    fill_between_3d(ax, *set_ground, *set_p_hat, mode = 1,c=c1,alpha=alpha)\n",
    "    fill_between_3d(ax, *set_ground1, *set_p1, mode = 1,c=c2,alpha=alpha)\n",
    "    fill_between_3d(ax, *set_ground2, *set_p2, mode = 1,c=c3,alpha=alpha)\n",
    "\n",
    "    y = np.linspace(ylims[0],ylims[1],1000)\n",
    "    yl = y\n",
    "    y_sc = 2.5\n",
    "    eps = .2\n",
    "    x = np.zeros_like(y)+xlims[1]\n",
    "    x1 = x+eps\n",
    "    x2 = x+eps*2\n",
    "    set_ground = [x,yl,np.zeros_like(x)]\n",
    "    set_ground1 = [x1,yl,np.zeros_like(x)]\n",
    "    set_ground2 = [x2,yl,np.zeros_like(x)]\n",
    "    set_p_hat = [x,yl,phi_hat(y)*y_sc]\n",
    "    set_p1 = [x1,yl,phi_p1(y)*y_sc]\n",
    "    set_p2 = [x2,yl,phi_p2(y)*y_sc]\n",
    "\n",
    "    ax.plot(*set_ground2,color=color,lw=lw)\n",
    "    ax.plot(*set_p_hat,lw=lw,color=c1)\n",
    "    ax.plot(*set_p1,lw=lw,color=c2)\n",
    "    ax.plot(*set_p2,lw=lw,color=c3)\n",
    "\n",
    "    alpha = .1\n",
    "    fill_between_3d(ax, *set_ground, *set_p_hat, mode = 1,c=c1,alpha=alpha)\n",
    "    fill_between_3d(ax, *set_ground1, *set_p1, mode = 1,c=c2,alpha=alpha)\n",
    "    fill_between_3d(ax, *set_ground2, *set_p2, mode = 1,c=c3,alpha=alpha)\n",
    "\n",
    "    color='grey'\n",
    "\n",
    "    lw = 1.5\n",
    "    sc=.2\n",
    "    plt.plot([xlims[0],xlims[1]],[ylims[0],ylims[0]],[0,0],c='black',lw=lw,zorder =1000)\n",
    "    plt.plot([xlims[1],xlims[1]],[ylims[0],ylims[1]],[0,0],c='black',lw=lw,zorder =1000)\n",
    "    plt.plot([xlims[1],xlims[1]],[ylims[0],ylims[0]],[0,max(np.max(phi_p2(y)*y_sc),.2)],c='black',lw=lw,zorder =1000)\n",
    "\n",
    "    ax.scatter(mean_p_hat[0],mean_p_hat[1]+y_move,0,c=c1,\n",
    "                marker=m1,s=ms1,zorder=0,lw=lw1)\n",
    "    ax.scatter(mean_p2[0],mean_p2[1]+y_move,0,marker=m3,s=ms3,c=c3,zorder =12,lw=lw3)\n",
    "    ax.scatter(mean_p1[0],mean_p1[1]+y_move,0,marker=m2,s=ms2,c=c2,zorder = 12,lw=lw2)\n",
    "\n",
    "    ax.grid(False)\n",
    "    plt.xticks([])\n",
    "    plt.yticks([])\n",
    "    ax.set_zticks([])\n",
    "\n",
    "    ax.xaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))\n",
    "    ax.yaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))\n",
    "    ax.xaxis._axinfo['axisline']['color'] = \"white\"\n",
    "    ax.xaxis._axinfo['axisline']['linewidth'] = 0\n",
    "\n",
    "    plt.xlim(xlims[0],xlims[1])\n",
    "    plt.ylim(yl[0],yl[-1])\n",
    "    ax.set_zlim(0,.2)\n",
    "    ax.xaxis.line.set_color((1.0, 1.0, 1.0, 0.0))\n",
    "    ax.yaxis.line.set_color((1.0, 1.0, 1.0, 0.0))\n",
    "    ax.zaxis.line.set_color((1.0, 1.0, 1.0, 0.0))\n",
    "    ax.view_init(30, 125)\n",
    "    \n",
    "    plt.title(r\"$\\phi(x) = [x, x^2]$\")\n",
    "    plt.xlabel(r\"$\\phi_1(x)$\")\n",
    "    plt.ylabel(r\"$\\phi_2(x)$\")\n",
    "\n",
    "#plt.savefig(\"mixture3.pdf\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "lp",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
