{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Notebook setup: project/third-party imports, environment flags, autoreload, and wandb login.\n",
    "import self_distillation\n",
    "import saving_loading\n",
    "import generate\n",
    "import wandb\n",
    "\n",
    "import util\n",
    "import os\n",
    "import torch\n",
    "\n",
    "# NOTE(review): presumably a workaround for the duplicate-OpenMP-runtime abort; it is set after\n",
    "# torch has already been imported - confirm that is early enough on the target machine.\n",
    "os.environ[\"KMP_DUPLICATE_LIB_OK\"]=\"TRUE\"\n",
    "\n",
    "# Working directory; the config and checkpoint paths in later cells are built from it.\n",
    "cwd = os.getcwd()\n",
    "# Reload edited modules automatically before each cell executes.\n",
    "%reload_ext autoreload\n",
    "%autoreload 2\n",
    "# The API key is not stored in the notebook; $WANDB_API_KEY has to be provided from outside.\n",
    "!wandb login $WANDB_API_KEY\n",
    "os.environ['WANDB_NOTEBOOK_NAME'] = \"Cin_256_custom.ipynb\"\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "\n",
    "# Configuration read by later cells: step count, path prefix and grid orientation.\n",
    "steps = 128\n",
    "BASE = r\"\"\n",
    "vertical = False\n",
    "\n",
    "# Fixed prompt sets from earlier runs, kept for reference:\n",
    "#   [992, 2, 625, 614]          mushroom, shark, gekko (25), tuktuk, kimono\n",
    "#   [538, 576, 616, 649, 719]   church, venice, rope, stonehenge, piggybank\n",
    "#   [538, 616, 649, 719]        church, rope, stonehenge, piggybank\n",
    "# Current choice: two random ids drawn from 0..999.\n",
    "prompts = [random.randrange(1000) for _ in range(2)]\n",
    "\n",
    "# Grid layout with one slot per prompt: a single column if `vertical`, else a single row.\n",
    "# (A fixed alternative that was tried: shape = [1, 4].)\n",
    "shape = (len(prompts), 1) if vertical else (1, len(prompts))\n",
    "\n",
    "print(prompts, shape, sep=\"\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "# Total gradient-update budget that the schedule cells below distribute across step sizes.\n",
    "gradient_updates = 4000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preview of the exponent used by the softmax schedule in the next cell: the reciprocal of each\n",
    "# successor step size (with a trailing 1 appended), reversed.\n",
    "# `step_sizes` is assigned here so the cell runs on a fresh kernel; it was previously only\n",
    "# defined in a later cell, which raises NameError under Restart & Run All.\n",
    "step_sizes = np.arange(64, 0, -2)\n",
    "(1 / np.append(step_sizes[1:], 1))[::-1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Schedule variant: softmax over the reversed reciprocals of the successor step sizes.\n",
    "step_sizes = np.arange(64, 0, -2)\n",
    "# step_sizes shifted by one position, padded with a final 1.\n",
    "next_sizes = np.append(step_sizes[1:], 1)\n",
    "\n",
    "exp_scores = np.exp((1 / next_sizes)[::-1])\n",
    "update_list = exp_scores / np.sum(exp_scores)\n",
    "\n",
    "# Scale by the update budget, divide by the successor step size, truncate to int, then double.\n",
    "update_list = (update_list * gradient_updates / next_sizes).astype(int) * 2\n",
    "\n",
    "print(update_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Schedule variant: softmax over the reciprocals of the successor step sizes (no reversal).\n",
    "step_sizes = np.arange(64, 0, -2)\n",
    "# step_sizes shifted by one position, padded with a final 1.\n",
    "next_sizes = np.append(step_sizes[1:], 1)\n",
    "\n",
    "exp_scores = np.exp(1 / next_sizes)\n",
    "update_list = exp_scores / np.sum(exp_scores)\n",
    "\n",
    "# Scale by the update budget, divide by the successor step size, truncate to int, then double.\n",
    "update_list = (update_list * gradient_updates / next_sizes).astype(int) * 2\n",
    "\n",
    "print(update_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "steps = 128\n",
    "\n",
    "# Number of halving stages. math.log2 is the dedicated base-2 logarithm, so the result no\n",
    "# longer depends on a quotient of two natural logs landing exactly on 6.0 before the floor.\n",
    "# NOTE(review): 64 is hard-coded while `steps` is 128 - confirm which starting point is intended.\n",
    "halvings = math.floor(math.log2(64))\n",
    "\n",
    "# The update budget is split evenly across the stages.\n",
    "updates_per_halving = int(gradient_updates / halvings)\n",
    "\n",
    "# Step size of each stage: steps, steps/2, steps/4, ...\n",
    "step_sizes = [steps // 2**stage for stage in range(halvings)]\n",
    "\n",
    "# Updates per stage: the stage budget divided by half of that stage's step size.\n",
    "update_list = [updates_per_halving // (size // 2) for size in step_sizes]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert both schedules to NumPy arrays so they support element-wise arithmetic.\n",
    "update_list, step_sizes = np.array(update_list), np.array(step_sizes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "7912"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sum over all stages of (updates in the stage * step size of the stage).\n",
    "# NOTE(review): the stored output 7912 appears to come from a run with steps = 64; with the\n",
    "# halving schedule built from steps = 128 and 4000 updates the sum is 7860 - re-run to refresh.\n",
    "(update_list * step_sizes).sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Step count passed as `steps=steps` to the CIN generation cells below;\n",
    "# this replaces the value of 128 assigned in earlier cells.\n",
    "steps = 16"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# CIN:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Config and original checkpoint for the CIN model, relative to the working directory.\n",
    "config_path = f\"{cwd}/models/configs/cin256-v2-custom.yaml\"\n",
    "original_path = f\"{cwd}/models/cin256_original.ckpt\"\n",
    "\n",
    "# Checkpoint paths consumed by the util.load_trained cells of this section.\n",
    "# They used to be commented out, so on a fresh kernel those cells failed with NameError.\n",
    "# Raw f-strings keep the Windows separators literal: no doubled backslashes and no reliance on\n",
    "# invalid escape sequences such as \\D or \\c.\n",
    "dsdi = rf\"{BASE}\\Diffusion_Thesis\\cin_256\\data\\trained_models_4000\\iterative\\cin_DSDI_64_1e-07_4000\\1.pt\"\n",
    "dsdgl = rf\"{BASE}\\Diffusion_Thesis\\cin_256\\data\\trained_models_4000\\gradual_linear\\cin_DSDGL_64_1e-07_4000\\1.pt\"\n",
    "dsdn = rf\"{BASE}\\Diffusion_Thesis\\cin_256\\data\\trained_models_4000\\naive\\cin_DSDN_64_1e-07_4000\\64.pt\"\n",
    "dsdgexp = rf\"{BASE}\\Diffusion_Thesis\\cin_256\\data\\trained_models_4000\\gradual_exp\\cin_DSDGEXP_64_1e-07_4000\\1.pt\"\n",
    "tsd = rf\"{BASE}\\Diffusion_Thesis\\cin_256\\data\\trained_models_previous\\TSD\\TSD_cin_50k_1e8\\1.pt\"\n",
    "\n",
    "# Alternative config/checkpoint pair kept for reference:\n",
    "# config_path=f\"{cwd}/models/configs/cin256-v2-custom copy.yaml\"\n",
    "# original_path=f\"{cwd}/models/64x64_diffusion.pt\"\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## All versions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Baseline: models built from the original checkpoint (student=False), sampled with `steps` steps.\n",
    "# The noise_list returned here is passed on to the later ablation calls of this section.\n",
    "teacher, sampler_student = util.create_models(config_path, original_path, student=False)\n",
    "original_img, noise_list, prompt_list = generate.ablation(\n",
    "    sampler_student,\n",
    "    steps=steps,\n",
    "    shape=shape,\n",
    "    celeb=False,\n",
    "    model_type=\"original\",\n",
    "    prompts=prompts,\n",
    ")\n",
    "new_image = original_img\n",
    "\n",
    "# Drop the model references, then release PyTorch's cached CUDA memory.\n",
    "del sampler_student, teacher\n",
    "torch.cuda.empty_cache()\n",
    "\n",
    "original_img"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reference run: the original checkpoint again, with a fixed 64 steps and the existing noise_list.\n",
    "# noise_list and prompt_list are rebound to whatever this call returns.\n",
    "teacher, sampler_student = util.create_models(config_path, original_path, student=False)\n",
    "original_64_img, noise_list, prompt_list = generate.ablation(\n",
    "    sampler_student,\n",
    "    steps=64,\n",
    "    shape=shape,\n",
    "    noise_list=noise_list,\n",
    "    celeb=False,\n",
    "    model_type=\"original\",\n",
    "    prompt_list=prompts,\n",
    ")\n",
    "new_image = original_64_img\n",
    "\n",
    "# Drop the model references, then release PyTorch's cached CUDA memory.\n",
    "del sampler_student, teacher\n",
    "torch.cuda.empty_cache()\n",
    "\n",
    "original_64_img"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# util.load_trained is a project helper; of its four return values only sampler_student is\n",
    "# used for sampling, the others are deleted again below.\n",
    "student, sampler_student, optimizer, scheduler = util.load_trained(dsdgexp, config_path)\n",
    "\n",
    "# Reuses the noise_list from the original-model cells (presumably so all variants share noise).\n",
    "dsdgexp_img, _, _ = generate.ablation(\n",
    "    sampler_student,\n",
    "    steps=steps,\n",
    "    shape=shape,\n",
    "    celeb=False,\n",
    "    model_type=\"DSDGEXP\",\n",
    "    noise_list=noise_list,\n",
    "    prompt_list=prompts,\n",
    ")\n",
    "new_image = dsdgexp_img\n",
    "\n",
    "# Drop all references to the loaded objects, then release PyTorch's cached CUDA memory.\n",
    "del student, sampler_student, optimizer, scheduler\n",
    "torch.cuda.empty_cache()\n",
    "\n",
    "dsdgexp_img"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# util.load_trained is a project helper; of its four return values only sampler_student is\n",
    "# used for sampling, the others are deleted again below.\n",
    "student, sampler_student, optimizer, scheduler = util.load_trained(dsdi, config_path)\n",
    "\n",
    "# Reuses the noise_list from the original-model cells (presumably so all variants share noise).\n",
    "dsdi_img, _, _ = generate.ablation(\n",
    "    sampler_student,\n",
    "    steps=steps,\n",
    "    shape=shape,\n",
    "    celeb=False,\n",
    "    model_type=\"DSDI\",\n",
    "    noise_list=noise_list,\n",
    "    prompt_list=prompts,\n",
    ")\n",
    "new_image = dsdi_img\n",
    "\n",
    "# Drop all references to the loaded objects, then release PyTorch's cached CUDA memory.\n",
    "del student, sampler_student, optimizer, scheduler\n",
    "torch.cuda.empty_cache()\n",
    "\n",
    "dsdi_img"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# util.load_trained is a project helper; of its four return values only sampler_student is\n",
    "# used for sampling, the others are deleted again below.\n",
    "student, sampler_student, optimizer, scheduler = util.load_trained(dsdgl, config_path)\n",
    "\n",
    "# Reuses the noise_list from the original-model cells (presumably so all variants share noise).\n",
    "dsdgl_img, _, _ = generate.ablation(\n",
    "    sampler_student,\n",
    "    steps=steps,\n",
    "    shape=shape,\n",
    "    celeb=False,\n",
    "    model_type=\"DSDGL\",\n",
    "    noise_list=noise_list,\n",
    "    prompt_list=prompts,\n",
    ")\n",
    "new_image = dsdgl_img\n",
    "\n",
    "# Drop all references to the loaded objects, then release PyTorch's cached CUDA memory.\n",
    "del student, sampler_student, optimizer, scheduler\n",
    "torch.cuda.empty_cache()\n",
    "\n",
    "dsdgl_img"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# util.load_trained is a project helper; of its four return values only sampler_student is\n",
    "# used for sampling, the others are deleted again below.\n",
    "student, sampler_student, optimizer, scheduler = util.load_trained(dsdn, config_path)\n",
    "\n",
    "# Reuses the noise_list from the original-model cells (presumably so all variants share noise).\n",
    "dsdn_img, _, _ = generate.ablation(\n",
    "    sampler_student,\n",
    "    steps=steps,\n",
    "    shape=shape,\n",
    "    celeb=False,\n",
    "    model_type=\"DSDN\",\n",
    "    noise_list=noise_list,\n",
    "    prompt_list=prompts,\n",
    ")\n",
    "new_image = dsdn_img\n",
    "\n",
    "# Drop all references to the loaded objects, then release PyTorch's cached CUDA memory.\n",
    "del student, sampler_student, optimizer, scheduler\n",
    "torch.cuda.empty_cache()\n",
    "\n",
    "dsdn_img"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# util.load_trained is a project helper; of its four return values only sampler_student is\n",
    "# used for sampling, the others are deleted again below.\n",
    "student, sampler_student, optimizer, scheduler = util.load_trained(tsd, config_path)\n",
    "\n",
    "# Reuses the noise_list from the original-model cells (presumably so all variants share noise).\n",
    "tsd_img, _, _ = generate.ablation(\n",
    "    sampler_student,\n",
    "    steps=steps,\n",
    "    shape=shape,\n",
    "    celeb=False,\n",
    "    model_type=\"TSD\",\n",
    "    noise_list=noise_list,\n",
    "    prompt_list=prompts,\n",
    ")\n",
    "new_image = tsd_img\n",
    "\n",
    "# Drop all references to the loaded objects, then release PyTorch's cached CUDA memory.\n",
    "del student, sampler_student, optimizer, scheduler\n",
    "torch.cuda.empty_cache()\n",
    "\n",
    "tsd_img"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Final CIN Grid"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Labels, listed in the same order as the images handed to the grid helper below.\n",
    "x_labels = [\"Original 64\", \"Original\", \"TSD\", \"DSDI\", \"DSDN\", \"DSDGL\", \"DSDGEXP\"]\n",
    "\n",
    "# Combine the images produced by the CIN cells above into one comparison grid.\n",
    "grid = generate.create_horizontal_grid(\n",
    "    original_64_img,\n",
    "    original_img,\n",
    "    tsd_img,\n",
    "    dsdi_img,\n",
    "    dsdn_img,\n",
    "    dsdgl_img,\n",
    "    dsdgexp_img,\n",
    "    celeb=False,\n",
    "    steps=steps,\n",
    "    font_size=20,\n",
    "    x_labels=x_labels,\n",
    ")\n",
    "grid"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# CELEB"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 343,
   "metadata": {},
   "outputs": [],
   "source": [
    "# CELEB section setup; the imports and environment flags repeat the notebook's first cell.\n",
    "import self_distillation\n",
    "import saving_loading\n",
    "import generate\n",
    "import wandb\n",
    "import util\n",
    "import os\n",
    "import torch\n",
    "\n",
    "os.environ[\"KMP_DUPLICATE_LIB_OK\"] = \"TRUE\"\n",
    "cwd = os.getcwd()\n",
    "\n",
    "%reload_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "# Values read by the CELEB generation cells below.\n",
    "steps, prompts, celeb = 2, None, True\n",
    "vertical = False\n",
    "\n",
    "# Grid layout for n slots: a single column if `vertical`, otherwise a single row.\n",
    "n = 2\n",
    "shape = (n, 1) if vertical else (1, n)\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# CELEB run:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 344,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Config and original checkpoint for the CELEB model, relative to the working directory.\n",
    "config_path = f\"{cwd}/models/configs/celebahq-ldm-vq-4.yaml\"\n",
    "original_path = f\"{cwd}/models/CelebA.ckpt\"\n",
    "\n",
    "# Checkpoint paths consumed by the util.load_trained cells of this section.\n",
    "# Raw f-strings: every backslash is a literal Windows path separator. The previous non-raw\n",
    "# literals mixed doubled backslashes with invalid escapes such as \\D, \\c and \\d, which Python\n",
    "# only tolerates with a deprecation warning. The resulting path strings are unchanged.\n",
    "dsdi = rf\"{BASE}\\Diffusion_Thesis\\cin_256\\data\\trained_models\\iterative\\celeb_DSDI_64_1e-07_4000\\1.pt\"\n",
    "dsdgl = rf\"{BASE}\\Diffusion_Thesis\\cin_256\\data\\trained_models\\gradual_linear\\celeb_DSDGL_64_1e-07_4000\\1.pt\"\n",
    "dsdn = rf\"{BASE}\\Diffusion_Thesis\\cin_256\\data\\trained_models_previous\\final_versions\\celeb\\DSDN\\64.pt\"\n",
    "dsdgexp = rf\"{BASE}\\Diffusion_Thesis\\cin_256\\data\\trained_models_4000\\gradual_exp\\celeb_DSDGEXP_64_1e-07_4000\\1.pt\"\n",
    "tsd = rf\"{BASE}\\Diffusion_Thesis\\cin_256\\data\\trained_models_previous\\final_versions\\celeb\\TSD\\1.pt\"\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## All versions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Baseline: models built from the original checkpoint (student=False), sampled with `steps` steps.\n",
    "# The noise_list returned here is passed on to the later ablation calls of this section.\n",
    "teacher, sampler_student = util.create_models(config_path, original_path, student=False)\n",
    "original_img, noise_list, prompt_list = generate.ablation(\n",
    "    sampler_student,\n",
    "    steps=steps,\n",
    "    shape=shape,\n",
    "    celeb=celeb,\n",
    "    model_type=\"original\",\n",
    "    prompts=prompts,\n",
    ")\n",
    "new_image = original_img\n",
    "\n",
    "# Drop the model references, then release PyTorch's cached CUDA memory.\n",
    "del sampler_student, teacher\n",
    "torch.cuda.empty_cache()\n",
    "\n",
    "original_img"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reference run: the original checkpoint again, with a fixed 64 steps and the existing noise_list.\n",
    "# noise_list and prompt_list are rebound to whatever this call returns.\n",
    "teacher, sampler_student = util.create_models(config_path, original_path, student=False)\n",
    "original_64_img, noise_list, prompt_list = generate.ablation(\n",
    "    sampler_student,\n",
    "    steps=64,\n",
    "    shape=shape,\n",
    "    noise_list=noise_list,\n",
    "    celeb=celeb,\n",
    "    model_type=\"original\",\n",
    "    prompts=prompts,\n",
    ")\n",
    "new_image = original_64_img\n",
    "\n",
    "# Drop the model references, then release PyTorch's cached CUDA memory.\n",
    "del sampler_student, teacher\n",
    "torch.cuda.empty_cache()\n",
    "\n",
    "original_64_img"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# util.load_trained is a project helper; of its four return values only sampler_student is\n",
    "# used for sampling, the others are deleted again below.\n",
    "student, sampler_student, optimizer, scheduler = util.load_trained(dsdgexp, config_path)\n",
    "\n",
    "# Reuses the noise_list from the original-model cells (presumably so all variants share noise).\n",
    "dsdgexp_img, _, _ = generate.ablation(\n",
    "    sampler_student,\n",
    "    steps=steps,\n",
    "    shape=shape,\n",
    "    celeb=celeb,\n",
    "    model_type=\"DSDGEXP\",\n",
    "    noise_list=noise_list,\n",
    "    prompt_list=prompts,\n",
    ")\n",
    "new_image = dsdgexp_img\n",
    "\n",
    "# Drop all references to the loaded objects, then release PyTorch's cached CUDA memory.\n",
    "del student, sampler_student, optimizer, scheduler\n",
    "torch.cuda.empty_cache()\n",
    "\n",
    "dsdgexp_img"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# util.load_trained is a project helper; of its four return values only sampler_student is\n",
    "# used for sampling, the others are deleted again below.\n",
    "student, sampler_student, optimizer, scheduler = util.load_trained(dsdi, config_path)\n",
    "\n",
    "# Reuses the noise_list from the original-model cells (presumably so all variants share noise).\n",
    "dsdi_img, _, _ = generate.ablation(\n",
    "    sampler_student,\n",
    "    steps=steps,\n",
    "    shape=shape,\n",
    "    celeb=celeb,\n",
    "    model_type=\"DSDI\",\n",
    "    noise_list=noise_list,\n",
    "    prompt_list=prompts,\n",
    ")\n",
    "new_image = dsdi_img\n",
    "\n",
    "# Drop all references to the loaded objects, then release PyTorch's cached CUDA memory.\n",
    "del student, sampler_student, optimizer, scheduler\n",
    "torch.cuda.empty_cache()\n",
    "\n",
    "dsdi_img"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# util.load_trained is a project helper; of its four return values only sampler_student is\n",
    "# used for sampling, the others are deleted again below.\n",
    "student, sampler_student, optimizer, scheduler = util.load_trained(dsdgl, config_path)\n",
    "\n",
    "# Reuses the noise_list from the original-model cells (presumably so all variants share noise).\n",
    "dsdgl_img, _, _ = generate.ablation(\n",
    "    sampler_student,\n",
    "    steps=steps,\n",
    "    shape=shape,\n",
    "    celeb=celeb,\n",
    "    model_type=\"DSDGL\",\n",
    "    noise_list=noise_list,\n",
    "    prompt_list=prompts,\n",
    ")\n",
    "new_image = dsdgl_img\n",
    "\n",
    "# Drop all references to the loaded objects, then release PyTorch's cached CUDA memory.\n",
    "del student, sampler_student, optimizer, scheduler\n",
    "torch.cuda.empty_cache()\n",
    "\n",
    "dsdgl_img"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# util.load_trained is a project helper; of its four return values only sampler_student is\n",
    "# used for sampling, the others are deleted again below.\n",
    "student, sampler_student, optimizer, scheduler = util.load_trained(dsdn, config_path)\n",
    "\n",
    "# Reuses the noise_list from the original-model cells (presumably so all variants share noise).\n",
    "dsdn_img, _, _ = generate.ablation(\n",
    "    sampler_student,\n",
    "    steps=steps,\n",
    "    shape=shape,\n",
    "    celeb=celeb,\n",
    "    model_type=\"DSDN\",\n",
    "    noise_list=noise_list,\n",
    "    prompt_list=prompts,\n",
    ")\n",
    "new_image = dsdn_img\n",
    "\n",
    "# Drop all references to the loaded objects, then release PyTorch's cached CUDA memory.\n",
    "del student, sampler_student, optimizer, scheduler\n",
    "torch.cuda.empty_cache()\n",
    "\n",
    "dsdn_img"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# util.load_trained is a project helper; of its four return values only sampler_student is\n",
    "# used for sampling, the others are deleted again below.\n",
    "student, sampler_student, optimizer, scheduler = util.load_trained(tsd, config_path)\n",
    "\n",
    "# Reuses the noise_list from the original-model cells (presumably so all variants share noise).\n",
    "tsd_img, _, _ = generate.ablation(\n",
    "    sampler_student,\n",
    "    steps=steps,\n",
    "    shape=shape,\n",
    "    celeb=celeb,\n",
    "    model_type=\"TSD\",\n",
    "    noise_list=noise_list,\n",
    "    prompt_list=prompts,\n",
    ")\n",
    "new_image = tsd_img\n",
    "\n",
    "# Drop all references to the loaded objects, then release PyTorch's cached CUDA memory.\n",
    "del student, sampler_student, optimizer, scheduler\n",
    "torch.cuda.empty_cache()\n",
    "\n",
    "tsd_img"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Final CELEB Grid"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Labels, listed in the same order as the images handed to the grid helper below.\n",
    "# (Alternative that was tried: x_labels=None.)\n",
    "x_labels = [\"Original 64\", \"Original\", \"TSD\", \"DSDI\", \"DSDN\", \"DSDGL\", \"DSDGEXP\"]\n",
    "\n",
    "# Combine the images produced by the CELEB cells above into one comparison grid.\n",
    "grid = generate.create_horizontal_grid(\n",
    "    original_64_img,\n",
    "    original_img,\n",
    "    tsd_img,\n",
    "    dsdi_img,\n",
    "    dsdn_img,\n",
    "    dsdgl_img,\n",
    "    dsdgexp_img,\n",
    "    celeb=celeb,\n",
    "    steps=steps,\n",
    "    x_labels=x_labels,\n",
    "    font_size=20,\n",
    ")\n",
    "grid"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "DSD",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
