{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import open_clip\n",
    "from PIL import Image\n",
    "import os\n",
    "import torch.nn.functional as nnf\n",
    "from tqdm import tqdm\n",
    "import csv\n",
    "sdxl = False\n",
    "gpu_num = 0\n",
    "device = torch.device(f\"cuda:{gpu_num}\") if torch.cuda.is_available() else \"cpu\"\n",
    "model, _, preprocess = open_clip.create_model_and_transforms('hf-hub:laion/CLIP-ViT-H-14-laion2B-s32B-b79K', device=device)\n",
    "tokenizer = open_clip.get_tokenizer('hf-hub:laion/CLIP-ViT-H-14-laion2B-s32B-b79K')\n",
    "\n",
    "\n",
    "# Define heads perturbed\n",
    "if not sdxl:\n",
    "    heads_perturbed = [0, 1, 11, 21, 31, 41, 51, 61, 71, 81, 91, 101, 111, 121, 128]\n",
    "else:\n",
    "    heads_perturbed = [0, 11, 111, 211, 311, 411, 511, 611, 711, 811, 911, 1011, 1111, 1211, 1300]\n",
    "\n",
    "# Define the classes\n",
    "Color = \"red, blue, green, yellow, black, white, purple, gray, pink, brown\".split(\", \")\n",
    "Animals = \"cat, dog, rabbit, frog, bird, squirrel, deer, lion, penguin, horse\".split(\", \")\n",
    "Fruits_and_Vegetables = \"lemons, bananas, apples, oranges, blueberries, carrots, broccoli, tomatoes, potatoes, grapes\".split(\", \")\n",
    "Image_Style = \"cubist, pop art, steampunk, impressionist, black-and-white, watercolor, cartoon, minimalist, sepia, sketch\".split(\", \")\n",
    "Material = \"glass, copper, marble, jade, gold, basalt, silver, clay, paper, leather\".split(\", \")\n",
    "Nature_Scenes = \"forest, desert, beach, waterfall, mountain, canyon, glacier, coral reef, jungle, lake\".split(\", \")\n",
    "Weather_Conditions = \"snowy, rainy, foggy, stormy\".split(\", \")\n",
    "Geometric_Patterns = \"polka-dot, leopard, stripe, greek-key, plaid\".split(\", \")\n",
    "Furniture = \"bed, table, chair, sofa, recliner, bookshelf, dresser, wardrobe, coffee table, TV stand\".split(\", \")\n",
    "Electronics = \"smartphone, laptop, tablet, smart TV, digital camera, drone, desktop computer, microwave, refrigerator, smartwatch\".split(\", \")\n",
    "Objects_A = \"car, bench, bowl, ballon, ball\".split(\", \")\n",
    "Objects_B = \"bowl, cup, table, ball, teapot\".split(\", \")\n",
    "Objects_C = \"T-shirt, pillow, wallpaper, umbrella, blanket\".split(\", \")\n",
    "Animals_A = \"cat, dog, rabbit, frog, bird\".split(\", \")\n",
    "Others = \"castle, mountain, cityscape, farmland, forest\".split(\", \")\n",
    "Animals_ood = \"rabbit, frog, sheep, pig, chicken, dolphin, goat, duck, deer, fox\".split(\", \")\n",
    "Color_ood = \"coral, beige, violet, cyan, magenta, indigo, orange, turquoise, teal, khaki\".split(\", \")\n",
    "Material_ood = \"copper, marble, jade, gold, basalt, silver, clay, steel, tin, bronze\".split(\", \")\n",
    "Fruits_and_Vegetables_ood = \"lemons, blueberries, onions, raspberries, pineapples, cherries, cucumbers, bell peppers, cauliflowers, mangoes\".split(\", \")\n",
    "Nature_Scenes_ood = \"glacier, coral reef, swamp, pond, fjord, rainforest, grassland, marsh, creek, island\".split(\", \")\n",
    "Tableware = \"salad bowl, serving platter, bread basket, fondue pot, spoon, fork, nut dish, coffee pot, tureen, chafing dish\".split(\",\")\n",
    "\n",
    "# Define seeds\n",
    "seeds = [10, 20, 30]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/50 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 50/50 [03:21<00:00,  4.04s/it]\n"
     ]
    }
   ],
   "source": [
    "\"\"\"Calculate CLIP image-text similarity and save the results as a csv file\"\"\"\n",
    "# --- Change only the following variables --- #\n",
    "exp_nums = [5]\n",
    "main_items_list = [Image_Style]\n",
    "categories = \"Image Style\".split(\", \")\n",
    "# ------------------------------------------- #\n",
    "\n",
    "name = \"wo_category\"\n",
    "\n",
    "\n",
    "for idx, exp_num in enumerate(exp_nums):\n",
    "    main_items = main_items_list[idx]\n",
    "    category = categories[idx]\n",
    "\n",
    "    directory = f\"./hp_outputs/exp_{exp_num}\"\n",
    "    subdirectories = [name for name in os.listdir(directory) if os.path.isdir(os.path.join(directory, name))]\n",
    "    scores = dict()\n",
    "\n",
    "    for subdirectory in tqdm(subdirectories):\n",
    "        # Prepare the text\n",
    "        for item in main_items:\n",
    "            if item in subdirectory:\n",
    "                prompt = item\n",
    "                break\n",
    "        if name == \"wo_category\":\n",
    "            prompt = [prompt]\n",
    "        elif name == \"w_category\":\n",
    "            prompt = [f\"{category}: {prompt}\"]\n",
    "        text = tokenizer(prompt * len(seeds)).to(device)\n",
    "\n",
    "        # Prepare the images\n",
    "        for order in [\"top\", \"bottom\"]:\n",
    "            subdirectory_order = os.path.join(directory, subdirectory, order)\n",
    "            suborders = [name for name in os.listdir(subdirectory_order) if os.path.isdir(os.path.join(subdirectory_order, name))]\n",
    "            for suborder in suborders: \n",
    "                images = []\n",
    "                for seed in seeds:\n",
    "                    image = preprocess(Image.open(os.path.join(subdirectory_order, suborder, f\"{seed}.png\"))).unsqueeze(0).to(device)\n",
    "                    images.append(image)\n",
    "                images = torch.cat(images, dim=0)\n",
    "                with torch.no_grad():\n",
    "                    image_features = model.encode_image(images)\n",
    "                    text_features = model.encode_text(text)\n",
    "                    score = nnf.cosine_similarity(image_features, text_features).mean(dim=0).item()\n",
    "                    if suborder not in scores:\n",
    "                        scores[suborder] = score\n",
    "                    else:\n",
    "                        scores[suborder] += score\n",
    "        \n",
    "    for key, _ in scores.items():\n",
    "        scores[key] /= len(subdirectories)\n",
    "\n",
    "    # Save as csv file\n",
    "    os.makedirs('./hp_results', exist_ok=True)\n",
    "    csv_file = f'./hp_results/hp_scores_{name}_exp_{exp_num}.csv'\n",
    "    keys = [f\"top_{num_heads}\" for num_heads in heads_perturbed]\n",
    "    keys += [f\"bottom_{num_heads}\" for num_heads in heads_perturbed] \n",
    "\n",
    "    with open(csv_file, mode='w', newline='') as file:\n",
    "        writer = csv.writer(file)\n",
    "        writer.writerow(['Key', 'Value'])  # Write the header\n",
    "        for key in keys:\n",
    "            writer.writerow([key, f\"{scores[key]:.4f}\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "diffuser",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
