{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/jovyan/conda/dfs/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n",
      "/home/jovyan/conda/dfs/lib/python3.10/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
      "  warnings.warn(\n",
      "`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config[\"id2label\"]` will be overriden.\n"
     ]
    }
   ],
   "source": [
    "from PIL import Image\n",
    "import os\n",
    "import io\n",
    "import numpy as np\n",
    "import torch\n",
    "import torchvision\n",
    "import rewards\n",
    "import csv\n",
    "\n",
    "# Build every reward scorer once, in float32 on the GPU.\n",
    "# aesthetic_score names its dtype argument `torch_dtype`; the other scorers use `inference_dtype`.\n",
    "aesthetic_fn = rewards.aesthetic_score(torch_dtype=torch.float32, device=\"cuda\")\n",
    "hps_fn = rewards.hps_score(inference_dtype=torch.float32, device=\"cuda\")\n",
    "imagereward = rewards.ImageReward(inference_dtype=torch.float32, device=\"cuda\")\n",
    "pick_fn = rewards.PickScore(inference_dtype=torch.float32, device=\"cuda\")\n",
    "clip_fn = rewards.clip_score(inference_dtype=torch.float32, device=\"cuda\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]\n",
      "Loading model from: /home/jovyan/conda/dfs/lib/python3.10/site-packages/lpips/weights/v0.1/alex.pth\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import torch\n",
    "from PIL import Image\n",
    "from transformers import CLIPProcessor, CLIPModel\n",
    "from tqdm import tqdm\n",
    "import numpy as np\n",
    "from scipy.spatial.distance import pdist\n",
    "import csv\n",
    "import lpips\n",
    "from torchvision import transforms\n",
    "\n",
    "# Use the GPU when one is available, otherwise fall back to the CPU.\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "\n",
    "# CLIP model and its matching processor (openai/clip-vit-large-patch14); the model is moved to `device`.\n",
    "model = CLIPModel.from_pretrained(\"openai/clip-vit-large-patch14\").to(device)\n",
    "processor = CLIPProcessor.from_pretrained(\"openai/clip-vit-large-patch14\")\n",
    "\n",
    "# LPIPS network (net='alex'), placed on the same device.\n",
    "lpips_model = lpips.LPIPS(net='alex').to(device)\n",
    "\n",
    "# Image preprocessing function\n",
    "def preprocess_image(image_path):\n",
    "    \"\"\"Open `image_path` as RGB and run it through the CLIP `processor`.\n",
    "\n",
    "    Returns the processor's `pixel_values` tensor for this one image with\n",
    "    dimension 0 squeezed out.\n",
    "    \"\"\"\n",
    "    rgb_image = Image.open(image_path).convert(\"RGB\")\n",
    "    processed = processor(images=rgb_image, return_tensors=\"pt\")\n",
    "    return processed['pixel_values'].squeeze(0)\n",
    "\n",
    "# Function to preprocess image for LPIPS\n",
    "def preprocess_image_lpips(image_path):\n",
    "    \"\"\"Open `image_path` as RGB, resize it to 224x224 and return it as a tensor.\n",
    "\n",
    "    The `ToTensor` result gets an extra leading dimension via `unsqueeze(0)`.\n",
    "    \"\"\"\n",
    "    resize = transforms.Resize((224, 224))\n",
    "    to_tensor = transforms.ToTensor()\n",
    "    rgb_image = Image.open(image_path).convert(\"RGB\")\n",
    "    return to_tensor(resize(rgb_image)).unsqueeze(0)\n",
    "\n",
    "# Function to calculate CLIP-based metrics and LPIPS\n",
    "def calculate_metrics(image_folder, K=20):\n",
    "    \"\"\"Compute CLIP- and LPIPS-based diversity statistics for one run folder.\n",
    "\n",
    "    Images are read from `<image_folder>/eval_vis`. Only names ending in\n",
    "    png/jpg/jpeg are used, and names containing \"ess\" or\n",
    "    \"intermediate_rewards\" are skipped.\n",
    "\n",
    "    Args:\n",
    "        image_folder: Run directory that contains the `eval_vis` subfolder.\n",
    "        K: Number of largest covariance eigenvalues kept for the truncated\n",
    "            CLIP entropy.\n",
    "\n",
    "    Returns:\n",
    "        A tuple `(mean_distance, std_error, TCE_K, mean_lpips, std_lpips)`:\n",
    "        the mean pairwise cosine distance between CLIP image embeddings, the\n",
    "        std of those distances divided by sqrt(number of pairs), the\n",
    "        truncated CLIP entropy, and the mean and std of the pairwise LPIPS\n",
    "        distances.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If the folder holds no matching images, or if fewer than\n",
    "            two images could be processed (pairwise metrics need a pair).\n",
    "    \"\"\"\n",
    "    image_folder = os.path.join(image_folder, \"eval_vis\")\n",
    "    # Sorted so the processing order (and any error log) is reproducible.\n",
    "    image_files = sorted(\n",
    "        os.path.join(image_folder, file)\n",
    "        for file in os.listdir(image_folder)\n",
    "        if file.endswith(('png', 'jpg', 'jpeg'))\n",
    "        and \"ess\" not in file\n",
    "        and \"intermediate_rewards\" not in file\n",
    "    )\n",
    "\n",
    "    if len(image_files) == 0:\n",
    "        raise ValueError(f\"No images found in the folder: {image_folder}\")\n",
    "\n",
    "    embeddings = []\n",
    "    lpips_images = []\n",
    "    for image_path in tqdm(image_files):\n",
    "        try:\n",
    "            # For CLIP\n",
    "            pixel_values = preprocess_image(image_path).unsqueeze(0).to(device)\n",
    "            with torch.no_grad():\n",
    "                embedding = model.get_image_features(pixel_values).cpu().numpy().squeeze()\n",
    "\n",
    "            # For LPIPS\n",
    "            lpips_image = preprocess_image_lpips(image_path).to(device)\n",
    "        except Exception as e:\n",
    "            print(f\"Error processing image {image_path}: {e}\")\n",
    "            continue\n",
    "        # Append only after both steps succeeded so the two lists always\n",
    "        # describe the same set of images.\n",
    "        embeddings.append(embedding)\n",
    "        lpips_images.append(lpips_image)\n",
    "\n",
    "    if len(embeddings) == 0:\n",
    "        raise ValueError(\"No embeddings were generated. Please check your images and preprocessing steps.\")\n",
    "    if len(embeddings) < 2:\n",
    "        # With one image there are no pairs and the covariance is undefined;\n",
    "        # fail loudly instead of returning NaN metrics.\n",
    "        raise ValueError(\"At least two images are required to compute pairwise diversity metrics.\")\n",
    "\n",
    "    embeddings = np.array(embeddings)\n",
    "\n",
    "    # ---- Calculate Mean Pairwise Distance (CLIP-based) ----\n",
    "    pairwise_distances = pdist(embeddings, metric='cosine')\n",
    "    mean_distance = np.mean(pairwise_distances)\n",
    "    num_distances = pairwise_distances.size\n",
    "    std_error = np.std(pairwise_distances) / np.sqrt(num_distances)\n",
    "\n",
    "    # ---- Calculate Truncated CLIP Entropy (TCE) ----\n",
    "    covariance_matrix = np.cov(embeddings, rowvar=False)\n",
    "    # eigvalsh returns eigenvalues in ascending order, so the last K are the largest.\n",
    "    # NOTE: with K or fewer images the covariance has fewer than K non-zero\n",
    "    # eigenvalues, and the log below is then -inf/NaN.\n",
    "    eigenvalues = np.linalg.eigvalsh(covariance_matrix)[-K:]\n",
    "    TCE_K = (K / 2) * np.log(2 * np.pi * np.e) + (1 / 2) * np.sum(np.log(eigenvalues))\n",
    "\n",
    "    # ---- Calculate LPIPS-based diversity ----\n",
    "    lpips_distances = []\n",
    "    num_images = len(lpips_images)\n",
    "    with torch.no_grad():\n",
    "        for i in range(num_images):\n",
    "            for j in range(i + 1, num_images):\n",
    "                # The inputs come from ToTensor and lie in [0, 1]; normalize=True\n",
    "                # makes LPIPS rescale them to the [-1, 1] range it expects.\n",
    "                distance = lpips_model(lpips_images[i], lpips_images[j], normalize=True).item()\n",
    "                lpips_distances.append(distance)\n",
    "\n",
    "    mean_lpips = np.mean(lpips_distances)\n",
    "    std_lpips = np.std(lpips_distances)\n",
    "\n",
    "    return mean_distance, std_error, TCE_K, mean_lpips, std_lpips"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run directory to evaluate; the cells below read images from its \"eval_vis\" subfolder\n",
    "# and write their CSV summaries into this directory.\n",
    "img_folder = \"logs/tdpo/pick/2024.10.01_18.23.52\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "018_a black motorcycle is parked by the side of the road | reward: 0.2263115793466568 | hps: 0.2904018461704254.png\n",
      "5.5359344482421875\n",
      "012_A man sitting at a table in front of bowls of spices. | reward: 0.21969418227672577 | hps: 0.2655247449874878.png\n",
      "6.492404460906982\n",
      "028_A man driving a luggage cart sitting on top of a runway. | reward: 0.22203242778778076 | hps: 0.2843267321586609.png\n",
      "5.146237373352051\n",
      "025_Two small planes sitting near each other on a run way. | reward: 0.18042421340942383 | hps: 0.23171466588974.png\n",
      "5.171357154846191\n",
      "011_there is a red bus that has a mans face on it | reward: 0.21346686780452728 | hps: 0.2765390872955322.png\n",
      "5.565062999725342\n",
      "019_there is a bus that has a bike attached to the front | reward: 0.20295165479183197 | hps: 0.24722620844841003.png\n",
      "5.390223503112793\n",
      "021_A toy elephant is sitting inside a wooden car toy. | reward: 0.23601214587688446 | hps: 0.3074624240398407.png\n",
      "5.9249067306518555\n",
      "001_A bus stopped on the side of the road while people board it. | reward: 0.22003355622291565 | hps: 0.27240437269210815.png\n",
      "5.209050178527832\n",
      "022_A TV sitting on top of a wooden stand. | reward: 0.2269500195980072 | hps: 0.28757303953170776.png\n",
      "5.042357444763184\n",
      "004_a dog with a plate of food on the ground | reward: 0.21278992295265198 | hps: 0.27024412155151367.png\n",
      "5.986759662628174\n",
      "006_A bird that is sitting in the rim of a tire. | reward: 0.24557647109031677 | hps: 0.306388258934021.png\n",
      "5.281464099884033\n",
      "029_A man sitting at a table in front of bowls of spices. | reward: 0.20907042920589447 | hps: 0.2606205344200134.png\n",
      "6.305418014526367\n",
      "014_An eye level counter-view shows blue tile, a faucet, dish scrubbers, bowls, a squirt bottle and similar kitchen items. | reward: 0.20683149993419647 | hps: 0.24000772833824158.png\n",
      "5.123661041259766\n",
      "010_Street merchant with bowls of grains and other products. | reward: 0.21805942058563232 | hps: 0.2634230852127075.png\n",
      "5.969822406768799\n",
      "009_A small bathroom with a tub, toilet, sink, and a laundry basket are shown. | reward: 0.20973145961761475 | hps: 0.2617783546447754.png\n",
      "5.243871688842773\n",
      "026_Residential bathroom with commode and shower and plain white walls. | reward: 0.2200029045343399 | hps: 0.2481478750705719.png\n",
      "5.136394023895264\n",
      "016_A pair of planes parked in a small rural airfield. | reward: 0.17797575891017914 | hps: 0.2226017713546753.png\n",
      "4.875957489013672\n",
      "005_The black motorcycle is parked on the sidewalk. | reward: 0.18742510676383972 | hps: 0.23827466368675232.png\n",
      "4.956569194793701\n",
      "003_there is a red bus that has a mans face on it | reward: 0.21511653065681458 | hps: 0.27950939536094666.png\n",
      "5.65103816986084\n",
      "015_People getting on a bus in the city | reward: 0.2402249276638031 | hps: 0.27848026156425476.png\n",
      "5.430393218994141\n",
      "017_Three people are preparing a meal in a small kitchen. | reward: 0.21512535214424133 | hps: 0.25484102964401245.png\n",
      "5.811896324157715\n",
      "027_a couple of people in uniforms are sitting together | reward: 0.22901883721351624 | hps: 0.2947719097137451.png\n",
      "5.541792869567871\n",
      "013_a black motorcycle is parked by the side of the road | reward: 0.21849696338176727 | hps: 0.27662134170532227.png\n",
      "5.49644660949707\n",
      "030_classic cars on a city street with people and a dog | reward: 0.23075035214424133 | hps: 0.2853015959262848.png\n",
      "6.105373859405518\n",
      "002_Three people are preparing a meal in a small kitchen. | reward: 0.2030303031206131 | hps: 0.24375441670417786.png\n",
      "5.80161190032959\n",
      "031_A bird that is sitting in the rim of a tire. | reward: 0.24216611683368683 | hps: 0.30761009454727173.png\n",
      "5.380100250244141\n",
      "024_there is a bathroom that has a lot of things on the floor | reward: 0.2343887984752655 | hps: 0.25823792815208435.png\n",
      "5.060821533203125\n",
      "008_there is a red bus that has a mans face on it | reward: 0.21813908219337463 | hps: 0.2808381915092468.png\n",
      "5.542859077453613\n",
      "020_A bird that is sitting in the rim of a tire. | reward: 0.2269342690706253 | hps: 0.293695867061615.png\n",
      "5.159752368927002\n",
      "023_Three people are preparing a meal in a small kitchen. | reward: 0.22745615243911743 | hps: 0.25838205218315125.png\n",
      "5.60786247253418\n",
      "000_A white toilet in a generic public bathroom stall. | reward: 0.23479938507080078 | hps: 0.26841050386428833.png\n",
      "5.0223188400268555\n",
      "007_A bunch of people posing with some bikes. | reward: 0.22777429223060608 | hps: 0.28350645303726196.png\n",
      "5.771385192871094\n",
      "Finished evaluating images in logs/tdpo/pick/2024.10.01_18.23.52\n",
      "Aesthetic score:  5.491909518837929\n",
      "Aesthetic score std:  0.3995667209327642\n",
      "HPS score:  0.2699212459847331\n",
      "HPS score std:  0.021868288609785395\n",
      "Image reward score:  0.08745209313929081\n",
      "Image reward score std:  1.0008913563159652\n",
      "Pick score:  0.21870136866346002\n",
      "Pick score std:  0.015924428839133666\n",
      "Clip score:  0.22988681285642087\n",
      "Clip score std:  0.0449489838036817\n"
     ]
    }
   ],
   "source": [
    "# Score every image in <img_folder>/eval_vis with each reward model, print the\n",
    "# per-image aesthetic score, then print and save the mean / std of every score.\n",
    "eval_dir = os.path.join(img_folder, \"eval_vis\")\n",
    "# Sorted so the per-image log below comes out in a reproducible order.\n",
    "image_names = sorted(\n",
    "    file\n",
    "    for file in os.listdir(eval_dir)\n",
    "    if file.endswith(('png', 'jpg', 'jpeg'))\n",
    "    and \"ess\" not in file\n",
    "    and \"intermediate_rewards\" not in file\n",
    ")\n",
    "if len(image_names) == 0:\n",
    "    # Without this check the statistics below would silently be NaN.\n",
    "    raise ValueError(f\"No images found in the folder: {eval_dir}\")\n",
    "\n",
    "aesthetic_score = []\n",
    "hps_score = []\n",
    "imagereward_score = []\n",
    "pick_score = []\n",
    "clip_score = []\n",
    "\n",
    "# Build the PIL -> tensor transform once rather than once per image.\n",
    "to_tensor = torchvision.transforms.ToTensor()\n",
    "\n",
    "for image_name in image_names:\n",
    "    image_path = os.path.join(eval_dir, image_name)\n",
    "\n",
    "    image = Image.open(image_path).convert(\"RGB\")\n",
    "    image = to_tensor(image).unsqueeze(0).to('cuda')\n",
    "\n",
    "    # File names look like \"<index>_<prompt> | reward: ... | hps: ....png\".\n",
    "    # Split on the first underscore only, so underscores inside the prompt are\n",
    "    # kept, and strip the space that precedes the \"|\" separator.\n",
    "    prompt = image_name.split(\"|\")[0].split(\"_\", 1)[-1].strip()\n",
    "\n",
    "    with torch.no_grad():\n",
    "        clip_score.append(clip_fn(image, prompt).item())\n",
    "        aesthetic_score.append(aesthetic_fn(image, prompt).item())\n",
    "        hps_score.append(hps_fn(image, prompt).item())\n",
    "        imagereward_score.append(imagereward(image, prompt).item())\n",
    "        pick_score.append(pick_fn(image, prompt).item())\n",
    "\n",
    "    print(image_name)\n",
    "    # Print the stored value instead of running the aesthetic model a second\n",
    "    # time (the original second call also ran outside torch.no_grad()).\n",
    "    print(aesthetic_score[-1])\n",
    "\n",
    "# (console label, per-image scores) in reporting order; the CSV columns below\n",
    "# follow the same order.\n",
    "reported_scores = [\n",
    "    (\"Aesthetic score\", aesthetic_score),\n",
    "    (\"HPS score\", hps_score),\n",
    "    (\"Image reward score\", imagereward_score),\n",
    "    (\"Pick score\", pick_score),\n",
    "    (\"Clip score\", clip_score),\n",
    "]\n",
    "\n",
    "print(f\"Finished evaluating images in {img_folder}\")\n",
    "values = []\n",
    "for label, per_image in reported_scores:\n",
    "    # Compute each statistic once and reuse it for both the log and the CSV.\n",
    "    score_mean, score_std = np.mean(per_image), np.std(per_image)\n",
    "    print(f\"{label}: \", score_mean)\n",
    "    print(f\"{label} std: \", score_std)\n",
    "    values.extend([score_mean, score_std])\n",
    "\n",
    "# Save the results to a CSV file\n",
    "names = [\"Aesthetic score\", \"Aesthetic score std\", \"HPS score\", \"HPS score std\",\n",
    "         \"Image reward score\", \"Image reward score std\", \"Pick score\", \"Pick score std\", \"CLIP score\", \"CLIP score std\"]\n",
    "\n",
    "# Format the values to 5 decimal places\n",
    "formatted_values = [f\"{v:.5f}\" for v in values]\n",
    "\n",
    "with open(os.path.join(img_folder, \"eval_results.csv\"), \"w\", newline='') as f:\n",
    "    writer = csv.writer(f)\n",
    "    writer.writerow(names)\n",
    "    writer.writerow(formatted_values)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 32/32 [00:03<00:00,  9.75it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Finished evaluating images in logs/tdpo/pick/2024.10.01_18.23.52\n",
      "Mean Pairwise Distance (CLIP-based Diversity Metric): 0.3427164799840677\n",
      "Standard Error of the Distance: 0.004059637652523934\n",
      "Truncated CLIP Entropy (TCE): 43.54003852285079\n",
      "Mean LPIPS Distance: 0.5601368356616266\n",
      "Standard Deviation of LPIPS Distance: 0.05981753115005913\n"
     ]
    }
   ],
   "source": [
    "# Calculate metrics\n",
    "try:\n",
    "    mean_distance, std_error, TCE, mean_lpips, std_lpips = calculate_metrics(img_folder, K=20)\n",
    "    print(f\"Finished evaluating images in {img_folder}\")\n",
    "\n",
    "    # (console label, value) pairs, printed in this order.\n",
    "    for label, value in (\n",
    "        (\"Mean Pairwise Distance (CLIP-based Diversity Metric)\", mean_distance),\n",
    "        (\"Standard Error of the Distance\", std_error),\n",
    "        (\"Truncated CLIP Entropy (TCE)\", TCE),\n",
    "        (\"Mean LPIPS Distance\", mean_lpips),\n",
    "        (\"Standard Deviation of LPIPS Distance\", std_lpips),\n",
    "    ):\n",
    "        print(f\"{label}: {value}\")\n",
    "\n",
    "    # Header row and values for the CSV, rounded to 5 decimal places.\n",
    "    names = [\"Mean Pairwise Distance (CLIP)\", \"Standard Error of the Distance (CLIP)\",\n",
    "             \"Truncated CLIP Entropy (TCE)\", \"Mean LPIPS Distance\", \"Std Dev LPIPS Distance\"]\n",
    "    values = [mean_distance, std_error, TCE, mean_lpips, std_lpips]\n",
    "    formatted_values = [format(v, \".5f\") for v in values]\n",
    "\n",
    "    # Written into the run directory itself, not the eval_vis subfolder.\n",
    "    with open(os.path.join(img_folder, \"eval_diversity_results.csv\"), \"w\", newline='') as f:\n",
    "        writer = csv.writer(f)\n",
    "        writer.writerows([names, formatted_values])\n",
    "\n",
    "except Exception as e:\n",
    "    # Any failure is reported as a single line; no traceback is shown.\n",
    "    print(f\"An error occurred: {e}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "dfs",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
