{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ[\"KMP_DUPLICATE_LIB_OK\"]=\"TRUE\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "executionInfo": {
     "elapsed": 21229,
     "status": "ok",
     "timestamp": 1719907365642,
     "user": {
      "displayName": "Xinxian Ma",
      "userId": "07518886605624739287"
     },
     "user_tz": -480
    },
    "id": "3saIzvMaKP29"
   },
   "outputs": [],
   "source": [
    "from sklearn.manifold import TSNE\n",
    "import umap\n",
    "from torchvision import datasets, transforms\n",
    "from sklearn.datasets import load_iris\n",
    "from numpy import reshape\n",
    "import seaborn as sns\n",
    "import pandas as pd\n",
    "import torch\n",
    "import numpy as np\n",
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.neighbors import KNeighborsClassifier, NearestNeighbors\n",
    "from sklearn.metrics import normalized_mutual_info_score, silhouette_score\n",
    "from sklearn.cluster import KMeans"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "id": "HRYct6XSDxzC"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "## Preparation for similarity method\n",
    "torch.cuda.is_available()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "2UA7pCT4maUj"
   },
   "source": [
    "# MNIST dataset TSNE fitting and visualizing"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Data prepocessing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 17772,
     "status": "ok",
     "timestamp": 1719907414636,
     "user": {
      "displayName": "Xinxian Ma",
      "userId": "07518886605624739287"
     },
     "user_tz": -480
    },
    "id": "B2JR6Pe-maUj",
    "outputId": "27c8bb8b-b1ea-4540-a83d-2fcdd93839d9"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "x_train shape: torch.Size([40000, 784])\n",
      "y_train shape: torch.Size([40000])\n"
     ]
    }
   ],
   "source": [
    "transform = transforms.Compose([\n",
    "    transforms.ToTensor()\n",
    "])\n",
    "\n",
    "mnist_train = datasets.MNIST(root='./data', train=True, download=True, transform=transform)\n",
    "mnist_test = datasets.MNIST(root='./data', train=False, download=True, transform=transform)\n",
    "\n",
    "n=40000\n",
    "\n",
    "x_train = torch.cat((mnist_train.data.float(), mnist_test.data.float()),0).to(torch.float32)[:n,]\n",
    "y_train = torch.cat((mnist_train.targets, mnist_test.targets),0).to(torch.float32)[:n,]\n",
    "\n",
    "del mnist_train\n",
    "del mnist_test\n",
    "\n",
    "x_train = x_train.view(x_train.shape[0], -1)\n",
    "\n",
    "print(f\"x_train shape: {x_train.shape}\")\n",
    "print(f\"y_train shape: {y_train.shape}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([784, 40000])\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "torch.Tensor"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x_new = x_train.T\n",
    "del x_train\n",
    "print(x_new.shape)\n",
    "type(x_new)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Function defined "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Nqy8rAPOmaUj"
   },
   "source": [
    "#### Y Approximation\n",
    "*   n = 70000\n",
    "*   k = 10\n",
    "*   ny = 500"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "collapsed": true,
    "executionInfo": {
     "elapsed": 29379,
     "status": "ok",
     "timestamp": 1719907453271,
     "user": {
      "displayName": "Xinxian Ma",
      "userId": "07518886605624739287"
     },
     "user_tz": -480
    },
    "id": "ssz4X4jRmaUj",
    "jupyter": {
     "outputs_hidden": true
    },
    "outputId": "1b9c1313-d743-41cf-9be5-4ba19309b2b7"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cuda\n"
     ]
    }
   ],
   "source": [
    "def gaussian_kernel(x, y, sigma=1000.0):\n",
    "    beta = 1.0 / (2.0 * sigma ** 2)\n",
    "    dist = torch.cdist(x.t(), y.t()) ** 2\n",
    "    # print(dist.mean())\n",
    "    kernel = torch.exp(-beta * dist)\n",
    "    return kernel\n",
    "\n",
    "def mmd_loss(x, y, sigma_list=[10.0, 20.0, 30.0, 50.0]):\n",
    "\n",
    "    mmd_list = []\n",
    "    for sigma in sigma_list:\n",
    "      xx_kernel = gaussian_kernel(x, x, sigma)\n",
    "      yy_kernel = gaussian_kernel(y, y, sigma)\n",
    "      xy_kernel = gaussian_kernel(x, y, sigma)\n",
    "      mmd = xx_kernel.mean() + yy_kernel.mean() - 2 * xy_kernel.mean()\n",
    "      mmd_list.append(mmd)\n",
    "    return max(mmd_list)\n",
    "\n",
    "# 假设我们有 k 个数据集 X1 到 Xk\n",
    "k = 10\n",
    "n = x_new.shape[1]//k\n",
    "m = x_new.shape[0]\n",
    "ny = 500\n",
    "batch_size = 256\n",
    "\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "print(device)\n",
    "x_tensor = x_new/255\n",
    "X = [x_tensor[:, i*n:(i+1)*n].to(device) for i in range(k)]\n",
    "\n",
    "def approximate_Y(beta = 0.0):\n",
    "  \n",
    "  Y = torch.rand((m, ny), dtype=torch.float32, requires_grad=True, device=device)\n",
    "  # print(Y)\n",
    "  learning_rate = 1e-1\n",
    "  # optimizer = torch.optim.SGD([Y], lr=learning_rate)\n",
    "  optimizer = torch.optim.Adam([Y], lr=learning_rate)\n",
    "  # optimizer = torch.optim.RMSprop([Y], lr=learning_rate)\n",
    "\n",
    "  # 训练循环\n",
    "  num_epochs = 500\n",
    "  for epoch in range(num_epochs):\n",
    "\n",
    "      total_loss = 0.0\n",
    "      optimizer.zero_grad()\n",
    "      total_grad = None\n",
    "\n",
    "      for i in range(k):\n",
    "          Xi = X[i]\n",
    "\n",
    "          # mini_batch\n",
    "          dataset = TensorDataset(Xi.t())\n",
    "          data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)\n",
    "\n",
    "          for batch in data_loader:\n",
    "            Xi_batch = batch[0].t()\n",
    "            loss = mmd_loss(Xi_batch, Y)\n",
    "            loss.backward()\n",
    "            total_loss += loss.item()\n",
    "\n",
    "            gradients = Y.grad\n",
    "            std_grad = torch.std(gradients)\n",
    "            noise = torch.randn_like(gradients) * std_grad * beta\n",
    "            noisy_gradients = gradients + noise\n",
    "\n",
    "            if total_grad is None:\n",
    "              total_grad = noisy_gradients\n",
    "            else:\n",
    "              total_grad += noisy_gradients\n",
    "            break # only use first mini-batch\n",
    "\n",
    "          Y.grad.zero_()\n",
    "\n",
    "      avg_grad = total_grad / k\n",
    "      total_loss = total_loss / k\n",
    "\n",
    "      with torch.no_grad():\n",
    "          Y.grad = avg_grad\n",
    "          Y.data=torch.clamp(Y.data, min=0, max=1)\n",
    "          optimizer.step()\n",
    "      if epoch== num_epochs-1:\n",
    "          print('Epoch finished')\n",
    "  return Y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 396
    },
    "executionInfo": {
     "elapsed": 2910,
     "status": "ok",
     "timestamp": 1719907480591,
     "user": {
      "displayName": "Xinxian Ma",
      "userId": "07518886605624739287"
     },
     "user_tz": -480
    },
    "id": "NXoMJxjg9VEu",
    "outputId": "e385c7ff-7eb7-4be2-f07e-eb96709398e8"
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "def visualize_Y(Y):\n",
    "\n",
    "    # figure\n",
    "    for i in range(24):\n",
    "        vector = Y[:, torch.randint(0,ny,(1,1))].cpu().detach().numpy()\n",
    "        image = vector.reshape(28, 28)\n",
    "        plt.subplot(3, 8, i+1)\n",
    "        plt.imshow(image, cmap='gray', vmin=0, vmax=1)\n",
    "        plt.title('28x28 Image')\n",
    "    plt.show()\n",
    "\n",
    "    # Y tsne visualization\n",
    "    Y_numpy = Y.detach().cpu().numpy()\n",
    "    # tsne = TSNE(n_components=2, verbose=1)\n",
    "    # z = tsne.fit_transform(Y_numpy.T)\n",
    "    reducer = umap.UMAP(n_components=2)\n",
    "    z = reducer.fit_transform(Y_numpy.T)\n",
    "\n",
    "    df = pd.DataFrame()\n",
    "    df[\"y\"] = 1\n",
    "    df[\"comp-1\"] = z[:,0]\n",
    "    df[\"comp-2\"] = z[:,1]\n",
    "\n",
    "    sns.scatterplot(x=\"comp-1\", y=\"comp-2\",\n",
    "                    data=df).set(title=\"Y UMAP projection\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ghEXehjTOLkF"
   },
   "source": [
    "#### TSNE result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "def calculate_overlap_ratio(high_dim_data, low_dim_data, k):\n",
    "    \"\"\"\n",
    "    high_dim_data: torch data (40000*784)\n",
    "    low_dim_data: z(40000*2)\n",
    "    \"\"\"\n",
    "    high_dim_data = high_dim_data.cpu().detach().numpy()\n",
    "\n",
    "    # 在高维空间中找到每个点的 k 近邻\n",
    "    nn_high = NearestNeighbors(n_neighbors=k+1).fit(high_dim_data)\n",
    "    distances_high, indices_high = nn_high.kneighbors(high_dim_data)\n",
    "    \n",
    "    # 在低维空间中找到每个点的 k 近邻\n",
    "    nn_low = NearestNeighbors(n_neighbors=k+1).fit(low_dim_data)\n",
    "    distances_low, indices_low = nn_low.kneighbors(low_dim_data)\n",
    "    \n",
    "    overlap_ratios = []\n",
    "    \n",
    "    # 对于每个点，计算两个集合的重叠比例\n",
    "    for i in range(len(high_dim_data)):\n",
    "        A = set(indices_high[i][1:])  # 跳过自身\n",
    "        B = set(indices_low[i][1:])   # 跳过自身\n",
    "        overlap = len(A.intersection(B)) / k\n",
    "        overlap_ratios.append(overlap)\n",
    "    \n",
    "    return np.mean(overlap_ratios)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "executionInfo": {
     "elapsed": 525,
     "status": "ok",
     "timestamp": 1719907490501,
     "user": {
      "displayName": "Xinxian Ma",
      "userId": "07518886605624739287"
     },
     "user_tz": -480
    },
    "id": "xxv9kXUJOnZl"
   },
   "outputs": [],
   "source": [
    "def tsne_res(high_dim_data, distance, y_train=y_train, visualize=False):\n",
    "  tsne = TSNE(n_components=2, verbose=0, metric=\"precomputed\", init='random')\n",
    "  z = tsne.fit_transform(distance)\n",
    "  # reducer = umap.UMAP(n_components=2, metric='precomputed', init='random')\n",
    "  # z = reducer.fit_transform(distance)\n",
    "\n",
    "  if visualize == True:\n",
    "    df = pd.DataFrame()\n",
    "    df[\"y\"] = y_train\n",
    "    df[\"comp-1\"] = z[:,0]\n",
    "    df[\"comp-2\"] = z[:,1]\n",
    "\n",
    "    sns.scatterplot(x=\"comp-1\", y=\"comp-2\", hue=df.y.tolist(),\n",
    "                    palette=sns.color_palette(\"hls\", 10),\n",
    "                    data=df).set(title=\"MNIST data UMAP projection\")\n",
    "\n",
    "#-------------------------------------------------------------------------\n",
    "  X_tr, X_te, y_tr, y_te = train_test_split(z, y_train, test_size=0.3)\n",
    "  k_values = [1, 10, 50]\n",
    "\n",
    "  knn_accuracies = {}\n",
    "\n",
    "  result = [0, 0, 0, 0, 0, 0, 0, 0]\n",
    "  # [1-nn, 10-nn, 50-nn, NMI, SC, overlap_1nn, overlap_10nn, overlap_50nn]\n",
    "\n",
    "  for i in range(0, len(k_values)):\n",
    "      # k-NN分类器\n",
    "      k = k_values[i]\n",
    "      knn = KNeighborsClassifier(n_neighbors=k)\n",
    "      knn.fit(X_tr, y_tr)\n",
    "      accuracy = knn.score(X_te, y_te)\n",
    "      result[i] = accuracy\n",
    "\n",
    "  # for key,value in knn_accuracies.items():\n",
    "  #     print(f'{key}-nn:/ {value}')\n",
    "\n",
    "  # 使用kmeans聚类\n",
    "  kmeans = KMeans(n_clusters=10, n_init=10)\n",
    "  y_pred = kmeans.fit_predict(z)\n",
    "\n",
    "  # 计算NMI\n",
    "  nmi = normalized_mutual_info_score(y_train, y_pred)\n",
    "  result[3] = nmi\n",
    "  # print('NMI:', nmi)\n",
    "\n",
    "  # 计算轮廓系数\n",
    "  sc = silhouette_score(z, y_pred)\n",
    "  result[4] = sc\n",
    "  # print('SC:', sc)\n",
    "\n",
    "  # 计算overlap\n",
    "  result[5] = calculate_overlap_ratio(high_dim_data, z, k=1)\n",
    "  result[6] = calculate_overlap_ratio(high_dim_data, z, k=10)\n",
    "  result[7] = calculate_overlap_ratio(high_dim_data, z, k=50)\n",
    "\n",
    "  return result"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "KLwr9kn5ne4x"
   },
   "source": [
    "#### approxiamted TSNE\n",
    "- estimate x dist from dist"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {
    "executionInfo": {
     "elapsed": 553,
     "status": "ok",
     "timestamp": 1719907994253,
     "user": {
      "displayName": "Xinxian Ma",
      "userId": "07518886605624739287"
     },
     "user_tz": -480
    },
    "id": "22nphApOlFXw"
   },
   "outputs": [],
   "source": [
    "def estimate_x_dist(shuffled_X_cat, Y):\n",
    "  '''\n",
    "  shuffled_X_cat\n",
    "  Y: learned dist\n",
    "  '''\n",
    "  B = torch.cdist(Y.t(), shuffled_X_cat.t())\n",
    "  A = torch.cdist(Y.t(), Y.t())\n",
    "  # A = A + torch.eye(A.shape[0], device = device)*1e-5\n",
    "  # B = Y.t() @ X_cat\n",
    "  # A = Y.t() @ Y\n",
    "  res = B.t() @ torch.linalg.inv(A) @ B\n",
    "  res[torch.arange(res.size(0)), torch.arange(res.size(1))] = 0\n",
    "  return res"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Experiments"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_cat = torch.cat(X, dim=1)\n",
    "#-----------------\n",
    "cols = X_cat.size(1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### correct TSNE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<correct TSNE>\n",
      "------------------\n",
      "<mean>\n",
      "1-nn:/ 0.9617749999999999\n",
      "10-nn:/ 0.96555\n",
      "50-nn:/ 0.9608916666666666\n",
      "NMI: 0.7747370408919848\n",
      "SC: 0.4226216286420822\n",
      "overlap_1nn: 0.417605\n",
      "overlap_10nn 0.39053875\n",
      "overlap_50nn 0.34405215000000006\n",
      "<std>\n",
      "1-nn:/ 0.0015152236725242054\n",
      "10-nn:/ 0.0017184456801294587\n",
      "50-nn:/ 0.001488404253629464\n",
      "NMI: 0.024343042071764465\n",
      "SC: 0.008243462045052514\n",
      "overlap_1nn: 0.0015824348327814373\n",
      "overlap_10nn 0.0005110140042112281\n",
      "overlap_50nn 0.0006676934195422084\n"
     ]
    }
   ],
   "source": [
    "# correct tsne\n",
    "print('<correct TSNE>')\n",
    "print('------------------')\n",
    "\n",
    "\n",
    "n = 10\n",
    "result = []\n",
    "for i in range(0, n):\n",
    "\n",
    "    perm = torch.randperm(cols)\n",
    "    shuffled_X_cat = X_cat[:, perm]\n",
    "\n",
    "    distance = torch.cdist(shuffled_X_cat.t(), shuffled_X_cat.t()).detach().cpu().numpy()\n",
    "    result.append(tsne_res(shuffled_X_cat.t(), distance, y_train[perm]))\n",
    "\n",
    "final_mean = np.mean(result, 0)\n",
    "final_var = np.var(result, 0)**(1/2)\n",
    "\n",
    "k_values = [1, 10, 50]\n",
    "\n",
    "# mean\n",
    "print('<mean>')\n",
    "for i in range(0, 3):\n",
    "    print(f'{k_values[i]}-nn:/ {final_mean[i]}')\n",
    "print('NMI:', final_mean[3])\n",
    "print('SC:', final_mean[4])\n",
    "print('overlap_1nn:', final_mean[5])\n",
    "print('overlap_10nn', final_mean[6])\n",
    "print('overlap_50nn', final_mean[7])\n",
    "\n",
    "# std\n",
    "print('<std>')\n",
    "for i in range(0, 3):\n",
    "    print(f'{k_values[i]}-nn:/ {final_var[i]}')\n",
    "print('NMI:', final_var[3])\n",
    "print('SC:', final_var[4])\n",
    "print('overlap_1nn:', final_var[5])\n",
    "print('overlap_10nn', final_var[6])\n",
    "print('overlap_50nn', final_var[7])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### approximate with beta = 0.0 (average 5 experiments)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<Approximate TSNE with beta= 0.0>\n",
      "------------------\n",
      "Epoch finished\n",
      "Epoch finished\n",
      "Epoch finished\n",
      "Epoch finished\n",
      "Epoch finished\n",
      "Epoch finished\n",
      "Epoch finished\n",
      "Epoch finished\n",
      "Epoch finished\n",
      "Epoch finished\n",
      "<mean>\n",
      "1-nn:/ 0.9400499999999999\n",
      "10-nn:/ 0.9476666666666667\n",
      "50-nn:/ 0.9401416666666668\n",
      "NMI: 0.7534086003990317\n",
      "SC: 0.44065276682376864\n",
      "overlap_1nn: 0.27275750000000004\n",
      "overlap_10nn 0.33728774999999994\n",
      "overlap_50nn 0.33005575\n",
      "<std>\n",
      "1-nn:/ 0.0017220788470785873\n",
      "10-nn:/ 0.001734054337223722\n",
      "50-nn:/ 0.0022453810614483974\n",
      "NMI: 0.020243921358309814\n",
      "SC: 0.010305848707925096\n",
      "overlap_1nn: 0.002240006975435572\n",
      "overlap_10nn 0.0006912465641867674\n",
      "overlap_50nn 0.0006803166266526256\n"
     ]
    }
   ],
   "source": [
    "print('<Approximate TSNE with beta= 0.0>')\n",
    "print('------------------')\n",
    "\n",
    "n = 10\n",
    "result = []\n",
    "for i in range(0, n):\n",
    "\n",
    "    perm = torch.randperm(cols)\n",
    "    shuffled_X_cat = X_cat[:, perm]\n",
    "\n",
    "    Y = approximate_Y(beta = 0.0)\n",
    "    distance = estimate_x_dist(shuffled_X_cat, Y).cpu().detach().numpy()\n",
    "    result.append(tsne_res(shuffled_X_cat.t(), distance, y_train[perm]))\n",
    "\n",
    "final_mean = np.mean(result, 0)\n",
    "final_var = np.var(result, 0)**(1/2)\n",
    "\n",
    "k_values = [1, 10, 50]\n",
    "\n",
    "# mean\n",
    "print('<mean>')\n",
    "for i in range(0, 3):\n",
    "    print(f'{k_values[i]}-nn:/ {final_mean[i]}')\n",
    "print('NMI:', final_mean[3])\n",
    "print('SC:', final_mean[4])\n",
    "print('overlap_1nn:', final_mean[5])\n",
    "print('overlap_10nn', final_mean[6])\n",
    "print('overlap_50nn', final_mean[7])\n",
    "\n",
    "# std\n",
    "print('<std>')\n",
    "for i in range(0, 3):\n",
    "    print(f'{k_values[i]}-nn:/ {final_var[i]}')\n",
    "print('NMI:', final_var[3])\n",
    "print('SC:', final_var[4])\n",
    "print('overlap_1nn:', final_var[5])\n",
    "print('overlap_10nn', final_var[6])\n",
    "print('overlap_50nn', final_var[7])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<Approximate TSNE with beta= 1.0>\n",
      "------------------\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch finished\n",
      "Epoch finished\n",
      "Epoch finished\n",
      "Epoch finished\n",
      "Epoch finished\n",
      "Epoch finished\n",
      "Epoch finished\n",
      "Epoch finished\n",
      "Epoch finished\n",
      "Epoch finished\n",
      "<mean>\n",
      "1-nn:/ 0.936375\n",
      "10-nn:/ 0.9442833333333333\n",
      "50-nn:/ 0.9353583333333333\n",
      "NMI: 0.7470868719520543\n",
      "SC: 0.4477608770132065\n",
      "overlap_1nn: 0.25434750000000006\n",
      "overlap_10nn 0.32634474999999996\n",
      "overlap_50nn 0.32582175\n",
      "<std>\n",
      "1-nn:/ 0.0019522956572541343\n",
      "10-nn:/ 0.0011548208134213298\n",
      "50-nn:/ 0.0021848754096184704\n",
      "NMI: 0.0072797452354834745\n",
      "SC: 0.006557980548155653\n",
      "overlap_1nn: 0.001552032941016394\n",
      "overlap_10nn 0.00045153827357157807\n",
      "overlap_50nn 0.0006541754065233722\n"
     ]
    }
   ],
   "source": [
    "print('<Approximate TSNE with beta= 1.0>')\n",
    "print('------------------')\n",
    "\n",
    "n = 10\n",
    "result = []\n",
    "for i in range(0, n):\n",
    "    perm = torch.randperm(cols)\n",
    "    shuffled_X_cat = X_cat[:, perm]\n",
    "\n",
    "    Y = approximate_Y(beta = 1.0)\n",
    "    distance = estimate_x_dist(shuffled_X_cat, Y).cpu().detach().numpy()\n",
    "    result.append(tsne_res(shuffled_X_cat.t(), distance, y_train[perm]))\n",
    "\n",
    "final_mean = np.mean(result, 0)\n",
    "final_var = np.var(result, 0)**(1/2)\n",
    "\n",
    "k_values = [1, 10, 50]\n",
    "\n",
    "# mean\n",
    "print('<mean>')\n",
    "for i in range(0, 3):\n",
    "    print(f'{k_values[i]}-nn:/ {final_mean[i]}')\n",
    "print('NMI:', final_mean[3])\n",
    "print('SC:', final_mean[4])\n",
    "print('overlap_1nn:', final_mean[5])\n",
    "print('overlap_10nn', final_mean[6])\n",
    "print('overlap_50nn', final_mean[7])\n",
    "\n",
    "# std\n",
    "print('<std>')\n",
    "for i in range(0, 3):\n",
    "    print(f'{k_values[i]}-nn:/ {final_var[i]}')\n",
    "print('NMI:', final_var[3])\n",
    "print('SC:', final_var[4])\n",
    "print('overlap_1nn:', final_var[5])\n",
    "print('overlap_10nn', final_var[6])\n",
    "print('overlap_50nn', final_var[7])"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "authorship_tag": "ABX9TyOdGSZao/j9LMFQffT25NSD",
   "gpuType": "T4",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
