{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5cea4e6a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append(\"../code\")\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "from scipy import optimize\n",
    "from scipy.io import savemat\n",
    "from sklearn.datasets import make_moons\n",
    "from sklearn.metrics.pairwise import euclidean_distances\n",
    "from sklearn.neighbors import KNeighborsClassifier\n",
    "\n",
    "import mnist_reader  # presumably a local helper located via ../code\n",
    "\n",
    "neigh = KNeighborsClassifier(n_neighbors=1)\n",
    "dataset = 2\n",
    "\n",
    "# Fashion-MNIST train / test splits from ../data/fashion.\n",
    "X_train, y_train = mnist_reader.load_mnist('../data/fashion', kind='train')\n",
    "X_test, y_test = mnist_reader.load_mnist('../data/fashion', kind='t10k')\n",
    "\n",
    "# Scale both splits by the maximum of the (float32) training data.\n",
    "X_train = X_train.astype(np.float32)\n",
    "max_val = np.max(X_train)\n",
    "X_train = X_train / max_val\n",
    "X_test = X_test / max_val\n",
    "n = X_train.shape[0]\n",
    "\n",
    "classes = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',\n",
    "           'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']\n",
    "\n",
    "print(X_train.shape, y_train.shape, X_train.dtype)\n",
    "\n",
    "# Centred 2-D PCA projection of the training set, stored as x_init.\n",
    "from sklearn.decomposition import PCA\n",
    "\n",
    "n_components = 2\n",
    "pca = PCA(n_components=n_components)\n",
    "x_init = pca.fit_transform(X_train)\n",
    "x_init = x_init - x_init.mean(axis=0)\n",
    "\n",
    "def print_stats(X):\n",
    "    \"\"\"Print the shape, mean, max, min and standard deviation of X.\"\"\"\n",
    "    print('size: ', X.shape)\n",
    "    summaries = (('Mean:', np.mean), ('Max: ', np.max), ('Min: ', np.min), ('STD: ', np.std))\n",
    "    for label, stat_fn in summaries:\n",
    "        print(label, stat_fn(X))\n",
    "\n",
    "# Optimisation settings shared by the cells below.\n",
    "epochs = 200\n",
    "n_neighbors = 15\n",
    "n_components = 2\n",
    "MIN_DIST = 0.1\n",
    "\n",
    "print('Training Statistics')\n",
    "print_stats(X_train)\n",
    "\n",
    "%matplotlib inline\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3e6924d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.decomposition import PCA\n",
    "\n",
    "n_components = 2\n",
    "neg_sample_rate = 5\n",
    "repulsion_strength = 1.0\n",
    "\n",
    "# PCA projection of the training data as a float32 starting layout (a copy).\n",
    "pca = PCA(n_components=n_components)\n",
    "init = pca.fit_transform(X_train)\n",
    "emb = np.array(init, dtype=np.float32)\n",
    "print(emb.shape)\n",
    "\n",
    "plt.figure()\n",
    "plt.scatter(emb[:, 0], emb[:, 1], c=y_train, s=0.01, cmap='Spectral')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d846eb38",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "import numba\n",
    "from numba import prange\n",
    "\n",
    "import random\n",
    "\n",
    "import scipy.sparse\n",
    "\n",
    "import gc\n",
    "\n",
    "\n",
    "@numba.jit(nopython=True, parallel=True)\n",
    "def euclidean_distances_numba(X, squared = True):\n",
    "    \"\"\"Dense pairwise Euclidean distance matrix between the rows of X.\n",
    "\n",
    "    Each pair is computed once (col >= row) and mirrored. Squared distances\n",
    "    are returned unless squared=False. Allocates an (n, n) array of X.dtype.\n",
    "    \"\"\"\n",
    "    n_rows = X.shape[0]\n",
    "    out = np.zeros((n_rows, n_rows), dtype=X.dtype)\n",
    "    for row in prange(n_rows):\n",
    "        for col in range(row, n_rows):\n",
    "            value = np.sum(np.square(X[row, :] - X[col, :]))\n",
    "            if not squared:\n",
    "                value = np.sqrt(value)\n",
    "            out[row, col] = value\n",
    "            out[col, row] = value\n",
    "    return out\n",
    "\n",
    "def get_weight_function(dists, rho, sigma):\n",
    "    \"\"\"Membership weights exp(-max(dists - rho, 0) / sigma).\n",
    "\n",
    "    Distances not exceeding rho get weight 1; larger ones decay\n",
    "    exponentially with scale sigma. The input array is not modified.\n",
    "    \"\"\"\n",
    "    excess = np.maximum(dists - rho, 0)\n",
    "    return np.exp(-excess / sigma)\n",
    "\n",
    "def search_sigma(dists, rho, k, tol = 10**-5, n_iteration=200, sigma_min=0, sigma_max=1000):\n",
    "    \"\"\"Bisection search for one point's bandwidth sigma.\n",
    "\n",
    "    Looks for sigma in [sigma_min, sigma_max] such that\n",
    "    sum(get_weight_function(dists, rho, sigma)) equals log2(k), stopping when\n",
    "    the sum is within tol of the target or after n_iteration bisection steps.\n",
    "\n",
    "    Returns (sigma, weights): the last sigma tried and the weights at it.\n",
    "    With n_iteration <= 0 no search is done and the interval mid-point is\n",
    "    returned together with its weights.\n",
    "    \"\"\"\n",
    "    target = np.log2(k)\n",
    "    lo, hi = sigma_min, sigma_max\n",
    "\n",
    "    if n_iteration <= 0:\n",
    "        cur_sigma = (lo + hi) / 2\n",
    "        return cur_sigma, get_weight_function(dists, rho, cur_sigma)\n",
    "\n",
    "    for _ in range(n_iteration):\n",
    "        cur_sigma = (lo + hi) / 2\n",
    "        probs = get_weight_function(dists, rho, cur_sigma)\n",
    "        weight = np.sum(probs)\n",
    "\n",
    "        if np.abs(target - weight) < tol:\n",
    "            break\n",
    "\n",
    "        # exp(-d / sigma) with d >= 0 grows with sigma, so a total below the\n",
    "        # target means sigma has to increase.\n",
    "        if weight < target:\n",
    "            lo = cur_sigma\n",
    "        else:\n",
    "            hi = cur_sigma\n",
    "\n",
    "    return cur_sigma, probs\n",
    "\n",
    "@numba.jit(nopython=True, parallel=True)\n",
    "def symmetrization_step(prob):\n",
    "    \"\"\"Symmetrise a directed membership matrix with the probabilistic t-conorm.\n",
    "\n",
    "    P[i, j] = P[j, i] = a + b - a * b with a = prob[i, j], b = prob[j, i].\n",
    "    Returns a dense float32 (n, n) matrix.\n",
    "    \"\"\"\n",
    "    size = prob.shape[0]\n",
    "    P = np.zeros((size, size), dtype=np.float32)\n",
    "    for i in prange(size):\n",
    "        # Numba only parallelises the outermost prange, so a plain range here\n",
    "        # runs exactly as the nested prange did.\n",
    "        for j in range(i, size):\n",
    "            forward = prob[i, j]\n",
    "            backward = prob[j, i]\n",
    "            union = forward + backward - forward * backward\n",
    "            P[i, j] = union\n",
    "            P[j, i] = union\n",
    "    return P\n",
    "\n",
    "def get_prob_matrix(X, n_neighbors=15):\n",
    "    \"\"\"Build the symmetric fuzzy k-nearest-neighbour graph of the rows of X.\n",
    "\n",
    "    1. Dense pairwise Euclidean distances (euclidean_distances_numba).\n",
    "    2. Per row keep sorted columns 1..n_neighbors; column 0 is skipped on the\n",
    "       assumption that it is the point itself (distance 0).\n",
    "    3. rho = distance to the first kept neighbour; the membership weights come\n",
    "       from search_sigma.\n",
    "    4. The directed weights are symmetrised with symmetrization_step.\n",
    "\n",
    "    Returns a scipy.sparse.coo_matrix of shape (n, n). Several dense (n, n)\n",
    "    intermediates are allocated, so memory grows quadratically with n.\n",
    "    \"\"\"\n",
    "    n = X.shape[0]\n",
    "    dist = euclidean_distances_numba(X, squared = False)\n",
    "    sort_idx = np.argsort(dist,axis=1)\n",
    "    sort_idx = sort_idx[:,1:n_neighbors+1]\n",
    "\n",
    "    # Nearest-neighbour distance of every point, gathered in one step.\n",
    "    rho = dist[np.arange(n), sort_idx[:, 0]]\n",
    "\n",
    "    # Fill a preallocated float32 matrix row by row instead of stacking n\n",
    "    # float64 rows and casting at the end: same values, much lower peak memory.\n",
    "    directed_graph = np.zeros((n, n), dtype=np.float32)\n",
    "    for i in range(n):\n",
    "        if (i+1)%1000 == 0:\n",
    "            print('Processed ', i+1, ' of ', n, ' samples.')\n",
    "        _, weights = search_sigma(dists = dist[i,sort_idx[i,:]],rho = rho[i],k = n_neighbors)\n",
    "        directed_graph[i, sort_idx[i,:]] = weights\n",
    "\n",
    "    P = symmetrization_step(directed_graph)\n",
    "\n",
    "    graph = scipy.sparse.coo_matrix(P)\n",
    "\n",
    "    return graph\n",
    "\n",
    "def make_epochs_per_sample(weights, n_epochs):\n",
    "    \"\"\"Given a set of weights and number of epochs generate the number of\n",
    "    epochs per sample for each weight.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    weights: array of shape (n_1_simplices)\n",
    "        The weights of how much we wish to sample each 1-simplex.\n",
    "    n_epochs: int\n",
    "        The total number of epochs we want to train for.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    A float64 array with the number of epochs per sample for each 1-simplex;\n",
    "    entries whose scaled weight is not positive are -1.0.\n",
    "\n",
    "    Adapted from the UMAP repo: https://github.com/lmcinnes/umap/\n",
    "    \"\"\"\n",
    "    n_samples = n_epochs * (weights / weights.max())\n",
    "    positive = n_samples > 0\n",
    "    result = np.full(weights.shape[0], -1.0, dtype=np.float64)\n",
    "    result[positive] = float(n_epochs) / n_samples[positive]\n",
    "    return result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7e5e3eec",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from scipy import sparse\n",
    "\n",
    "# Building the graph allocates dense n x n matrices, so the result is cached\n",
    "# on disk and only rebuilt when a cache file is missing.\n",
    "EPOCHS_PER_SAMPLE_CACHE = 'FMNIST_epoch_per_sample_og.npy'\n",
    "GRAPH_CACHE = 'FMNIST_graph.npz'\n",
    "\n",
    "if os.path.exists(EPOCHS_PER_SAMPLE_CACHE) and os.path.exists(GRAPH_CACHE):\n",
    "    with open(EPOCHS_PER_SAMPLE_CACHE, 'rb') as f:\n",
    "        epochs_per_sample_og = np.load(f)\n",
    "    graph = sparse.load_npz(GRAPH_CACHE)\n",
    "else:\n",
    "    graph = get_prob_matrix(X_train,n_neighbors=n_neighbors)\n",
    "    print(len(graph.data))\n",
    "    print('prune value: ', graph.data.max() / float(epochs))\n",
    "    # Edges below max / epochs would be sampled less than once over the run.\n",
    "    graph.data[graph.data < (graph.data.max() / float(epochs))] = 0.0\n",
    "    graph.eliminate_zeros()\n",
    "    print(len(graph.data))\n",
    "    epochs_per_sample_og = make_epochs_per_sample(graph.data, epochs)\n",
    "    gc.collect()\n",
    "\n",
    "    with open(EPOCHS_PER_SAMPLE_CACHE, 'wb') as f:\n",
    "        np.save(f, epochs_per_sample_og)\n",
    "    # save_npz returns None, so graph must not be rebound to its result.\n",
    "    sparse.save_npz(GRAPH_CACHE, graph)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dae7b3b5",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(epochs_per_sample_og.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c1fb6880",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c621b84e",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "42908467",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fit 1 / (1 + a * x^(2b)) to the target membership curve: 1 for\n",
    "# x <= MIN_DIST and exp(-(x - MIN_DIST)) beyond it. a and b are used by the\n",
    "# optimisation cell further down.\n",
    "x = np.linspace(0, 3, 300)\n",
    "\n",
    "y = (x>MIN_DIST) * np.exp(-x+MIN_DIST)\n",
    "y[x<=MIN_DIST] = 1.0\n",
    "\n",
    "function = lambda x, a, b: 1 / (1 + a*x**(2*b))\n",
    "\n",
    "p , _ = optimize.curve_fit(function, x, y) \n",
    "\n",
    "a = p[0]\n",
    "b = p[1] \n",
    "print(\"Hyperparameters a = \" + str(a) + \" and b = \" + str(b))\n",
    "\n",
    "# Compare target and fitted curves on a 0.01-spaced grid.\n",
    "x_p = np.arange(0,3,0.01)\n",
    "y_p = np.exp(- (x_p-MIN_DIST) * ( (x_p - MIN_DIST) >=0 ) )\n",
    "y_p2 = 1 / (1 + a*x_p**(2*b))\n",
    "\n",
    "plt.figure()\n",
    "plt.plot(x_p,y_p, label='Target')\n",
    "plt.plot(x_p,y_p2, label='Fitted')\n",
    "plt.xlabel('distance in the embedding')\n",
    "plt.ylabel('membership strength')\n",
    "# Both curves are labelled; without legend() they cannot be told apart.\n",
    "plt.legend();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f344cefc",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "31ed258c",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1a1b90ad",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Placeholders; not referenced elsewhere in this notebook.\n",
    "attr_coeff, rep_coeff = [], []\n",
    "idx_to_map = 100\n",
    "\n",
    "\n",
    "@numba.jit(nopython=True)\n",
    "def clip(x,val=4.0):\n",
    "    \"\"\"Clamp x to [-val, val]; bounds each per-coordinate update.\"\"\"\n",
    "    if x>val:\n",
    "        return val\n",
    "    elif x<-val:\n",
    "        return -val\n",
    "    else:\n",
    "        return x\n",
    "\n",
    "@numba.jit(nopython=True)\n",
    "def update_attraction(x, y, a, b, dim, lr, P, epoch, idx):\n",
    "    \"\"\"Pull the (1, dim) row views x and y towards each other, in place.\n",
    "\n",
    "    Per coordinate the step is lr * clip(grad_coeff * P * (x - y)) with\n",
    "    grad_coeff = 2ab d^(b-1) / (1 + a d^b) and d the squared distance; x moves\n",
    "    by -step and y by +step. epoch and idx are unused.\n",
    "    \"\"\"\n",
    "    dist = np.sum((x - y)**2)\n",
    "\n",
    "    if dist>0.0:\n",
    "        grad_coeff = 2*a*b*dist**(b-1.0) / (1 + a * dist**b) \n",
    "    else:\n",
    "        grad_coeff = 0.0\n",
    "\n",
    "    for d in range(dim):\n",
    "        mv = clip(grad_coeff * P * (x[0,d]-y[0,d])) \n",
    "        mv = mv * lr\n",
    "\n",
    "        x[0,d] -= mv\n",
    "        y[0,d] += mv\n",
    "\n",
    "    return\n",
    "\n",
    "@numba.jit(nopython=True)\n",
    "def update_repulsion(x, y, a, b, dim, lr, P, epoch, idx, repulsion_strength=1.0):\n",
    "    \"\"\"Push the (1, dim) row view x away from y, in place (y is not moved).\n",
    "\n",
    "    repulsion_strength is a parameter, not a read of the module-level global:\n",
    "    Numba freezes globals as compile-time constants, so reassigning the\n",
    "    global after the first compile would be ignored. The default 1.0 equals\n",
    "    the value the global held. epoch and idx are unused.\n",
    "    \"\"\"\n",
    "    dist = np.sum((x - y)**2)\n",
    "\n",
    "    if dist>0.0:\n",
    "        # The 0.001 offset bounds the coefficient as dist approaches 0.\n",
    "        grad_coeff = 2 * repulsion_strength * b / ( (0.001+dist) * (1.0 + a * dist**b) )\n",
    "    else:\n",
    "        grad_coeff = 0.0\n",
    "\n",
    "    for d in range(dim):\n",
    "        grad = clip(grad_coeff  * (x[0,d]-y[0,d]) * (1-P))\n",
    "        mv = grad * lr\n",
    "\n",
    "        x[0,d] += mv\n",
    "\n",
    "    return\n",
    "\n",
    "\n",
    "@numba.jit(nopython=True)\n",
    "def one_step_in_a_set(emb, idx, rows, columns, a, b, dim,\n",
    "                   n_points,\n",
    "                   epochs_per_sample,\n",
    "                   epoch_of_next_sample,\n",
    "                   epochs_per_negative_sample,\n",
    "                   epoch_of_next_negative_sample,\n",
    "                   lr, epoch, repulsion_strength=1.0):\n",
    "    \"\"\"Process edge idx for this epoch if its schedule says it is due.\n",
    "\n",
    "    When epoch_of_next_sample[idx] <= epoch: attract emb[rows[idx]] and\n",
    "    emb[columns[idx]], then repel emb[rows[idx]] from uniformly drawn points\n",
    "    (negative samples; the point itself is skipped), and advance both\n",
    "    schedules of the edge. emb and the schedule arrays are modified in place.\n",
    "    \"\"\"\n",
    "    if epoch_of_next_sample[idx] <= epoch:\n",
    "        x_idx = rows[idx]\n",
    "        y_idx = columns[idx]\n",
    "\n",
    "        # Slices are views, so the updates below write straight into emb.\n",
    "        x = emb[x_idx:x_idx+1,:]\n",
    "        y = emb[y_idx:y_idx+1, :]\n",
    "\n",
    "        update_attraction(x, y, a, b, dim, lr, 1, epoch, idx)\n",
    "\n",
    "        epoch_of_next_sample[idx] += epochs_per_sample[idx]\n",
    "\n",
    "        # Negative-sample slots that have elapsed since epoch_of_next_negative_sample[idx].\n",
    "        n_neg_samples = int(\n",
    "                (epoch - epoch_of_next_negative_sample[idx]) / epochs_per_negative_sample[idx]\n",
    "            )\n",
    "\n",
    "        for i in range(n_neg_samples):\n",
    "            y_idx = np.random.choice(n_points)\n",
    "\n",
    "            if x_idx == y_idx:\n",
    "                continue\n",
    "\n",
    "            y = emb[y_idx:y_idx+1, :]\n",
    "\n",
    "            update_repulsion(x, y, a, b, dim, lr, 0, epoch, idx, repulsion_strength)\n",
    "\n",
    "        epoch_of_next_negative_sample[idx] += (\n",
    "                n_neg_samples * epochs_per_negative_sample[idx]\n",
    "            )\n",
    "\n",
    "    return \n",
    "\n",
    "@numba.jit(nopython=True,parallel=True)\n",
    "def one_epoch_2sets_2(emb,\n",
    "                     rows, columns,\n",
    "                     n_points,\n",
    "                     n_edges,\n",
    "                     a, b, dim,\n",
    "                     lr, epoch,\n",
    "                     epochs_per_sample,\n",
    "                     epoch_of_next_sample,\n",
    "                     epochs_per_negative_sample,\n",
    "                     epoch_of_next_negative_sample,\n",
    "                     repulsion_strength=1.0):\n",
    "    \"\"\"Run one epoch over all n_edges edges, parallelised with prange.\n",
    "\n",
    "    Iterations update rows of emb in place without locking, so edges that\n",
    "    share a point may race (lock-free updates). repulsion_strength is\n",
    "    forwarded to the negative-sampling step.\n",
    "    \"\"\"\n",
    "    for i in prange(n_edges):\n",
    "\n",
    "        one_step_in_a_set(emb=emb, idx=i, \n",
    "                              rows=rows, columns=columns, a=a, b=b, dim=dim,\n",
    "                              n_points=n_points,\n",
    "                              epochs_per_sample=epochs_per_sample,\n",
    "                              epoch_of_next_sample=epoch_of_next_sample,\n",
    "                              epochs_per_negative_sample=epochs_per_negative_sample,\n",
    "                              epoch_of_next_negative_sample=epoch_of_next_negative_sample,\n",
    "                              lr=lr, epoch=epoch,\n",
    "                              repulsion_strength=repulsion_strength)\n",
    "\n",
    "    return\n",
    "\n",
    "from sklearn.decomposition import PCA\n",
    "\n",
    "n_components = 2\n",
    "\n",
    "neg_sample_rate = 5\n",
    "repulsion_strength=1.0\n",
    "init_lr = 1.0\n",
    "\n",
    "with open('random_init_test_orig/fmnist_seeds.npy', 'rb') as f:\n",
    "    seeds = np.load(f)\n",
    "\n",
    "# One run per stored seed, capped at 100; seeds[nr] must exist for every run.\n",
    "no_of_random_init = min(100, len(seeds))\n",
    "\n",
    "embeddings_m1 = []\n",
    "\n",
    "for nr in range(no_of_random_init):\n",
    "    # Gaussian random start, rescaled so the largest |coordinate| is 10.\n",
    "    np.random.seed(seeds[nr])\n",
    "    init = np.random.randn(len(X_train),2)*10\n",
    "    emb = init.astype(np.float32).copy()\n",
    "    expansion = 10.0 / np.abs(emb).max()\n",
    "    emb = (emb * expansion).astype(np.float32)\n",
    "\n",
    "    # Fresh per-edge sampling schedules for every run.\n",
    "    epochs_per_sample = epochs_per_sample_og.copy()\n",
    "    epoch_of_next_sample = epochs_per_sample.copy()\n",
    "    epochs_per_negative_sample = epochs_per_sample / neg_sample_rate\n",
    "    epoch_of_next_negative_sample = epochs_per_negative_sample.copy()\n",
    "\n",
    "    n_edges = epochs_per_sample.shape[0]\n",
    "\n",
    "    # NOTE: this seeds NumPy's global generator only. Numba-compiled code (the\n",
    "    # negative sampling inside one_epoch_2sets_2) keeps its own RNG state and\n",
    "    # its prange updates can race, so runs are not bit-reproducible. Kept so\n",
    "    # the NumPy random state seen by later cells is unchanged.\n",
    "    np.random.seed(500)\n",
    "\n",
    "    for epoch in range(epochs):\n",
    "        if epoch%100==0:\n",
    "            print(nr, ': epoch ', epoch, 'of ', epochs)\n",
    "\n",
    "        # Learning rate decays linearly from init_lr towards 0.\n",
    "        lr = init_lr * (1.0 - float(epoch)/float(epochs))\n",
    "\n",
    "        one_epoch_2sets_2(emb=emb, \n",
    "                         rows=graph.row, columns=graph.col, \n",
    "                         n_points=X_train.shape[0],\n",
    "                         n_edges = n_edges,\n",
    "                         a=a, b=b, dim=n_components,\n",
    "                         lr=lr, epoch=epoch+1,\n",
    "                         epochs_per_sample=epochs_per_sample,\n",
    "                         epoch_of_next_sample=epoch_of_next_sample,\n",
    "                         epochs_per_negative_sample=epochs_per_negative_sample,\n",
    "                         epoch_of_next_negative_sample=epoch_of_next_negative_sample,\n",
    "                         repulsion_strength=repulsion_strength)\n",
    "\n",
    "    embeddings_m1.append(emb)\n",
    "\n",
    "with open('random_init_test_orig/fmnist_random_init.npy', 'wb') as f:\n",
    "    np.save(f, embeddings_m1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "218fef88",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "f01a8312",
   "metadata": {},
   "source": [
    "<h1>Statistics</h1>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a33634c5",
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.spatial import procrustes\n",
    "\n",
    "def plot_idxs(array,idxs,figsize=(10,10), title=None, values = None, tick_off=True, labels=None):\n",
    "    \"\"\"Scatter-plot selected 2-D embeddings on a square grid of subplots.\n",
    "\n",
    "    array: stack of embeddings indexed as array[k, :, 0] / array[k, :, 1].\n",
    "    idxs: embeddings to draw; only the first int(sqrt(len(idxs)))**2 are used.\n",
    "    title: figure-level title. values: optional per-panel titles aligned with idxs.\n",
    "    tick_off: hide axis ticks. labels: point colours; when None the global\n",
    "    y_train is used, for backward compatibility.\n",
    "    \"\"\"\n",
    "    if labels is None:\n",
    "        labels = y_train\n",
    "\n",
    "    plt.figure(figsize=figsize)\n",
    "\n",
    "    if title is not None:\n",
    "        # plt.title would go on a figure-sized Axes that the subplots below\n",
    "        # overlap (depending on the Matplotlib version it is deleted or left\n",
    "        # behind the grid); suptitle attaches the title to the figure itself.\n",
    "        plt.suptitle(title)\n",
    "\n",
    "    n_plot = int(np.sqrt(len(idxs)))\n",
    "\n",
    "    for i in range(n_plot**2):\n",
    "        plt.subplot(n_plot, n_plot, i+1)\n",
    "        plt.scatter(array[idxs[i],:,0], array[idxs[i],:,1], c=labels, s=0.01, cmap='Spectral')\n",
    "        if values is not None:\n",
    "            plt.title(str(values[i]))\n",
    "        if tick_off:\n",
    "            plt.xticks([])\n",
    "            plt.yticks([])\n",
    "\n",
    "    return\n",
    "\n",
    "def _extreme_k_sorted(metric, k, largest):\n",
    "    \"\"\"Indices and values of the k smallest (or largest) entries of metric,\n",
    "    both ordered by ascending value. metric must be a NumPy array.\n",
    "    If k >= len(metric) every entry is returned.\n",
    "    \"\"\"\n",
    "    if k >= len(metric):\n",
    "        # argpartition raises for kth outside [-len, len); every entry is selected anyway.\n",
    "        idxs_arg = np.arange(len(metric))\n",
    "    elif largest:\n",
    "        idxs_arg = np.argpartition(metric, -k)[-k:]\n",
    "    else:\n",
    "        idxs_arg = np.argpartition(metric, k)[:k]\n",
    "\n",
    "    values = metric[idxs_arg]\n",
    "    order = np.argsort(values)\n",
    "    return idxs_arg[order], values[order]\n",
    "\n",
    "def plot_low_k_idxs(array, metric, k, title=None):\n",
    "    \"\"\"Plot the k embeddings with the smallest metric, titled with their values.\"\"\"\n",
    "    idxs_arg, values = _extreme_k_sorted(metric, k, largest=False)\n",
    "    plot_idxs(array, idxs_arg, title=title, values=values)\n",
    "    return\n",
    "\n",
    "def plot_high_k_idxs(array, metric, k, title=None):\n",
    "    \"\"\"Plot the k embeddings with the largest metric, titled with their values.\"\"\"\n",
    "    idxs_arg, values = _extreme_k_sorted(metric, k, largest=True)\n",
    "    plot_idxs(array, idxs_arg, title=title, values=values)\n",
    "    return\n",
    "\n",
    "def procrustes_distances(standard_array, array): \n",
    "    \"\"\"Procrustes-align each embedding in array to standard_array.\n",
    "\n",
    "    Returns (pds, X_pdx): pds[i] is the disparity reported by\n",
    "    scipy.spatial.procrustes for array[i], X_pdx[i] the aligned array[i].\n",
    "    The mean and standard deviation of pds are printed.\n",
    "    \"\"\"\n",
    "    disparities = []\n",
    "    aligned = []\n",
    "    for emb in array:\n",
    "        _, emb_aligned, disparity = procrustes(standard_array, emb)\n",
    "        disparities.append(disparity)\n",
    "        aligned.append(emb_aligned)\n",
    "\n",
    "    pds, X_pdx = np.array(disparities), np.array(aligned)\n",
    "    print('Procrusted Distance: Mean: ', np.mean(pds), ' STD: ', np.std(pds))\n",
    "\n",
    "    return pds, X_pdx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd1c300c",
   "metadata": {},
   "outputs": [],
   "source": [
    "fname = 'random_init_test_orig'\n",
    "\n",
    "# Reference embedding (PCA-initialised, per the file name); it is the\n",
    "# Procrustes target for the random-init runs below.\n",
    "umap_pca = np.load('random_init_test_orig/fmnist_PCA_init.npy')\n",
    "print(umap_pca.shape)\n",
    "\n",
    "plt.figure()\n",
    "plt.scatter(umap_pca[:, 0], umap_pca[:, 1], c=y_train, s=0.01, cmap='Spectral')\n",
    "# Discrete colour bar: one bin per class, labelled with the class names.\n",
    "cbar = plt.colorbar(boundaries=np.arange(11)-0.5)\n",
    "cbar.set_ticks(np.arange(10))\n",
    "cbar.set_ticklabels(classes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a6fd54c",
   "metadata": {},
   "outputs": [],
   "source": [
    "embds = np.array(embeddings_m1)\n",
    "# Draw from the embeddings that actually exist instead of a hard-coded 100.\n",
    "idxs = np.random.choice(len(embds), min(9, len(embds)), replace=False)\n",
    "plot_idxs(embds,idxs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4c688be",
   "metadata": {},
   "outputs": [],
   "source": [
    "pd_r2, umap_pdr2 = procrustes_distances(umap_pca, embds)\n",
    "\n",
    "print('mean', np.mean(pd_r2), ', standard deviation:', np.std(pd_r2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1e5f1593",
   "metadata": {},
   "outputs": [],
   "source": [
    "for plot_fn, plot_title in ((plot_low_k_idxs, 'UMAP Random Init - Low'),\n",
    "                            (plot_high_k_idxs, 'UMAP Random Init - High')):\n",
    "    plot_fn(umap_pdr2, pd_r2, 9, title=plot_title)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "25f29870",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
