{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from dataset.vimeo90k_septuplet_6 import VimeoSepTuplet\n",
    "from dataset.Davis_test import Davis\n",
    "from dataset.SNU_FILM import SNUFILM\n",
    "from model.SCubA import SCubA\n",
    "from model.SGuTA import SGuTA\n",
    "from model.FLAVR_arch import UNet_3D_3D as FLAVR\n",
    "from model_VFIT.VFIT_B import UNet_3D_3D as VFIT_B\n",
    "from model_VFIT.VFIT_S import UNet_3D_3D as VFIT_S\n",
    "import random\n",
    "import torch\n",
    "\n",
    "import myutils\n",
    "\n",
    "# Every model is moved to this device. Evaluation code should send its inputs to the\n",
    "# device the loaded models live on rather than calling `.cuda()` (the default GPU).\n",
    "device = 'cuda:1'\n",
    "\n",
    "\n",
    "def load_checkpoint(model, path, strip_module_prefix=False):\n",
    "    \"\"\"Load the 'state_dict' entry of a checkpoint into `model` and put it in eval mode.\n",
    "\n",
    "    Args:\n",
    "        model: network to fill; it should already have been moved to `device`.\n",
    "        path: checkpoint file holding a dict with a 'state_dict' key.\n",
    "        strip_module_prefix: drop a leading 'module.' from every key, as found in\n",
    "            checkpoints saved from a torch.nn.DataParallel wrapper.\n",
    "\n",
    "    Returns:\n",
    "        The same `model`, with its weights loaded and in eval mode.\n",
    "    \"\"\"\n",
    "    # map_location stops torch.load from restoring tensors onto the GPU they were saved from.\n",
    "    state_dict = torch.load(path, map_location=device)['state_dict']\n",
    "    if strip_module_prefix:\n",
    "        prefix = 'module.'\n",
    "        # Strip only a *leading* prefix: `k.partition('module.')[-1]` turns keys without\n",
    "        # the prefix into '' and cuts keys that contain 'module.' further in.\n",
    "        state_dict = {(k[len(prefix):] if k.startswith(prefix) else k): v\n",
    "                      for k, v in state_dict.items()}\n",
    "    model.load_state_dict(state_dict, strict=True)\n",
    "    model.eval()\n",
    "    return model\n",
    "\n",
    "\n",
    "scuba = SCubA(in_channels=3, out_channels=3, n_feat=64, patch_size=(1,4,4), cube_size=(2,4,4), stage=3).to(device)\n",
    "scuba_best_model_path = 'saved_models_final/vimeo90K_septuplet_6/ScubA_Hola/model_best.pth'\n",
    "load_checkpoint(scuba, scuba_best_model_path)\n",
    "\n",
    "sguta = SGuTA(in_channels=3, out_channels=3, n_feat=64, patch_size=(1,4,4), stage=3, num_frm=8).to(device)\n",
    "sguta_best_model_path = 'saved_models_final/vimeo90K_septuplet_6/SGuTA_3/model_best.pth'\n",
    "load_checkpoint(sguta, sguta_best_model_path)\n",
    "\n",
    "# The VFIT and FLAVR checkpoints are loaded with the 'module.' key prefix stripped.\n",
    "vift_s = VFIT_S(n_inputs=4, joinType=\"concat\").to(device)\n",
    "vift_s_best_model_path = 'saved_models_final/VFIT/VFIT-S_best.pth'\n",
    "load_checkpoint(vift_s, vift_s_best_model_path, strip_module_prefix=True)\n",
    "\n",
    "vift_b = VFIT_B(n_inputs=4, joinType=\"concat\").to(device)\n",
    "vift_b_best_model_path = 'saved_models_final/VFIT/VFIT-B_best.pth'\n",
    "load_checkpoint(vift_b, vift_b_best_model_path, strip_module_prefix=True)\n",
    "\n",
    "flavr = FLAVR(\"unet_18\", n_inputs=4, n_outputs=1, joinType=\"concat\").to(device)\n",
    "flavr_best_model_path = 'FLAVR_2x.pth'\n",
    "load_checkpoint(flavr, flavr_best_model_path, strip_module_prefix=True)\n",
    "\n",
    "vimeo_set = VimeoSepTuplet('/home/esthen/Datasets/vimeo_septuplet', is_training=False)\n",
    "davis_set = Davis(\"/home/esthen/Datasets/Davis_test/\")\n",
    "\n",
    "snufilm_set = SNUFILM('/mnt/sdb5/SNU-FILM',mode=\"hard\", n_inputs=6)\n",
    "\n",
    "print(len(vimeo_set))\n",
    "print(len(davis_set))\n",
    "print(len(snufilm_set))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "from tqdm import tqdm\n",
    "from torchvision.utils import save_image\n",
    "\n",
    "# Sweep SNU-FILM (hard) and save the samples where the PSNR ranking is\n",
    "# SCubA > SGuTA > VFIT-B > VFIT-S > FLAVR, with 0.5 dB margins on the last three steps.\n",
    "snufilm_out_dir = 'visual_result/snufilm/hard'\n",
    "os.makedirs(snufilm_out_dir, exist_ok=True)  # save_image does not create missing folders\n",
    "# Send inputs to the device the models were loaded on; `.cuda()` would pick the default GPU.\n",
    "model_device = next(scuba.parameters()).device\n",
    "\n",
    "for snufilm_idx in tqdm(range(len(snufilm_set))):\n",
    "    snufilm_sample = snufilm_set[snufilm_idx]  # fetch (and decode) each sample only once\n",
    "    snufilm_input = [img.unsqueeze(0).to(model_device) for img in snufilm_sample[0]]\n",
    "    snufilm_gt = [gt.unsqueeze(0).to(model_device) for gt in snufilm_sample[1]]\n",
    "    snufilm_overlay = snufilm_input[1]*0.5 + snufilm_input[4]*0.5\n",
    "    # FLAVR and the VFIT models take four frames: input indices 0, 2, 3 and 5.\n",
    "    four_frame_input = [snufilm_input[i] for i in (0, 2, 3, 5)]\n",
    "\n",
    "    with torch.no_grad():\n",
    "        gt = torch.cat(snufilm_gt)\n",
    "        # Concatenate each model's output sequence into a single tensor.\n",
    "        out_scuba = torch.cat(scuba(snufilm_input))\n",
    "        out_sguta = torch.cat(sguta(snufilm_input))\n",
    "        out_flavr = torch.cat(flavr(four_frame_input))\n",
    "        out_vift_s = torch.cat(vift_s(four_frame_input))\n",
    "        out_vift_b = torch.cat(vift_b(four_frame_input))\n",
    "\n",
    "        psnr_scuba = myutils.calc_psnr(out_scuba, gt)\n",
    "        psnr_sguta = myutils.calc_psnr(out_sguta, gt)\n",
    "        psnr_vift_s = myutils.calc_psnr(out_vift_s, gt)\n",
    "        psnr_vift_b = myutils.calc_psnr(out_vift_b, gt)\n",
    "        psnr_flavr = myutils.calc_psnr(out_flavr, gt)\n",
    "\n",
    "        if psnr_scuba > psnr_sguta > psnr_vift_b+0.5 > psnr_vift_s+0.5 > psnr_flavr+0.5:\n",
    "            # Pair every output with its own PSNR so no file name carries another method's score.\n",
    "            for tag, image, psnr in [('SCubA', out_scuba, psnr_scuba),\n",
    "                                     ('SGuTA', out_sguta, psnr_sguta),\n",
    "                                     ('FLAVR', out_flavr, psnr_flavr),\n",
    "                                     ('VFITS', out_vift_s, psnr_vift_s),\n",
    "                                     ('VFITB', out_vift_b, psnr_vift_b)]:\n",
    "                save_image(image.squeeze(0), os.path.join(\n",
    "                    snufilm_out_dir, '{}_{}_{:.2f}.png'.format(snufilm_idx, tag, psnr)))\n",
    "            save_image(gt.squeeze(0), os.path.join(snufilm_out_dir, '{}_gt.png'.format(snufilm_idx)))\n",
    "            save_image(snufilm_overlay.squeeze(0), os.path.join(snufilm_out_dir, '{}_overlay.png'.format(snufilm_idx)))\n",
    "\n",
    "            print('images of index [{}] saved'.format(snufilm_idx))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "from tqdm import tqdm\n",
    "from torchvision.utils import save_image\n",
    "\n",
    "# Sweep the Vimeo test split and save the samples where the PSNR ranking is\n",
    "# SCubA > SGuTA > VFIT-B > VFIT-S > FLAVR with margins of 0.3 / 0.6 / 0.9 dB\n",
    "# and FLAVR itself below 32 dB.\n",
    "vimeo_out_dir = 'visual_result/vimeo'\n",
    "os.makedirs(vimeo_out_dir, exist_ok=True)  # save_image does not create missing folders\n",
    "# Send inputs to the device the models were loaded on; `.cuda()` would pick the default GPU.\n",
    "model_device = next(scuba.parameters()).device\n",
    "\n",
    "# Iterate over the dataset's real length rather than a hard-coded sample count.\n",
    "for vimeo_idx in tqdm(range(len(vimeo_set))):\n",
    "    vimeo_sample = vimeo_set[vimeo_idx]  # fetch (and decode) each sample only once\n",
    "    vimeo_input = [img.unsqueeze(0).to(model_device) for img in vimeo_sample[0]]\n",
    "    vimeo_gt = [gt.unsqueeze(0).to(model_device) for gt in vimeo_sample[1]]\n",
    "    vimeo_overlay = vimeo_input[1]*0.5 + vimeo_input[4]*0.5\n",
    "    # FLAVR and the VFIT models take four frames: input indices 0, 2, 3 and 5.\n",
    "    four_frame_input = [vimeo_input[i] for i in (0, 2, 3, 5)]\n",
    "\n",
    "    with torch.no_grad():\n",
    "        gt = torch.cat(vimeo_gt)\n",
    "        # Concatenate each model's own output sequence exactly once.\n",
    "        out_scuba = torch.cat(scuba(vimeo_input))\n",
    "        out_sguta = torch.cat(sguta(vimeo_input))\n",
    "        out_flavr = torch.cat(flavr(four_frame_input))\n",
    "        out_vift_s = torch.cat(vift_s(four_frame_input))\n",
    "        out_vift_b = torch.cat(vift_b(four_frame_input))\n",
    "\n",
    "        psnr_scuba = myutils.calc_psnr(out_scuba, gt)\n",
    "        psnr_sguta = myutils.calc_psnr(out_sguta, gt)\n",
    "        psnr_vift_s = myutils.calc_psnr(out_vift_s, gt)\n",
    "        psnr_vift_b = myutils.calc_psnr(out_vift_b, gt)\n",
    "        psnr_flavr = myutils.calc_psnr(out_flavr, gt)\n",
    "\n",
    "        if (psnr_scuba > psnr_sguta > psnr_vift_b+0.3 > psnr_vift_s+0.6 > psnr_flavr+0.9) and psnr_flavr < 32:\n",
    "            # Pair every output with its own PSNR so no file name carries another method's score.\n",
    "            for tag, image, psnr in [('scuba', out_scuba, psnr_scuba),\n",
    "                                     ('sguta', out_sguta, psnr_sguta),\n",
    "                                     ('flavr', out_flavr, psnr_flavr),\n",
    "                                     ('vift_b', out_vift_b, psnr_vift_b),\n",
    "                                     ('vift_s', out_vift_s, psnr_vift_s)]:\n",
    "                save_image(image.squeeze(0), os.path.join(\n",
    "                    vimeo_out_dir, '{}_{}_{:.2f}.png'.format(vimeo_idx, tag, psnr)))\n",
    "            save_image(gt.squeeze(0), os.path.join(vimeo_out_dir, '{}_gt.png'.format(vimeo_idx)))\n",
    "            save_image(vimeo_overlay.squeeze(0), os.path.join(vimeo_out_dir, '{}_overlay.png'.format(vimeo_idx)))\n",
    "            print('images of index [{}] saved'.format(vimeo_idx))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Inspect one random Vimeo sample: show every model's prediction, the ground truth and\n",
    "# the blend of input frames 1 and 4, and print PSNR/SSIM for each model.\n",
    "vimeo_idx = random.randrange(len(vimeo_set))  # always a valid index for this dataset\n",
    "print('vimeo_idx:', vimeo_idx)\n",
    "# Use the device the models live on; do not rebind the global `device`, and do not\n",
    "# use `.cuda()`, which targets the default GPU rather than the models' GPU.\n",
    "model_device = next(scuba.parameters()).device\n",
    "\n",
    "vimeo_sample = vimeo_set[vimeo_idx]\n",
    "vimeo_input = [img.unsqueeze(0).to(model_device) for img in vimeo_sample[0]]\n",
    "vimeo_gt = [gt.unsqueeze(0).to(model_device) for gt in vimeo_sample[1]]\n",
    "vimeo_overlay = vimeo_input[1]*0.5 + vimeo_input[4]*0.5\n",
    "# FLAVR and the VFIT models take four frames: input indices 0, 2, 3 and 5.\n",
    "four_frame_input = [vimeo_input[i] for i in (0, 2, 3, 5)]\n",
    "\n",
    "with torch.no_grad():\n",
    "    gt = torch.cat(vimeo_gt)\n",
    "    # Concatenate every model's output sequence (VFIT included) so that\n",
    "    # `.squeeze(0)` / `.cpu()` below operate on tensors, not lists.\n",
    "    out_scuba = torch.cat(scuba(vimeo_input))\n",
    "    out_sguta = torch.cat(sguta(vimeo_input))\n",
    "    out_vift_s = torch.cat(vift_s(four_frame_input))\n",
    "    out_vift_b = torch.cat(vift_b(four_frame_input))\n",
    "    out_flavr = torch.cat(flavr(four_frame_input))\n",
    "\n",
    "    for fig_num, (label, prediction) in enumerate([('Scuba', out_scuba),\n",
    "                                                   ('SGuTA', out_sguta),\n",
    "                                                   ('VFIT_S', out_vift_s),\n",
    "                                                   ('VFIT_B', out_vift_b),\n",
    "                                                   ('FLAVR', out_flavr)]):\n",
    "        plt.figure(fig_num)\n",
    "        plt.title(label)\n",
    "        plt.imshow(prediction.squeeze(0).cpu().permute(1, 2, 0))\n",
    "        _, psnrs, ssims = myutils.init_meters('1*L1')\n",
    "        myutils.eval_metrics(prediction, gt, psnrs, ssims)\n",
    "        print('%s_PSNR: %f, SSIM: %f' % (label, psnrs.avg, ssims.avg))\n",
    "        print(myutils.calc_psnr(prediction, gt))\n",
    "\n",
    "    plt.figure(5)\n",
    "    plt.title('GT')\n",
    "    plt.imshow(gt.squeeze(0).cpu().permute(1, 2, 0))\n",
    "\n",
    "    plt.figure(6)\n",
    "    plt.title('Overlay')\n",
    "    plt.imshow(vimeo_overlay.squeeze(0).cpu().permute(1, 2, 0))\n",
    "\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import torch\n",
    "from torchvision.utils import save_image\n",
    "\n",
    "\n",
    "def save_images(out_dir='visual_result/vimeo'):\n",
    "    \"\"\"Save the results of the single-sample Vimeo inspection as PNG files.\n",
    "\n",
    "    Reads the notebook globals `vimeo_idx`, `out_scuba`, `out_sguta`, `out_flavr`,\n",
    "    `out_vift_b`, `out_vift_s`, `gt` and `vimeo_overlay`, and writes one\n",
    "    `<name>_<vimeo_idx>.png` file per image into `out_dir`.\n",
    "\n",
    "    Args:\n",
    "        out_dir: destination folder, created if missing; defaults to 'visual_result/vimeo'.\n",
    "    \"\"\"\n",
    "    os.makedirs(out_dir, exist_ok=True)  # save_image does not create missing folders\n",
    "    named_images = [\n",
    "        ('scuba', out_scuba),\n",
    "        ('sguta', out_sguta),\n",
    "        ('flavr', out_flavr),\n",
    "        ('vift_b', out_vift_b),\n",
    "        ('vift_s', out_vift_s),\n",
    "        ('gt', gt),\n",
    "        # The Vimeo overlay: `davis_overlay` would belong to whichever Davis sample ran last.\n",
    "        ('overlay', vimeo_overlay),\n",
    "    ]\n",
    "    for name, image in named_images:\n",
    "        # Model outputs may still be un-concatenated lists of frames; merge them first.\n",
    "        if isinstance(image, (list, tuple)):\n",
    "            image = torch.cat(image)\n",
    "        save_image(image.squeeze(0), os.path.join(out_dir, '{}_{}.png'.format(name, vimeo_idx)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "from tqdm import tqdm\n",
    "from torchvision.utils import save_image\n",
    "\n",
    "# Sweep the Davis test set and save the samples where the PSNR ranking is\n",
    "# SCubA > SGuTA > VFIT-B > VFIT-S > FLAVR, with 0.5 dB margins on the last three steps.\n",
    "davis_out_dir = 'visual_result/davis'\n",
    "os.makedirs(davis_out_dir, exist_ok=True)  # save_image does not create missing folders\n",
    "# Send inputs to the device the models were loaded on; `.cuda()` would pick the default GPU.\n",
    "model_device = next(scuba.parameters()).device\n",
    "\n",
    "for davis_idx in tqdm(range(len(davis_set))):\n",
    "    davis_sample = davis_set[davis_idx]  # fetch (and decode) each sample only once\n",
    "    davis_input = [img.unsqueeze(0).to(model_device) for img in davis_sample[0]]\n",
    "    davis_gt = [gt.unsqueeze(0).to(model_device) for gt in davis_sample[1]]\n",
    "    davis_overlay = davis_input[1]*0.5 + davis_input[4]*0.5\n",
    "    # FLAVR and the VFIT models take four frames: input indices 0, 2, 3 and 5.\n",
    "    four_frame_input = [davis_input[i] for i in (0, 2, 3, 5)]\n",
    "\n",
    "    with torch.no_grad():\n",
    "        gt = torch.cat(davis_gt)\n",
    "        # Concatenate each model's output sequence into a single tensor.\n",
    "        out_scuba = torch.cat(scuba(davis_input))\n",
    "        out_sguta = torch.cat(sguta(davis_input))\n",
    "        out_flavr = torch.cat(flavr(four_frame_input))\n",
    "        out_vift_s = torch.cat(vift_s(four_frame_input))\n",
    "        out_vift_b = torch.cat(vift_b(four_frame_input))\n",
    "\n",
    "        psnr_scuba = myutils.calc_psnr(out_scuba, gt)\n",
    "        psnr_sguta = myutils.calc_psnr(out_sguta, gt)\n",
    "        psnr_vift_s = myutils.calc_psnr(out_vift_s, gt)\n",
    "        psnr_vift_b = myutils.calc_psnr(out_vift_b, gt)\n",
    "        psnr_flavr = myutils.calc_psnr(out_flavr, gt)\n",
    "\n",
    "        if psnr_scuba > psnr_sguta > psnr_vift_b+0.5 > psnr_vift_s+0.5 > psnr_flavr+0.5:\n",
    "            # Pair every output with its own PSNR so no file name carries another method's score.\n",
    "            for tag, image, psnr in [('SCubA', out_scuba, psnr_scuba),\n",
    "                                     ('SGuTA', out_sguta, psnr_sguta),\n",
    "                                     ('FLAVR', out_flavr, psnr_flavr),\n",
    "                                     ('VFITS', out_vift_s, psnr_vift_s),\n",
    "                                     ('VFITB', out_vift_b, psnr_vift_b)]:\n",
    "                save_image(image.squeeze(0), os.path.join(\n",
    "                    davis_out_dir, '{}_{}_{:.2f}.png'.format(davis_idx, tag, psnr)))\n",
    "            save_image(gt.squeeze(0), os.path.join(davis_out_dir, '{}_gt.png'.format(davis_idx)))\n",
    "            save_image(davis_overlay.squeeze(0), os.path.join(davis_out_dir, '{}_overlay.png'.format(davis_idx)))\n",
    "\n",
    "            print('images of index [{}] saved'.format(davis_idx))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.7.10 ('esthen')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "3933cbb7e64290da9f72fcee5df6f9f8f4b1290734ecabea63894f133005982c"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
