{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "import time\n",
    "import io\n",
    "import copy\n",
    "import os\n",
    "\n",
    "import pickle\n",
    "from pylab import *\n",
    "\n",
    "from torchvision import transforms\n",
    "\n",
    "from scipy.stats import gaussian_kde\n",
    "\n",
    "from diffusers import DDPMScheduler\n",
    "from statsmodels.tsa.stattools import acf\n",
    "\n",
    "from diffusers import EulerDiscreteScheduler\n",
    "from timeit import default_timer as timer\n",
    "\n",
    "import sys\n",
    "\n",
    "current_path = os.getcwd()\n",
    "two_folders_up = os.path.join(current_path, '..', '..')\n",
    "desired_folder = os.path.join(two_folders_up, 'SeqDiff')\n",
    "absolute_desired_folder = os.path.abspath(desired_folder)\n",
    "sys.path.insert(0, absolute_desired_folder)\n",
    "\n",
    "from dynamicsdiffusion import TemporalUnet, TemporalUnetEnergy, GaussianDiffusion, Trainer, count_parameters\n",
    "\n",
    "device = 'cuda'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scheduler settings handed to EulerDiscreteScheduler.from_config below.\n",
    "# NOTE(review): the keys carry a DDPMScheduler class name -- confirm they are all meant for the Euler scheduler.\n",
    "scheduler_config_euler = dict(\n",
    "    num_train_timesteps=1000,\n",
    "    beta_start=0.0001,\n",
    "    beta_end=0.02,\n",
    "    beta_schedule=\"linear\",\n",
    "    trained_betas=None,\n",
    "    variance_type=\"fixed_large\",\n",
    "    clip_sample=True,\n",
    "    prediction_type=\"epsilon\",\n",
    "    _class_name=\"DDPMScheduler\",\n",
    "    _diffusers_version=\"0.1.1\",\n",
    ")\n",
    "scheduler = EulerDiscreteScheduler.from_config(scheduler_config_euler)\n",
    "\n",
    "horizon = [64]\n",
    "transition_dim = 30\n",
    "\n",
    "# Instantiate TemporalUnetEnergy on `device`; the count_parameters(model) call is the cell's output.\n",
    "model = TemporalUnetEnergy(\n",
    "    horizon=horizon[0],\n",
    "    transition_dim=transition_dim,\n",
    "    cond_dim=None,\n",
    "    dim=96,\n",
    "    dim_mults=(1, 2),\n",
    "    kernel_size=3,\n",
    "    attention=False,\n",
    ").to(device)\n",
    "count_parameters(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Time the denoising loop (a single scheduler step here) with the model weights cast to float16.\n",
    "# Sets `start` and `end` (timer() readings) for the throughput estimate in the next cell.\n",
    "num_inference_steps = 1\n",
    "samples = torch.randn(size=(5000, 128, 30), dtype=torch.float16, device=device)\n",
    "model = model.half()\n",
    "scheduler.set_timesteps(num_inference_steps)\n",
    "with torch.no_grad():\n",
    "    # NOTE(review): weights and inputs are float16 while autocast requests bfloat16 -- confirm this mix is intended.\n",
    "    with torch.cuda.amp.autocast(dtype=torch.bfloat16):\n",
    "        # CUDA kernels are launched asynchronously, so drain any pending GPU work\n",
    "        # before starting the clock; otherwise earlier work is billed to this loop.\n",
    "        torch.cuda.synchronize()\n",
    "        start = timer()\n",
    "        for t in scheduler.timesteps:\n",
    "            model_output = model(samples, None, t.repeat((samples.shape[0],)).to(device).long())\n",
    "            samples = scheduler.step(model_output, t, samples, 1).prev_sample\n",
    "        # Wait for the queued kernels to actually finish before stopping the clock;\n",
    "        # without this the timer only measures kernel launch overhead.\n",
    "        torch.cuda.synchronize()\n",
    "        end = timer()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Extrapolate the timed loop to n_steps steps and express the result per day of wall-clock time.\n",
    "SECONDS_PER_DAY = 86400\n",
    "n_steps = 30  # presumably the number of denoising steps of a full sampling run -- TODO confirm\n",
    "\n",
    "time_diff = end - start\n",
    "total_time = time_diff * n_steps\n",
    "total_time_days = total_time / SECONDS_PER_DAY\n",
    "\n",
    "# NOTE(review): 2e-5 looks like the simulated time in ns covered by one entry along\n",
    "# the first two axes of `samples` -- confirm against the dataset's time step.\n",
    "n_batch, n_frames = samples.shape[:2]\n",
    "ns_total = n_batch * n_frames * 2e-5\n",
    "ns_per_day = ns_total / total_time_days"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the ns_per_day value computed in the previous cell as this cell's output.\n",
    "ns_per_day"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "639d8a7c3e620b1d142eea4deabde5aac9ed3b21a6e651e4622d69fbdac2ed0a"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
