{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0916941b-5e00-4803-b1a9-6aac02444cc3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import pickle\n",
    "import numpy as np\n",
    "\n",
    "data_name = 'cifar10'\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "print(device)\n",
    "\n",
    "save_result_path = \"./result/{}\".format(data_name)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "79474e08-a2a4-4342-b773-6fb51e6c33cf",
   "metadata": {},
   "outputs": [],
   "source": [
    "estimator_list = []\n",
    "count = 0\n",
    "\n",
    "skip = 'quad'\n",
    "\n",
    "for image_index in range(10000):\n",
    "\n",
    "    num_ddim_steps_1 = 50\n",
    "    \n",
    "    image_path = '{}/S_{}/image_{:d}'.format(save_result_path, num_ddim_steps_1, image_index)\n",
    "    if skip == 'linear':\n",
    "        noise_filename1 = '{}/model_{}_intermediates.pt'.format(image_path, model_name_1)\n",
    "        jac_filename1 = '{}/model_{}_jac_logdet.pt'.format(image_path, model_name_1)\n",
    "    else:\n",
    "        noise_filename1 = '{}/model_{}_{}_intermediates.pt'.format(image_path, model_name_1, skip)\n",
    "        jac_filename1 = '{}/model_{}_{}_jac_logdet.pt'.format(image_path, model_name_1, skip)\n",
    "\n",
    "    num_ddim_steps_2 = 100\n",
    "    image_path = '{}/S_{}/image_{:d}'.format(save_result_path, num_ddim_steps_2, image_index)\n",
    "    if skip == 'linear':\n",
    "        noise_filename2 = '{}/model_{}_intermediates.pt'.format(image_path, model_name_1)\n",
    "        jac_filename2 = '{}/model_{}_jac_logdet.pt'.format(image_path, model_name_1)\n",
    "    else:\n",
    "        noise_filename2 = '{}/model_{}_{}_intermediates.pt'.format(image_path, model_name_1, skip)\n",
    "        jac_filename2 = '{}/model_{}_{}_jac_logdet.pt'.format(image_path, model_name_1, skip)\n",
    "\n",
    "    # load saved results to compute the estimator\n",
    "    if os.path.exists(noise_filename1) and os.path.exists(noise_filename2): \n",
    "        f = open(noise_filename1, 'rb')\n",
    "        intermediates1 = pickle.load(f)\n",
    "        f.close()\n",
    "        f = open(jac_filename1, 'rb')\n",
    "        jac_logdet_model1 = pickle.load(f)\n",
    "        f.close()\n",
    "        \n",
    "        f = open(noise_filename2, 'rb')\n",
    "        intermediates2 = pickle.load(f)\n",
    "        f.close()\n",
    "        f = open(jac_filename2, 'rb')\n",
    "        jac_logdet_model2 = pickle.load(f)\n",
    "        f.close()\n",
    "        \n",
    "        estimator = 0.5 * intermediates2[-1].pow(2).sum() - 0.5 * intermediates1[-1].pow(2).sum() + jac_logdet_model1 - jac_logdet_model2\n",
    "        estimator_list.append(estimator.item())\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cad45d0d-74e6-4ca2-95a6-1bc72c1bcc98",
   "metadata": {},
   "outputs": [],
   "source": [
    "estimator_list = np.array(estimator_list)\n",
    "estimator = estimator_list.mean()\n",
    "CI_lower_bound = estimator_list.mean() - z_half_alpha * estimator_list.std() / np.sqrt(len(estimator_list))\n",
    "CI_upper_bound = estimator_list.mean() + z_half_alpha * estimator_list.std() / np.sqrt(len(estimator_list))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "21c634de-5c57-4b95-9148-438b4ea2a822",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "myenv",
   "language": "python",
   "name": "myenv"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
