{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from torchvision.transforms.functional import to_pil_image\n",
    "import matplotlib.pyplot as plt\n",
    "from modules import UNet_conditional\n",
    "from ddpm_conditional import Diffusion, generate_random_tensor\n",
    "import torch\n",
    "\n",
    "from utils import convert_to_grayscale, plot_images, wasserstein_distance, get_data_conditional, save_images, normalize_sample\n",
    "from projection import Projection\n",
    "\n",
    "\n",
    "# Prefer the GPU, but fall back to the CPU so the notebook still runs on machines\n",
    "# without CUDA (sampling will be much slower there).\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "\n",
    "# Diffusion helper configured for 64-pixel images; it is handed the same device as the model.\n",
    "diffusion = Diffusion(img_size=64, device=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Location of the trained weights -- point this at your own checkpoint file.\n",
    "checkpoint_path = '/path/to/dir/models/conditional/ckpt.pt'\n",
    "\n",
    "# Build the conditional UNet (c_in=1, c_out=1) and move it onto `device`.\n",
    "model = UNet_conditional(c_in=1, c_out=1)\n",
    "model = model.to(device)\n",
    "\n",
    "# Restore the saved parameters, remapping stored tensors onto `device` while loading.\n",
    "# NOTE(review): torch.load unpickles the file -- only load checkpoints from a trusted source.\n",
    "model.load_state_dict(torch.load(checkpoint_path, map_location=device))\n",
    "\n",
    "# This notebook only samples, so switch to evaluation mode\n",
    "# (call model.train() instead if you intend to continue training).\n",
    "model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def give_me_images(model, diffusion, generate_random_tensor, device, num_images, data_dir):\n",
    "    \"\"\"Sample `num_images` images from the conditional model and show them in one row.\n",
    "\n",
    "    NOTE: a later cell in this notebook defines another `give_me_images` with the\n",
    "    same leading parameters; once that cell runs it replaces this plotting variant.\n",
    "\n",
    "    Args:\n",
    "        model: network passed straight through to `diffusion.sample`.\n",
    "        diffusion: object exposing `sample(model, n=..., labels=...)`.\n",
    "        generate_random_tensor: callable taking `data_dir` and returning either the\n",
    "            label tensor alone or a `(labels, params)` tuple.\n",
    "        device: target passed to `labels.to(...)`.\n",
    "        num_images: number of images to sample; also the number of subplot columns.\n",
    "        data_dir: forwarded unchanged to `generate_random_tensor`.\n",
    "\n",
    "    Returns:\n",
    "        None. The figure is displayed with `plt.show()`.\n",
    "    \"\"\"\n",
    "    generated = generate_random_tensor(data_dir)\n",
    "    # The other `give_me_images` cell unpacks `(labels, params)` from this helper, so\n",
    "    # accept both return shapes instead of failing on `tuple.to(...)`.\n",
    "    labels = generated[0] if isinstance(generated, tuple) else generated\n",
    "    print(labels)\n",
    "    labels = labels.to(device)\n",
    "    sampled_images = diffusion.sample(model, n=num_images, labels=labels)\n",
    "\n",
    "    # squeeze=False always yields a 2-D array of axes, so a single column needs no\n",
    "    # special-casing; figure width scales with the number of images.\n",
    "    fig, axes = plt.subplots(nrows=1, ncols=num_images, figsize=(num_images * 3.75, 3), squeeze=False)\n",
    "    row_axes = axes[0]\n",
    "\n",
    "    for i, img_tensor in enumerate(sampled_images):\n",
    "        ax = row_axes[i]\n",
    "        ax.imshow(to_pil_image(img_tensor.cpu()))  # tensor -> PIL image for display\n",
    "        ax.axis('off')  # hide ticks and frame\n",
    "\n",
    "    plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def give_me_images(model, diffusion, generate_random_tensor, device, num_images, data_dir, label_override=-0.2):\n",
    "    \"\"\"Sample `num_images` images after optionally overwriting the last label entry.\n",
    "\n",
    "    This definition shadows the earlier plotting `give_me_images` cell: it returns the\n",
    "    sampled images instead of drawing them.\n",
    "\n",
    "    Args:\n",
    "        model: network passed straight through to `diffusion.sample`.\n",
    "        diffusion: object exposing `sample(model, n=..., labels=..., projection=...)`.\n",
    "        generate_random_tensor: callable taking `data_dir` and returning a\n",
    "            `(labels, params)` pair; only `labels` is used here.\n",
    "        device: target passed to `labels.to(...)`.\n",
    "        num_images: number of images requested from `diffusion.sample`.\n",
    "        data_dir: forwarded unchanged to `generate_random_tensor`.\n",
    "        label_override: value written into `labels[-1]` before sampling. Defaults to\n",
    "            -0.2 (the value previously hard-coded here); pass None to keep the\n",
    "            generated labels untouched.\n",
    "\n",
    "    Returns:\n",
    "        Whatever `diffusion.sample` returns for the (possibly modified) labels.\n",
    "    \"\"\"\n",
    "    labels, _ = generate_random_tensor(data_dir)  # second item (params) is not needed here\n",
    "\n",
    "    if label_override is not None:\n",
    "        # Force the last label entry to a fixed value for this experiment.\n",
    "        labels[-1] = label_override\n",
    "\n",
    "    labels = labels.to(device)\n",
    "    return diffusion.sample(model, n=num_images, labels=labels, projection=False)\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Directory handed to generate_random_tensor when the labels are drawn.\n",
    "data_directory = \"/path/to/dir/Moments/Train\"\n",
    "\n",
    "# Request ten samples from the loaded model and keep them for plotting below.\n",
    "sample = give_me_images(\n",
    "    model=model,\n",
    "    diffusion=diffusion,\n",
    "    generate_random_tensor=generate_random_tensor,\n",
    "    device=device,\n",
    "    num_images=10,\n",
    "    data_dir=data_directory,\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Pass the first ten entries of `sample` (the call above requested ten images)\n",
    "# to the plot_images helper imported from utils.\n",
    "plot_images(sample[:10])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "torchmetrics",
   "language": "python",
   "name": "torchmetrics"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
