{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "0860b8c9-e6f3-437e-918c-63b191c365cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import shutil\n",
    "from PIL import Image\n",
    "from torchvision import transforms\n",
    "import torch\n",
    "import lpips\n",
    "from skimage.metrics import mean_squared_error, structural_similarity\n",
    "from skimage.metrics import peak_signal_noise_ratio\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "08cae41a-18bc-49ce-88bb-c8b4bd949cb3",
   "metadata": {},
   "outputs": [],
   "source": [
    "base_dir = './output_celebahq_age_0.03'\n",
    "\n",
    "mod_src_dir = os.path.join(base_dir, 'Results', 'experiment', 'explanation', 'CC', 'CCF', 'CF')\n",
    "orig_src_dir = os.path.join(base_dir, 'Original', 'Correct')\n",
    "\n",
    "collected_orig_dir = './collected/originals'\n",
    "collected_mod_dir = './collected/modified'\n",
    "\n",
    "# Start from empty scratch directories so files left over from an earlier run\n",
    "# (e.g. with a different base_dir) cannot leak into the metrics computed below.\n",
    "# If removal fails, os.makedirs raises FileExistsError instead of silently reusing stale files.\n",
    "for scratch_dir in (collected_orig_dir, collected_mod_dir):\n",
    "    shutil.rmtree(scratch_dir, ignore_errors=True)\n",
    "    os.makedirs(scratch_dir)\n",
    "\n",
    "# Copy only images present in BOTH source folders, keeping the same filename in\n",
    "# each collected folder so later cells can pair them by name.\n",
    "n_pairs = 0\n",
    "missing_orig = []\n",
    "for fname in sorted(os.listdir(mod_src_dir)):\n",
    "    mod_path = os.path.join(mod_src_dir, fname)\n",
    "    orig_path = os.path.join(orig_src_dir, fname)\n",
    "    if not os.path.isfile(mod_path):\n",
    "        continue\n",
    "    if not os.path.isfile(orig_path):\n",
    "        missing_orig.append(fname)\n",
    "        continue\n",
    "    shutil.copy(mod_path, os.path.join(collected_mod_dir, fname))\n",
    "    shutil.copy(orig_path, os.path.join(collected_orig_dir, fname))\n",
    "    n_pairs += 1\n",
    "\n",
    "print(f\"Collected {n_pairs} image pairs; {len(missing_orig)} modified files had no matching original.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "bae0ad4b-c110-4368-9513-78b05b239451",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ToTensor maps an RGB PIL image to a float CHW tensor in [0, 1]; Normalize with\n",
    "# mean=std=0.5 per channel then rescales it to [-1, 1].\n",
    "transform = transforms.Compose([\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize([0.5]*3, [0.5]*3)\n",
    "])\n",
    "\n",
    "def load_img(path, device='cuda'):\n",
    "    \"\"\"Load an image file as a normalized (1, 3, H, W) float tensor on `device`.\n",
    "\n",
    "    The image is converted to RGB, passed through `transform` (values in [-1, 1])\n",
    "    and given a leading batch dimension. `device` defaults to 'cuda'.\n",
    "    \"\"\"\n",
    "    # Context manager closes the underlying file handle deterministically.\n",
    "    with Image.open(path) as img:\n",
    "        tensor = transform(img.convert('RGB'))\n",
    "    return tensor.unsqueeze(0).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "4490b863-44a2-48ee-8c65-875d028a37a9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Setting up [LPIPS] perceptual loss: trunk [vgg], v[0.1], spatial [off]\n",
      "Loading model from: /home/qyr/anaconda3/lib/python3.12/site-packages/lpips/weights/v0.1/vgg.pth\n"
     ]
    }
   ],
   "source": [
    "# LPIPS perceptual distance with a VGG trunk, placed on the GPU.\n",
    "lpips_model = lpips.LPIPS(net='vgg').cuda()\n",
    "\n",
    "# Sorted listings of the two collected folders.\n",
    "orig_files, mod_files = (sorted(os.listdir(folder)) for folder in (collected_orig_dir, collected_mod_dir))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 92,
   "id": "a1469b1d-9afc-4b15-8618-60f0c3ba31f6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Avg LPIPS: 0.047316651939065195\n",
      "Avg PSNR: 33.10051010117345\n",
      "Avg ssim: 0.9542420029441536\n"
     ]
    }
   ],
   "source": [
    "def _to_uint8_hwc(t):\n",
    "    \"\"\"Convert a (1, 3, H, W) tensor in [-1, 1] to an (H, W, 3) uint8 numpy array.\"\"\"\n",
    "    t = (t.squeeze(0) * 0.5 + 0.5).clamp(0, 1)\n",
    "    # round() before byte(): byte() truncates, so float error from the\n",
    "    # normalize/denormalize round trip (e.g. 254.99998) could lower a pixel value by 1.\n",
    "    return (t * 255).round().byte().permute(1, 2, 0).cpu().numpy()\n",
    "\n",
    "# Pair files by name rather than by position in two independent listings.\n",
    "mod_names = set(mod_files)\n",
    "pair_names = [fname for fname in orig_files if fname in mod_names]\n",
    "if not pair_names:\n",
    "    raise RuntimeError(f\"No image pairs with matching names in {collected_orig_dir} and {collected_mod_dir}\")\n",
    "\n",
    "lpips_scores = []\n",
    "ssim_scores = []\n",
    "psnr_scores = []\n",
    "with torch.no_grad():  # evaluation only: no autograd graph needed\n",
    "    for fname in pair_names:\n",
    "        x = load_img(os.path.join(collected_orig_dir, fname))\n",
    "        x_prime = load_img(os.path.join(collected_mod_dir, fname))\n",
    "        lpips_scores.append(lpips_model(x, x_prime).item())\n",
    "\n",
    "        x_np = _to_uint8_hwc(x)\n",
    "        x_prime_np = _to_uint8_hwc(x_prime)\n",
    "\n",
    "        psnr_scores.append(peak_signal_noise_ratio(x_np, x_prime_np, data_range=255.0))\n",
    "\n",
    "        # SSIM on color images: channel_axis=-1 marks the last (HWC) axis as channels\n",
    "        # (the replacement for the deprecated multichannel=True).\n",
    "        ssim_scores.append(structural_similarity(x_np, x_prime_np, data_range=255, channel_axis=-1))\n",
    "\n",
    "print(\"Pairs evaluated:\", len(pair_names))\n",
    "print(\"Avg LPIPS:\", np.mean(lpips_scores))\n",
    "print(\"Avg PSNR:\", np.mean(psnr_scores))\n",
    "print(\"Avg ssim:\", np.mean(ssim_scores))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "a1b5f0cb-af32-4b71-b711-d746d8487e18",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Remove the scratch copies; guarded so re-running this cell does not raise.\n",
    "if os.path.isdir('./collected'):\n",
    "    shutil.rmtree('./collected')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4226f785-9e1d-44c0-b585-be475b81427e",
   "metadata": {},
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
