{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "7dac0241-8c3f-41f7-969d-f06540000c66",
   "metadata": {},
   "source": [
    "# Perceptual evaluation based on GMMs in the feature space\n",
    "\n",
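    "We replace the mixture-Wasserstein (MW) metric between Gaussian mixture models of Inception-v3 features used in\n",
    "[Luzi et al., 2023](https://openaccess.thecvf.com/content/WACV2023/papers/Luzi_Evaluating_Generative_Networks_Using_Gaussian_Mixtures_of_Image_Features_WACV_2023_paper.pdf)\n",
    "with our DSMW/MSW metrics, and compare them (together with FID) on CIFAR-10 images degraded by Gaussian blur, salt-and-pepper noise and additive Gaussian noise."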
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "d7b6ab80-3dde-4c1b-8c99-d21e171323fd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import time\n",
    "import warnings\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import ot\n",
    "import scipy.linalg as spl\n",
    "import scipy.stats as sps\n",
    "import torch\n",
    "import torchvision\n",
    "import torchvision.transforms as transforms\n",
    "from scipy.linalg import sqrtm\n",
    "from scipy.optimize import linprog\n",
    "from torchvision.models import inception_v3\n",
    "\n",
    "import GMM_utils as GMM\n",
    "import sliced_mw as SMW\n",
    "\n",
    "# Reproducibility: seed the global NumPy and PyTorch RNGs.\n",
    "torch.manual_seed(42)\n",
    "np.random.seed(42)\n",
    "\n",
    "# Ridge added to covariance matrices (FID covariances and the real-image GMM fit).\n",
    "reg_cov = 1e-2\n",
    "\n",
    "# Large default font for the generated plots.\n",
    "plt.rcParams[\"font.size\"] = 22\n",
    "\n",
    "# NOTE(review): this silences *every* warning for the whole kernel session;\n",
    "# consider narrowing it to the specific warnings that are expected.\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "torch.cuda.empty_cache()\n",
    "\n",
    "\n",
    "def GaussianW2(m0,m1,Sigma0,Sigma1):\n",
    "    \"\"\"Squared 2-Wasserstein distance between N(m0, Sigma0) and N(m1, Sigma1).\n",
    "\n",
    "    Closed form (source: https://github.com/judelo):\n",
    "        ||m0 - m1||^2 + tr(Sigma0 + Sigma1 - 2 (Sigma0^(1/2) Sigma1 Sigma0^(1/2))^(1/2))\n",
    "\n",
    "    Args:\n",
    "        m0, m1 (np.ndarray): means of the two Gaussians, shape (d,).\n",
    "        Sigma0, Sigma1 (np.ndarray): covariance matrices, shape (d, d).\n",
    "\n",
    "    Returns:\n",
    "        float: mean term plus covariance term (the squared W2 distance).\n",
    "    \"\"\"\n",
    "    Sigma00  = spl.sqrtm(Sigma0)\n",
    "    Sigma010 = spl.sqrtm(Sigma00@Sigma1@Sigma00)\n",
    "    mean_term = np.linalg.norm(m0-m1)**2\n",
    "    # sqrtm may return a complex array with negligible imaginary parts;\n",
    "    # keep only the real part so the distance is a real number.\n",
    "    cov_term = np.real(np.trace(Sigma0+Sigma1-2*Sigma010))\n",
    "    # Both terms contribute: the mean term must be added, not overwritten.\n",
    "    return mean_term + cov_term\n",
    "\n",
    "def MW2(pi_0,pi_1,mu_0,mu_1,Sigma0_arr,Sigma1_arr):\n",
    "    \"\"\"Mixture-Wasserstein (MW) cost between two Gaussian mixtures.\n",
    "\n",
    "    Builds the matrix of pairwise GaussianW2 costs between the components of the\n",
    "    two mixtures, solves the discrete optimal-transport problem between the\n",
    "    mixture weights with ot.emd, and returns the total transport cost.\n",
    "    Source: https://github.com/judelo\n",
    "\n",
    "    Args:\n",
    "        pi_0, pi_1: component weights of the two mixtures (K0 and K1 entries).\n",
    "        mu_0, mu_1: component means, shapes (K0, d) and (K1, d).\n",
    "        Sigma0_arr, Sigma1_arr: component covariances, reshaped to (K0, d, d)\n",
    "            and (K1, d, d).\n",
    "\n",
    "    Returns:\n",
    "        Scalar sum of (optimal plan * pairwise cost matrix).\n",
    "    \"\"\"\n",
    "    n_src, dim = mu_0.shape[0], mu_0.shape[1]\n",
    "    n_dst = mu_1.shape[0]\n",
    "    covs_src = Sigma0_arr.reshape(n_src, dim, dim)\n",
    "    covs_dst = Sigma1_arr.reshape(n_dst, dim, dim)\n",
    "\n",
    "    # cost[k, l]: GaussianW2 between component k of mixture 0 and component l of mixture 1.\n",
    "    cost = np.zeros((n_src, n_dst))\n",
    "    for k, (mean_k, cov_k) in enumerate(zip(mu_0, covs_src)):\n",
    "        for l, (mean_l, cov_l) in enumerate(zip(mu_1, covs_dst)):\n",
    "            cost[k, l] = GaussianW2(mean_k, mean_l, cov_k, cov_l)\n",
    "\n",
    "    # Optimal transport plan between the weight vectors under this cost.\n",
    "    plan = ot.emd(pi_0, pi_1, cost)\n",
    "    return np.sum(plan * cost)\n",
    "\n",
    "def calc_MW_org(gmm1, gmm2):\n",
    "    \"\"\"MW2 between two fitted GMMs using the NumPy/POT implementation.\n",
    "\n",
    "    Args:\n",
    "        gmm1, gmm2: fitted mixtures exposing `weights`, `means` and `covariances`\n",
    "            as torch tensors (possibly on a GPU).\n",
    "\n",
    "    Returns:\n",
    "        The value returned by MW2 for the two mixtures.\n",
    "    \"\"\"\n",
    "    def _as_numpy(tensor):\n",
    "        # .numpy() only works on CPU tensors that do not require grad, so detach and\n",
    "        # move to the CPU first (both are cheap for plain CPU tensors).\n",
    "        return tensor.detach().cpu().numpy()\n",
    "\n",
    "    return MW2(_as_numpy(gmm1.weights), _as_numpy(gmm2.weights),\n",
    "               _as_numpy(gmm1.means), _as_numpy(gmm2.means),\n",
    "               _as_numpy(gmm1.covariances), _as_numpy(gmm2.covariances))\n",
    "\n",
    "def add_noise(images, noise_level):\n",
    "    \"\"\"\n",
    "    Adds zero-mean Gaussian noise with standard deviation `noise_level`.\n",
    "\n",
    "    Args:\n",
    "        images (torch.Tensor): Input images.\n",
    "        noise_level (float): Standard deviation of the noise.\n",
    "\n",
    "    Returns:\n",
    "        torch.Tensor: Noisy images, clamped to the valid range [0, 1].\n",
    "    \"\"\"\n",
    "    jitter = torch.randn_like(images) * noise_level\n",
    "    return (images + jitter).clamp(0, 1)\n",
    "\n",
    "def add_salt_and_pepper_noise(images, noise_level):\n",
    "    \"\"\"\n",
    "    Salt-and-pepper corruption of a batch of images.\n",
    "\n",
    "    A uniform random value is drawn for every element of `images` (so each\n",
    "    channel is corrupted independently). Elements whose value falls below\n",
    "    noise_level / 2 become 0 (pepper); elements above 1 - noise_level / 2 become 1\n",
    "    (salt), with salt taking precedence if both apply.\n",
    "\n",
    "    Args:\n",
    "        images (torch.Tensor): Tensor of shape (N, C, H, W).\n",
    "        noise_level (float): Expected fraction of altered elements.\n",
    "\n",
    "    Returns:\n",
    "        torch.Tensor: New tensor with the noise applied; `images` is not modified.\n",
    "    \"\"\"\n",
    "    uniform = torch.rand_like(images)\n",
    "    half = noise_level / 2\n",
    "    peppered = images.masked_fill(uniform < half, 0.0)\n",
    "    return peppered.masked_fill(uniform > 1 - half, 1.0)\n",
    "\n",
    "def apply_gaussian_blur(images, sigma, kernel_size=3):\n",
    "    \"\"\"\n",
    "    Applies Gaussian blur to a batch of images.\n",
    "\n",
    "    Args:\n",
    "        images (torch.Tensor): Tensor of shape (N, C, H, W).\n",
    "        sigma (float): Standard deviation for Gaussian kernel (must be > 0).\n",
    "        kernel_size (int): Side length of the square blur kernel (odd, positive).\n",
    "            Defaults to 3; note that a 3x3 kernel truncates the blur for larger sigmas.\n",
    "\n",
    "    Returns:\n",
    "        torch.Tensor: Blurred images, same shape as `images`.\n",
    "    \"\"\"\n",
    "    gaussian_blur = transforms.GaussianBlur(kernel_size=kernel_size, sigma=sigma)\n",
    "    # GaussianBlur accepts tensors of shape (..., C, H, W), so the whole batch is\n",
    "    # blurred in one call. A scalar sigma yields a fixed kernel, so every image\n",
    "    # receives the same blur as with per-image application.\n",
    "    return gaussian_blur(images)\n",
    "\n",
    "def apply_distortion(images, distortion_type, distortion_level):\n",
    "    \"\"\"\n",
    "    Dispatches to the distortion function selected by `distortion_type`.\n",
    "\n",
    "    Args:\n",
    "        images (torch.Tensor): Tensor of shape (N, C, H, W).\n",
    "        distortion_type (str): \"Gaussian\" (add_noise), \"SP\"\n",
    "            (add_salt_and_pepper_noise) or \"Blur\" (apply_gaussian_blur).\n",
    "        distortion_level (float): Strength passed through to the chosen function.\n",
    "\n",
    "    Returns:\n",
    "        torch.Tensor: Distorted images.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If `distortion_type` is not one of the names above.\n",
    "    \"\"\"\n",
    "    if distortion_type == \"Gaussian\":\n",
    "        return add_noise(images, distortion_level)\n",
    "    if distortion_type == \"SP\":\n",
    "        return add_salt_and_pepper_noise(images, distortion_level)\n",
    "    if distortion_type == \"Blur\":\n",
    "        return apply_gaussian_blur(images, distortion_level)\n",
    "    raise ValueError(f\"Unknown distortion type: {distortion_type}\")\n",
    "\n",
    "def compute_embeddings_batchwise(images, model, device, batch_size=50):\n",
    "    \"\"\"\n",
    "    Runs `model` over `images` in chunks of `batch_size` to bound memory use.\n",
    "\n",
    "    Each chunk is moved to `device`, resized with transforms.Resize(299) (the input\n",
    "    size expected by Inception v3), passed through `model` in eval mode with\n",
    "    gradients disabled, and copied back to the CPU as a NumPy array.\n",
    "\n",
    "    Args:\n",
    "        images (torch.Tensor): Tensor of shape (N, C, H, W).\n",
    "        model (torch.nn.Module): Inception model without the final fc layer.\n",
    "        device (str): 'cuda' or 'cpu'.\n",
    "        batch_size (int): Number of images per chunk.\n",
    "\n",
    "    Returns:\n",
    "        np.ndarray: Chunk outputs concatenated along axis 0, shape (N, embedding_dim).\n",
    "    \"\"\"\n",
    "    resize_to_299 = transforms.Resize(299)\n",
    "    model.eval()\n",
    "    outputs = []\n",
    "    with torch.no_grad():\n",
    "        for start in range(0, images.shape[0], batch_size):\n",
    "            chunk = images[start:start + batch_size].to(device)\n",
    "            features = model(resize_to_299(chunk))\n",
    "            outputs.append(features.cpu().numpy())\n",
    "    return np.concatenate(outputs, axis=0)\n",
    "\n",
    "def calculate_fid_from_embeddings(embeddings1, embeddings2):\n",
    "    \"\"\"\n",
    "    Frechet Inception Distance (FID) between two sets of embeddings.\n",
    "\n",
    "    Each set is summarised by its mean and sample covariance; both covariances are\n",
    "    regularised with `reg_cov * I` before evaluating\n",
    "        ||mu1 - mu2||^2 + tr(sigma1 + sigma2 - 2 * sqrtm(sigma1 sigma2)).\n",
    "\n",
    "    Args:\n",
    "        embeddings1 (np.ndarray): Embeddings of set 1, shape (N, D).\n",
    "        embeddings2 (np.ndarray): Embeddings of set 2, shape (M, D).\n",
    "\n",
    "    Returns:\n",
    "        float: The FID score.\n",
    "    \"\"\"\n",
    "    def _gaussian_stats(embeddings):\n",
    "        # Per-dimension mean and (D, D) covariance; rows are observations.\n",
    "        return np.mean(embeddings, axis=0), np.cov(embeddings, rowvar=False)\n",
    "\n",
    "    mu1, sigma1 = _gaussian_stats(embeddings1)\n",
    "    mu2, sigma2 = _gaussian_stats(embeddings2)\n",
    "\n",
    "    # Ridge from the module-level reg_cov guards sqrtm against near-singular covariances.\n",
    "    sigma1 += reg_cov * np.eye(sigma1.shape[0])\n",
    "    sigma2 += reg_cov * np.eye(sigma2.shape[0])\n",
    "\n",
    "    diff = mu1 - mu2\n",
    "    # Matrix square root of the covariance product; drop any spurious imaginary part.\n",
    "    covmean = sqrtm(np.dot(sigma1, sigma2))\n",
    "    covmean = covmean.real if np.iscomplexobj(covmean) else covmean\n",
    "\n",
    "    return np.dot(diff, diff) + np.trace(sigma1 + sigma2 - 2 * covmean)\n",
    "\n",
    "def compute_fid(real_images, noise_levels, embeddings_folder=\"embeddings\", batch_size=50,\n",
    "                type=\"FID\", distortion_type=\"Gaussian\", save=True, pnum=10000):\n",
    "    \"\"\"\n",
    "    Scores distorted copies of `real_images` against the originals with an\n",
    "    FID-style metric computed on Inception-v3 embeddings.\n",
    "\n",
    "    Embeddings are cached as .npy files in `embeddings_folder` and reused by later\n",
    "    calls; the Inception model is only loaded if some embeddings must be computed.\n",
    "    For the GMM_* metrics a Gaussian mixture is fitted to the real embeddings and to\n",
    "    each set of distorted embeddings, and the chosen distance between the two\n",
    "    mixtures is reported.\n",
    "\n",
    "    Args:\n",
    "        real_images (torch.Tensor): Tensor of shape (N, C, H, W) for real images.\n",
    "        noise_levels (iterable): Distortion levels to evaluate.\n",
    "        embeddings_folder (str): Folder where embeddings are cached.\n",
    "        batch_size (int): Batch size for embedding extraction.\n",
    "        type (str): One of \"FID\", \"GMM_MW\", \"GMM_MSW\", \"GMM_SMSW\", \"GMM_SMW\".\n",
    "        distortion_type (str): Type of distortion (\"Gaussian\", \"SP\", or \"Blur\").\n",
    "        save (bool): If True, plot the scores and save the figure to perception_plots/.\n",
    "        pnum (int): Passed as `pnum` to the sliced metrics (GMM_SMSW, GMM_MSW, GMM_SMW).\n",
    "\n",
    "    Returns:\n",
    "        dict: Mapping noise_level -> metric value.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If `type` is not one of the supported metrics.\n",
    "    \"\"\"\n",
    "    supported = (\"FID\", \"GMM_MW\", \"GMM_MSW\", \"GMM_SMSW\", \"GMM_SMW\")\n",
    "    if type not in supported:\n",
    "        # Fail fast, before loading models or computing embeddings.\n",
    "        raise ValueError(f\"Unknown metric type: {type!r}; expected one of {supported}\")\n",
    "\n",
    "    os.makedirs(embeddings_folder, exist_ok=True)\n",
    "    plots_folder = \"perception_plots\"\n",
    "    os.makedirs(plots_folder, exist_ok=True)\n",
    "    fid_scores = {}\n",
    "    dt = distortion_type\n",
    "    uses_gmm = \"GMM\" in type\n",
    "\n",
    "    # Only GMM_SMSW is run on the GPU (when available); every other metric uses the CPU.\n",
    "    device = \"cuda\" if (torch.cuda.is_available() and type == \"GMM_SMSW\") else \"cpu\"\n",
    "\n",
    "    inception = None\n",
    "\n",
    "    def _embed(images):\n",
    "        # Load Inception-v3 lazily so fully cached runs skip the model load.\n",
    "        nonlocal inception\n",
    "        if inception is None:\n",
    "            inception = inception_v3(pretrained=True, transform_input=True).to(device)\n",
    "            # Replace the classification head with an identity so the model outputs\n",
    "            # the features that would otherwise feed the final fc layer.\n",
    "            inception.fc = torch.nn.Identity()\n",
    "        return compute_embeddings_batchwise(images, inception, device, batch_size=batch_size)\n",
    "\n",
    "    def _load_or_compute(path, make_images):\n",
    "        # Returns (embeddings, loaded_from_cache). A missing or unreadable cache\n",
    "        # file is recomputed and overwritten (best-effort cache).\n",
    "        try:\n",
    "            return np.load(path), True\n",
    "        except (OSError, ValueError, EOFError):\n",
    "            embeddings = _embed(make_images())\n",
    "            np.save(path, embeddings)\n",
    "            return embeddings, False\n",
    "\n",
    "    # NOTE(review): cache file names do not encode which/how many images were used;\n",
    "    # clear `embeddings_folder` whenever `real_images` changes.\n",
    "    real_path = os.path.join(embeddings_folder, \"real_embeddings.npy\")\n",
    "    real_embeddings, cached = _load_or_compute(real_path, lambda: real_images)\n",
    "    if cached:\n",
    "        print(\"Loaded Real Embedding\")\n",
    "\n",
    "    if uses_gmm:\n",
    "        torch_real_embeddings = torch.tensor(real_embeddings).to(device)\n",
    "        # NOTE(review): the real mixture uses 5 components and reg_cov=reg_cov, while the\n",
    "        # distorted mixtures below use 10 components and the fitter's default reg_cov;\n",
    "        # confirm this asymmetry is intended.\n",
    "        real_gmm = GMM.FittedGaussianMixtureModel(torch_real_embeddings, 5, device=device, reg_cov=reg_cov)\n",
    "\n",
    "    for noise_level in noise_levels:\n",
    "        savefile_raw = f\"fake_embeddings_{dt}_{noise_level}.npy\"\n",
    "        fake_path = os.path.join(embeddings_folder, savefile_raw)\n",
    "        fake_embeddings, cached = _load_or_compute(\n",
    "            fake_path, lambda: apply_distortion(real_images, distortion_type, noise_level))\n",
    "        if cached:\n",
    "            print(\"Loaded \", savefile_raw)\n",
    "        if uses_gmm:\n",
    "            torch_fake_embeddings = torch.tensor(fake_embeddings).to(device)\n",
    "            fake_gmm = GMM.FittedGaussianMixtureModel(torch_fake_embeddings, 10, device=device)\n",
    "\n",
    "        # Time only the metric itself (embedding and GMM fitting are excluded).\n",
    "        t0 = time.time()\n",
    "        if type == \"FID\":\n",
    "            fid = calculate_fid_from_embeddings(real_embeddings, fake_embeddings)\n",
    "        elif type == \"GMM_SMSW\":\n",
    "            fid = SMW.calc_SMSW(fake_gmm, real_gmm, pnum=pnum).item()\n",
    "        elif type == \"GMM_MSW\":\n",
    "            fid = SMW.calc_MSW(fake_gmm, real_gmm, pnum=pnum).item()\n",
    "        elif type == \"GMM_MW\":\n",
    "            fid = float(calc_MW_org(fake_gmm, real_gmm))\n",
    "        else:  # \"GMM_SMW\" (type was validated above)\n",
    "            fid = SMW.calc_test_SMW(fake_gmm, real_gmm, pnum=pnum).item()\n",
    "        fid_scores[noise_level] = fid\n",
    "        print(f\"{type} for {distortion_type} distortion at level {noise_level}: {fid}\")\n",
    "        print(\"Seconds: \", time.time() - t0)\n",
    "        del fake_embeddings\n",
    "        torch.cuda.empty_cache()\n",
    "\n",
    "    # Plot the metric values against the distortion level.\n",
    "    levels = sorted(fid_scores.keys())\n",
    "    scores = [fid_scores[lvl] for lvl in levels]\n",
    "    if save:\n",
    "        fig, ax = plt.subplots(figsize=(8, 6))\n",
    "        ax.plot(levels, scores, marker='o')\n",
    "        ax.set_xlabel(\"Distortion Level\")\n",
    "        ax.set_ylabel(type)\n",
    "        ax.set_title(f\"{type} vs. {distortion_type} distortion\")\n",
    "        ax.grid(True)\n",
    "        plot_filename = os.path.join(plots_folder, f\"{type}_{distortion_type}_plot.png\")\n",
    "        fig.savefig(plot_filename)\n",
    "        plt.show()\n",
    "        plt.close(fig)\n",
    "        print(f\"Plot saved to {plot_filename}\")\n",
    "\n",
    "    return fid_scores\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    # Experiment configuration.\n",
    "    N_IMAGES = 1000  # Adjust number of images as needed\n",
    "    BATCH_SIZE = 10\n",
    "    # (distortion type, distortion levels), evaluated in this order.\n",
    "    EXPERIMENTS = [\n",
    "        (\"Blur\", np.linspace(.5, 1.5, 10)),\n",
    "        (\"SP\", np.linspace(.05, .3, 10)),\n",
    "        (\"Gaussian\", np.linspace(.01, .2, 10)),\n",
    "    ]\n",
    "    # Metrics computed for every distortion, in this order.\n",
    "    METRICS = [\"GMM_MW\", \"GMM_MSW\", \"GMM_SMSW\", \"FID\"]\n",
    "\n",
    "    # Load the first N_IMAGES CIFAR-10 training images.\n",
    "    transform = transforms.Compose([transforms.ToTensor()])\n",
    "    dataset = torchvision.datasets.CIFAR10(root=\"./data\", train=True, download=True, transform=transform)\n",
    "    real_images = torch.stack([dataset[i][0] for i in range(N_IMAGES)])\n",
    "\n",
    "    # all_scores[(distortion type, metric)] -> {distortion level: score}\n",
    "    all_scores = {}\n",
    "    for dt, noise_levels in EXPERIMENTS:\n",
    "        for metric in METRICS:\n",
    "            all_scores[(dt, metric)] = compute_fid(real_images, noise_levels, batch_size=BATCH_SIZE,\n",
    "                                                   type=metric, distortion_type=dt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ade66d06-3838-42ec-8504-0a8e9f6e94bd",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
