{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from loss import log_t_normalizing_const\n",
    "from univariate.sampling import t_density, t_density_contour\n",
    "\n",
    "def visualize_density(model_title_list, model_gen_list, \n",
    "                      K, sample_nu_list, sample_mu_list, sample_var_list, ratio_list, xlim) :\n",
    "    model_gen_list = [gen[torch.isfinite(gen)].cpu().numpy() for gen in model_gen_list]\n",
    "\n",
    "    M = len(model_gen_list)\n",
    "    input = np.arange(-xlim * 100, xlim * 100 + 1) * 0.01\n",
    "    contour = t_density_contour(input, K, sample_nu_list, sample_mu_list, sample_var_list, ratio_list).squeeze().numpy()\n",
    "\n",
    "    # plot\n",
    "    fig = plt.figure(figsize = (3.5 * M, 7))\n",
    "\n",
    "    for m in range(M) : \n",
    "        ax = fig.add_subplot(2,M,m+1)\n",
    "        plt.plot(input, contour, color='black')\n",
    "        plt.hist(model_gen_list[m], bins = 100, range = [-10, 10], density=True, alpha = 0.5, color='dodgerblue')\n",
    "        plt.xlim(-10, 10)\n",
    "        plt.title(f'{model_title_list[m]}')\n",
    "\n",
    "        ax = fig.add_subplot(2,M,M+m+1)\n",
    "        plt.plot(input, contour, color='black')\n",
    "        plt.hist(model_gen_list[m], bins = 100, range = [-xlim, xlim], density=True, alpha = 0.5, color='dodgerblue')\n",
    "        plt.xlim(-xlim, xlim)\n",
    "        plt.yscale(\"log\")\n",
    "        plt.ylim(1e-6, 1)\n",
    "\n",
    "    return fig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dirname = \"results\"\n",
    "K=2\n",
    "sample_nu_list = [5.0, 5.0]\n",
    "sample_mu_list = [mu * torch.ones(1) for mu in [-2.0, 2.0]]\n",
    "sample_var_list = [var * torch.ones(1,1) for var in [1.0, 1.0]]\n",
    "ratio_list = [0.6, 0.4]\n",
    "xlim = 12\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from util import make_reproducibility\n",
    "from mmd import make_masking, mmd_linear, mmd_linear_bootstrap_test\n",
    "\n",
    "file_list = np.asarray(os.listdir(f'./1D_results/{dirname}'))\n",
    "csv_list = file_list[np.where(['.csv' in name for name in  file_list])[0]]\n",
    "csv_list = np.asarray([name[0:-4] for name in csv_list])\n",
    "csv_list = csv_list[np.where(csv_list != 'test_data')[0]]\n",
    "csv_selected = [\n",
    "    't3VAE_nu_9.0', 't3VAE_nu_12.0', 't3VAE_nu_15.0', 't3VAE_nu_18.0', 't3VAE_nu_21.0', \n",
    "    'VAE', 'betaVAE_0.1', 't-VAE', 'Disentangled_VAE_nu_9.0', 'VAE-st_nu_12.0'\n",
    "]\n",
    "\n",
    "name_list = [\n",
    "    r't3VAE ($\\nu=9.0$)', r't3VAE ($\\nu=12.0$)', r't3VAE ($\\nu=15.0$)', r't3VAE ($\\nu=18.0$)', r't3VAE ($\\nu=21.0$)', \n",
    "    'VAE', r'$\\beta$-VAE ($\\beta = 0.1$)', 'Student-t VAE', r'DE-VAE ($\\nu = 9.0$)', r'VAE-st ($\\nu = 12.0$)'\n",
    "]\n",
    "\n",
    "gen_list = [np.asarray(pd.read_csv(f'./1D_results/{dirname}/{csv_name}.csv', header = None)) for csv_name in csv_selected]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "M = len(gen_list)\n",
    "input = np.arange(-xlim * 100, xlim * 100 + 1) * 0.01\n",
    "contour = t_density_contour(input, K, sample_nu_list, sample_mu_list, sample_var_list, ratio_list).squeeze().numpy()\n",
    "\n",
    "# plot\n",
    "fig = plt.figure(figsize = (4 * M, 8))\n",
    "\n",
    "for m in range(M) : \n",
    "    ax = fig.add_subplot(2,5,m + 1)\n",
    "    plt.plot(input, contour, color='black')\n",
    "    plt.hist(gen_list[m], bins = 100, range = [-xlim, xlim], density=True, alpha = 0.5, color='dodgerblue')\n",
    "    plt.xlim(-xlim, xlim)\n",
    "    plt.yscale(\"log\")\n",
    "    plt.ylim(1e-6, 1)\n",
    "    plt.rc('xtick', labelsize=15)  \n",
    "    plt.rc('ytick', labelsize=15)  \n",
    "    plt.title(f'{name_list[m]}', fontdict = {'fontsize' : 24})\n",
    "fig.subplots_adjust(wspace=0.2, hspace = 0.3)\n",
    "fig.savefig('univariate_log_histogram.png')\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
