{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "initial_id",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-03T18:24:32.903051Z",
     "start_time": "2025-09-03T18:24:18.428778Z"
    }
   },
   "outputs": [],
   "source": [
     "import os\n",
     "\n",
     "import matplotlib.pyplot as plt\n",
     "import numpy as np\n",
     "import torch\n",
     "\n",
     "from cddpm import cddpm, denoiseNet\n",
     "from tddpm import tddpm\n",
     "import configs\n",
     "import evaluate\n",
     "from load import myMNIST\n",
     "\n",
     "# import ot\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2b03aeb30864539c",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-03T18:25:07.399637Z",
     "start_time": "2025-09-03T18:25:07.395852Z"
    }
   },
   "outputs": [],
   "source": [
    "2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7b8e2cec62a89c6a",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-08-19T17:24:20.351649Z",
     "start_time": "2025-08-19T17:24:17.207428Z"
    }
   },
   "outputs": [],
   "source": [
    "MNIST_dataset = myMNIST(blur_again=False, loadtest=True)\n",
    "\n",
    "\n",
    "\n",
    "pic_ind = 3002\n",
    "\n",
    "for i in range(11):\n",
    "    ind = i * 10\n",
    "    plt.subplot(1,11,i+1)\n",
    "    plt.imshow(MNIST_dataset.dataset[pic_ind, ind].cpu().numpy(), cmap='gray')\n",
    "    plt.axis('off')\n",
    "plt.subplots_adjust(wspace=0.5)\n",
    "plt.savefig('./saved_fig/evolve'+'.png', dpi=300, bbox_inches='tight', transparent=True)\n",
    "plt.show()\n",
    "\n",
    "for i in range(11):\n",
    "    ind = i * 10\n",
    "    plt.subplot(1,11,i+1)\n",
    "    plt.imshow(MNIST_dataset.blurred_dataset[pic_ind, ind].cpu().numpy(), cmap='gray')\n",
    "    plt.axis('off')\n",
    "\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1075f75d1a48f9a",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-08-19T15:11:01.141288Z",
     "start_time": "2025-08-19T15:11:01.098702Z"
    }
   },
   "outputs": [],
   "source": [
    "test_tDDPM = tddpm()\n",
    "test_cDDPM = cddpm()\n",
    "test_tFM = tddpm()\n",
    "test_cFM = cddpm()\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "477e8f02b5453a78",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-08-19T15:11:58.188240Z",
     "start_time": "2025-08-19T15:11:58.142142Z"
    }
   },
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "ckpttDDPM = torch.load('./saved_model/MNIST_tDDPM_96_500.pth', map_location=configs.device, weights_only=True)\n",
    "#Fashion_tddpmNo2_500 MNIST_tddpm_no1_200 MNIST_tddpm_FM_0p3_85_200\n",
    "ckptcDDPM = torch.load('./saved_model/MNIST_cDDPM_96_500.pth', map_location=configs.device, weights_only=True)\n",
    "ckpttFM = torch.load('./saved_model/MNIST_tFM_96_500.pth', map_location=configs.device, weights_only=True)\n",
    "ckptcFM = torch.load('./saved_model/MNIST_cFM_96_500.pth', map_location=configs.device, weights_only=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3bdd6cbdf45124cb",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-08-19T15:11:59.398658Z",
     "start_time": "2025-08-19T15:11:59.366340Z"
    }
   },
   "outputs": [],
   "source": [
    "\n",
    "test_cDDPM.denoise_model.load_state_dict(ckptcDDPM)\n",
    "test_tDDPM.load_denoise_model_noparallel(ckpttDDPM)\n",
    "test_tFM.denoise_model.load_state_dict(ckpttFM)\n",
    "test_cFM.denoise_model.load_state_dict(ckptcFM)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1ef2d96530251d03",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-08-19T16:20:39.548957Z",
     "start_time": "2025-08-19T15:12:01.726014Z"
    }
   },
   "outputs": [],
   "source": [
    "evaluate.MNIST_eval_batch([test_cDDPM, test_cFM, test_tDDPM, test_tFM], MNIST_dataset, model_type_list=[\"DDPM\",\"FM\", \"DDPM\", \"FM\"], save_name=\"save\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d9b45b5d603607c0",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
