{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "from PIL import Image\n",
    "import glob\n",
    "import torchvision\n",
    "from matplotlib import pyplot as plt\n",
    "import os\n",
    "import torch\n",
    "\n",
    "\n",
    "\n",
    "import torchvision.transforms as transforms"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 161,
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/PycharmProjects/IDGEN/venv/lib/python3.10/site-packages/torchvision/transforms/functional.py:1603: UserWarning: The default value of the antialias parameter of all the resizing transforms (Resize(), RandomResizedCrop(), etc.) will change from None to True in v0.17, in order to be consistent across the PIL and Tensor backends. To suppress this warning, directly pass antialias=True (recommended, future default), antialias=None (current default, which means False for Tensors and True for PIL), or antialias=False (only works on Tensors - PIL will still use antialiasing). This also applies if you are using the inference transforms from the models weights: update the call to weights.transforms(antialias=True).\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "\n",
    "\n",
    "\n",
    "transform = transforms.Compose([\n",
    "                                # transforms.PILToTensor(),\n",
    "                                transforms.ToTensor(),\n",
    "                                transforms.Resize((128,128))])\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "main_path='/root/PycharmProjects/STAR_GAN/IDGEN/original_celeba_128'\n",
    "star_path='/root/PycharmProjects/STAR_GAN/IDGEN/stargan_celeba_128'\n",
    "egdse_path='/root/PycharmProjects/STAR_GAN/IDGEN/egsde_celeba_256'\n",
    "\n",
    "\n",
    "main_list= []\n",
    "star_list = []\n",
    "egdse_list=[]\n",
    "\n",
    "\n",
    "num_images=848\n",
    "for iter in range(num_images):\n",
    "    fname0= f'{main_path}/img{iter}.png'\n",
    "    fname1= f'{star_path}/img{iter}.png'\n",
    "    fname2= f'{egdse_path}/img{iter}.png'\n",
    "\n",
    "\n",
    "    if os.path.isfile(fname0) and os.path.isfile(fname1) and os.path.isfile(fname2):\n",
    "\n",
    "        im=Image.open(fname0)\n",
    "        im= transform(im).to('cuda')\n",
    "        main_list.append(im.unsqueeze(0))\n",
    "\n",
    "        im=Image.open(fname1)\n",
    "        im= transform(im).to('cuda')\n",
    "        star_list.append(im.unsqueeze(0))\n",
    "\n",
    "        im=Image.open(fname2)\n",
    "        im= transform(im).to('cuda')\n",
    "        egdse_list.append(im.unsqueeze(0))\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 162,
   "outputs": [
    {
     "data": {
      "text/plain": "848"
     },
     "execution_count": 162,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(egdse_list)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 163,
   "outputs": [],
   "source": [
    "main_images = torch.cat(main_list)\n",
    "stargan_images = torch.cat(star_list)\n",
    "egsde_images = torch.cat(egdse_list)\n"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 132,
   "outputs": [
    {
     "data": {
      "text/plain": "420"
     },
     "execution_count": 132,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(egsde_images)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def plot_image_ara(img_ara, folder=None, title=None):\n",
    "    rows=img_ara.shape[0]\n",
    "    cols=img_ara.shape[1]\n",
    "\n",
    "    print(rows,cols)\n",
    "\n",
    "    f, axarr = plt.subplots(rows, cols, figsize=(cols, rows), squeeze=False)\n",
    "    for c in range(cols):\n",
    "\n",
    "        if c==0:\n",
    "            # axarr[r, c].get_yaxis().set_ticks([ '','x', ''])\n",
    "            # axarr[r, c].get_yaxis().set_ticks([x for x in range(-10,11)])\n",
    "            axarr[0, c].set_yticklabels( ['', '', 'Original'], rotation=45 ,fontsize=12 )\n",
    "            axarr[1, c].set_yticklabels( ['', '', 'StarGAN'], rotation=45 ,fontsize=12)\n",
    "            axarr[2, c].set_yticklabels( ['', '', 'EGDSE'], rotation=45 ,fontsize=12)\n",
    "\n",
    "        for r in range(rows):\n",
    "            if c>0:\n",
    "                axarr[r, c].get_yaxis().set_ticks([])\n",
    "\n",
    "            axarr[r, c].get_xaxis().set_ticks([])\n",
    "            img= img_ara[r][c].cpu().detach().numpy()\n",
    "            img= np.transpose(img, (1,2,0))\n",
    "            axarr[r, c].imshow(img)\n",
    "\n",
    "\n",
    "        f.subplots_adjust(left=0, right=1, top=1, bottom=0, wspace=0, hspace=0)\n",
    "\n",
    "    if folder==None:\n",
    "\n",
    "        plt.show()\n",
    "    else:\n",
    "        os.makedirs(folder, exist_ok=True)\n",
    "        plt.savefig(f'{folder}/{title}.png', bbox_inches='tight')\n",
    "\n",
    "    plt.close()"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "\n",
    "for st in range(200,300,10):\n",
    "# st=120\n",
    "    en= st+8\n",
    "\n",
    "    all_img= torch.cat([main_images[st:en].unsqueeze(0), stargan_images[st:en].unsqueeze(0), egsde_images[st:en].unsqueeze(0)])\n",
    "    # plot_image_ara(all_img)\n",
    "    plot_image_ara(all_img,'/root/PycharmProjects/STAR_GAN/IDGEN/PLOTS/',f'comparison{st}_{en}')\n",
    "\n"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 113,
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Classifier loaded\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/PycharmProjects/IDGEN/venv/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:67: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default\n"
     ]
    }
   ],
   "source": [
    "\n",
    "\n",
    "from IDGEN.constantFunctions import get_classifier\n",
    "\n",
    "label_path = \"/root/Datasets/CelebAMask-HQ/CelebAMask-HQ-attribute-anno.txt\"\n",
    "attributes = open(label_path).readlines()[1].split(' ')\n",
    "attributes[-1] = attributes[-1].strip('\\n')\n",
    "classifier, trainer = get_classifier(attributes, IMAGE_SIZE=128)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "markdown",
   "source": [],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 164,
   "outputs": [],
   "source": [
    "from torch.utils.data import DataLoader\n",
    "\n",
    "def get_prediction(classifier, trainer, images):\n",
    "\n",
    "    transform = transforms.Compose(\n",
    "        [\n",
    "            transforms.Resize((224, 224)),\n",
    "            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))\n",
    "        ]\n",
    "    )\n",
    "\n",
    "    data_list = []\n",
    "    for img in images:\n",
    "        lbl = torch.zeros(40, 1)\n",
    "        data_list.append([transform(img), lbl])\n",
    "\n",
    "    predict_loader = DataLoader(dataset=data_list, batch_size=1, shuffle=False)\n",
    "    prediction = trainer.predict(classifier, predict_loader)  # without fine-tuning\n",
    "    all=[]\n",
    "    for idx, data_input in enumerate(prediction):\n",
    "        pred= data_input[2][0]\n",
    "        all.append(pred)\n",
    "    return all\n",
    "\n",
    "\n"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 165,
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n",
      "/root/PycharmProjects/IDGEN/venv/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=127` in the `DataLoader` to improve performance.\n"
     ]
    },
    {
     "data": {
      "text/plain": "Predicting: |          | 0/? [00:00<?, ?it/s]",
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "caef0e9b896c499ba14d3c35639931d1"
      }
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n"
     ]
    },
    {
     "data": {
      "text/plain": "Predicting: |          | 0/? [00:00<?, ?it/s]",
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "bc5bee5d9a2f44269a0b0e4f2bfc1c9f"
      }
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n"
     ]
    },
    {
     "data": {
      "text/plain": "Predicting: |          | 0/? [00:00<?, ?it/s]",
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "2c375989790845ab98112b6aef83b31c"
      }
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "pred_main = get_prediction(classifier, trainer, main_images)\n",
    "pred_star = get_prediction(classifier, trainer, stargan_images)\n",
    "pred_egsde = get_prediction(classifier, trainer, egsde_images)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 166,
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "####\n",
    "def get_new_att(pred1, pred2):\n",
    "    increased =  {}\n",
    "    for lb in attributes:\n",
    "        increased[lb] = 0\n",
    "\n",
    "    for st, en in zip(pred1, pred2):\n",
    "        diff = set(en) - set(st)\n",
    "        for lb in diff:\n",
    "            increased[lb] += 1\n",
    "\n",
    "    for lb in increased:\n",
    "        increased[lb] = increased[lb]/ (len(pred1))*100\n",
    "    increased =dict(sorted(increased.items(), key=lambda item: item[1]))\n",
    "\n",
    "    return increased\n",
    "\n",
    "x= get_new_att(pred_main, pred_star)\n",
    "y= get_new_att(pred_main, pred_egsde)\n",
    "\n",
    "\n",
    "x = dict(reversed(x.items()))\n",
    "y = dict(reversed(y.items()))"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 167,
   "outputs": [
    {
     "data": {
      "text/plain": "{'Wearing_Lipstick': 33.72641509433962,\n 'No_Beard': 26.650943396226417,\n 'High_Cheekbones': 20.28301886792453,\n 'Arched_Eyebrows': 20.28301886792453,\n 'Heavy_Makeup': 18.160377358490564,\n 'Receding_Hairline': 12.382075471698114,\n 'Attractive': 11.556603773584905,\n 'Black_Hair': 10.731132075471699,\n 'Pointy_Nose': 9.19811320754717,\n 'Blurry': 7.783018867924528,\n 'Young': 6.721698113207547,\n 'Smiling': 6.25,\n 'Oval_Face': 5.89622641509434,\n 'Bags_Under_Eyes': 5.89622641509434,\n 'Mouth_Slightly_Open': 5.660377358490567,\n 'Gray_Hair': 5.306603773584905,\n 'Brown_Hair': 4.952830188679245,\n 'Wearing_Earrings': 2.5943396226415096,\n 'Narrow_Eyes': 2.1226415094339623,\n 'Big_Nose': 2.1226415094339623,\n 'Bangs': 2.1226415094339623,\n 'Blond_Hair': 1.2971698113207548,\n 'Wavy_Hair': 1.179245283018868,\n 'Straight_Hair': 0.9433962264150944,\n 'Eyeglasses': 0.7075471698113208,\n 'Double_Chin': 0.7075471698113208,\n 'Big_Lips': 0.589622641509434,\n 'Wearing_Hat': 0.3537735849056604,\n 'Chubby': 0.3537735849056604,\n 'Rosy_Cheeks': 0.2358490566037736,\n 'Bushy_Eyebrows': 0.2358490566037736,\n 'Bald': 0.2358490566037736,\n 'Wearing_Necktie': 0.0,\n 'Wearing_Necklace': 0.0,\n 'Sideburns': 0.0,\n 'Pale_Skin': 0.0,\n 'Mustache': 0.0,\n 'Male': 0.0,\n 'Goatee': 0.0,\n '5_o_Clock_Shadow': 0.0}"
     },
     "execution_count": 167,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 168,
   "outputs": [
    {
     "data": {
      "text/plain": "{'Wearing_Lipstick': 81.60377358490565,\n 'Heavy_Makeup': 68.27830188679245,\n 'Arched_Eyebrows': 45.75471698113208,\n 'Oval_Face': 41.863207547169814,\n 'High_Cheekbones': 36.556603773584904,\n 'Attractive': 36.20283018867924,\n 'Wearing_Earrings': 33.9622641509434,\n 'No_Beard': 27.830188679245282,\n 'Young': 26.41509433962264,\n 'Mouth_Slightly_Open': 21.34433962264151,\n 'Smiling': 19.339622641509436,\n 'Straight_Hair': 16.27358490566038,\n 'Rosy_Cheeks': 11.556603773584905,\n 'Pointy_Nose': 11.43867924528302,\n 'Receding_Hairline': 10.141509433962264,\n 'Bangs': 8.962264150943396,\n 'Blond_Hair': 8.372641509433961,\n 'Brown_Hair': 7.547169811320755,\n 'Big_Lips': 4.009433962264151,\n 'Wavy_Hair': 3.4198113207547167,\n 'Black_Hair': 3.0660377358490565,\n 'Big_Nose': 2.358490566037736,\n 'Double_Chin': 1.4150943396226416,\n 'Bags_Under_Eyes': 1.4150943396226416,\n 'Eyeglasses': 1.179245283018868,\n 'Wearing_Necklace': 0.7075471698113208,\n 'Narrow_Eyes': 0.7075471698113208,\n 'Gray_Hair': 0.7075471698113208,\n 'Chubby': 0.7075471698113208,\n 'Wearing_Hat': 0.2358490566037736,\n 'Bushy_Eyebrows': 0.1179245283018868,\n 'Blurry': 0.1179245283018868,\n 'Bald': 0.1179245283018868,\n 'Wearing_Necktie': 0.0,\n 'Sideburns': 0.0,\n 'Pale_Skin': 0.0,\n 'Mustache': 0.0,\n 'Male': 0.0,\n 'Goatee': 0.0,\n '5_o_Clock_Shadow': 0.0}"
     },
     "execution_count": 168,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [],
   "metadata": {
    "collapsed": false
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
