{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": 1,
      "id": "b673ca8c",
      "metadata": {
        "id": "b673ca8c"
      },
      "outputs": [],
      "source": [
        "import math, os, time, random\n",
        "from dataclasses import dataclass\n",
        "from typing import Optional, Tuple\n",
        "\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "from torch.utils.data import Dataset, DataLoader\n",
        "from torchvision import datasets, transforms, utils as tvu\n",
        "\n",
        "from torchmetrics.image.fid import FrechetInceptionDistance\n",
        "from torchmetrics.image.inception import InceptionScore\n",
        "from torchmetrics.functional.image.ssim import structural_similarity_index_measure as ssim\n",
        "\n",
        "from nn import *\n",
        "from utils import *"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "id": "6544381c",
      "metadata": {
        "id": "6544381c"
      },
      "outputs": [],
      "source": [
        "DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
        "\n",
        "# Teacher training hyperparameters.\n",
        "train_cfg = TrainCfg(\n",
        "    data_root='./data',\n",
        "    epochs=10,                 # adjust\n",
        "    batch_size=128,\n",
        "    lr=2e-4,\n",
        "    timesteps=1000,\n",
        "    base_ch=48,\n",
        "    time_emb_dim=96,\n",
        "    sigma_c=0.9,              # std for c = x + N(0, sigma_c^2)\n",
        "    cf_drop_prob=0.2,\n",
        "    seed=42,\n",
        "    device=DEVICE\n",
        ")\n",
        "\n",
        "# Number of teacher DDIM steps used to build the distillation target.\n",
        "# student_steps must be teacher_steps // 2, so it is derived from this value\n",
        "# instead of being typed in separately; an odd count would make the halving inexact.\n",
        "TEACHER_STEPS = 50\n",
        "assert TEACHER_STEPS > 0 and TEACHER_STEPS % 2 == 0, 'TEACHER_STEPS must be a positive even integer'\n",
        "\n",
        "# Distillation hyperparameters.\n",
        "# timesteps and sigma_c are read from train_cfg rather than repeated, so that\n",
        "# editing the teacher config cannot silently leave the student on different values.\n",
        "distill_cfg = DistillCfg(\n",
        "    timesteps=train_cfg.timesteps,\n",
        "    teacher_steps=TEACHER_STEPS,\n",
        "    student_steps=TEACHER_STEPS // 2,\n",
        "    epochs=50,\n",
        "    batch_size=128,\n",
        "    lr=1e-4,\n",
        "    sigma_c=train_cfg.sigma_c,\n",
        "    data_root='./data',\n",
        "    seed=42,\n",
        "    device=DEVICE,\n",
        "    cfg_scale_teacher=2.0     # CFG scale used in distillation target (teacher)\n",
        ")\n",
        "\n",
        "# Evaluation args object (to keep your evaluate_model signature):\n",
        "# a bare attribute container whose fields are assigned right below.\n",
        "class EvalArgs:\n",
        "    pass\n",
        "\n",
        "eval_args = EvalArgs()\n",
        "eval_args.guidance_scale = 2.0    # CFG at sampling time\n",
        "eval_args.timesteps = train_cfg.timesteps  # schedule timesteps (always 1000)\n",
        "eval_args.ddim_steps = distill_cfg.student_steps  # only used internally by our sampler hook\n",
        "eval_num_batches = 50              # how many batches to evaluate on"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "id": "222c3356",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "222c3356",
        "outputId": "e1c80b68-f744-40d1-dd9d-46404e5d8d60"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "100%|██████████| 9.91M/9.91M [00:00<00:00, 43.1MB/s]\n",
            "100%|██████████| 28.9k/28.9k [00:00<00:00, 1.11MB/s]\n",
            "100%|██████████| 1.65M/1.65M [00:00<00:00, 10.1MB/s]\n",
            "100%|██████████| 4.54k/4.54k [00:00<00:00, 10.9MB/s]\n"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[Teacher] ep 1/10 step 200 loss=0.0625\n",
            "[Teacher] ep 1/10 step 400 loss=0.0343\n",
            "[Teacher] ep 2/10 step 600 loss=0.0346\n",
            "[Teacher] ep 2/10 step 800 loss=0.0284\n",
            "[Teacher] ep 3/10 step 1000 loss=0.0256\n",
            "[Teacher] ep 3/10 step 1200 loss=0.0260\n",
            "[Teacher] ep 3/10 step 1400 loss=0.0217\n",
            "[Teacher] ep 4/10 step 1600 loss=0.0230\n",
            "[Teacher] ep 4/10 step 1800 loss=0.0246\n",
            "[Teacher] ep 5/10 step 2000 loss=0.0245\n",
            "[Teacher] ep 5/10 step 2200 loss=0.0299\n",
            "[Teacher] ep 6/10 step 2400 loss=0.0311\n",
            "[Teacher] ep 6/10 step 2600 loss=0.0214\n",
            "[Teacher] ep 6/10 step 2800 loss=0.0273\n",
            "[Teacher] ep 7/10 step 3000 loss=0.0241\n",
            "[Teacher] ep 7/10 step 3200 loss=0.0202\n",
            "[Teacher] ep 8/10 step 3400 loss=0.0180\n",
            "[Teacher] ep 8/10 step 3600 loss=0.0228\n",
            "[Teacher] ep 9/10 step 3800 loss=0.0157\n",
            "[Teacher] ep 9/10 step 4000 loss=0.0177\n",
            "[Teacher] ep 9/10 step 4200 loss=0.0206\n",
            "[Teacher] ep 10/10 step 4400 loss=0.0161\n",
            "[Teacher] ep 10/10 step 4600 loss=0.0196\n"
          ]
        }
      ],
      "source": [
        "# Train the teacher with the SmallUNet architecture.\n",
        "# train_teacher hands back two objects: the model and the schedule that the\n",
        "# distillation and evaluation cells below take as input.\n",
        "teacher, schedule = train_teacher(\n",
        "    train_cfg,\n",
        "    teacher_model=SmallUNet,\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "id": "0a0da0eb",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "0a0da0eb",
        "outputId": "d2737b0b-aeca-42e7-f586-c05f91e82ab4"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[Distill] ep 1/50 it 200/469 loss=0.0124\n",
            "[Distill] ep 1/50 it 400/469 loss=0.0101\n",
            "[Distill] ep 2/50 it 200/469 loss=0.0090\n",
            "[Distill] ep 2/50 it 400/469 loss=0.0090\n",
            "[Distill] ep 3/50 it 200/469 loss=0.0079\n",
            "[Distill] ep 3/50 it 400/469 loss=0.0075\n",
            "[Distill] ep 4/50 it 200/469 loss=0.0056\n",
            "[Distill] ep 4/50 it 400/469 loss=0.0058\n",
            "[Distill] ep 5/50 it 200/469 loss=0.0050\n",
            "[Distill] ep 5/50 it 400/469 loss=0.0051\n",
            "[Distill] ep 6/50 it 200/469 loss=0.0047\n",
            "[Distill] ep 6/50 it 400/469 loss=0.0048\n",
            "[Distill] ep 7/50 it 200/469 loss=0.0043\n",
            "[Distill] ep 7/50 it 400/469 loss=0.0043\n",
            "[Distill] ep 8/50 it 200/469 loss=0.0044\n",
            "[Distill] ep 8/50 it 400/469 loss=0.0039\n",
            "[Distill] ep 9/50 it 200/469 loss=0.0035\n",
            "[Distill] ep 9/50 it 400/469 loss=0.0036\n",
            "[Distill] ep 10/50 it 200/469 loss=0.0035\n",
            "[Distill] ep 10/50 it 400/469 loss=0.0035\n",
            "[Distill] ep 11/50 it 200/469 loss=0.0033\n",
            "[Distill] ep 11/50 it 400/469 loss=0.0033\n",
            "[Distill] ep 12/50 it 200/469 loss=0.0031\n",
            "[Distill] ep 12/50 it 400/469 loss=0.0033\n",
            "[Distill] ep 13/50 it 200/469 loss=0.0029\n",
            "[Distill] ep 13/50 it 400/469 loss=0.0030\n",
            "[Distill] ep 14/50 it 200/469 loss=0.0029\n",
            "[Distill] ep 14/50 it 400/469 loss=0.0027\n",
            "[Distill] ep 15/50 it 200/469 loss=0.0030\n",
            "[Distill] ep 15/50 it 400/469 loss=0.0027\n",
            "[Distill] ep 16/50 it 200/469 loss=0.0027\n",
            "[Distill] ep 16/50 it 400/469 loss=0.0026\n",
            "[Distill] ep 17/50 it 200/469 loss=0.0027\n",
            "[Distill] ep 17/50 it 400/469 loss=0.0024\n",
            "[Distill] ep 18/50 it 200/469 loss=0.0026\n",
            "[Distill] ep 18/50 it 400/469 loss=0.0026\n",
            "[Distill] ep 19/50 it 200/469 loss=0.0024\n",
            "[Distill] ep 19/50 it 400/469 loss=0.0025\n",
            "[Distill] ep 20/50 it 200/469 loss=0.0024\n",
            "[Distill] ep 20/50 it 400/469 loss=0.0023\n",
            "[Distill] ep 21/50 it 200/469 loss=0.0024\n",
            "[Distill] ep 21/50 it 400/469 loss=0.0023\n",
            "[Distill] ep 22/50 it 200/469 loss=0.0023\n",
            "[Distill] ep 22/50 it 400/469 loss=0.0023\n",
            "[Distill] ep 23/50 it 200/469 loss=0.0023\n",
            "[Distill] ep 23/50 it 400/469 loss=0.0021\n",
            "[Distill] ep 24/50 it 200/469 loss=0.0022\n",
            "[Distill] ep 24/50 it 400/469 loss=0.0022\n",
            "[Distill] ep 25/50 it 200/469 loss=0.0020\n",
            "[Distill] ep 25/50 it 400/469 loss=0.0022\n",
            "[Distill] ep 26/50 it 200/469 loss=0.0021\n",
            "[Distill] ep 26/50 it 400/469 loss=0.0021\n",
            "[Distill] ep 27/50 it 200/469 loss=0.0021\n",
            "[Distill] ep 27/50 it 400/469 loss=0.0020\n",
            "[Distill] ep 28/50 it 200/469 loss=0.0022\n",
            "[Distill] ep 28/50 it 400/469 loss=0.0020\n",
            "[Distill] ep 29/50 it 200/469 loss=0.0021\n",
            "[Distill] ep 29/50 it 400/469 loss=0.0019\n",
            "[Distill] ep 30/50 it 200/469 loss=0.0019\n",
            "[Distill] ep 30/50 it 400/469 loss=0.0019\n",
            "[Distill] ep 31/50 it 200/469 loss=0.0019\n",
            "[Distill] ep 31/50 it 400/469 loss=0.0019\n",
            "[Distill] ep 32/50 it 200/469 loss=0.0018\n",
            "[Distill] ep 32/50 it 400/469 loss=0.0018\n",
            "[Distill] ep 33/50 it 200/469 loss=0.0018\n",
            "[Distill] ep 33/50 it 400/469 loss=0.0018\n",
            "[Distill] ep 34/50 it 200/469 loss=0.0018\n",
            "[Distill] ep 34/50 it 400/469 loss=0.0019\n",
            "[Distill] ep 35/50 it 200/469 loss=0.0017\n",
            "[Distill] ep 35/50 it 400/469 loss=0.0018\n",
            "[Distill] ep 36/50 it 200/469 loss=0.0018\n",
            "[Distill] ep 36/50 it 400/469 loss=0.0018\n",
            "[Distill] ep 37/50 it 200/469 loss=0.0018\n",
            "[Distill] ep 37/50 it 400/469 loss=0.0017\n",
            "[Distill] ep 38/50 it 200/469 loss=0.0017\n",
            "[Distill] ep 38/50 it 400/469 loss=0.0017\n",
            "[Distill] ep 39/50 it 200/469 loss=0.0016\n",
            "[Distill] ep 39/50 it 400/469 loss=0.0018\n",
            "[Distill] ep 40/50 it 200/469 loss=0.0017\n",
            "[Distill] ep 40/50 it 400/469 loss=0.0016\n",
            "[Distill] ep 41/50 it 200/469 loss=0.0017\n",
            "[Distill] ep 41/50 it 400/469 loss=0.0016\n",
            "[Distill] ep 42/50 it 200/469 loss=0.0016\n",
            "[Distill] ep 42/50 it 400/469 loss=0.0015\n",
            "[Distill] ep 43/50 it 200/469 loss=0.0016\n",
            "[Distill] ep 43/50 it 400/469 loss=0.0015\n",
            "[Distill] ep 44/50 it 200/469 loss=0.0016\n",
            "[Distill] ep 44/50 it 400/469 loss=0.0016\n",
            "[Distill] ep 45/50 it 200/469 loss=0.0015\n",
            "[Distill] ep 45/50 it 400/469 loss=0.0015\n",
            "[Distill] ep 46/50 it 200/469 loss=0.0015\n",
            "[Distill] ep 46/50 it 400/469 loss=0.0016\n",
            "[Distill] ep 47/50 it 200/469 loss=0.0014\n",
            "[Distill] ep 47/50 it 400/469 loss=0.0015\n",
            "[Distill] ep 48/50 it 200/469 loss=0.0015\n",
            "[Distill] ep 48/50 it 400/469 loss=0.0014\n",
            "[Distill] ep 49/50 it 200/469 loss=0.0015\n",
            "[Distill] ep 49/50 it 400/469 loss=0.0014\n",
            "[Distill] ep 50/50 it 200/469 loss=0.0014\n",
            "[Distill] ep 50/50 it 400/469 loss=0.0015\n"
          ]
        }
      ],
      "source": [
        "# Distil the trained teacher into a StudentDenoiseCNN, reusing the teacher's schedule.\n",
        "student_cnn = distill_student(\n",
        "    teacher,\n",
        "    schedule,\n",
        "    distill_cfg,\n",
        "    student_model=StudentDenoiseCNN,\n",
        ")\n",
        "\n",
        "# Wrap the raw student together with the DDIM step count it was distilled for.\n",
        "student_wrapped_cnn = DistilledWrapper(\n",
        "    student_cnn,\n",
        "    ddim_steps=distill_cfg.student_steps,\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "id": "4271ed47",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "4271ed47",
        "outputId": "ee92c3d0-50bc-4e36-d769-d083d0609276"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "Downloading: \"https://github.com/toshas/torch-fidelity/releases/download/v0.2.0/weights-inception-2015-12-05-6726825d.pth\" to /root/.cache/torch/hub/checkpoints/weights-inception-2015-12-05-6726825d.pth\n",
            "100%|██████████| 91.2M/91.2M [00:00<00:00, 243MB/s]\n",
            "/usr/local/lib/python3.12/dist-packages/torchmetrics/utilities/prints.py:43: UserWarning: Metric `InceptionScore` will save all extracted features in buffer. For large datasets this may lead to large memory footprint.\n",
            "  warnings.warn(*args, **kwargs)\n"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Distilled student metrics: {'MSE': 0.08602545097470284, 'PSNR': 10.658369887252679, 'SSIM': 0.5083821958303452, 'FID': 3.5713491439819336, 'Inception Score (mean)': 2.0830953121185303, 'GenTime (s/batch)': 0.11877147197723388, 'GenTime (s/img)': 0.0009279021248221397}\n"
          ]
        }
      ],
      "source": [
        "# NOTE(review): torch_fidelity is not referenced in this cell; presumably it is imported\n",
        "# so a missing FID/Inception backend fails here rather than mid-evaluation - confirm.\n",
        "import torch_fidelity\n",
        "\n",
        "# Convert images to tensors and rescale them with mean 0.5 / std 0.5.\n",
        "transform = transforms.Compose([\n",
        "    transforms.ToTensor(),\n",
        "    transforms.Normalize((0.5,), (0.5,))  # [-1,1]\n",
        "])\n",
        "\n",
        "# Held-out split (train=False), built with the same sigma_c as the teacher config.\n",
        "val_ds = MNISTNoisyConditionDataset(\n",
        "    root='./data',\n",
        "    train=False,\n",
        "    download=True,\n",
        "    sigma_c=train_cfg.sigma_c,\n",
        "    transform=transform,\n",
        ")\n",
        "val_loader = DataLoader(\n",
        "    val_ds,\n",
        "    batch_size=train_cfg.batch_size,\n",
        "    shuffle=False,\n",
        "    num_workers=2,\n",
        "    # Pinned host memory only speeds up host-to-GPU copies; requesting it on a\n",
        "    # CPU-only run (DEVICE falls back to 'cpu') just triggers a warning.\n",
        "    pin_memory=(DEVICE == 'cuda'),\n",
        ")\n",
        "\n",
        "# Evaluate the distilled student (few-step DDIM is invoked under the hood)\n",
        "metrics_student_cnn = evaluate_model(\n",
        "    student_wrapped_cnn, val_loader, schedule, torch.device(DEVICE), eval_args,\n",
        "    num_batches=eval_num_batches, model_type=\"ddpm\",\n",
        ")\n",
        "print(\"Distilled student metrics:\", metrics_student_cnn)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "52cd296a",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "52cd296a",
        "outputId": "630d1b2c-65a8-485f-e725-965141636b17"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "import torch\n",
        "\n",
        "SAVE_DIR = \"./saved_models\"\n",
        "os.makedirs(SAVE_DIR, exist_ok=True)\n",
        "\n",
        "# Destination files, all inside SAVE_DIR.\n",
        "teacher_path = os.path.join(SAVE_DIR, \"teacher.pth\")\n",
        "student_path = os.path.join(SAVE_DIR, \"student_cnn_raw.pth\")\n",
        "wrapped_path = os.path.join(SAVE_DIR, \"student_cnn_wrapped.pth\")\n",
        "cfg_path = os.path.join(SAVE_DIR, \"distill_cfg.pt\")\n",
        "\n",
        "# Write each model's state_dict in turn: teacher, raw student, wrapped student.\n",
        "for _label, _module, _path in (\n",
        "    (\"teacher\", teacher, teacher_path),\n",
        "    (\"student_raw\", student_cnn, student_path),\n",
        "    (\"student_wrapped\", student_wrapped_cnn, wrapped_path),\n",
        "):\n",
        "    torch.save(_module.state_dict(), _path)\n",
        "    print(f\"Saved {_label} → {_path}\")\n",
        "\n",
        "# Optionally save config (the config object itself is written, not a state_dict)\n",
        "torch.save(distill_cfg, cfg_path)\n",
        "print(f\"Saved distillation config → {cfg_path}\")\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "1fa9c346",
      "metadata": {
        "id": "1fa9c346"
      },
      "outputs": [],
      "source": [
        "# Load distillation config.\n",
        "# The file holds the pickled config object, not tensors, so it cannot be read with\n",
        "# weights_only=True (the torch.load default from PyTorch 2.6 on). Unpickling can run\n",
        "# arbitrary code: only load a config file you wrote yourself.\n",
        "distill_cfg = torch.load(\"./saved_models/distill_cfg.pt\", weights_only=False)\n",
        "\n",
        "# Load teacher\n",
        "# NOTE(review): `...` is a placeholder - replace it with the constructor arguments\n",
        "# used at training time before running this cell.\n",
        "teacher = SmallUNet(...)   # must match architecture\n",
        "# map_location lets checkpoints written on a GPU load on a CPU-only machine.\n",
        "teacher.load_state_dict(torch.load(\"./saved_models/teacher.pth\", map_location=DEVICE, weights_only=True))\n",
        "teacher.to(DEVICE)\n",
        "teacher.eval()\n",
        "\n",
        "# Load raw student\n",
        "# NOTE(review): same placeholder caveat as for the teacher above.\n",
        "student_cnn = StudentDenoiseCNN(...)\n",
        "student_cnn.load_state_dict(torch.load(\"./saved_models/student_cnn_raw.pth\", map_location=DEVICE, weights_only=True))\n",
        "student_cnn.to(DEVICE)\n",
        "student_cnn.eval()\n",
        "\n",
        "# Load wrapped student (must create wrapper manually)\n",
        "student_wrapped_cnn = DistilledWrapper(student_cnn, ddim_steps=distill_cfg.student_steps)\n",
        "student_wrapped_cnn.load_state_dict(torch.load(\"./saved_models/student_cnn_wrapped.pth\", map_location=DEVICE, weights_only=True))\n",
        "student_wrapped_cnn.to(DEVICE)\n",
        "student_wrapped_cnn.eval()"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "gpuType": "A100",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 5
}
