{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch \n",
    "from scipy.linalg import sqrtm\n",
    "#code the FID \n",
    "import pytorch_fid as fid\n",
    "import os \n",
    "from scipy import linalg\n",
    "# Configuration: a data file (presumably the original CIFAR-10 set) and the root\n",
    "# directory holding one sub-folder per run, each with <which_model>/vec.pt inside.\n",
    "# NOTE(review): path_ori is not referenced by any other cell in this notebook.\n",
    "path_ori = \"./Cifar_10_ori.pt\"\n",
    "path_com = \"./output_vector\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(10000, 2048)\n",
      "1.562034372940417\n",
      "102.9493076414865 Cifar-10-diff_class_unet-150-001-6-timesteps-10-sam-1\n",
      "103.00896214866238 Cifar-10-test-150-001-30-timesteps-50-sam-1\n",
      "103.87331986041738 Cifar-10-diff_class_unet-150-001-48-timesteps-80-sam-1\n",
      "108.18734985869978 Cifar-10-diff_class-75-001-15\n",
      "108.21147349071329 Cifar-10-diff_class-75-001-5\n",
      "108.23811817579103 Cifar-10-diff_class-75-001-10-timesteps-10-sam-0\n",
      "108.24947880517959 Cifar-10-diff_class-35-001-15\n",
      "108.25282327519496 Cifar-10-diff_class_152-75-001-10-timesteps-10-sam-1\n",
      "108.253880054648 Cifar-10-diff_class-0-001-1\n",
      "108.26486336540852 Cifar-10-diff_class-75-001-1\n",
      "108.26777094869891 Cifar-10-diff_class-50\n",
      "108.26885855119644 Cifar-10-diff_class-75-001-10-timesteps-30-sam-True\n",
      "108.3136770768738 Cifar-10-diff_class_unet-75-001-10-timesteps-10-sam-1\n",
      "108.32016655765622 Cifar-10-diff_class_152-75-001-5-timesteps-10-sam-1\n",
      "108.33892610868422 Cifar-10-diff_class-75-001\n",
      "108.3412034961526 Cifar-10-diff_class-150-001\n",
      "108.38151825685881 Cifar-10-diff_class-ori\n",
      "108.44888283133906 Cifar-10-diff_class-150\n",
      "108.77463215644431 Cifar-10-diff_class-35-001-25\n",
      "112.24670776893703 Cifar-10-diff_class-75-001-10-timesteps-50-sam-True\n",
      "112.66637803545116 Cifar-10-test\n",
      "113.94146148948602 Cifar-10-diff_class-75-001-10-timesteps-10-sam-True\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "# Feature sub-folder read under every run directory, and how many reference rows to use.\n",
    "which_model = \"fid\"\n",
    "N_REF_SAMPLES = 10000\n",
    "ref_vec_file = os.path.join(path_com, \"Cifar-10-ori\", which_model, \"vec.pt\")\n",
    "# map_location=\"cpu\": tensors saved from a GPU can still be converted with .numpy().\n",
    "ori_vec = torch.load(ref_vec_file, map_location=\"cpu\").numpy()[:N_REF_SAMPLES]\n",
    "print(ori_vec.shape)\n",
    "# Reference statistics: per-feature mean and feature-by-feature covariance (columns are features).\n",
    "ori_mu = np.mean(ori_vec, axis=0)\n",
    "ori_sigma = np.cov(ori_vec, rowvar=False)\n",
    "def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):\n",
    "    \"\"\"Frechet distance between two multivariate Gaussians.\n",
    "\n",
    "    For X_1 ~ N(mu1, sigma1) and X_2 ~ N(mu2, sigma2):\n",
    "            d^2 = ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2*sqrt(sigma1*sigma2)).\n",
    "    Stable version by Dougal J. Sutherland: when the matrix square root of\n",
    "    sigma1*sigma2 is not finite, eps is added to both diagonals and the\n",
    "    square root is recomputed.\n",
    "\n",
    "    Params:\n",
    "    -- mu1   : mean activation vector of the first sample set.\n",
    "    -- sigma1: covariance matrix of the first sample set.\n",
    "    -- mu2   : mean activation vector of the second sample set.\n",
    "    -- sigma2: covariance matrix of the second sample set.\n",
    "    -- eps   : diagonal offset, used only in the near-singular fallback.\n",
    "\n",
    "    Returns:\n",
    "    --   : d^2 (the value reported as FID).\n",
    "\n",
    "    Raises:\n",
    "    -- AssertionError if the two means or the two covariances differ in shape.\n",
    "    -- ValueError if the square root keeps a non-negligible imaginary diagonal.\n",
    "    \"\"\"\n",
    "    m1 = np.atleast_1d(mu1)\n",
    "    m2 = np.atleast_1d(mu2)\n",
    "    s1 = np.atleast_2d(sigma1)\n",
    "    s2 = np.atleast_2d(sigma2)\n",
    "\n",
    "    assert m1.shape == m2.shape, (\n",
    "        'Training and test mean vectors have different lengths')\n",
    "    assert s1.shape == s2.shape, (\n",
    "        'Training and test covariances have different dimensions')\n",
    "\n",
    "    # The covariance product may be (almost) singular; sqrtm then yields non-finite values.\n",
    "    root, _ = linalg.sqrtm(s1.dot(s2), disp=False)\n",
    "    if not np.isfinite(root).all():\n",
    "        print(('fid calculation produces singular product; '\n",
    "               'adding %s to diagonal of cov estimates') % eps)\n",
    "        jitter = np.eye(s1.shape[0]) * eps\n",
    "        root = linalg.sqrtm((s1 + jitter).dot(s2 + jitter))\n",
    "\n",
    "    # Round-off can leave a tiny imaginary part; accept it only if the diagonal's is negligible.\n",
    "    if np.iscomplexobj(root):\n",
    "        if not np.allclose(np.diagonal(root).imag, 0, atol=1e-3):\n",
    "            worst = np.max(np.abs(root.imag))\n",
    "            raise ValueError('Imaginary component {}'.format(worst))\n",
    "        root = root.real\n",
    "\n",
    "    delta = m1 - m2\n",
    "    return (delta.dot(delta) + np.trace(s1)\n",
    "            + np.trace(s2) - 2 * np.trace(root))\n",
    "# Baseline: FID between the full reference set and the subset of its first rows\n",
    "# that ori_mu/ori_sigma were computed from.\n",
    "ori_vec_se = torch.load(f\"{path_com}/Cifar-10-ori/{which_model}/vec.pt\").numpy()\n",
    "com_mu, com_sigma = np.mean(ori_vec_se, axis=0), np.cov(ori_vec_se, rowvar=False)\n",
    "print(calculate_frechet_distance(com_mu, com_sigma, ori_mu, ori_sigma))\n",
    "# Score every run directory under path_com against the reference statistics\n",
    "# and print the runs from best (lowest FID) to worst.\n",
    "all_sort = []\n",
    "for path in os.listdir(path_com):\n",
    "    run_dir = os.path.join(path_com, path)\n",
    "    if \"Cifar-10-ori\" in path or not os.path.isdir(run_dir):\n",
    "        continue  # skip the reference set itself and stray files\n",
    "    try:\n",
    "        com_vec = torch.load(run_dir + \"/\" + which_model + \"/vec.pt\", map_location=\"cpu\")\n",
    "        # Drop trailing singleton dims, e.g. (N, C, 1, 1) -> (N, C). squeeze(-1) is a\n",
    "        # no-op on a non-singleton dim, so check the size to avoid an endless loop.\n",
    "        while com_vec.dim() >= 3 and com_vec.shape[-1] == 1:\n",
    "            com_vec = com_vec.squeeze(-1)\n",
    "        if com_vec.dim() != 2:\n",
    "            raise ValueError(\"expected a 2-D feature matrix, got shape {}\".format(tuple(com_vec.shape)))\n",
    "        com_vec = com_vec.numpy()\n",
    "        com_mu = np.mean(com_vec, axis=0)\n",
    "        com_sigma = np.cov(com_vec, rowvar=False)\n",
    "        fid_value = calculate_frechet_distance(ori_mu, ori_sigma, com_mu, com_sigma)\n",
    "    except Exception as exc:\n",
    "        # Keep going, but report why a run could not be scored instead of hiding it.\n",
    "        print(\"skipping {}: {}\".format(path, exc))\n",
    "        continue\n",
    "    all_sort.append((fid_value, path))\n",
    "all_sort.sort(key=lambda item: item[0])\n",
    "# fid_value (not fid) so the pytorch_fid module imported as `fid` is not shadowed.\n",
    "for fid_value, path in all_sort:\n",
    "    print(fid_value, path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "ename": "ImportError",
     "evalue": "cannot import name 'load_img' from 'keras.preprocessing.image' (d:\\Anaconda\\lib\\site-packages\\keras\\preprocessing\\image.py)",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mImportError\u001b[0m                               Traceback (most recent call last)",
      "\u001b[1;32md:\\Living_and_Study_In_University\\Research_Project\\Model_Inversion\\FID.ipynb Cell 3\u001b[0m line \u001b[0;36m<cell line: 4>\u001b[1;34m()\u001b[0m\n\u001b[0;32m      <a href='vscode-notebook-cell:/d%3A/Living_and_Study_In_University/Research_Project/Model_Inversion/FID.ipynb#W2sZmlsZQ%3D%3D?line=1'>2</a>\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mscipy\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mlinalg\u001b[39;00m \u001b[39mimport\u001b[39;00m sqrtm\n\u001b[0;32m      <a href='vscode-notebook-cell:/d%3A/Living_and_Study_In_University/Research_Project/Model_Inversion/FID.ipynb#W2sZmlsZQ%3D%3D?line=2'>3</a>\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mkeras\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mapplications\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39minception_v3\u001b[39;00m \u001b[39mimport\u001b[39;00m InceptionV3, preprocess_input\n\u001b[1;32m----> <a href='vscode-notebook-cell:/d%3A/Living_and_Study_In_University/Research_Project/Model_Inversion/FID.ipynb#W2sZmlsZQ%3D%3D?line=3'>4</a>\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mkeras\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mpreprocessing\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mimage\u001b[39;00m \u001b[39mimport\u001b[39;00m load_img, img_to_array\n\u001b[0;32m      <a href='vscode-notebook-cell:/d%3A/Living_and_Study_In_University/Research_Project/Model_Inversion/FID.ipynb#W2sZmlsZQ%3D%3D?line=4'>5</a>\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mkeras\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mmodels\u001b[39;00m \u001b[39mimport\u001b[39;00m Model\n\u001b[0;32m      <a href='vscode-notebook-cell:/d%3A/Living_and_Study_In_University/Research_Project/Model_Inversion/FID.ipynb#W2sZmlsZQ%3D%3D?line=6'>7</a>\u001b[0m \u001b[39m# scale an array of images to a new size\u001b[39;00m\n",
      "\u001b[1;31mImportError\u001b[0m: cannot import name 'load_img' from 'keras.preprocessing.image' (d:\\Anaconda\\lib\\site-packages\\keras\\preprocessing\\image.py)"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "from scipy.linalg import sqrtm\n",
    "from scipy.ndimage import zoom\n",
    "from keras.applications.inception_v3 import InceptionV3, preprocess_input\n",
    "# load_img/img_to_array cannot be imported from keras.preprocessing.image in this\n",
    "# environment (ImportError); keras.utils exposes them in current Keras releases.\n",
    "from keras.utils import load_img, img_to_array\n",
    "from keras.models import Model\n",
    "\n",
    "# scale an array of images to a new size\n",
    "def scale_images(images, new_shape):\n",
    "    \"\"\"Resize each image in `images` to `new_shape` with nearest-neighbour sampling.\n",
    "\n",
    "    Params:\n",
    "    -- images   : iterable of image arrays, e.g. a (N, H, W, C) batch.\n",
    "    -- new_shape: target shape of every image, e.g. (299, 299, 3); needs one\n",
    "                  entry per image axis.\n",
    "\n",
    "    Returns:\n",
    "    --   : numpy array stacking the resized images.\n",
    "    \"\"\"\n",
    "    images_list = []\n",
    "    for image in images:\n",
    "        image = np.asarray(image)\n",
    "        # Per-axis zoom factor so the output has exactly new_shape; order=0 is nearest neighbour.\n",
    "        factors = [new / old for new, old in zip(new_shape, image.shape)]\n",
    "        images_list.append(zoom(image, factors, order=0))\n",
    "    return np.asarray(images_list)\n",
    "\n",
    "# calculate frechet inception distance\n",
    "def calculate_fid(model, images1, images2):\n",
    "    \"\"\"FID between two image batches, using `model` as the feature extractor.\n",
    "\n",
    "    Activations from model.predict are reduced to mean/covariance and passed to\n",
    "    calculate_frechet_distance (defined in the previous cell). That function handles\n",
    "    a near-singular covariance product and rejects a non-negligible imaginary\n",
    "    square-root component rather than silently discarding it.\n",
    "\n",
    "    Raises:\n",
    "    -- ValueError if either batch yields fewer than 2 activation rows\n",
    "       (the covariance would be undefined).\n",
    "    \"\"\"\n",
    "    act1 = model.predict(images1)\n",
    "    act2 = model.predict(images2)\n",
    "    if len(act1) < 2 or len(act2) < 2:\n",
    "        raise ValueError('FID needs at least 2 images per set, got {} and {}'.format(len(act1), len(act2)))\n",
    "    # mean and covariance of the activations (one row per image, columns are features)\n",
    "    mu1, sigma1 = act1.mean(axis=0), np.cov(act1, rowvar=False)\n",
    "    mu2, sigma2 = act2.mean(axis=0), np.cov(act2, rowvar=False)\n",
    "    return calculate_frechet_distance(mu1, sigma1, mu2, sigma2)\n",
    "\n",
    "# prepare the inception v3 model\n",
    "model = InceptionV3(include_top=False, pooling='avg', input_shape=(299,299,3))\n",
    "\n",
    "def load_image_batch(paths):\n",
    "    \"\"\"Load image files as one float32 batch of shape (len(paths), 299, 299, 3).\"\"\"\n",
    "    # load_img(target_size=...) already resizes to the InceptionV3 input size,\n",
    "    # so no separate scale_images pass is needed.\n",
    "    arrays = [img_to_array(load_img(p, target_size=(299, 299))) for p in paths]\n",
    "    return np.stack(arrays).astype('float32')\n",
    "\n",
    "# FID compares distributions, so each list needs many images; with a single image\n",
    "# per set the covariance is undefined. Add more paths before trusting the score.\n",
    "image_paths1 = ['image1.jpg']\n",
    "image_paths2 = ['image2.jpg']\n",
    "# stack into (N, 299, 299, 3) batches and pre-process for InceptionV3\n",
    "images1 = preprocess_input(load_image_batch(image_paths1))\n",
    "images2 = preprocess_input(load_image_batch(image_paths2))\n",
    "# fid between images1 and images2 (fid_score, so the imported `fid` module is not shadowed)\n",
    "fid_score = calculate_fid(model, images1, images2)\n",
    "print('FID: %.3f' % fid_score)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
