{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
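    "## Introduction Example: Comparing Cell-Perturbation Strategies\n",
    "\n",
    "For each example image, one cell of a regular grid is perturbed with several baseline strategies (box blurring, Telea and Navier-Stokes inpainting, filling with black) and with our approach, which fills the cell with the RGB-cube corner farthest from the cell's mean color. All panels are saved as a single figure to a PDF."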
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import cv2\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib.backends.backend_pdf import PdfPages\n",
    "\n",
    "plt.rcParams.update({\n",
    "    \"text.usetex\": True,  # needs a local LaTeX installation (see matplotlib usetex docs)\n",
    "    \"font.family\": \"serif\",\n",
    "    \"font.serif\": [\"Computer Modern Roman\"],  # Use preferred LaTeX font\n",
    "    \"font.size\": 14,  # Adjust font size as needed\n",
    "})\n",
    "\n",
    "# Root folder of the figure assets (input images in imgs/).\n",
    "# NOTE(review): absolute, machine-specific path -- adjust when running elsewhere.\n",
    "FIGS_DIR = \"/home/tim/Documents/XX_Perturbation_SHAP/perturbation_shap/figs\"\n",
    "\n",
    "# Paths to images (one figure row per image)\n",
    "image_paths = [\n",
    "    f\"{FIGS_DIR}/imgs/vigo_1.jpg\",\n",
    "    f\"{FIGS_DIR}/imgs/sunny_2.jpg\",\n",
    "    f\"{FIGS_DIR}/imgs/porsche_1.jpg\",\n",
    "]\n",
    "\n",
    "# Number of grid cells per image side (each image is split into grid_size x grid_size cells)\n",
    "grid_size = 8\n",
    "\n",
    "# (row, col) of the grid cell to perturb, one entry per image in image_paths\n",
    "selected_cells = [\n",
    "    (4, 2),  # first image (vigo_1)\n",
    "    (4, 5),  # second image (sunny_2)\n",
    "    (3, 4),  # third image (porsche_1)\n",
    "]\n",
    "assert len(selected_cells) == len(image_paths), \"need exactly one selected cell per image\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "### Maximum-Distance Color in RGB Space"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Perturbation results saved to 'perturbation_results.pdf'\n"
     ]
    }
   ],
   "source": [
    "\n",
    "\n",
    "# Helper function to crop to the largest square from the center\n",
    "def crop_to_square(image):\n",
    "    \"\"\"Return the largest centered square crop of ``image``.\n",
    "\n",
    "    Accepts (H, W) grayscale as well as (H, W, C) color arrays. The excess of the\n",
    "    longer side is split between both ends; when it is odd, the extra pixel is\n",
    "    removed from the bottom/right. The result is a NumPy view, not a copy.\n",
    "    \"\"\"\n",
    "    h, w = image.shape[:2]\n",
    "    side = min(h, w)\n",
    "    start_y = (h - side) // 2\n",
    "    start_x = (w - side) // 2\n",
    "    return image[start_y : start_y + side, start_x : start_x + side]\n",
    "\n",
    "# Function to bring all images to one common square resolution (the smallest available)\n",
    "def resize_to_smallest(images):\n",
    "    \"\"\"Center-crop each image to a square and resize all crops to one common side.\n",
    "\n",
    "    The common side is the smallest square any input can provide,\n",
    "    ``min(min(h, w) for each image)``, so images are only shrunk (never enlarged)\n",
    "    and every returned array has the same (side, side) spatial shape.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    images : sequence of np.ndarray\n",
    "        Decoded images, e.g. from ``cv2.imread`` (``None`` entries are not allowed).\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    list of np.ndarray\n",
    "        Square images of identical size, in input order; ``[]`` for empty input.\n",
    "    \"\"\"\n",
    "    if not images:\n",
    "        return []\n",
    "\n",
    "    # Scaling wide images to the minimum height and tall ones to the minimum width\n",
    "    # independently would give squares of different sizes when orientations are mixed,\n",
    "    # so all images share a single target side instead.\n",
    "    side = min(min(image.shape[:2]) for image in images)\n",
    "\n",
    "    resized_images = []\n",
    "    for image in images:\n",
    "        square = crop_to_square(image)\n",
    "        # INTER_AREA is the OpenCV-recommended interpolation for shrinking; resizing\n",
    "        # also yields a fresh contiguous array even when the size is unchanged.\n",
    "        resized_images.append(cv2.resize(square, (side, side), interpolation=cv2.INTER_AREA))\n",
    "\n",
    "    return resized_images\n",
    "\n",
    "def calculate_max_distance_color(mean_rgb, max_value=255):\n",
    "    \"\"\"Return the corner of the RGB cube farthest (Euclidean) from ``mean_rgb``.\n",
    "\n",
    "    The candidates are the 8 pure colors black, red, green, blue, yellow, magenta,\n",
    "    cyan and white, scaled to ``max_value``; ties go to the earlier color in that\n",
    "    list (``np.argmax`` returns the first maximum).\n",
    "\n",
    "    Two earlier variants were dropped: a fixed 8-color palette (the same corners in a\n",
    "    different order) and the channel-wise inverse ``255 - mean``, which lies close to\n",
    "    the input for mid-grey cells (e.g. 127 -> 128) and so barely perturbs them.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    mean_rgb : array-like of length 3\n",
    "        Mean color of the region to be replaced.\n",
    "    max_value : int, optional\n",
    "        Maximum channel value; the default 255 matches 8-bit images.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    np.ndarray of shape (3,)\n",
    "        The farthest corner; every component is 0 or ``max_value``.\n",
    "    \"\"\"\n",
    "    # All 8 corners of the unit RGB cube (this order determines tie-breaking)\n",
    "    unit_corners = np.array([\n",
    "        [0, 0, 0],  # Black\n",
    "        [1, 0, 0],  # Red\n",
    "        [0, 1, 0],  # Green\n",
    "        [0, 0, 1],  # Blue\n",
    "        [1, 1, 0],  # Yellow\n",
    "        [1, 0, 1],  # Magenta\n",
    "        [0, 1, 1],  # Cyan\n",
    "        [1, 1, 1],  # White\n",
    "    ])\n",
    "    corners = unit_corners * max_value\n",
    "\n",
    "    # Distance from the mean color to every corner; keep the farthest one\n",
    "    distances = np.linalg.norm(corners - np.asarray(mean_rgb), axis=1)\n",
    "    return corners[np.argmax(distances)]\n",
    "\n",
    "\n",
    "\n",
    "# Function to apply perturbations\n",
    "def apply_perturbations(image, cell_coords, blur_kernel):\n",
    "    \"\"\"Perturb one grid cell of ``image`` with every strategy shown in the figure.\n",
    "\n",
    "    The image is split into ``grid_size`` x ``grid_size`` cells (module-level\n",
    "    ``grid_size``). All results are new arrays with the grid drawn on top;\n",
    "    ``image`` itself is left unchanged.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    image : np.ndarray\n",
    "        3-channel image of shape (H, W, 3), presumably uint8 (cv2.inpaint needs 8-bit\n",
    "        input) and in RGB order (the outline color (213, 0, 28) is red only in RGB).\n",
    "    cell_coords : tuple of int\n",
    "        (row, col) of the cell to perturb, each in ``range(grid_size)``.\n",
    "    blur_kernel : tuple of int\n",
    "        (width, height) of the normalized box filter applied by ``cv2.blur``.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    tuple of np.ndarray\n",
    "        (image with the cell outlined in red, blurred cell, Telea inpainting,\n",
    "        Navier-Stokes inpainting, black cell, cell filled with the RGB-cube corner\n",
    "        farthest from its mean color).\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If ``cell_coords`` lies outside the grid.\n",
    "    \"\"\"\n",
    "    h, w, _ = image.shape\n",
    "    cell_h = h // grid_size\n",
    "    cell_w = w // grid_size\n",
    "\n",
    "    row, col = cell_coords\n",
    "    if not (0 <= row < grid_size and 0 <= col < grid_size):\n",
    "        # Otherwise the slices below come out empty (or clipped) and the mean color is NaN\n",
    "        raise ValueError(f\"cell_coords {cell_coords} outside the {grid_size}x{grid_size} grid\")\n",
    "\n",
    "    y1, x1 = row * cell_h, col * cell_w\n",
    "    y2, x2 = y1 + cell_h, x1 + cell_w\n",
    "\n",
    "    # Inpainting mask: non-zero marks the pixels to reconstruct\n",
    "    mask = np.zeros((h, w), dtype=np.uint8)\n",
    "    mask[y1:y2, x1:x2] = 1\n",
    "\n",
    "    # Blurring: box-filter only the selected cell\n",
    "    blurred = image.copy()\n",
    "    blurred[y1:y2, x1:x2] = cv2.blur(image[y1:y2, x1:x2], blur_kernel)\n",
    "    blurred = draw_grid(blurred)\n",
    "\n",
    "    # Inpainting with OpenCV's two algorithms (Telea / Navier-Stokes)\n",
    "    inpaint1 = draw_grid(cv2.inpaint(image, mask, inpaintRadius=3, flags=cv2.INPAINT_TELEA))\n",
    "    inpaint2 = draw_grid(cv2.inpaint(image, mask, inpaintRadius=3, flags=cv2.INPAINT_NS))\n",
    "\n",
    "    # Black fill\n",
    "    blacked = image.copy()\n",
    "    blacked[y1:y2, x1:x2] = 0\n",
    "    blacked = draw_grid(blacked)\n",
    "\n",
    "    # Mean RGB of the cell.\n",
    "    # NOTE(review): if the caller already drew grid lines on `image`, white line pixels\n",
    "    # inside the cell may enter this mean (and the blur/inpaint input) -- confirm intent.\n",
    "    mean_color = np.mean(image[y1:y2, x1:x2], axis=(0, 1))\n",
    "    max_distance_color = calculate_max_distance_color(mean_color)\n",
    "\n",
    "    # \"Our Approach\": fill the cell with the farthest RGB-cube corner\n",
    "    our_approach = image.copy()\n",
    "    our_approach[y1:y2, x1:x2] = max_distance_color\n",
    "    our_approach = draw_grid(our_approach)\n",
    "\n",
    "    # Outline the selected cell in red. cv2.rectangle draws pt2 inclusively, so pass the\n",
    "    # last pixel inside the cell rather than the first pixel past it.\n",
    "    image_with_red_cell = image.copy()\n",
    "    cv2.rectangle(image_with_red_cell, (x1, y1), (x2 - 1, y2 - 1), (213, 0, 28), 2)\n",
    "\n",
    "    return image_with_red_cell, blurred, inpaint1, inpaint2, blacked, our_approach\n",
    "\n",
    "# Function to draw grid on an image\n",
    "def draw_grid(image):\n",
    "    \"\"\"Draw the grid_size x grid_size cell grid as 1-px white lines, in place.\n",
    "\n",
    "    ``image`` is modified directly (cv2.line draws into it) and is also returned so\n",
    "    calls can be chained. Line k sits at ``k * size // grid_size`` for k = 1..grid_size-1.\n",
    "    \"\"\"\n",
    "    height, width, _ = image.shape\n",
    "    white = (255, 255, 255)\n",
    "    for k in range(1, grid_size):\n",
    "        y = k * height // grid_size\n",
    "        x = k * width // grid_size\n",
    "        cv2.line(image, (0, y), (width, y), white, 1)  # horizontal line\n",
    "        cv2.line(image, (x, 0), (x, height), white, 1)  # vertical line\n",
    "    return image\n",
    "\n",
    "# Process images and save results to a PDF\n",
    "output_pdf = \"/home/tim/Documents/XX_Perturbation_SHAP/perturbation_shap/figs/results/rgb_perturbation_results.pdf\"\n",
    "\n",
    "# Panel titles, in the order apply_perturbations returns its images\n",
    "column_titles = [\"Original\", \"Blurring\", \"Inpaint Telea\", \"Inpaint NS\", \"UCR - Black\", \"Our Approach\"]\n",
    "\n",
    "# Box-filter (cv2.blur) kernel size for each row; currently identical for all images\n",
    "blur_kernels = [(55, 55)] * len(image_paths)\n",
    "\n",
    "# zip() below silently drops rows when these lengths disagree\n",
    "assert len(selected_cells) == len(image_paths), \"need exactly one selected cell per image\"\n",
    "\n",
    "# Load all images (OpenCV returns BGR, or None instead of raising on failure).\n",
    "# Validate before opening the PDF so a bad path fails early.\n",
    "loaded_images = [cv2.imread(img_path) for img_path in image_paths]\n",
    "unreadable = [path for path, loaded in zip(image_paths, loaded_images) if loaded is None]\n",
    "if unreadable:\n",
    "    raise FileNotFoundError(f\"could not read image(s): {unreadable}\")\n",
    "\n",
    "# Bring all images to one common square resolution\n",
    "resized_images = resize_to_smallest(loaded_images)\n",
    "\n",
    "with PdfPages(output_pdf) as pdf:\n",
    "    # One figure: a row per image, a column per perturbation strategy.\n",
    "    # squeeze=False keeps `axes` 2-D even when there is only one image.\n",
    "    n_rows, n_cols = len(image_paths), len(column_titles)\n",
    "    fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows), squeeze=False)\n",
    "\n",
    "    for row, (img, blur_kernel, cell_coords) in enumerate(zip(resized_images, blur_kernels, selected_cells)):\n",
    "        # Convert BGR to RGB for Matplotlib\n",
    "        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)\n",
    "\n",
    "        # Crop to square (a safeguard; resize_to_smallest already returns squares) and draw grid\n",
    "        cropped_img = crop_to_square(img)\n",
    "        grid_img = draw_grid(cropped_img.copy())\n",
    "\n",
    "        # Apply perturbations\n",
    "        panels = apply_perturbations(grid_img, cell_coords, blur_kernel)\n",
    "\n",
    "        for ax, panel, title in zip(axes[row], panels, column_titles):\n",
    "            ax.imshow(panel)\n",
    "            if row == 0:\n",
    "                ax.set_title(title)\n",
    "            ax.axis(\"off\")\n",
    "\n",
    "    plt.tight_layout()\n",
    "    pdf.savefig(fig)\n",
    "    plt.close(fig)\n",
    "\n",
    "print(f\"Perturbation results saved to '{output_pdf}'\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "perturbation_shap",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
