{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('..')\n",
    "import os\n",
    "# os.chdir(\"./playgrounds/\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from classes.PnP_class import *\n",
    "from ours_model.model_loader import *"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Image object\n",
    "- **image_path**: Path to the image  \n",
    "- **forward_model_name**: Name of the forward model  \n",
    "- **forward_model_args**: Arguments for the forward model  \n",
    "- **noise_level**: Noise level to be added to the image  \n",
    "- **kernel_path**: Path to the kernel file (Custom forward model through convolution; not implemented yet)  \n",
    "- **color_mode**: `'RGB'` for color images, `'L'` for grayscale images  \n",
    "- **crop**: `True` to crop the image to 256x256  \n",
    "- **seed_val**: Seed value for reproducibility  \n",
    "- **save_path**: Path to save the results; default is `./results/`  \n",
    "\n",
    "> **Note**: Only the method `get_images()` will save the images when called with `save=True`.\n",
    "\n",
    "> The detailed description of the attributes can be found at the end of the notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "image_id = 'CBSD10/0001'\n",
    "color_mode = 'RGB'\n",
    "algo = 'HQS'\n",
    "# color_mode = 'L'\n",
    "device = 'cuda:0'\n",
    "application_type = 'deblurring'\n",
    "image_path = \"../images/{}.png\".format(image_id)\n",
    "denoiser_name = 'OURS'\n",
    "save_path = \"../Experiments/Exp1a/{}/{}/{}/{}/\".format(\n",
    "    image_id, application_type, denoiser_name, algo)\n",
    "my_image = img_PnP(image_path, forward_model_name=application_type,\n",
    "                         forward_model_args={'scale_factor': 1,\n",
    "                                             'kernel_id': 7, 'device': device},\n",
    "                         noise_level=0.03, color_mode=color_mode, save_path=save_path, crop=True)\n",
    "_ = my_image.get_images(plot=True, save=True)\n",
    "print(\"Image shape:\", my_image.image.shape)\n",
    "\n",
    "denoiser_ours = load_model(\n",
    "    config_folder=\"../ours_model/configs-50g-new\",\n",
    "                            window_rad=None,\n",
    "                            device=device\n",
    "                            ).to(device) ### Model Object\n",
    "run_ours = OURS ### Model Infer Method"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#### compute reference image z\n",
    "denoiser_args = {\n",
    "    'guide_image': my_image.observed,  # Guide image is the observed image\n",
    "                  'h': 0.03} ### Denoiser Args\n",
    "algo_params={'transpose': False, 'name': 'FBS', 'step_size': 2.0, 'clip': False}\n",
    "algo_params={'transpose': False, 'name': 'HQS', 'step_size': 5.0, 'clip': False}\n",
    "warm_up_iters = 100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "warm_up_iters = 35\n",
    "denoiser_args['h'] = 8.0/255.0\n",
    "algo_params['step_size'] = 6.5\n",
    "for i in range(warm_up_iters):\n",
    "    my_image.PnP(denoiser=run_ours,\n",
    "             denoiser_args=denoiser_args,\n",
    "             denoiser_object=denoiser_ours,\n",
    "             num_iterations=1,\n",
    "                plot_graphs=True,\n",
    "                plot_interval=101,\n",
    "                algo_params=algo_params,\n",
    "        )\n",
    "    denoiser_args['guide_image'] = my_image.reconstruction\n",
    "    my_image.start_image = my_image.reconstruction.copy()\n",
    "\n",
    "my_image.get_images(plot=True, save=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "my_image.start_image = my_image.observed.copy()\n",
    "my_image.PnP(denoiser=run_ours,\n",
    "             denoiser_args=denoiser_args,\n",
    "             denoiser_object=denoiser_ours,\n",
    "             num_iterations=101,\n",
    "                plot_graphs=True,\n",
    "                plot_interval=50,\n",
    "                algo_params=algo_params,\n",
    "        )"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "dn-arg",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
