{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from pathlib import Path\n",
    "from PIL import Image\n",
    "import torch\n",
    "from torchvision.transforms import  v2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [],
   "source": [
    "def apply_attack(image_folder: str, attacker: torch.nn.Module, output_folder_name: str):\n",
    "    \"\"\"\n",
    "    Apply a distortion attack to all images in a folder and save the results.\n",
    "\n",
    "    Every .png/.jpg/.jpeg file directly inside `image_folder` is loaded as a\n",
    "    uint8 tensor on the 0-255 scale, passed through `attacker` with a leading\n",
    "    batch dimension of 1, and saved as PNG to\n",
    "    `<parent of image_folder>/<output_folder_name>/<stem>.png`.\n",
    "\n",
    "    Args:\n",
    "        image_folder (str): Path to the folder containing watermarked images\n",
    "        attacker (torch.nn.Module): Transformation to apply (e.g., GaussianBlur, JPEG).\n",
    "            Receives a uint8 tensor of shape (1, C, H, W) and is expected to return\n",
    "            a tensor of the same shape on the 0-255 scale.\n",
    "        output_folder_name (str): Name of the folder to save attacked images\n",
    "\n",
    "    Raises:\n",
    "        NotADirectoryError: If `image_folder` is not an existing directory.\n",
    "\n",
    "    Note:\n",
    "        Images whose mode is not L or RGB (e.g. RGBA, palette) are converted to RGB,\n",
    "        so any alpha channel is dropped. Inputs sharing a stem (a.jpg / a.png) map\n",
    "        to the same output file; the one processed last wins.\n",
    "    \"\"\"\n",
    "    input_path = Path(image_folder)\n",
    "    if not input_path.is_dir():\n",
    "        raise NotADirectoryError(f\"Image folder not found: {input_path}\")\n",
    "\n",
    "    # Create output folder in the same directory as input folder\n",
    "    output_path = input_path.parent / output_folder_name\n",
    "    output_path.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "    # Stay in uint8 end to end: ToImage yields uint8 for PIL input and ToPILImage\n",
    "    # accepts uint8 without rescaling. The previous float round trip\n",
    "    # (ToTensor -> *255 -> uint8, then /255 -> ToPILImage) truncated instead of\n",
    "    # rounding, so pixel values could drop by one level at each conversion.\n",
    "    # ToImage is also the documented replacement for the deprecated ToTensor.\n",
    "    to_tensor = v2.ToImage()\n",
    "    to_pil = v2.ToPILImage()\n",
    "\n",
    "    # Process each image in the folder, in a deterministic (sorted) order\n",
    "    for img_path in sorted(input_path.iterdir()):\n",
    "        if not img_path.is_file() or img_path.suffix.lower() not in ('.png', '.jpg', '.jpeg'):\n",
    "            continue\n",
    "\n",
    "        # The context manager closes the file handle that Image.open keeps open\n",
    "        with Image.open(img_path) as img:\n",
    "            # JPEG encoding (v2.JPEG) expects 1 or 3 channels; normalize other modes\n",
    "            if img.mode not in ('L', 'RGB'):\n",
    "                img = img.convert('RGB')\n",
    "            img_tensor = to_tensor(img)\n",
    "\n",
    "        # Apply attack to a batch of one image\n",
    "        attacked_tensor = attacker(img_tensor.unsqueeze(0)).squeeze(0)\n",
    "\n",
    "        # Some attackers may return floats; treat them as 0-255 scale (like the\n",
    "        # input) and round rather than truncate when converting back to uint8\n",
    "        if attacked_tensor.dtype != torch.uint8:\n",
    "            attacked_tensor = attacked_tensor.round().clamp(0, 255).to(torch.uint8)\n",
    "\n",
    "        # Save with same name but always as PNG (lossless)\n",
    "        output_file = output_path / f\"{img_path.stem}.png\"\n",
    "        to_pil(attacked_tensor).save(output_file, format='PNG')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/local/scratch/g/liu3351/conda/envs/eraser/lib/python3.9/site-packages/torchvision/transforms/v2/_deprecated.py:42: UserWarning: The transform `ToTensor()` is deprecated and will be removed in a future release. Instead, please use `v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])`.Output is equivalent up to float precision.\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "# Strong JPEG attack: quality ranges 1-100, and 1 is the heaviest compression\n",
    "source_path = \"dog_wm\"\n",
    "jpeg_compression = v2.JPEG(quality=1)\n",
    "\n",
    "apply_attack(source_path, attacker=jpeg_compression, output_folder_name=\"jpeg_low\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Milder JPEG attack: quality=50 keeps more detail than the quality=1 run above\n",
    "source_path = \"dog_wm\"\n",
    "jpeg_compression = v2.JPEG(quality=50)\n",
    "\n",
    "apply_attack(source_path, attacker=jpeg_compression, output_folder_name=\"jpeg_high\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Processed and saved: gaus_low/0.png\n",
      "Processed and saved: gaus_low/1.png\n",
      "Processed and saved: gaus_low/2.png\n",
      "Processed and saved: gaus_low/3.png\n",
      "Processed and saved: gaus_low/4.png\n"
     ]
    }
   ],
   "source": [
    "# Strong blur: with the same sigma, a 15x15 kernel averages over more pixels than 5x5.\n",
    "# Uses `source_path` defined in the JPEG cells above.\n",
    "gaus = v2.GaussianBlur(kernel_size=15, sigma=10.0)\n",
    "\n",
    "apply_attack(source_path, attacker=gaus, output_folder_name=\"gaus_low\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Processed and saved: gaus_high/0.png\n",
      "Processed and saved: gaus_high/1.png\n",
      "Processed and saved: gaus_high/2.png\n",
      "Processed and saved: gaus_high/3.png\n",
      "Processed and saved: gaus_high/4.png\n"
     ]
    }
   ],
   "source": [
    "# Milder blur: same sigma but a small 5x5 kernel limits how far pixels are mixed.\n",
    "# Uses `source_path` defined in the JPEG cells above.\n",
    "gaus = v2.GaussianBlur(kernel_size=5, sigma=10.0)\n",
    "\n",
    "apply_attack(source_path, attacker=gaus, output_folder_name=\"gaus_high\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "eraser",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
