{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5cea4e6a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append(\"code\")\n",
    "\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.metrics.pairwise import euclidean_distances\n",
    "from scipy import optimize\n",
    "\n",
    "from scipy.io import savemat\n",
    "\n",
    "from sklearn.neighbors import KNeighborsClassifier\n",
    "neigh = KNeighborsClassifier(n_neighbors=1)\n",
    "\n",
    "from sklearn.datasets import make_moons\n",
    "\n",
    "dataset = 2\n",
    "import mnist_reader \n",
    "    \n",
    "X_train, y_train = mnist_reader.load_mnist('data/mnist', kind='train')\n",
    "X_test, y_test = mnist_reader.load_mnist('data/mnist', kind='t10k')\n",
    "\n",
    "X_train = X_train.astype(np.float32)\n",
    "max_val = np.max(X_train)\n",
    "X_train = X_train/max_val\n",
    "\n",
    "X_test = X_test/max_val\n",
    "\n",
    "n = X_train.shape[0]\n",
    "\n",
    "classes = [\n",
    "    'T-shirt/top',\n",
    "    'Trouser',\n",
    "    'Pullover',\n",
    "    'Dress',\n",
    "    'Coat',\n",
    "    'Sandal',\n",
    "    'Shirt',\n",
    "    'Sneaker',\n",
    "    'Bag',\n",
    "    'Ankle boot']\n",
    "\n",
    "print(X_train.shape, y_train.shape, X_train.dtype)\n",
    "\n",
    "#Torch Setups\n",
    "from sklearn.decomposition import PCA\n",
    "\n",
    "n_components = 2\n",
    "\n",
    "pca = PCA(n_components = n_components)\n",
    "x_init = pca.fit_transform(X_train)\n",
    "x_init = x_init - np.mean(x_init, axis=0)\n",
    "\n",
    "def print_stats(X):\n",
    "    print('size: ', X.shape)\n",
    "    print('Mean:', np.mean(X))\n",
    "    print('Max: ', np.max(X))\n",
    "    print('Min: ', np.min(X))\n",
    "    print('STD: ', np.std(X))\n",
    "    \n",
    "    return\n",
    "\n",
    "print('Training Statistics')\n",
    "print_stats(X_train)\n",
    "print('Test Statistics')\n",
    "print_stats(X_test)\n",
    "\n",
    "\n",
    "epochs = 200\n",
    "n_neighbors= 15\n",
    "n_components = 2\n",
    "MIN_DIST = 0.1\n",
    "\n",
    "    \n",
    "%matplotlib notebook\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d846eb38",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "import numba\n",
    "from numba import prange\n",
    "\n",
    "import random\n",
    "\n",
    "import scipy.sparse\n",
    "\n",
    "import gc\n",
    "\n",
    "\n",
    "@numba.jit(nopython=True, parallel=True)\n",
    "def euclidean_distances_numba(X, squared = True):\n",
    "    n = X.shape[0]\n",
    "    xcorr = np.zeros((n,n),dtype=X.dtype)\n",
    "    for i in prange(n):\n",
    "        for j in range(i,n):\n",
    "            dist = np.sum( np.square(X[i,:] - X[j,:]) )\n",
    "            if not squared:\n",
    "                dist = np.sqrt(dist)\n",
    "            xcorr[i,j] = dist\n",
    "            xcorr[j,i] = dist\n",
    "    \n",
    "    return xcorr\n",
    "\n",
    "#@numba.jit(nopython=True)\n",
    "def get_weight_function(dists, rho, sigma):\n",
    "    d = dists - rho\n",
    "    #print(d)\n",
    "    d[d<0] = 0\n",
    "    weight = np.exp(- d / sigma )\n",
    "    return weight\n",
    "\n",
    "#@numba.jit(nopython=True)\n",
    "def search_sigma(dists, rho, k, tol = 10**-5, n_iteration=200):\n",
    "    sigma_min = 0\n",
    "    sigma_max = 1000\n",
    "    \n",
    "    cur_sigma = 100\n",
    "    \n",
    "    logk = np.log2(k)\n",
    "    #print(logk)\n",
    "    \n",
    "    for i in range(n_iteration):\n",
    "        \n",
    "        cur_sigma = (sigma_min+sigma_max)/2\n",
    "        probs = get_weight_function(dists,rho,cur_sigma)\n",
    "        weight = np.sum(probs)\n",
    "        #print(weight)\n",
    "        \n",
    "        if np.abs(logk - weight) < tol:\n",
    "            break\n",
    "        \n",
    "        if weight < logk:\n",
    "            sigma_min = cur_sigma\n",
    "        else:\n",
    "            sigma_max = cur_sigma\n",
    "        \n",
    "    return cur_sigma, probs\n",
    "\n",
    "@numba.jit(nopython=True, parallel=True)\n",
    "def symmetrization_step(prob):\n",
    "    n = prob.shape[0]\n",
    "    P = np.zeros((n,n),dtype=np.float32)\n",
    "\n",
    "    for i in prange(n):\n",
    "        #if i%1000 == 0:\n",
    "        #    print('Completed ', i, ' of ', n)\n",
    "        for j in prange(i,n):\n",
    "            p = prob[i,j] + prob[j,i] - prob[i,j] * prob[j,i] #t-conorm\n",
    "            P[i,j] = p\n",
    "            P[j,i] = p\n",
    "            \n",
    "    return P\n",
    "\n",
    "def get_prob_matrix(X, n_neighbors=15):\n",
    "    n = X.shape[0]\n",
    "    dist = euclidean_distances_numba(X, squared = False)\n",
    "    sort_idx = np.argsort(dist,axis=1)\n",
    "    #sort_idx = sort_idx.astype(np.int32)\n",
    "    sort_idx = sort_idx[:,1:n_neighbors+1]\n",
    "    \n",
    "    rho = [ dist[i, sort_idx[i,0] ] for i in range(n)]\n",
    "    rho = np.array(rho)\n",
    "    \n",
    "    \n",
    "\n",
    "    sigmas = []\n",
    "\n",
    "    directed_graph = []\n",
    "\n",
    "\n",
    "    #'''\n",
    "    for i in range(n):\n",
    "        if (i+1)%1000 == 0:\n",
    "            print('Processed ', i+1, ' of ', n, ' samples.')\n",
    "        sigma, weights = search_sigma(dists = dist[i,sort_idx[i,:]],rho = rho[i],k = n_neighbors)\n",
    "\n",
    "        probs = np.zeros(n)\n",
    "        probs[sort_idx[i,:]] = weights\n",
    "        #print(sum(weights), np.log2(n_neighbors))\n",
    "        #print(sort_idx[i,:])\n",
    "        #print(probs[1770:1780])\n",
    "\n",
    "        directed_graph.append(probs)\n",
    "\n",
    "    directed_graph = np.array(directed_graph).astype(np.float32)\n",
    "    prob = directed_graph\n",
    "    \n",
    "    P = symmetrization_step(prob)\n",
    "    #P = prob\n",
    "    \n",
    "    graph = scipy.sparse.coo_matrix(P)\n",
    "    \n",
    "    return graph\n",
    "\n",
    "def make_epochs_per_sample(weights, n_epochs):\n",
    "    \"\"\"Given a set of weights and number of epochs generate the number of\n",
    "    epochs per sample for each weight.\n",
    "    Parameters\n",
    "    ----------\n",
    "    weights: array of shape (n_1_simplices)\n",
    "        The weights ofhow much we wish to sample each 1-simplex.\n",
    "    n_epochs: int\n",
    "        The total number of epochs we want to train for.\n",
    "    Returns\n",
    "    -------\n",
    "    An array of number of epochs per sample, one for each 1-simplex.\n",
    "    Copied from UMAP repo: https://github.com/lmcinnes/umap/\n",
    "    \"\"\"\n",
    "    result = -1.0 * np.ones(weights.shape[0], dtype=np.float64)\n",
    "    n_samples = n_epochs * (weights / weights.max())\n",
    "    result[n_samples > 0] = float(n_epochs) / n_samples[n_samples > 0]\n",
    "    return result\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7e5e3eec",
   "metadata": {},
   "outputs": [],
   "source": [
    "#graph = get_prob_matrix(X_train,n_neighbors=n_neighbors)\n",
    "#print(len(graph.data))\n",
    "#print('prune value: ', graph.data.max() / float(epochs))\n",
    "#graph.data[graph.data < (graph.data.max() / float(epochs))] = 0.0\n",
    "#graph.eliminate_zeros()\n",
    "#print(len(graph.data))\n",
    "#epochs_per_sample_og = make_epochs_per_sample(graph.data, epochs)\n",
    "#gc.collect()\n",
    "\n",
    "\n",
    "with open('MNIST_epoch_per_sample_og.npy', 'rb') as f:\n",
    "    epochs_per_sample_og = np.load(f)\n",
    "\n",
    "from scipy import sparse\n",
    "graph = sparse.load_npz('MNIST_graph.npz')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dae7b3b5",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(epochs_per_sample_og.data.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c621b84e",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.decomposition import PCA\n",
    "\n",
    "n_components = 2\n",
    "\n",
    "pca = PCA(n_components = n_components)\n",
    "init = pca.fit_transform(X_train)\n",
    "emb = init.astype(np.float32).copy()\n",
    "\n",
    "neg_sample_rate = 5\n",
    "repulsion_strength=1.0\n",
    "\n",
    "print(emb.shape)\n",
    "\n",
    "plt.figure()\n",
    "plt.scatter(emb[:,0], emb[:,1], c=y_train, s=0.01, cmap='Spectral')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "42908467",
   "metadata": {},
   "outputs": [],
   "source": [
    "x = np.linspace(0, 3, 300)\n",
    "\n",
    "MIN_DIST = 0.1\n",
    "\n",
    "y = (x>MIN_DIST) * np.exp(-x+MIN_DIST)\n",
    "y[x<=MIN_DIST] = 1.0\n",
    "\n",
    "function = lambda x, a, b: 1 / (1 + a*x**(2*b))\n",
    "\n",
    "p , _ = optimize.curve_fit(function, x, y) \n",
    "\n",
    "a = p[0]\n",
    "b = p[1] \n",
    "print(\"Hyperparameters a = \" + str(a) + \" and b = \" + str(b))\n",
    "\n",
    "x_p = np.arange(0,3,0.01)\n",
    "y_p = np.exp(- (x_p-MIN_DIST) * ( (x_p - MIN_DIST) >=0 ) )\n",
    "y_p2 = 1 / (1 + a*x_p**(2*b))\n",
    "\n",
    "plt.figure()\n",
    "plt.plot(x_p,y_p, label=r'$p_{ij}$')\n",
    "plt.plot(x_p,y_p2, label=r'$q_{ij}$')\n",
    "plt.xlabel('distance')\n",
    "plt.ylabel('weight')\n",
    "plt.legend()\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "31ed258c",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "attr_coeff = []\n",
    "rep_coeff = []\n",
    "\n",
    "idx_to_map = 100\n",
    "\n",
    "\n",
    "@numba.jit(nopython=True)\n",
    "def clip(x,val=4.0):\n",
    "\n",
    "    if x>val:\n",
    "        return val\n",
    "    elif x<-val:\n",
    "        return -val\n",
    "    else:\n",
    "        return x\n",
    "    \n",
    "@numba.jit(nopython=True)\n",
    "def update_attraction(x, y, a, b, dim, lr, P, idx, alpha):\n",
    "    dist = np.sum((x - y)**2)\n",
    "\n",
    "    if dist>0.0:\n",
    "        grad_coeff = alpha * 2*a*b*dist**(b-1.0) / (1 + a * dist**b)\n",
    "        \n",
    "\n",
    "    else:\n",
    "        grad_coeff = 0.0\n",
    "\n",
    "\n",
    "    for d in range(dim):\n",
    "        mv = clip(grad_coeff * P * (x[0,d]-y[0,d]))  # * P[idx,idy]\n",
    "        mv = mv * lr\n",
    "\n",
    "        x[0,d] -= mv\n",
    "        y[0,d] += mv\n",
    "        \n",
    "    return\n",
    "\n",
    "@numba.jit(nopython=True)\n",
    "def update_repulsion(x, y, a, b, dim, lr, P, idx, beta):\n",
    "    dist = np.sum((x - y)**2)\n",
    "\n",
    "\n",
    "    if dist>0.0:\n",
    "        grad_coeff = beta * 2 * repulsion_strength * b / ( (0.001+dist) * (1.0 + a * dist**b) )\n",
    "    else:\n",
    "        grad_coeff = 0\n",
    "        \n",
    "    for d in range(dim):\n",
    "        grad = clip(grad_coeff  * (x[0,d]-y[0,d]) * (1-P))\n",
    "        mv = grad * lr\n",
    "\n",
    "        x[0,d] += mv\n",
    "        #y[0,d] -= mv\n",
    "\n",
    "    return\n",
    "\n",
    "\n",
    "@numba.jit(nopython=True)\n",
    "def one_step_in_a_set(emb, idx, rows, columns, a, b, dim,\n",
    "                   n_points,\n",
    "                   epochs_per_sample,\n",
    "                   epoch_of_next_sample,\n",
    "                   epochs_per_negative_sample,\n",
    "                   epoch_of_next_negative_sample,\n",
    "                   lr, epoch,\n",
    "                     alpha, beta):\n",
    "    \n",
    "    if epoch_of_next_sample[idx] <= epoch:\n",
    "        x_idx = rows[idx]\n",
    "        y_idx = columns[idx]\n",
    "        \n",
    "        \n",
    "        x = emb[x_idx:x_idx+1,:]\n",
    "        y = emb[y_idx:y_idx+1, :]\n",
    "            \n",
    "        update_attraction(x, y, a, b, dim, lr, 1, idx, alpha=alpha)\n",
    "        \n",
    "        epoch_of_next_sample[idx] += epochs_per_sample[idx]\n",
    "        \n",
    "        n_neg_samples = int(\n",
    "                (epoch - epoch_of_next_negative_sample[idx]) / epochs_per_negative_sample[idx]\n",
    "            )\n",
    "        \n",
    "        for i in range(n_neg_samples):\n",
    "            y_idx = np.random.choice(n_points)\n",
    "            \n",
    "            if x_idx == y_idx:\n",
    "                continue\n",
    "            \n",
    "            y = emb[y_idx:y_idx+1, :]\n",
    "                \n",
    "            update_repulsion(x, y, a, b, dim, lr, 0, idx, beta=beta)\n",
    "            \n",
    "        epoch_of_next_negative_sample[idx] += (\n",
    "                n_neg_samples * epochs_per_negative_sample[idx]\n",
    "            )\n",
    "            \n",
    "    return \n",
    "\n",
    "@numba.jit(nopython=True,parallel=True)\n",
    "def one_epoch(emb,\n",
    "                     rows, columns,\n",
    "                     n_points,\n",
    "                     n_edges,\n",
    "                     a, b, dim,\n",
    "                     lr, epoch,\n",
    "                     epochs_per_sample,\n",
    "                     epoch_of_next_sample,\n",
    "                     epochs_per_negative_sample,\n",
    "                     epoch_of_next_negative_sample,\n",
    "                     alpha, beta,\n",
    "                     repulsion_strength=1.0):\n",
    "    '''\n",
    "    Set1 = 1 * np.ones(epochs_per_sample_1.shape[0])\n",
    "    Set2 = 2 * np.ones(epochs_per_sample_2.shape[0])\n",
    "    Set = np.random.permutation(np.concatenate((Set1,Set2)))\n",
    "    '''\n",
    "    \n",
    "    for i in prange(n_edges):\n",
    "\n",
    "        one_step_in_a_set(emb=emb, idx=i, \n",
    "                              rows=rows, columns=columns, a=a, b=b, dim=dim,\n",
    "                              n_points=n_points,\n",
    "                              epochs_per_sample=epochs_per_sample,\n",
    "                              epoch_of_next_sample=epoch_of_next_sample,\n",
    "                              epochs_per_negative_sample=epochs_per_negative_sample,\n",
    "                              epoch_of_next_negative_sample=epoch_of_next_negative_sample,\n",
    "                              lr=lr, epoch=epoch,\n",
    "                              alpha=alpha, beta=beta)\n",
    "    \n",
    "    return\n",
    "\n",
    "from sklearn.decomposition import PCA\n",
    "from sklearn.metrics import silhouette_score as SIL_score\n",
    "from sklearn.manifold import trustworthiness\n",
    "\n",
    "np.random.seed(2398734219)\n",
    "chosen_idx = np.random.choice(50000, size=10000, replace=False)\n",
    "\n",
    "n_components = 2\n",
    "\n",
    "alphas = np.arange(0,1.1,0.1)\n",
    "betas  = np.arange(0,1.1,0.1)\n",
    "\n",
    "embs = []\n",
    "sil_scores = np.zeros((len(alphas), len(betas)))\n",
    "t_scores   = np.zeros((len(alphas), len(betas)))\n",
    "\n",
    "sil_dict = {}\n",
    "t_dict = {}\n",
    "\n",
    "for i in range(len(alphas)):\n",
    "    for j in range(len(betas)):\n",
    "        print('Completing:', i, j)\n",
    "        pca = PCA(n_components = n_components)\n",
    "        init = pca.fit_transform(X_train)\n",
    "        emb = init.astype(np.float32).copy()\n",
    "        expansion = 10.0 / np.abs(emb).max()\n",
    "        emb = (emb * expansion).astype(np.float32)\n",
    "\n",
    "        neg_sample_rate = 5\n",
    "        repulsion_strength=1.0\n",
    "\n",
    "\n",
    "        epochs_per_sample = epochs_per_sample_og.copy()\n",
    "        epoch_of_next_sample = epochs_per_sample.copy()\n",
    "        epochs_per_negative_sample = epochs_per_sample / neg_sample_rate\n",
    "        epoch_of_next_negative_sample = epochs_per_negative_sample.copy()\n",
    "\n",
    "        init_lr = 1.0\n",
    "\n",
    "        n_edges = epochs_per_sample.shape[0]\n",
    "\n",
    "        np.random.seed(500)\n",
    "\n",
    "        a = 1.58\n",
    "        b = 0.89\n",
    "\n",
    "\n",
    "        import timeit\n",
    "\n",
    "\n",
    "\n",
    "        for epoch in range(epochs):\n",
    "\n",
    "            #if epoch%20==0:\n",
    "            #    print('epoch ', epoch, 'of ', epochs)\n",
    "\n",
    "            #start = timeit.default_timer()\n",
    "\n",
    "            lr = init_lr #* 0.1 #* (1.0 - float(epoch)/float(epochs))\n",
    "\n",
    "            one_epoch(emb=emb, \n",
    "                             rows=graph.row, columns=graph.col, \n",
    "                             n_points=X_train.shape[0],\n",
    "                             n_edges = n_edges,\n",
    "                             a=a, b=b, dim=n_components,\n",
    "                             lr=lr, epoch=epoch+1,\n",
    "                             epochs_per_sample=epochs_per_sample,\n",
    "                             epoch_of_next_sample=epoch_of_next_sample,\n",
    "                             epochs_per_negative_sample=epochs_per_negative_sample,\n",
    "                             epoch_of_next_negative_sample=epoch_of_next_negative_sample,\n",
    "                             repulsion_strength=repulsion_strength,\n",
    "                             alpha=alphas[i], beta=betas[j])\n",
    "\n",
    "\n",
    "            #stop = timeit.default_timer()\n",
    "        embs.append(emb)\n",
    "        sil_scores[i,j] = SIL_score(emb,y_train)\n",
    "        t_scores[i,j] = trustworthiness(X_train[chosen_idx], emb[chosen_idx])\n",
    "        sil_dict[alphas[i],betas[j]] = sil_scores[i,j]\n",
    "        t_dict[alphas[i],betas[j]] = t_scores[i,j]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "39942df5",
   "metadata": {},
   "outputs": [],
   "source": [
    "#with open('random_init_test_orig/mnist_PCA_init.npy', 'wb') as f:\n",
    "#    np.save(f,emb)\n",
    "\n",
    "#embs = []\n",
    "#embs.append(emb)\n",
    "#print(len(embs))\n",
    "\n",
    "print(sil_scores)\n",
    "\n",
    "print(t_dict[0.1,0.1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cb4f7c57",
   "metadata": {},
   "outputs": [],
   "source": [
    "#embs = np.array(embs)\n",
    "#print(embs.shape)\n",
    "\n",
    "with open('mnist_alpha_beta.npy', 'wb') as f:\n",
    "    np.save(f, embs)\n",
    "    np.save(f,alphas)\n",
    "    np.save(f,betas)\n",
    "    np.save(f,sil_scores)\n",
    "    np.save(f,t_scores)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1a1b90ad",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "plt.figure()\n",
    "plt.imshow(sil_scores[1:,1:].T, cmap='brg', origin='lower')\n",
    "plt.xlabel(r'$k_1$')\n",
    "plt.ylabel(r'$k_2$')\n",
    "plt.xticks([0,2,4,6,8],(np.array([0,2,4,6,8])+1)/10)\n",
    "plt.yticks([0,2,4,6,8],(np.array([0,2,4,6,8])+1)/10)\n",
    "#plt.yticks([0,2,4,6,8,10],np.array([0,2,4,6,8,10])/10)\n",
    "#plt.xticks([0,2,4,6,8,10],np.array([0,2,4,6,8,10])/10)\n",
    "plt.colorbar()\n",
    "plt.savefig('figure/mnist_k1k2_silhouette.eps')\n",
    "\n",
    "plt.figure()\n",
    "plt.imshow(t_scores[1:,1:].T, cmap='brg', origin='lower')\n",
    "plt.xlabel(r'$k_1$')\n",
    "plt.ylabel(r'$k_2$')\n",
    "plt.xticks([0,2,4,6,8],(np.array([0,2,4,6,8])+1)/10)\n",
    "plt.yticks([0,2,4,6,8],(np.array([0,2,4,6,8])+1)/10)\n",
    "#plt.yticks([0,2,4,6,8,10],np.array([0,2,4,6,8,10])/10)\n",
    "#plt.xticks([0,2,4,6,8,10],np.array([0,2,4,6,8,10])/10)\n",
    "plt.colorbar()\n",
    "plt.savefig('figure/mnist_k1k2_trustworthiness.eps')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "226e085d",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "#%matplotlib notebook\n",
    "\n",
    "from scale_bar import add_scalebar\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "plt.scatter(emb[:,0], emb[:,1], c=y_train, s=0.1, cmap='Spectral')\n",
    "plt.axis('off')\n",
    "add_scalebar(ax)\n",
    "#plt.title(r'UMAP of MNIST')\n",
    "#plt.savefig('figure/UMAP_1fa_1fr.png')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e975d2fc",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import silhouette_score as SIL_score\n",
    "from sklearn.manifold import trustworthiness\n",
    "\n",
    "print(SIL_score(emb, y_train))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a33634c5",
   "metadata": {},
   "outputs": [],
   "source": [
    "np.random.seed(2398734219)\n",
    "chosen_idx = np.random.choice(50000, size=10000, replace=False)\n",
    "print(trustworthiness(X_train[chosen_idx], emb[chosen_idx]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd1c300c",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a6fd54c",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pysegPy3.10",
   "language": "python",
   "name": "pysegpy3.10"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
