{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d7d83683-fb6d-4850-b6f6-03648b6af8f6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn.functional as F\n",
    "# ===== Compute the sliced Wasserstein distance =====\n",
    "def sw_distance(mu1, mu2, Sigma1, Sigma2, m=100):\n",
    "    \"\"\"Monte Carlo estimate of the sliced 2-Wasserstein distance between two Gaussians.\n",
    "\n",
    "    Both Gaussians are projected onto ``m`` random unit directions. For every\n",
    "    direction the 1-D closed form ``W2^2 = (m1 - m2)^2 + (s1 - s2)^2`` is\n",
    "    evaluated, where ``s`` is the projected standard deviation; the squared\n",
    "    distances are averaged over directions and the square root is returned.\n",
    "\n",
    "    Args:\n",
    "        mu1, mu2: Mean tensors whose first dimension is the ambient dimension.\n",
    "        Sigma1, Sigma2: Covariance matrices of size (dim, dim).\n",
    "        m: Number of random projection directions.\n",
    "\n",
    "    Returns:\n",
    "        0-dim tensor with the estimated distance. Directions are redrawn on\n",
    "        every call, so the value is stochastic unless the torch RNG is seeded.\n",
    "    \"\"\"\n",
    "    dim = mu1.shape[0]\n",
    "    # Draw directions with the inputs' dtype/device so float64 or GPU tensors work.\n",
    "    theta = torch.randn(m, dim, dtype=mu1.dtype, device=mu1.device)\n",
    "    theta = F.normalize(theta, p=2, dim=1)\n",
    "\n",
    "    # Flatten so that (dim,) and (dim, 1) means both give one value per direction.\n",
    "    proj_mean1 = (theta @ mu1).reshape(-1)\n",
    "    proj_mean2 = (theta @ mu2).reshape(-1)\n",
    "\n",
    "    # theta_i^T Sigma theta_i for each row i, without forming the m x m product.\n",
    "    proj_var1 = ((theta @ Sigma1) * theta).sum(dim=-1)\n",
    "    proj_var2 = ((theta @ Sigma2) * theta).sum(dim=-1)\n",
    "\n",
    "    # The closed form uses standard deviations, not variances. Clamp at zero so\n",
    "    # round-off or a non-PSD input cannot produce NaN under the square root.\n",
    "    proj_std1 = proj_var1.clamp_min(0).sqrt()\n",
    "    proj_std2 = proj_var2.clamp_min(0).sqrt()\n",
    "\n",
    "    w2_sq = (proj_mean1 - proj_mean2) ** 2 + (proj_std1 - proj_std2) ** 2\n",
    "    return torch.sqrt(w2_sq.mean())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "75083c2c-7f65-42b0-85fe-9752f9bc9ad9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from scipy.linalg import solve\n",
    "# ===== Perform sliced Wasserstein compression, see Algorithm 2 =====\n",
    "def dkmppf_sw(model, x, kappa, prune_above=30):\n",
    "    \"\"\"Greedily prune training points while the posterior at ``x`` stays within ``kappa``.\n",
    "\n",
    "    Starting from all training points, each round tentatively removes every\n",
    "    kept point in turn, recomputes the posterior mean and covariance at ``x``\n",
    "    from the remaining points, and measures the sliced Wasserstein distance\n",
    "    (``sw_distance``) to the full-data posterior. The point whose removal gives\n",
    "    the smallest distance is dropped if that distance is <= ``kappa``;\n",
    "    otherwise pruning stops.\n",
    "\n",
    "    Args:\n",
    "        model: Object exposing ``train_xs``, ``train_ys``, ``_get_Kxx_dx2()``,\n",
    "            ``_get_KxX_dx(x)`` and ``get_KXX_inv()``.\n",
    "        x: Query input forwarded to ``model._get_KxX_dx``.\n",
    "        kappa: Largest distance to the full posterior tolerated for a removal.\n",
    "        prune_above: Pruning is attempted only when there are more than this\n",
    "            many training points (default 30).\n",
    "\n",
    "    Returns:\n",
    "        List of training-set indices that are kept, in ascending order.\n",
    "    \"\"\"\n",
    "    D = model.train_xs\n",
    "    y = model.train_ys\n",
    "    Y = list(range(D.shape[0]))\n",
    "\n",
    "    # Only the index list is returned, so no autograd graph is needed anywhere.\n",
    "    with torch.no_grad():\n",
    "        k_XX = model._get_Kxx_dx2()\n",
    "        k_DX = model._get_KxX_dx(x)\n",
    "        K_inv = model.get_KXX_inv()\n",
    "        # NOTE(review): assumes K_inv is a 2-D (N, N) matrix and k_DX carries a\n",
    "        # leading batch dimension of size 1 - confirm against the model class.\n",
    "\n",
    "        mu_full = k_DX @ K_inv @ y\n",
    "        Sigma_full = k_XX - k_DX @ K_inv @ k_DX.transpose(1, 2)\n",
    "        # NOTE(review): clamp_min acts element-wise, so negative off-diagonal\n",
    "        # covariances are raised to 1e-9 as well - confirm this is intended.\n",
    "        Sigma_full = Sigma_full.clamp_min(1e-9)\n",
    "\n",
    "        if len(Y) <= prune_above:\n",
    "            return Y\n",
    "\n",
    "        k_x = k_DX[0]\n",
    "        # Inverse Gram matrix of the kept points; row/column j belongs to Y[j].\n",
    "        P = K_inv\n",
    "\n",
    "        # Never remove the last point: nothing would be left to predict from.\n",
    "        while len(Y) > 1:\n",
    "            gmin, imin, P_min = None, None, None\n",
    "            for i in range(len(Y)):\n",
    "                keep = [j for j in range(len(Y)) if j != i]\n",
    "                remains = Y[:i] + Y[i + 1:]\n",
    "\n",
    "                # A sub-block of K^{-1} is not the inverse of that sub-block of K.\n",
    "                # The rank-one Schur-complement downdate below gives the exact\n",
    "                # inverse of the Gram matrix with row/column i deleted.\n",
    "                P_keep = P[keep][:, keep] - torch.outer(P[keep, i], P[i, keep]) / P[i, i]\n",
    "\n",
    "                k_rem = k_x[:, remains]\n",
    "                mu_removal = k_rem @ P_keep @ y[remains]\n",
    "                Sigma_removal = k_XX - k_rem @ P_keep @ k_rem.T\n",
    "\n",
    "                gi = sw_distance(mu_removal, mu_full[0], Sigma_removal.clamp_min(1e-9), Sigma_full[0])\n",
    "\n",
    "                # Strict '<' keeps the lowest index when several candidates tie.\n",
    "                if gmin is None or gi < gmin:\n",
    "                    gmin, imin, P_min = gi, i, P_keep\n",
    "\n",
    "            if gmin <= kappa:\n",
    "                Y.pop(imin)\n",
    "                P = P_min\n",
    "            else:\n",
    "                break\n",
    "    return Y"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
