{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 14127,
     "status": "ok",
     "timestamp": 1739480328566,
     "user": {
      "displayName": "Vardaan Tekriwal",
      "userId": "17776599370791499308"
     },
     "user_tz": 480
    },
    "id": "695Ug5HmtQQ9",
    "outputId": "35f90ffb-09be-402f-890d-c52fa8f1fba9"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Collecting pot\n",
      "  Downloading POT-0.9.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (34 kB)\n",
      "Requirement already satisfied: numpy>=1.16 in /usr/local/lib/python3.11/dist-packages (from pot) (1.26.4)\n",
      "Requirement already satisfied: scipy>=1.6 in /usr/local/lib/python3.11/dist-packages (from pot) (1.13.1)\n",
      "Downloading POT-0.9.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (897 kB)\n",
      "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m897.5/897.5 kB\u001b[0m \u001b[31m15.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[?25hInstalling collected packages: pot\n",
      "Successfully installed pot-0.9.5\n"
     ]
    }
   ],
   "source": [
    "# POT (Python Optimal Transport) provides the `ot` module imported in the next cell.\n",
    "# Install it only when it is missing, pinned to the version this notebook was\n",
    "# last run with, so a fresh kernel can run all cells top to bottom.\n",
    "import importlib.util\n",
    "\n",
    "if importlib.util.find_spec(\"ot\") is None:\n",
    "    get_ipython().run_line_magic(\"pip\", \"install -q pot==0.9.5\")\n",
    "    importlib.invalidate_caches()  # let the import system see the new package\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "executionInfo": {
     "elapsed": 177,
     "status": "ok",
     "timestamp": 1739483247997,
     "user": {
      "displayName": "Vardaan Tekriwal",
      "userId": "17776599370791499308"
     },
     "user_tz": 480
    },
    "id": "97aM7TYao4WS"
   },
   "outputs": [
    {
     "ename": "ModuleNotFoundError",
     "evalue": "No module named 'ot'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mModuleNotFoundError\u001b[0m                       Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[1], line 4\u001b[0m\n\u001b[1;32m      2\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mpandas\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mpd\u001b[39;00m\n\u001b[1;32m      3\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01msklearn\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mmanifold\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m MDS\n\u001b[0;32m----> 4\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mot\u001b[39;00m\n\u001b[1;32m      5\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mot\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mgromov\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m gromov_wasserstein, fused_gromov_wasserstein\n\u001b[1;32m      6\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mmatplotlib\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mpyplot\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mplt\u001b[39;00m\n",
      "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'ot'"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "from sklearn.manifold import MDS\n",
    "import ot\n",
    "from ot.gromov import gromov_wasserstein, fused_gromov_wasserstein\n",
    "import matplotlib.pyplot as plt\n",
    "from scipy.stats import norm\n",
    "\n",
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "executionInfo": {
     "elapsed": 10654,
     "status": "ok",
     "timestamp": 1739480367590,
     "user": {
      "displayName": "Vardaan Tekriwal",
      "userId": "17776599370791499308"
     },
     "user_tz": 480
    },
    "id": "H-kW1iP_qY90"
   },
   "outputs": [],
   "source": [
    "# Directory holding the pickled MNIST splits (appears to be a Colab Google Drive mount).\n",
    "base_path = \"/content/drive/MyDrive/Experiments/MNIST/\"\n",
    "\n",
    "# Read each pickled split and convert it straight to a NumPy array.\n",
    "# NOTE(review): pd.read_pickle can execute arbitrary code; only load trusted files.\n",
    "X_train, y_train, X_test, y_test = (\n",
    "    np.array(pd.read_pickle(base_path + fname))\n",
    "    for fname in (\"x_train_MNIST.pkl\", \"y_train_MNIST.pkl\",\n",
    "                  \"x_test_MNIST.pkl\", \"y_test_MNIST.pkl\")\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {
    "executionInfo": {
     "elapsed": 122,
     "status": "ok",
     "timestamp": 1739483178783,
     "user": {
      "displayName": "Vardaan Tekriwal",
      "userId": "17776599370791499308"
     },
     "user_tz": 480
    },
    "id": "Sz1Alcgdtlab"
   },
   "outputs": [],
   "source": [
    "# Work on small prefixes of the splits: the pairwise FGW matrices computed\n",
    "# below grow quadratically with the number of images.\n",
    "subset, n_test = 50, 10\n",
    "X_train_sub, y_train_sub = X_train[:subset], y_train[:subset]\n",
    "X_test_sub, y_test_sub = X_test[:n_test], y_test[:n_test]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "executionInfo": {
     "elapsed": 177,
     "status": "ok",
     "timestamp": 1739480490202,
     "user": {
      "displayName": "Vardaan Tekriwal",
      "userId": "17776599370791499308"
     },
     "user_tz": 480
    },
    "id": "IouZeX9dtmMO"
   },
   "outputs": [],
   "source": [
    "# Pixel coordinates of a 28x28 image in row-major order (row i, column j),\n",
    "# and the Euclidean distance between every pair of pixels: the structural\n",
    "# cost matrix shared by all images in the FGW computations.\n",
    "h, w = 28, 28\n",
    "grid = np.indices((h, w)).reshape(2, -1).T\n",
    "C = ot.dist(grid, grid, metric='euclidean')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "executionInfo": {
     "elapsed": 127,
     "status": "ok",
     "timestamp": 1739480490914,
     "user": {
      "displayName": "Vardaan Tekriwal",
      "userId": "17776599370791499308"
     },
     "user_tz": 480
    },
    "id": "LUw1gNSruffd"
   },
   "outputs": [],
   "source": [
    "# Define FGW distance using fused_gromov_wasserstein.\n",
    "# For two images, M is the cost matrix between their pixel intensities,\n",
    "# and both structural cost matrices default to the pixel-grid matrix C.\n",
    "def compute_fgw_distance(img1, img2, alpha=0.5, C_struct=None):\n",
    "    \"\"\"Fused Gromov-Wasserstein (FGW) distance between two images.\n",
    "\n",
    "    Each image is treated as a probability distribution over its pixels,\n",
    "    weighted by its normalized intensities. The feature cost M is the\n",
    "    Euclidean (here: absolute) difference between pixel intensities, and\n",
    "    the structural cost on both sides is C_struct.\n",
    "\n",
    "    Args:\n",
    "        img1, img2: images whose flattened size equals the number of rows of\n",
    "            C_struct (784 for the default 28x28 grid).\n",
    "        alpha: FGW trade-off forwarded to POT; it weights the structural\n",
    "            (Gromov) term, while 1 - alpha weights the feature term M.\n",
    "        C_struct: structural cost matrix; defaults to the global grid cost C.\n",
    "\n",
    "    Returns:\n",
    "        The FGW distance reported by POT in log['fgw_dist'].\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if an image has no positive total intensity and so\n",
    "            cannot be normalized into a distribution.\n",
    "    \"\"\"\n",
    "    if C_struct is None:\n",
    "        C_struct = C\n",
    "    # Cast to float64 up front: integer pixel arrays (e.g. uint8) risk\n",
    "    # overflow in the distance arithmetic of ot.dist.\n",
    "    x1 = np.asarray(img1, dtype=np.float64).ravel()\n",
    "    x2 = np.asarray(img2, dtype=np.float64).ravel()\n",
    "    mass1, mass2 = x1.sum(), x2.sum()\n",
    "    if mass1 <= 0 or mass2 <= 0:\n",
    "        raise ValueError(\"compute_fgw_distance: images must have positive total intensity\")\n",
    "    # Normalize exactly so both marginals sum to 1; the OT solver requires equal masses.\n",
    "    p = x1 / mass1\n",
    "    q = x2 / mass2\n",
    "    M = ot.dist(x1[:, None], x2[:, None], metric='euclidean')\n",
    "    _, log = fused_gromov_wasserstein(M, C_struct, C_struct, p, q,\n",
    "                                       loss_fun='square_loss',\n",
    "                                       alpha=alpha,\n",
    "                                       log=True)\n",
    "    return log['fgw_dist']\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 161,
     "status": "ok",
     "timestamp": 1739480514423,
     "user": {
      "displayName": "Vardaan Tekriwal",
      "userId": "17776599370791499308"
     },
     "user_tz": 480
    },
    "id": "DGTDUJjYuhjM",
    "outputId": "35c09f52-156d-40fa-f547-00f16cde0cd7"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "D_train exists\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "\n",
    "n = X_train_sub.shape[0]\n",
    "D_train_path = base_path + \"D_train.npy\"\n",
    "# The cache file name does not encode the subset size, so reuse the cached\n",
    "# matrix only when its shape matches the current subset; otherwise recompute.\n",
    "# NOTE(review): changes to alpha or to the images themselves are not detected.\n",
    "D_train = np.load(D_train_path) if os.path.exists(D_train_path) else None\n",
    "if D_train is not None and D_train.shape == (n, n):\n",
    "    print(\"D_train exists\")\n",
    "else:\n",
    "    if D_train is not None:\n",
    "        print(f\"Cached D_train has shape {D_train.shape}, expected {(n, n)}; recomputing\")\n",
    "    total = n * (n + 1) // 2\n",
    "    iteration = 0\n",
    "    D_train = np.zeros((n, n))\n",
    "    # Only the upper triangle (including the diagonal) is computed; each value\n",
    "    # is mirrored, i.e. the FGW distance is treated as symmetric.\n",
    "    for i in range(n):\n",
    "        for j in range(i, n):\n",
    "            iteration += 1\n",
    "            print(f\"Iteration {iteration}/{total}\")\n",
    "            d = compute_fgw_distance(X_train_sub[i], X_train_sub[j])\n",
    "            D_train[i, j] = d\n",
    "            D_train[j, i] = d\n",
    "    np.save(D_train_path, D_train)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 124,
     "status": "ok",
     "timestamp": 1739482474850,
     "user": {
      "displayName": "Vardaan Tekriwal",
      "userId": "17776599370791499308"
     },
     "user_tz": 480
    },
    "id": "HQiIkoPuwyWD",
    "outputId": "9cdfc5cb-929a-45fc-e457-275aee5de7b7"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "D_test exists\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "\n",
    "# Test-vs-train FGW distances: rows index test images, columns training images.\n",
    "D_test_path = base_path + \"D_test.npy\"\n",
    "# Reuse the cache only when its shape matches the current subsets (the file\n",
    "# name does not encode them); otherwise recompute and overwrite it.\n",
    "D_test = np.load(D_test_path) if os.path.exists(D_test_path) else None\n",
    "if D_test is not None and D_test.shape == (n_test, n):\n",
    "    print(\"D_test exists\")\n",
    "else:\n",
    "    if D_test is not None:\n",
    "        print(f\"Cached D_test has shape {D_test.shape}, expected {(n_test, n)}; recomputing\")\n",
    "    total = n_test * n\n",
    "    iteration = 0\n",
    "    D_test = np.zeros((n_test, n))\n",
    "    for i in range(n_test):\n",
    "        for j in range(n):\n",
    "            iteration += 1\n",
    "            print(f\"Iteration {iteration}/{total}\")\n",
    "            D_test[i, j] = compute_fgw_distance(X_test_sub[i], X_train_sub[j])\n",
    "    np.save(D_test_path, D_test)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "executionInfo": {
     "elapsed": 197,
     "status": "ok",
     "timestamp": 1739482423376,
     "user": {
      "displayName": "Vardaan Tekriwal",
      "userId": "17776599370791499308"
     },
     "user_tz": 480
    },
    "id": "__KPZznQOvx6"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Test-vs-test FGW distances, cached on disk.\n",
    "# NOTE(review): the next cell repeats this computation; once this cell has\n",
    "# saved the cache, that cell just reloads it.\n",
    "D_test_test_path = base_path + \"D_test_test.npy\"\n",
    "D_test_test = np.load(D_test_test_path) if os.path.exists(D_test_test_path) else None\n",
    "if D_test_test is not None and D_test_test.shape == (n_test, n_test):\n",
    "    print(\"D_test_test exists\")\n",
    "else:\n",
    "    if D_test_test is not None:\n",
    "        print(f\"Cached D_test_test has shape {D_test_test.shape}, expected {(n_test, n_test)}; recomputing\")\n",
    "    total = n_test * (n_test + 1) // 2\n",
    "    iteration = 0\n",
    "    D_test_test = np.zeros((n_test, n_test))\n",
    "    # Upper triangle (including the diagonal) only; each value is mirrored.\n",
    "    for i in range(n_test):\n",
    "        for j in range(i, n_test):\n",
    "            iteration += 1\n",
    "            print(f\"Iteration {iteration}/{total}\")\n",
    "            d = compute_fgw_distance(X_test_sub[i], X_test_sub[j])\n",
    "            D_test_test[i, j] = d\n",
    "            D_test_test[j, i] = d\n",
    "    np.save(D_test_test_path, D_test_test)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 163239,
     "status": "ok",
     "timestamp": 1739482908824,
     "user": {
      "displayName": "Vardaan Tekriwal",
      "userId": "17776599370791499308"
     },
     "user_tz": 480
    },
    "id": "No3vVWF6wyQQ",
    "outputId": "4a231a16-0ed9-48b3-92ce-31dc77c8c6de"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Iteration 1/55\n",
      "Iteration 2/55\n",
      "Iteration 3/55\n",
      "Iteration 4/55\n",
      "Iteration 5/55\n",
      "Iteration 6/55\n",
      "Iteration 7/55\n",
      "Iteration 8/55\n",
      "Iteration 9/55\n",
      "Iteration 10/55\n",
      "Iteration 11/55\n",
      "Iteration 12/55\n",
      "Iteration 13/55\n",
      "Iteration 14/55\n",
      "Iteration 15/55\n",
      "Iteration 16/55\n",
      "Iteration 17/55\n",
      "Iteration 18/55\n",
      "Iteration 19/55\n",
      "Iteration 20/55\n",
      "Iteration 21/55\n",
      "Iteration 22/55\n",
      "Iteration 23/55\n",
      "Iteration 24/55\n",
      "Iteration 25/55\n",
      "Iteration 26/55\n",
      "Iteration 27/55\n",
      "Iteration 28/55\n",
      "Iteration 29/55\n",
      "Iteration 30/55\n",
      "Iteration 31/55\n",
      "Iteration 32/55\n",
      "Iteration 33/55\n",
      "Iteration 34/55\n",
      "Iteration 35/55\n",
      "Iteration 36/55\n",
      "Iteration 37/55\n",
      "Iteration 38/55\n",
      "Iteration 39/55\n",
      "Iteration 40/55\n",
      "Iteration 41/55\n",
      "Iteration 42/55\n",
      "Iteration 43/55\n",
      "Iteration 44/55\n",
      "Iteration 45/55\n",
      "Iteration 46/55\n",
      "Iteration 47/55\n",
      "Iteration 48/55\n",
      "Iteration 49/55\n",
      "Iteration 50/55\n",
      "Iteration 51/55\n",
      "Iteration 52/55\n",
      "Iteration 53/55\n",
      "Iteration 54/55\n",
      "Iteration 55/55\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "\n",
    "# Test-vs-test FGW distances. Reuse the cache only when its shape matches the\n",
    "# current test subset (the file name does not encode it); otherwise recompute.\n",
    "D_test_test_path = base_path + \"D_test_test.npy\"\n",
    "D_test_test = np.load(D_test_test_path) if os.path.exists(D_test_test_path) else None\n",
    "if D_test_test is not None and D_test_test.shape == (n_test, n_test):\n",
    "    print(\"D_test_test exists\")\n",
    "else:\n",
    "    if D_test_test is not None:\n",
    "        print(f\"Cached D_test_test has shape {D_test_test.shape}, expected {(n_test, n_test)}; recomputing\")\n",
    "    total = n_test * (n_test + 1) // 2\n",
    "    iteration = 0\n",
    "    D_test_test = np.zeros((n_test, n_test))\n",
    "    # Upper triangle (including the diagonal) only; each value is mirrored.\n",
    "    for i in range(n_test):\n",
    "        for j in range(i, n_test):\n",
    "            iteration += 1\n",
    "            print(f\"Iteration {iteration}/{total}\")\n",
    "            d = compute_fgw_distance(X_test_sub[i], X_test_sub[j])\n",
    "            D_test_test[i, j] = d\n",
    "            D_test_test[j, i] = d\n",
    "    np.save(D_test_test_path, D_test_test)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {
    "executionInfo": {
     "elapsed": 258,
     "status": "ok",
     "timestamp": 1739483040164,
     "user": {
      "displayName": "Vardaan Tekriwal",
      "userId": "17776599370791499308"
     },
     "user_tz": 480
    },
    "id": "bfTRjvvjwyBe"
   },
   "outputs": [],
   "source": [
    "# Assemble the full (train + test) FGW dissimilarity matrix in block form:\n",
    "#   [[train-train, train-test],\n",
    "#    [test-train,  test-test ]]\n",
    "# Fail fast if a cached matrix does not match the current subset sizes.\n",
    "assert D_train.shape == (n, n), D_train.shape\n",
    "assert D_test.shape == (n_test, n), D_test.shape\n",
    "assert D_test_test.shape == (n_test, n_test), D_test_test.shape\n",
    "D_all = np.zeros((n + n_test, n + n_test))\n",
    "D_all[:n, :n] = D_train\n",
    "D_all[:n, n:] = D_test.T\n",
    "D_all[n:, :n] = D_test\n",
    "D_all[n:, n:] = D_test_test\n",
    "# The iterative FGW solver is not guaranteed to return exactly 0 for an image\n",
    "# compared with itself; a dissimilarity matrix should have a zero diagonal.\n",
    "np.fill_diagonal(D_all, 0.0)\n",
    "\n",
    "# Embed train and test points jointly with sklearn's MDS on the precomputed\n",
    "# dissimilarities; the first n rows are the training embeddings.\n",
    "embed_dim = 50\n",
    "# n_init is pinned to 4 (sklearn's long-standing default) so results do not\n",
    "# shift if the library default changes.\n",
    "mds = MDS(n_components=embed_dim, dissimilarity='precomputed', random_state=0, n_init=4)\n",
    "Z_all = mds.fit_transform(D_all)\n",
    "Z_train = Z_all[:n]\n",
    "Z_test = Z_all[n:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {
    "executionInfo": {
     "elapsed": 156,
     "status": "ok",
     "timestamp": 1739483197102,
     "user": {
      "displayName": "Vardaan Tekriwal",
      "userId": "17776599370791499308"
     },
     "user_tz": 480
    },
    "id": "dsyenBtlRCrT"
   },
   "outputs": [],
   "source": [
    "# plt.hist(y_train_sub, bins=range(11), align='left', rwidth=0.8)\n",
    "# plt.xlabel('Class')\n",
    "# plt.ylabel('Frequency')\n",
    "# plt.title('Distribution of y_train')\n",
    "# plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {
    "executionInfo": {
     "elapsed": 129,
     "status": "ok",
     "timestamp": 1739483198936,
     "user": {
      "displayName": "Vardaan Tekriwal",
      "userId": "17776599370791499308"
     },
     "user_tz": 480
    },
    "id": "d2t6s2NzRgHu"
   },
   "outputs": [],
   "source": [
    "# Class histogram of the test subset (currently disabled).\n",
    "# plt.hist(y_test_sub, bins=range(11), align='left', rwidth=0.8)\n",
    "# plt.xlabel('Class')\n",
    "# plt.ylabel('Frequency')\n",
    "# plt.title('Distribution of y_test')\n",
    "# plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {
    "executionInfo": {
     "elapsed": 154,
     "status": "ok",
     "timestamp": 1739483211913,
     "user": {
      "displayName": "Vardaan Tekriwal",
      "userId": "17776599370791499308"
     },
     "user_tz": 480
    },
    "id": "A1Nzyd_SRo5m"
   },
   "outputs": [],
   "source": [
    "def rbf_kernel_scalar(x1, x2, l=1.0, sigma_f=1.0):\n",
    "    \"\"\"Squared-exponential (RBF) kernel value between two points:\n",
    "    sigma_f**2 * exp(-||x1 - x2||**2 / (2 * l**2)).\"\"\"\n",
    "    return sigma_f**2 * np.exp(-np.sum((x1 - x2)**2) / (2 * l**2))\n",
    "\n",
    "def local_gp_predict(x_star, X_train, y_train, k_neighbors=10, l=1.0, sigma_f=1.0, sigma_n=1e-6):\n",
    "    \"\"\"Zero-mean GP regression conditioned on the training points nearest to x_star.\n",
    "\n",
    "    Args:\n",
    "        x_star: query point (1-D array).\n",
    "        X_train: training inputs, one row per point.\n",
    "        y_train: training targets aligned with the rows of X_train.\n",
    "        k_neighbors: number of nearest neighbours (Euclidean distance) to use;\n",
    "            if X_train has fewer rows, all of them are used.\n",
    "        l, sigma_f: RBF length-scale and signal scale.\n",
    "        sigma_n: value added to the diagonal of the local kernel matrix.\n",
    "\n",
    "    Returns:\n",
    "        (f_mean, f_var, idx): posterior mean and variance of the latent function\n",
    "        at x_star (sigma_n is not added to f_var; round-off can make f_var\n",
    "        slightly negative) and the indices of the neighbours used.\n",
    "    \"\"\"\n",
    "    dists = np.linalg.norm(X_train - x_star, axis=1)\n",
    "    idx = np.argsort(dists)[:k_neighbors]\n",
    "    X_local = X_train[idx]\n",
    "    y_local = y_train[idx]\n",
    "    K = np.array([[rbf_kernel_scalar(xi, xj, l, sigma_f) for xj in X_local] for xi in X_local])\n",
    "    # Size the diagonal term by the neighbours actually selected: np.eye(k_neighbors)\n",
    "    # would not match K when X_train has fewer than k_neighbors rows.\n",
    "    K += sigma_n * np.eye(len(idx))\n",
    "    k_star = np.array([rbf_kernel_scalar(x_star, xi, l, sigma_f) for xi in X_local])\n",
    "    # Solve linear systems instead of forming K^-1 explicitly: same quantities,\n",
    "    # numerically more stable when K is nearly singular (sigma_n is tiny).\n",
    "    f_mean = k_star.dot(np.linalg.solve(K, y_local))\n",
    "    k_star_star = rbf_kernel_scalar(x_star, x_star, l, sigma_f)\n",
    "    f_var = k_star_star - k_star.dot(np.linalg.solve(K, k_star))\n",
    "    return f_mean, f_var, idx\n",
    "\n",
    "def crps_gaussian(y, m, s):\n",
    "    \"\"\"Closed-form CRPS of the Gaussian forecast N(m, s**2) at observation y.\n",
    "\n",
    "    For s < 1e-8 the forecast is treated as a point mass and |m - y| is returned.\n",
    "    \"\"\"\n",
    "    if s < 1e-8:\n",
    "        return np.abs(m - y)\n",
    "    z = (y - m) / s\n",
    "    return s * (z * (2 * norm.cdf(z) - 1) + 2 * norm.pdf(z) - 1/np.sqrt(np.pi))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "metadata": {
    "executionInfo": {
     "elapsed": 243,
     "status": "ok",
     "timestamp": 1739483397435,
     "user": {
      "displayName": "Vardaan Tekriwal",
      "userId": "17776599370791499308"
     },
     "user_tz": 480
    },
    "id": "29h5Ar5nR2dF"
   },
   "outputs": [],
   "source": [
    "# Local GP hyperparameters: neighbourhood size, RBF length-scale l,\n",
    "# signal scale sigma_f and diagonal term sigma_n.\n",
    "k_neighbors = 10\n",
    "l = 1.0\n",
    "sigma_f = 1.0\n",
    "sigma_n = 1e-6\n",
    "\n",
    "# Predict each test embedding from its nearest training embeddings and\n",
    "# collect the neighbour indices that were used.\n",
    "f_mean_local = np.zeros(n_test)\n",
    "pred_var_local = np.zeros(n_test)\n",
    "neighbor_indices = []\n",
    "for i in range(n_test):\n",
    "    f_mean_local[i], pred_var_local[i], idx = local_gp_predict(\n",
    "        Z_test[i], Z_train, y_train_sub,\n",
    "        k_neighbors=k_neighbors, l=l, sigma_f=sigma_f, sigma_n=sigma_n,\n",
    "    )\n",
    "    neighbor_indices.append(idx)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 261,
     "status": "ok",
     "timestamp": 1739483397694,
     "user": {
      "displayName": "Vardaan Tekriwal",
      "userId": "17776599370791499308"
     },
     "user_tz": 480
    },
    "id": "2eVL5bRDR6Ja",
    "outputId": "843fec2b-6a41-4c40-ace1-dc1d3dfcf323"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Nearest Neighbors GP RMSE: 3.4540\n",
      "Nearest Neighbors GP Mean CRPS: 2.1671\n"
     ]
    }
   ],
   "source": [
    "# Predictive std from the latent variance, floored at 1e-8 before the sqrt\n",
    "# (presumably to guard against tiny negative variances from round-off).\n",
    "pred_std_local = np.sqrt(np.maximum(pred_var_local, 1e-8))\n",
    "rmse_local = np.sqrt(np.square(f_mean_local - y_test_sub).mean())\n",
    "print(\"Nearest Neighbors GP RMSE: {:.4f}\".format(rmse_local))\n",
    "\n",
    "# Mean CRPS of the Gaussian predictive distributions over the test subset.\n",
    "crps_scores_local = np.array([\n",
    "    crps_gaussian(y_obs, mu, sd)\n",
    "    for y_obs, mu, sd in zip(y_test_sub, f_mean_local, pred_std_local)\n",
    "])\n",
    "mean_crps_local = crps_scores_local.mean()\n",
    "print(\"Nearest Neighbors GP Mean CRPS: {:.4f}\".format(mean_crps_local))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 3,
     "status": "ok",
     "timestamp": 1739483397948,
     "user": {
      "displayName": "Vardaan Tekriwal",
      "userId": "17776599370791499308"
     },
     "user_tz": 480
    },
    "id": "XjmNtq-WR82j",
    "outputId": "dfe96873-fdf3-4a12-a9c5-458829b44e43"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Classification Accuracy: 0.4000\n",
      "Classification Report:\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0       0.50      1.00      0.67         1\n",
      "           1       1.00      1.00      1.00         2\n",
      "           2       0.00      0.00      0.00         0\n",
      "           3       0.00      0.00      0.00         1\n",
      "           4       0.00      0.00      0.00         0\n",
      "           5       0.00      0.00      0.00         0\n",
      "           6       1.00      1.00      1.00         1\n",
      "           7       0.00      0.00      0.00         2\n",
      "           8       0.00      0.00      0.00         3\n",
      "\n",
      "    accuracy                           0.40        10\n",
      "   macro avg       0.28      0.33      0.30        10\n",
      "weighted avg       0.35      0.40      0.37        10\n",
      "\n",
      "Confusion Matrix:\n",
      "[[1 0 0 0 0 0 0 0 0]\n",
      " [0 2 0 0 0 0 0 0 0]\n",
      " [0 0 0 0 0 0 0 0 0]\n",
      " [0 0 0 0 0 1 0 0 0]\n",
      " [0 0 0 0 0 0 0 0 0]\n",
      " [0 0 0 0 0 0 0 0 0]\n",
      " [0 0 0 0 0 0 1 0 0]\n",
      " [1 0 0 1 0 0 0 0 0]\n",
      " [0 0 1 0 1 0 0 1 0]]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.11/dist-packages/sklearn/metrics/_classification.py:1565: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n",
      "/usr/local/lib/python3.11/dist-packages/sklearn/metrics/_classification.py:1565: UndefinedMetricWarning: Recall is ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n",
      "/usr/local/lib/python3.11/dist-packages/sklearn/metrics/_classification.py:1565: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n",
      "/usr/local/lib/python3.11/dist-packages/sklearn/metrics/_classification.py:1565: UndefinedMetricWarning: Recall is ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n",
      "/usr/local/lib/python3.11/dist-packages/sklearn/metrics/_classification.py:1565: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n",
      "/usr/local/lib/python3.11/dist-packages/sklearn/metrics/_classification.py:1565: UndefinedMetricWarning: Recall is ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n"
     ]
    }
   ],
   "source": [
    "from sklearn.metrics import classification_report, confusion_matrix\n",
    "\n",
    "# Turn the GP regression output into class labels by rounding, clipped to the\n",
    "# label range seen in training so out-of-range predictions (e.g. -1 or 10)\n",
    "# do not appear as spurious classes in the report.\n",
    "label_min, label_max = y_train_sub.min(), y_train_sub.max()\n",
    "pred_labels = np.clip(np.round(f_mean_local), label_min, label_max).astype(int)\n",
    "accuracy = np.mean(pred_labels == y_test_sub)\n",
    "print(\"Classification Accuracy: {:.4f}\".format(accuracy))\n",
    "print(\"Classification Report:\")\n",
    "# zero_division=0 keeps the 0.0 that sklearn reports for classes without\n",
    "# predicted/true samples, without emitting UndefinedMetricWarning.\n",
    "print(classification_report(y_test_sub, pred_labels, zero_division=0))\n",
    "print(\"Confusion Matrix:\")\n",
    "print(confusion_matrix(y_test_sub, pred_labels))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "BzVDOH1ISJED"
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "colab": {
   "authorship_tag": "ABX9TyO0XVXvFOdaCdd0Cqon0GEp",
   "mount_file_id": "1qjKOsOASGRUHi1x6S9eIZwIshV8wjGLp",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "MNISTenv",
   "language": "python",
   "name": "mnistenv"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
