{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os, re\n",
    "import numpy as np\n",
    "from datetime import datetime\n",
    "import cv2\n",
    "import matplotlib.pyplot as plt\n",
    "import pydicom\n",
    "from pydicom.dataset import Dataset, FileMetaDataset\n",
    "import imageio\n",
    "from scipy.ndimage import median_filter\n",
    "from numpy.lib.stride_tricks import sliding_window_view\n",
    "from concurrent.futures import ThreadPoolExecutor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def pad_image(image, pad_size):\n",
    "    # PAP section: Zero-pad the input image to handle local patch boundaries.\n",
    "    # (Paper: \"Patch-Wise Adaptive Post-Processing\" – Calculation of Difference Map & Local Patch Classification)\n",
    "    c, h, w = image.shape\n",
    "    padded_image = np.zeros((c, h + 2 * pad_size, w + 2 * pad_size), dtype=image.dtype)\n",
    "    padded_image[:, pad_size:pad_size + h, pad_size:pad_size + w] = image\n",
    "    return padded_image\n",
    "\n",
    "def process_image_with_window(image, pad_size, std_factor, threshold, org_std, cnt,\n",
    "                              noise_secnd, noise_third, noise_full):\n",
    "    \"\"\"Patch-wise correction of `image` (PAP: local patch classification + correction).\n",
    "\n",
    "    Each pixel is classified from the (2 * pad_size + 1)^2 zero-padded patch centred on it:\n",
    "      * vessel-dominant: at least `cnt` patch pixels are >= `threshold` AND the patch std\n",
    "        is >= `org_std * std_factor`  -> `noise_third` is subtracted at that pixel;\n",
    "      * normal (enough bright pixels, low variability) or background-dominant (too few\n",
    "        bright pixels) -> `noise_secnd` is subtracted at that pixel.\n",
    "\n",
    "    Args:\n",
    "        image: (C, H, W) array; a corrected copy is returned, the input is not modified.\n",
    "        pad_size: patch half-width.\n",
    "        std_factor, org_std: their product is the patch-std cut-off.\n",
    "        threshold: intensity a patch pixel must reach to be counted.\n",
    "        cnt: minimum number of counted pixels per patch.\n",
    "        noise_secnd, noise_third: noise estimates, shaped like `image`.\n",
    "        noise_full: accepted for interface compatibility; not read by this function.\n",
    "\n",
    "    Returns:\n",
    "        Array with the shape and dtype of `image`, clipped to [0, 1].\n",
    "    \"\"\"\n",
    "    side = 2 * pad_size + 1\n",
    "\n",
    "    # One (side x side) view per pixel of the zero-padded input.\n",
    "    # (Paper: \"Local Patch Classification\" – description before Eq.)\n",
    "    patches = sliding_window_view(pad_image(image, pad_size), (side, side), axis=(1, 2))\n",
    "\n",
    "    # Per-patch statistics. (Paper: \"Local Patch Classification\": µ_v and σ_v)\n",
    "    bright_count = np.count_nonzero(patches >= threshold, axis=(-2, -1))\n",
    "    patch_std = np.std(patches, axis=(-2, -1))\n",
    "\n",
    "    # Vessel-dominant patches meet both criteria. The \"normal\" and \"background\" classes\n",
    "    # are exactly the complement and share one correction, so a single mask suffices.\n",
    "    vessel_dominant = (bright_count >= cnt) & (patch_std >= org_std * std_factor)\n",
    "\n",
    "    # (Paper: \"Adaptive Histogram-Based Correction\" – D_corrected equation)\n",
    "    # noise_third: preserve vessels, noise_secnd: suppress noise.\n",
    "    corrected = np.copy(image)\n",
    "    corrected -= np.where(vessel_dominant, noise_third, noise_secnd)\n",
    "\n",
    "    # Clip in place to [0, 1].\n",
    "    np.clip(corrected, 0.0, 1.0, out=corrected)\n",
    "    return corrected\n",
    "\n",
    "def process_slice(idx, slice_, rec_slice, pad_size, std_factor, threshold, cnt):\n",
    "    \"\"\"Run the PAP correction on a single 2-D slice.\n",
    "\n",
    "    Args:\n",
    "        idx: position of the slice in its volume; returned unchanged so the caller can\n",
    "            put the result back in order.\n",
    "        slice_: noisy input slice (y_i).\n",
    "        rec_slice: reconstructed slice (ŷ_i) subtracted from `slice_`.\n",
    "        pad_size, std_factor, threshold, cnt: forwarded to `process_image_with_window`.\n",
    "\n",
    "    Returns:\n",
    "        Tuple `(idx, corrected_slice)`.\n",
    "    \"\"\"\n",
    "    # Difference map D = y_i − ŷ_i and its global statistics.\n",
    "    # (Paper: \"Calculation of Difference Map\" – Eq. (8))\n",
    "    diff = slice_ - rec_slice\n",
    "    sigma = np.std(diff)\n",
    "    mu = np.mean(diff)\n",
    "\n",
    "    def keep_within(k):\n",
    "        # Keep residuals inside mean ± k·σ unchanged; values outside the band become 0.\n",
    "        upper, lower = mu + k * sigma, mu - k * sigma\n",
    "        return np.where((diff > upper) | (diff < lower), 0.0, diff)\n",
    "\n",
    "    noise_secnd = keep_within(2)  # residuals within ±2σ\n",
    "    noise_third = keep_within(3)  # residuals within ±3σ\n",
    "\n",
    "    # The PAP module works on (C, H, W) arrays, so add a leading axis and strip it again.\n",
    "    corrected = process_image_with_window(\n",
    "        slice_[np.newaxis], pad_size, std_factor, threshold, sigma, cnt,\n",
    "        noise_secnd[np.newaxis], noise_third[np.newaxis], diff[np.newaxis],\n",
    "    )\n",
    "\n",
    "    return idx, corrected[0]\n",
    "\n",
    "def denoise_direction(volume, rec_volume, pad_size, std_factor, threshold, cnt):\n",
    "    \"\"\"Apply slice-wise PAP denoising along the first axis of `volume`.\n",
    "\n",
    "    (Paper: sequential PAP application along each axis — sagittal/coronal/axial.)\n",
    "    Slice `i` of `volume` is paired with slice `i` of `rec_volume`; slices are processed\n",
    "    by a thread pool and re-assembled in their original order.\n",
    "\n",
    "    Returns:\n",
    "        `np.stack` of the processed slices.\n",
    "    \"\"\"\n",
    "    n_slices = len(volume)\n",
    "    jobs = [(i, volume[i], rec_volume[i], pad_size, std_factor, threshold, cnt)\n",
    "            for i in range(n_slices)]\n",
    "    ordered = [None] * n_slices\n",
    "\n",
    "    with ThreadPoolExecutor() as pool:\n",
    "        # map() schedules every job up front and yields results in submission order.\n",
    "        for position, corrected in pool.map(lambda job: process_slice(*job), jobs):\n",
    "            ordered[position] = corrected\n",
    "\n",
    "    return np.stack(ordered)\n",
    "\n",
    "def denoise_3d_slices(sagittal_vol, coronal_vol, axial_vol,\n",
    "                      rec_sag_vol, rec_cor_vol, rec_axi_vol,\n",
    "                      pad_size, std_factor, threshold, cnt):\n",
    "    \"\"\"Apply PAP sequentially along three axes: sagittal → coronal → axial.\n",
    "\n",
    "    Only `sagittal_vol` seeds the chain: each later pass consumes the previous pass's\n",
    "    output, re-oriented with a transpose. `coronal_vol` and `axial_vol` are accepted\n",
    "    but not read by this function.\n",
    "\n",
    "    Args:\n",
    "        sagittal_vol: noisy volume in sagittal orientation.\n",
    "        rec_sag_vol, rec_cor_vol, rec_axi_vol: reconstructed volumes used in the\n",
    "            sagittal, coronal and axial pass respectively.\n",
    "        pad_size, std_factor, threshold, cnt: forwarded to `denoise_direction`.\n",
    "\n",
    "    Returns:\n",
    "        Output of the final (axial) pass.\n",
    "    \"\"\"\n",
    "    shared = (pad_size, std_factor, threshold, cnt)\n",
    "\n",
    "    # Pass 1: sagittal.\n",
    "    current = denoise_direction(sagittal_vol, rec_sag_vol, *shared)\n",
    "\n",
    "    # Passes 2 (coronal) and 3 (axial): the axis permutation that re-orients the\n",
    "    # previous output, paired with the reconstruction for that orientation.\n",
    "    for axes, rec_vol in (((2, 1, 0), rec_cor_vol), ((1, 0, 2), rec_axi_vol)):\n",
    "        current = denoise_direction(current.transpose(axes), rec_vol, *shared)\n",
    "\n",
    "    return current\n",
    "\n",
    "def save_slices_as_png(volume, save_dir):\n",
    "    \"\"\"Write every slice along axis 0 of `volume` to `save_dir` as `slice_NNN.png`.\n",
    "\n",
    "    Values are clipped to [0, 1], multiplied by 255 and cast to uint8 before writing.\n",
    "    `save_dir` is created when it does not exist yet.\n",
    "    \"\"\"\n",
    "    os.makedirs(save_dir, exist_ok=True)\n",
    "\n",
    "    as_uint8 = (np.clip(volume, 0.0, 1.0) * 255).astype(np.uint8)\n",
    "\n",
    "    for idx, frame in zip(range(volume.shape[0]), as_uint8):\n",
    "        target = os.path.join(save_dir, f\"slice_{idx:03d}.png\")\n",
    "        imageio.imwrite(target, frame)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Root directories\n",
    "# NOTE(review): these look like placeholder paths — point them at your own data.\n",
    "\n",
    "ori_axi_root = \"/original_noisy_axial_output_slices/\"\n",
    "rec_axi_root = \"/axial_output_slices/\"\n",
    "rec_cor_root = \"/coronal_output_slices/\"\n",
    "rec_sag_root = \"/sagittal_output_slices/\"\n",
    "\n",
    "dst_denoised = 'path_for_results'\n",
    "# exist_ok avoids the check-then-create race and matches save_slices_as_png.\n",
    "os.makedirs(dst_denoised, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load data\n",
    "\n",
    "def num_key(fname):\n",
    "    # Natural-sort key: order file names by the integers they contain, so that\n",
    "    # e.g. \"slice_10.png\" sorts after \"slice_2.png\" (plain string order would not).\n",
    "    # NOTE(review): `num_key` was used below but never defined in this notebook, which\n",
    "    # raises NameError on a fresh kernel. Confirm this ordering matches the intended one.\n",
    "    return [int(tok) for tok in re.findall(r'\\d+', fname)]\n",
    "\n",
    "\n",
    "def load_slice_stack(root, fnames):\n",
    "    # Read grayscale slices into one (N, H, W) float32 volume scaled to [0, 1].\n",
    "    # cv2.imread yields uint8 values in 0..255, while the PAP functions clip to [0, 1] and\n",
    "    # save_slices_as_png multiplies by 255, so the reconstructions are rescaled here.\n",
    "    # NOTE(review): assumes org_volume.npy is already in [0, 1] — confirm.\n",
    "    slices = []\n",
    "    for fname in fnames:\n",
    "        path = os.path.join(root, fname)\n",
    "        img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)\n",
    "        if img is None:  # cv2.imread reports failure by returning None instead of raising\n",
    "            raise FileNotFoundError(f\"Could not read image: {path}\")\n",
    "        slices.append(img)\n",
    "    return np.stack(slices).astype(np.float32) / 255.0\n",
    "\n",
    "\n",
    "rec_axi_list = os.listdir(rec_axi_root)\n",
    "rec_axi_list_sorted = sorted(rec_axi_list, key=num_key)[:248]  # keep the first 248 axial slices\n",
    "rec_cor_list = os.listdir(rec_cor_root)\n",
    "rec_cor_list_sorted = sorted(rec_cor_list, key=num_key)\n",
    "rec_sag_list = os.listdir(rec_sag_root)\n",
    "rec_sag_list_sorted = sorted(rec_sag_list, key=num_key)\n",
    "\n",
    "rec_axi_volume = load_slice_stack(rec_axi_root, rec_axi_list_sorted)\n",
    "rec_cor_volume = load_slice_stack(rec_cor_root, rec_cor_list_sorted)\n",
    "rec_sag_volume = load_slice_stack(rec_sag_root, rec_sag_list_sorted)\n",
    "\n",
    "axial_volume = np.load(os.path.join(ori_axi_root, 'org_volume.npy'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional sanity check: uncomment to print the loaded volume shapes before cropping.\n",
    "# print(axial_volume.shape)\n",
    "# print(rec_axi_volume.shape)\n",
    "# print(rec_cor_volume.shape)\n",
    "# print(rec_sag_volume.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Crop a 2-pixel border and derive the coronal / sagittal views of the noisy volume.\n",
    "axi_vol = axial_volume[:, 2:-2, 2:-2]  # 248,796,620\n",
    "cor_vol = axi_vol.transpose(1, 0, 2)   # (796, 248, 620)\n",
    "sag_vol = axi_vol.transpose(2, 0, 1)   # (620, 248, 796)\n",
    "\n",
    "# NOTE: the rec_* names are re-bound to their cropped versions, so this cell is not\n",
    "# idempotent. Re-run the \"Load data\" cell before re-running this one.\n",
    "rec_axi_volume = rec_axi_volume[:, 2:-2, 2:-2]  # 248,796,620\n",
    "rec_cor_volume = rec_cor_volume[:, :, 2:-2]     # (796, 248, 620)\n",
    "rec_sag_volume = rec_sag_volume[:, :, 2:-2]     # (620, 248, 796)\n",
    "\n",
    "# Fail loudly on an accidental double crop or any noisy/reconstructed shape mismatch.\n",
    "assert rec_axi_volume.shape == axi_vol.shape, f\"axial: {rec_axi_volume.shape} vs {axi_vol.shape}\"\n",
    "assert rec_cor_volume.shape == cor_vol.shape, f\"coronal: {rec_cor_volume.shape} vs {cor_vol.shape}\"\n",
    "assert rec_sag_volume.shape == sag_vol.shape, f\"sagittal: {rec_sag_volume.shape} vs {sag_vol.shape}\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# parameters\n",
    "w, t, s, c = 15, 0.3, 1.5, 180\n",
    "denoised_volume = denoise_3d_slices(sag_vol, cor_vol, axi_vol, rec_sag_volume, rec_cor_volume, rec_axi_volume, w, s, t, c) # 248,796,620"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save results: one PNG per slice of the denoised volume, written into dst_denoised.\n",
    "\n",
    "save_slices_as_png(denoised_volume, \"{}/\".format(dst_denoised))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
