{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchvision.transforms.functional import to_pil_image\n",
    "import matplotlib.pyplot as plt\n",
    "from scoremodel import Model, AnnealedLangevinDynamic\n",
    "from ddpm_conditional import Diffusion, generate_random_tensor\n",
    "import torch\n",
    "import os\n",
    "\n",
    "from utils import convert_to_grayscale, plot_images, wasserstein_distance, get_data_conditional, save_images\n",
    "from projection import Projection, ProbabilisticDamage\n",
    "\n",
    "from torchmetrics.image.fid import FrechetInceptionDistance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sampler configuration: every knob a reader might want to tune lives here.\n",
    "\n",
    "# Step-size epsilon passed to AnnealedLangevinDynamic(eps=...).\n",
    "eps = 1.5e-5\n",
    "\n",
    "# Smallest and largest noise levels (sigma) of the annealed Langevin dynamic.\n",
    "sigma_min = 0.005\n",
    "sigma_max = 10\n",
    "\n",
    "# Langevin step count and annealed step count, passed to Model / AnnealedLangevinDynamic.\n",
    "# NOTE(review): give_me_images() overwrites dynamic.annealed_step before sampling.\n",
    "n_steps = 10\n",
    "annealed_step = 25\n",
    "\n",
    "# NOTE(review): not referenced by any other cell in this notebook.\n",
    "sample_size = 10\n",
    "\n",
    "# Langevin sampling is presumably stochastic (it injects noise); seed torch so that\n",
    "# Restart & Run All reproduces the same samples and figures.\n",
    "SEED = 0\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the score network (Model is imported from scoremodel, configured in the cell above).\n",
    "model = Model(device, n_steps, sigma_min, sigma_max)\n",
    "\n",
    "# Load the trained weights *before* handing the model to the sampler, so the sampler\n",
    "# never sees an untrained network even if its constructor inspects or copies the model.\n",
    "# NOTE(review): torch.load unpickles the file - only load checkpoints from a trusted source.\n",
    "checkpoint_path = './models/score_model/ckpt.pt'\n",
    "model.load_state_dict(torch.load(checkpoint_path, map_location=device))\n",
    "\n",
    "# Inference only: switch the network to evaluation mode.\n",
    "model.eval()\n",
    "\n",
    "# Annealed Langevin sampler driven by the loaded score model; reconfigured per run by give_me_images().\n",
    "dynamic = AnnealedLangevinDynamic(sigma_min, sigma_max, n_steps, annealed_step, model, device, eps=eps)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def give_me_images(num_images, const, only_final, i, annealed_step=100):\n",
    "    \"\"\"Sample images from the global ``dynamic`` sampler under a (ub, lb) constraint.\n",
    "\n",
    "    Side effect: reconfigures the module-level ``dynamic`` object; the attributes\n",
    "    set here persist after the call returns.\n",
    "\n",
    "    Args:\n",
    "        num_images: number of samples, forwarded to ``dynamic.sampling``.\n",
    "        const: ``(ub, lb)`` pair; stored as ``dynamic.damage_constraints`` and\n",
    "            unpacked into ``dynamic.ub`` / ``dynamic.lb``.\n",
    "        only_final: forwarded to ``dynamic.sampling``; presumably True returns only\n",
    "            the final sample rather than the trajectory - TODO confirm in scoremodel.\n",
    "        i: stored as ``dynamic.start``; presumably the annealing step at which the\n",
    "            constraint starts to apply - TODO confirm in scoremodel.\n",
    "        annealed_step: value written to ``dynamic.annealed_step`` (defaults to the\n",
    "            previously hard-coded 100).\n",
    "\n",
    "    Returns:\n",
    "        The result of ``dynamic.sampling(num_images, only_final)``; assumed to be a\n",
    "        tensor, since ``.float().mean()`` is applied to it below.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``const[0]`` (ub) is smaller than ``const[1]`` (lb).\n",
    "    \"\"\"\n",
    "    ub, lb = const[0], const[1]\n",
    "    # Validate before touching the shared sampler so a bad call leaves it unchanged.\n",
    "    if ub < lb:\n",
    "        raise ValueError(f\"const must be (ub, lb) with ub >= lb, got {const!r}\")\n",
    "\n",
    "    # Enable the sampler's pgd mode; damage projection and post-processing stay off.\n",
    "    dynamic.annealed_step = annealed_step\n",
    "    dynamic.pgd = True\n",
    "    dynamic.post_process = False\n",
    "    dynamic.damage_projection = False\n",
    "    dynamic.damage_constraints = const\n",
    "    dynamic.start = i\n",
    "    dynamic.ub, dynamic.lb = ub, lb\n",
    "\n",
    "    sample = dynamic.sampling(num_images, only_final)\n",
    "\n",
    "    # Share of sampled values below 0.0, as a percentage; presumably these render as\n",
    "    # black pixels (images in a signed range such as [-1, 1]) - TODO confirm.\n",
    "    black_pct = ((sample < 0.0).float().mean() * 100).item()\n",
    "    print(\"| Projected Black Pixel  (%) : \", black_pct)\n",
    "\n",
    "    return sample"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Settings for this sampling run.\n",
    "num_images = 100\n",
    "constraint = (0.43, 0.37)  # (ub, lb)\n",
    "start_step = 6  # passed as give_me_images(i=...), stored as dynamic.start\n",
    "\n",
    "# Other (ub, lb) presets kept for reference; not used by this run.\n",
    "porosity = [(0.5, 0.48), (0.41, 0.39), (0.31, 0.29), (0.21, 0.19), (0.11, 0.09)]\n",
    "\n",
    "# Pure inference: no autograd graph is needed.\n",
    "with torch.no_grad():\n",
    "    sample = give_me_images(num_images, constraint, True, start_step)\n",
    "    print(sample.max().item(), sample.min().item())\n",
    "    plot_images(sample[:8])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "torchmetrics",
   "language": "python",
   "name": "torchmetrics"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
