{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "59b6061a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import warnings\n",
    "\n",
    "import torch\n",
    "import SimpleITK as sitk\n",
    "\n",
    "from model.dncnn import *\n",
    "from model.Loss import *\n",
    "from model.PIRATE_structure import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "86bec97f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_checkpoint(model, checkpoint_PATH, device, filename=\"save.pth.tar\"):\n",
    "    \"\"\"Load weights from `checkpoint_PATH/filename` into `model` and return it.\n",
    "\n",
    "    Args:\n",
    "        model: torch.nn.Module whose state dict is updated (non-strict).\n",
    "        checkpoint_PATH: directory containing the checkpoint file.\n",
    "        device: passed as torch.load's map_location, so tensors are loaded onto it.\n",
    "        filename: checkpoint file name inside checkpoint_PATH.\n",
    "\n",
    "    Returns:\n",
    "        The same `model` object with the checkpoint's 'state_dict' loaded.\n",
    "    \"\"\"\n",
    "    ckpt_file = os.path.join(checkpoint_PATH, filename)\n",
    "    # NOTE(review): torch.load unpickles arbitrary objects -- only load trusted checkpoints.\n",
    "    model_CKPT = torch.load(ckpt_file, map_location=device)\n",
    "    # strict=False tolerates key mismatches, but skipped keys would leave layers at their\n",
    "    # initial values; report them. print (not warnings.warn) because this notebook\n",
    "    # silences all warnings.\n",
    "    result = model.load_state_dict(model_CKPT['state_dict'], strict=False)\n",
    "    missing = getattr(result, 'missing_keys', [])\n",
    "    unexpected = getattr(result, 'unexpected_keys', [])\n",
    "    if missing or unexpected:\n",
    "        print(f'checkpoint key mismatch: missing={missing}, unexpected={unexpected}')\n",
    "    print('loading checkpoint!')\n",
    "    return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "73b117a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Input-volume directory, checkpoint directory and output directory (relative paths).\n",
    "image_path, model_path, target_path = \"./data\", \"./model/save/\", \"./output/\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "5e15d209",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Silence every warning category for the rest of the kernel session.\n",
    "# NOTE(review): this also hides useful torch/SimpleITK warnings; consider a narrower filter.\n",
    "warnings.simplefilter('ignore')\n",
    "\n",
    "# Registration pairs as [moving, fixed] -- presumably file stems of .nii.gz volumes under image_path; confirm.\n",
    "image_list = [[\"moving\", \"fixed\"]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "8bf57aed",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyper-parameters passed to URED; image_shape also sizes the initial field and the warp.\n",
    "# NOTE(review): the keys (including the 'gamma_inti'/'tau_inti' spellings) must match what URED reads.\n",
    "config = dict(\n",
    "    gamma_inti=5e5,   # presumably the initial gamma value -- confirm in URED\n",
    "    tau_inti=1e-7,    # presumably the initial tau value -- confirm in URED\n",
    "    iteration=500,    # presumably an iteration budget -- confirm in URED/DEQ\n",
    "    image_shape=[160, 192, 224],  # must equal the (z, y, x) array shape of the input volumes\n",
    "    weight_grad=5e-1, # presumably a gradient-term weight -- confirm in URED\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "30961f85",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "device(type='cuda')"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Prefer the GPU when torch can see one, otherwise fall back to the CPU.\n",
    "device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n",
    "device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "8073c00c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loading checkpoint!\n"
     ]
    }
   ],
   "source": [
    "# Resize module for the estimated field (ResizeTransform(1/2, 3)); presumably upsamples it back to full resolution -- confirm.\n",
    "resize = ResizeTransform(1/2, 3).to(device)\n",
    "\n",
    "# Assemble DnCNN -> URED -> DEQ, move it to `device`, then load the trained weights from model_path.\n",
    "denoiser = DnCNN()\n",
    "ForwardIteration = URED(denoiser, config)\n",
    "PIRATE = load_checkpoint(DEQ(ForwardIteration).to(device), model_path, device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "f7b4a19d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def register_pair(moving_file, fixed_file):\n",
    "    \"\"\"Register one moving/fixed NIfTI pair with PIRATE and return the warped moving image.\n",
    "\n",
    "    Relies on the notebook-level `PIRATE`, `resize`, `transformer`, `config` and `device`.\n",
    "\n",
    "    Args:\n",
    "        moving_file: path of the image to be warped.\n",
    "        fixed_file: path of the reference image.\n",
    "\n",
    "    Returns:\n",
    "        sitk.Image of the warped moving image, carrying the fixed image's\n",
    "        spacing/origin/direction.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if either volume's array shape differs from config['image_shape'].\n",
    "    \"\"\"\n",
    "    moving_img = sitk.ReadImage(moving_file)\n",
    "    fixed_img = sitk.ReadImage(fixed_file)\n",
    "    # SimpleITK arrays are (z, y, x)-ordered. Cast to float32 (torch's default parameter\n",
    "    # dtype) so integer or float64 volumes do not hit a dtype mismatch inside the network.\n",
    "    moving_np = sitk.GetArrayFromImage(moving_img).astype('float32')\n",
    "    fixed_np = sitk.GetArrayFromImage(fixed_img).astype('float32')\n",
    "\n",
    "    # The initial field and the SpatialTransformer are sized from config['image_shape'];\n",
    "    # volumes of any other shape would be mis-registered, so fail loudly instead.\n",
    "    expected_shape = tuple(config['image_shape'])\n",
    "    for path, arr in ((moving_file, moving_np), (fixed_file, fixed_np)):\n",
    "        if arr.shape != expected_shape:\n",
    "            raise ValueError(f\"{path}: array shape {arr.shape} != config['image_shape'] {expected_shape}\")\n",
    "\n",
    "    with torch.no_grad():\n",
    "        moving = torch.from_numpy(moving_np).view(1, 1, *expected_shape).to(device)\n",
    "        fixed = torch.from_numpy(fixed_np).view(1, 1, *expected_shape).to(device)\n",
    "\n",
    "        # Zero-initialised 3-channel field at half of config['image_shape'].\n",
    "        # NOTE(review): requires_grad only matters if PIRATE re-enables grad internally -- confirm.\n",
    "        half_shape = [s // 2 for s in expected_shape]\n",
    "        field = torch.zeros((1, 3, *half_shape), requires_grad=True, device=device)\n",
    "\n",
    "        field_hat, _, _ = PIRATE(field, moving, fixed)\n",
    "        field_full = resize(field_hat)  # presumably back to full resolution -- confirm\n",
    "        warped_image = transformer(moving, field_full, return_phi=False)\n",
    "\n",
    "    warped_np = warped_image.view(warped_image.shape[-3:]).detach().cpu().numpy()\n",
    "    warped = sitk.GetImageFromArray(warped_np)\n",
    "    # GetImageFromArray yields unit spacing / zero origin; the warped image lives on the\n",
    "    # fixed image's voxel grid, so copy its geometry for correct overlay in viewers.\n",
    "    warped.CopyInformation(fixed_img)\n",
    "    return warped\n",
    "\n",
    "\n",
    "PIRATE.eval()\n",
    "transformer = SpatialTransformer(config['image_shape']).to(device)\n",
    "os.makedirs(target_path, exist_ok=True)\n",
    "\n",
    "for moving_name, fixed_name in image_list:\n",
    "    warped = register_pair(os.path.join(image_path, moving_name + '.nii.gz'),\n",
    "                           os.path.join(image_path, fixed_name + '.nii.gz'))\n",
    "    # One pair keeps the name warped_image.nii.gz; several pairs get distinct names so\n",
    "    # later pairs do not overwrite earlier outputs.\n",
    "    out_name = ('warped_image.nii.gz' if len(image_list) == 1\n",
    "                else f'warped_{moving_name}_to_{fixed_name}.nii.gz')\n",
    "    sitk.WriteImage(warped, os.path.join(target_path, out_name))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "83848b53",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
