{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "93fb9b53",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from pathlib import Path\n",
    "import os\n",
    "\n",
    "# metrics_utils, notebooks.prompts_dict and the neurips_demo/ result paths used\n",
    "# below are expected to resolve relative to the SDv1.5 directory. Guard the chdir\n",
    "# so re-running this cell does not move further up the directory tree.\n",
    "if Path.cwd().name != 'SDv1.5':\n",
    "    os.chdir('../../SDv1.5')\n",
    "\n",
    "from metrics_utils import VGGCalculator, CLIPScorer, calculate_aesthetic_score, DINOStyleScorer\n",
    "from notebooks.prompts_dict import prompts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a6a3bbdd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Instantiate the metric scorers once; the evaluation cells below reuse them.\n",
    "# Construction order (VGG, CLIP, DINO) matches the original one-per-line form.\n",
    "VGG_Scorer, CLIP_Scorer, DINO_Scorer = VGGCalculator(), CLIPScorer(), DINOStyleScorer()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9d438645",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Score the kv_swap results: for each interpolation value and style, average the\n",
    "# CLIP, VGG, aesthetic and DINO metrics over the per-prompt result directories.\n",
    "skip_steps = 10\n",
    "CFG = 15\n",
    "swap_guidance_scale = 15\n",
    "sparse_weight = 0\n",
    "interpolation = [0.1]\n",
    "styles = [1, 2, 3, 4, 5]\n",
    "text_prefix = 'a sketch of '\n",
    "\n",
    "for int_value in interpolation:\n",
    "    for j in styles:\n",
    "        clip_loss = 0\n",
    "        vgg_loss = 0\n",
    "        aesthetic_score = 0\n",
    "        dino_loss = 0\n",
    "        ref_image_path = f'neurips_demo/kv_swap/CFG{CFG}_sfg{swap_guidance_scale}_skip{skip_steps}/style{j}_{sparse_weight}_int_{int_value}'\n",
    "        # Only the number of entries is used: prompt directories are looked up as\n",
    "        # prompts[1..n] (assumes prompts is keyed by 1-based ints - TODO confirm).\n",
    "        n_prompts = len(os.listdir(ref_image_path))\n",
    "        if n_prompts == 0:\n",
    "            print(f'Style {j} int {int_value}: no results in {ref_image_path}, skipping')\n",
    "            continue\n",
    "        for prompt_idx in range(1, n_prompts + 1):\n",
    "            prompt = prompts[prompt_idx]\n",
    "            prompt_dir = os.path.join(ref_image_path, prompt)\n",
    "            # os.listdir order is arbitrary; sort so the chosen run directory is deterministic.\n",
    "            run_dir = os.path.join(prompt_dir, sorted(os.listdir(prompt_dir))[0])\n",
    "            full_image_path = os.path.join(run_dir, 'out_transfer---seed_42.png')\n",
    "            style_image_path = os.path.join(run_dir, 'out_style---seed_42.png')\n",
    "\n",
    "            clip_loss += CLIP_Scorer(full_image_path, text_prefix + prompt)\n",
    "            vgg_loss += VGG_Scorer.calculate_similarity(full_image_path, style_image_path)\n",
    "            aesthetic_score += calculate_aesthetic_score(full_image_path)\n",
    "            dino_loss += DINO_Scorer(full_image_path, style_image_path)\n",
    "\n",
    "        print(f'Style {j} int {int_value}, CLIP loss: {clip_loss/n_prompts}, VGG loss: {vgg_loss/n_prompts}, '\n",
    "              f'aesthetic_score: {aesthetic_score/n_prompts}, dino_loss: {dino_loss/n_prompts}')\n",
    "    print('----------------------------------')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "98a01836",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Score style 6 (sparse_weight=60, swap guidance 25): for each interpolation value,\n",
    "# average the CLIP, VGG, aesthetic and DINO metrics over the per-prompt result directories.\n",
    "skip_steps = 10\n",
    "CFG = 15\n",
    "swap_guidance_scale = 25\n",
    "sparse_weight = 60\n",
    "interpolation = [0.05]\n",
    "styles = [6]\n",
    "text_prefix = 'a sketch of '\n",
    "\n",
    "for int_value in interpolation:\n",
    "    for j in styles:\n",
    "        clip_loss = 0\n",
    "        vgg_loss = 0\n",
    "        aesthetic_score = 0\n",
    "        dino_loss = 0\n",
    "        # NOTE(review): unlike the kv_swap evaluation cell, this path has no kv_swap/ component - confirm intended.\n",
    "        ref_image_path = f'neurips_demo/CFG{CFG}_sfg{swap_guidance_scale}_skip{skip_steps}/style{j}_{sparse_weight}_int_{int_value}'\n",
    "        # Only the number of entries is used: prompt directories are looked up as\n",
    "        # prompts[1..n] (assumes prompts is keyed by 1-based ints - TODO confirm).\n",
    "        n_prompts = len(os.listdir(ref_image_path))\n",
    "        if n_prompts == 0:\n",
    "            print(f'Style {j} int {int_value}: no results in {ref_image_path}, skipping')\n",
    "            continue\n",
    "        for prompt_idx in range(1, n_prompts + 1):\n",
    "            prompt = prompts[prompt_idx]\n",
    "            prompt_dir = os.path.join(ref_image_path, prompt)\n",
    "            # os.listdir order is arbitrary; sort so the chosen run directory is deterministic.\n",
    "            run_dir = os.path.join(prompt_dir, sorted(os.listdir(prompt_dir))[0])\n",
    "            full_image_path = os.path.join(run_dir, 'out_transfer---seed_42.png')\n",
    "            style_image_path = os.path.join(run_dir, 'out_style---seed_42.png')\n",
    "\n",
    "            clip_loss += CLIP_Scorer(full_image_path, text_prefix + prompt)\n",
    "            vgg_loss += VGG_Scorer.calculate_similarity(full_image_path, style_image_path)\n",
    "            aesthetic_score += calculate_aesthetic_score(full_image_path)\n",
    "            dino_loss += DINO_Scorer(full_image_path, style_image_path)\n",
    "\n",
    "        print(f'Style {j} int {int_value}, CLIP loss: {clip_loss/n_prompts}, VGG loss: {vgg_loss/n_prompts}, '\n",
    "              f'aesthetic_score: {aesthetic_score/n_prompts}, dino_loss: {dino_loss/n_prompts}')\n",
    "    print('----------------------------------')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "cross_image",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
