{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "initial_id",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-15T20:37:21.578714Z",
     "start_time": "2025-09-15T20:37:02.623601Z"
    }
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Project-local modules (relative order kept as in the original notebook).\n",
    "from cddpm import cddpm, denoiseNet\n",
    "from tddpm import tddpm\n",
    "import configs\n",
    "from load import myMNIST\n",
    "\n",
    "# Seed torch and numpy so runs are more reproducible. Python's `random` module\n",
    "# and nondeterministic GPU kernels are not covered by these two calls.\n",
    "SEED = 0\n",
    "torch.manual_seed(SEED)\n",
    "np.random.seed(SEED)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "aaee456e63669e76",
   "metadata": {},
   "source": [
    "## Load data\n",
    "\n",
    "Build the project's `myMNIST` wrapper (with `blur_again=False`), then inspect one sample from both its `dataset` and `blurred_dataset` tensors."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "414b7d6f7638c9",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-15T20:37:27.159251Z",
     "start_time": "2025-09-15T20:37:21.871713Z"
    }
   },
   "outputs": [],
   "source": [
    "# Construct the dataset wrapper from the project's `load` module; `blur_again=False` is forwarded as-is.\n",
    "MNIST_dataset = myMNIST(\n",
    "    blur_again=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "25973a22c969b818",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-08-06T18:11:13.442903Z",
     "start_time": "2025-08-06T18:11:13.440169Z"
    }
   },
   "outputs": [],
   "source": [
    "# Shape of the blurred data; the plotting cell below indexes it along its first two axes.\n",
    "MNIST_dataset.blurred_dataset.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "48616eccdffc2b6e",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-08-06T18:11:14.599375Z",
     "start_time": "2025-08-06T18:11:14.268852Z"
    }
   },
   "outputs": [],
   "source": [
    "# Visual check of one sample: panels at axis-1 indices 0, 10, ..., 100, with `dataset`\n",
    "# in the top row and `blurred_dataset` in the bottom row so the columns line up.\n",
    "pic_ind = 3000   # index along axis 0\n",
    "N_PANELS = 11    # panels per row\n",
    "FRAME_STEP = 10  # stride along axis 1; last index shown is (N_PANELS - 1) * FRAME_STEP\n",
    "\n",
    "row_sources = (MNIST_dataset.dataset, MNIST_dataset.blurred_dataset)\n",
    "fig, axes = plt.subplots(len(row_sources), N_PANELS, figsize=(1.2 * N_PANELS, 1.5 * len(row_sources)))\n",
    "for row, images in enumerate(row_sources):\n",
    "    for panel in range(N_PANELS):\n",
    "        idx = panel * FRAME_STEP\n",
    "        ax = axes[row, panel]\n",
    "        ax.imshow(images[pic_ind, idx].cpu().numpy(), cmap='gray')\n",
    "        if row == 0:\n",
    "            ax.set_title(str(idx), fontsize=8)\n",
    "        ax.axis('off')\n",
    "fig.suptitle(f'Sample {pic_ind}: dataset (top) vs. blurred_dataset (bottom); axis-1 index above each column')\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4d1a48cef4a6ae76",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-08-06T18:11:19.122181Z",
     "start_time": "2025-08-06T18:11:19.119402Z"
    }
   },
   "outputs": [],
   "source": [
    "# Number of entries in the un-blurred data (for a tensor, its size along axis 0).\n",
    "len(MNIST_dataset.dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f0e73fa8ab5c2cad",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-08-06T18:11:20.231869Z",
     "start_time": "2025-08-06T18:11:20.203897Z"
    }
   },
   "outputs": [],
   "source": [
    "# Quick check that both model classes construct; the training cell creates fresh instances for each run.\n",
    "test_tddpm, test_cddpmNo1 = tddpm(), cddpm()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "807d006f2ffaf6c",
   "metadata": {},
   "source": [
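    "## Training\n",
    "\n",
    "Train both model classes (`tddpm` and `cddpm`) with each objective (`DDPM` and `FM`). Every run constructs a fresh model instance before calling `.train`."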
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2bd31c4a032ec503",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-08-06T18:45:17.226386Z",
     "start_time": "2025-08-06T18:11:21.478451Z"
    }
   },
   "outputs": [],
   "source": [
    "# Train tDDPM, tFM, cDDPM, cFM (same order and save names as before).\n",
    "# Each run constructs a fresh model instance; afterwards `test_tddpm` and\n",
    "# `test_cddpmNo1` hold the last-trained (FM) instance of each class.\n",
    "MODEL_TYPES = ('DDPM', 'FM')\n",
    "RUN_SUFFIX = '0p2_90'  # appended verbatim to every save_name\n",
    "REP_LEN = 8            # rep_len forwarded to cddpm.train\n",
    "# The original notebook had a checkpoint load commented out here, e.g.\n",
    "#   torch.load('./saved_model/pre_trained_info_enc_100.pth')\n",
    "# presumably meant to be passed as pretrained_info_enc_ckpt (confirm before use).\n",
    "PRETRAINED_INFO_ENC_CKPT = None\n",
    "\n",
    "for model_type in MODEL_TYPES:\n",
    "    test_tddpm = tddpm()\n",
    "    test_tddpm.train(MNIST_dataset, model_type=model_type,\n",
    "                     save_name=f'MNIST_t{model_type}_{RUN_SUFFIX}',\n",
    "                     pretrained_info_enc_ckpt=PRETRAINED_INFO_ENC_CKPT)\n",
    "\n",
    "for model_type in MODEL_TYPES:\n",
    "    test_cddpmNo1 = cddpm()\n",
    "    test_cddpmNo1.train(MNIST_dataset, save_name=f'MNIST_c{model_type}_{RUN_SUFFIX}',\n",
    "                        rep_len=REP_LEN, model_type=model_type)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
