{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "efa9a988",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "%matplotlib inline\n",
    "\n",
    "import sys\n",
    "sys.path.append(\"code\")  # make project modules in ./code (e.g. mnist_reader) importable\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.metrics.pairwise import euclidean_distances\n",
    "from scipy import optimize\n",
    "from scipy.io import savemat\n",
    "from sklearn.neighbors import KNeighborsClassifier\n",
    "from sklearn.datasets import make_moons\n",
    "from sklearn.decomposition import PCA\n",
    "\n",
    "import mnist_reader\n",
    "\n",
    "neigh = KNeighborsClassifier(n_neighbors=1)\n",
    "dataset = 2\n",
    "\n",
    "# Load the MNIST train / test splits.\n",
    "X_train, y_train = mnist_reader.load_mnist('data/mnist', kind='train')\n",
    "X_test, y_test = mnist_reader.load_mnist('data/mnist', kind='t10k')\n",
    "\n",
    "# Scale both splits by the *training* maximum so they share one scale; cast the\n",
    "# test split to float32 explicitly too so both splits end up with the same dtype.\n",
    "X_train = X_train.astype(np.float32)\n",
    "max_val = np.max(X_train)\n",
    "X_train = X_train / max_val\n",
    "X_test = X_test.astype(np.float32) / max_val\n",
    "\n",
    "n = X_train.shape[0]\n",
    "\n",
    "# NOTE(review): these are the Fashion-MNIST class names; the MNIST cells below do\n",
    "# not use them, and the Fashion-MNIST cell redefines the same list.\n",
    "classes = [\n",
    "    'T-shirt/top',\n",
    "    'Trouser',\n",
    "    'Pullover',\n",
    "    'Dress',\n",
    "    'Coat',\n",
    "    'Sandal',\n",
    "    'Shirt',\n",
    "    'Sneaker',\n",
    "    'Bag',\n",
    "    'Ankle boot']\n",
    "\n",
    "print(X_train.shape, y_train.shape, X_train.dtype)\n",
    "\n",
    "# Embedding hyper-parameters. Only n_components is read in this notebook (by the\n",
    "# PCA below); the others presumably record the settings behind the saved *.npy files.\n",
    "epochs = 200\n",
    "n_neighbors = 15\n",
    "n_components = 2\n",
    "MIN_DIST = 0.1\n",
    "\n",
    "# PCA projection of the scaled training images, shifted to zero mean.\n",
    "pca = PCA(n_components=n_components)\n",
    "x_init = pca.fit_transform(X_train)\n",
    "x_init = x_init - np.mean(x_init, axis=0)\n",
    "\n",
    "\n",
    "def print_stats(X):\n",
    "    \"\"\"Print the shape and the global mean / max / min / std of array X.\"\"\"\n",
    "    print('size: ', X.shape)\n",
    "    print('Mean:', np.mean(X))\n",
    "    print('Max: ', np.max(X))\n",
    "    print('Min: ', np.min(X))\n",
    "    print('STD: ', np.std(X))\n",
    "\n",
    "\n",
    "print('Training Statistics')\n",
    "print_stats(X_train)\n",
    "print('Test Statistics')\n",
    "print_stats(X_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5c828e10",
   "metadata": {},
   "outputs": [],
   "source": [
    "# `%load_ext` on an extension that is already loaded only prints a notice;\n",
    "# `%reload_ext` loads autoreload if needed and reloads it otherwise.\n",
    "%reload_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f3833bf8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The file holds five arrays stored back to back; read them in that order.\n",
    "with open('mnist_alpha_beta.npy', 'rb') as f:\n",
    "    embs, alphas, betas, sil_scores, t_scores = (np.load(f) for _ in range(5))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8aff6b9a",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "from scale_bar import add_scalebar\n",
    "\n",
    "# PCA-initialisation embedding (embs[0]) coloured by digit label.\n",
    "fig, ax = plt.subplots()\n",
    "scatter = ax.scatter(embs[0][:, 0], embs[0][:, 1], c=y_train, s=0.1, cmap='Spectral')\n",
    "\n",
    "# num=None gives exactly one legend entry per unique label; the default num='auto'\n",
    "# may fall back to a tick locator for more than 9 unique values and skip labels.\n",
    "# `fontsize` is not passed: Legend ignores it whenever `prop` is given.\n",
    "handles, legend_labels = scatter.legend_elements(num=None)\n",
    "lgd = ax.legend(handles, legend_labels, loc=\"upper left\", ncol=1, prop={'size': 8})\n",
    "\n",
    "add_scalebar(ax)\n",
    "\n",
    "fig.text(0.9, 0.8, 'SIL= {:0.3f}'.format(sil_scores[0, 0]), fontsize=15, horizontalalignment='right')\n",
    "fig.text(0.9, 0.75, 'T= {:0.3f}'.format(t_scores[0, 0]), fontsize=15, horizontalalignment='right')\n",
    "\n",
    "ax.set_title('PCA Initialization', fontsize=30)\n",
    "fig.savefig('figure/mnist_init.png', dpi=150)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "675ad2b2",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8b1856eb",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3d57a464",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Max / min of each score over the grid, excluding row 0 and column 0\n",
    "# (entry [0, 0] is the PCA-initialisation score shown in the init plot).\n",
    "for score_grid in (sil_scores, t_scores):\n",
    "    interior = score_grid[1:, 1:]\n",
    "    print(np.max(interior), np.min(interior))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "926561ef",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "from helper_codes import plot_matrix, plot_rows\n",
    "\n",
    "# Fixed colour-scale limits for the MNIST grid figures.\n",
    "# NOTE(review): presumably vmins/vmaxs bound the SIL panel and vmint/vmaxt the T panel — confirm in helper_codes.\n",
    "mnist_matrix_limits = dict(vs=True, vmins=0.1, vmaxs=0.5,\n",
    "                           vt=True, vmint=0.7, vmaxt=0.96)\n",
    "plot_matrix(embs, alphas, betas, sil_scores, t_scores, y_train,\n",
    "            savename='figure/mnist', **mnist_matrix_limits)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1ee8901d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# helper_codes.plot_rows on the MNIST results; output files use the 'figure/mnist' prefix.\n",
    "plot_rows(embs, alphas, betas,\n",
    "          sil_scores, t_scores, y_train,\n",
    "          savename='figure/mnist')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "43594ac5",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# embs holds one 2-D embedding per (alpha, beta) pair; after this reshape\n",
    "# embs2[i, j] is the run for (alphas[i], betas[j]) (C order: flat index i*len(betas) + j).\n",
    "embs2 = embs.reshape(len(alphas), len(betas), X_train.shape[0], 2)\n",
    "\n",
    "# One row per alpha and one column per beta. The loop fills subplots in row-major\n",
    "# order (alpha outer, beta inner), so the grid must be len(alphas) x len(betas)\n",
    "# for every panel to land in its own cell when the two lists differ in length.\n",
    "N_height = len(alphas)\n",
    "N_width = len(betas)\n",
    "\n",
    "# Scale the default figure size by the number of panels in each direction.\n",
    "default_figsize = plt.rcParams['figure.figsize']\n",
    "new_figsize_width = default_figsize[0] * N_width\n",
    "new_figsize_height = default_figsize[1] * N_height\n",
    "plt.figure(figsize=(new_figsize_width, new_figsize_height))\n",
    "\n",
    "k = 0\n",
    "for i in range(len(alphas)):\n",
    "    for j in range(len(betas)):\n",
    "        k += 1\n",
    "        ax = plt.subplot(N_height, N_width, k)\n",
    "        plt.scatter(embs2[i, j][:, 0], embs2[i, j][:, 1], c=y_train, s=0.1, cmap='Spectral')\n",
    "        plt.title(r'$k_1$=' + str(alphas[i]) + ' $k_2$=' + str(betas[j]))\n",
    "\n",
    "plt.savefig('figure/mnist_2.png')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "22c83379",
   "metadata": {},
   "source": [
    "<h1>Fashion MNIST</h1>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "01650920",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "from helper_codes import plot_matrix, plot_rows\n",
    "\n",
    "# fy_train colours the plots; fX_train is only used later for its row count.\n",
    "fX_train, fy_train = mnist_reader.load_mnist('data/fashion', kind='train')\n",
    "\n",
    "# Five arrays stored back to back in the results file.\n",
    "with open('fmnist_alpha_beta.npy', 'rb') as f:\n",
    "    fembs, falphas, fbetas, fsil_scores, ft_scores = (np.load(f) for _ in range(5))\n",
    "\n",
    "# Max / min over the whole grid, including row 0 and column 0.\n",
    "for score_grid in (fsil_scores, ft_scores):\n",
    "    print(np.max(score_grid), np.min(score_grid))\n",
    "\n",
    "fmnist_prefix = 'figure_fmnist/fmnist'\n",
    "plot_matrix(fembs, falphas, fbetas, fsil_scores, ft_scores, fy_train,\n",
    "            vs=True, vmins=0.04, vmaxs=0.2,\n",
    "            vt=True, vmint=0.85, vmaxt=0.97,\n",
    "            savename=fmnist_prefix)\n",
    "\n",
    "plot_rows(fembs, falphas, fbetas, fsil_scores, ft_scores, fy_train, savename=fmnist_prefix)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0f5e5c85",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b6c8f6a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Max / min of each score over the grid excluding row 0 and column 0 (entry [0, 0]\n",
    "# is the PCA-initialisation score). Slice once so max and min cover the same block.\n",
    "interior_sil = fsil_scores[1:, 1:]\n",
    "interior_t = ft_scores[1:, 1:]\n",
    "print(np.max(interior_sil), np.min(interior_sil))\n",
    "print(np.max(interior_t), np.min(interior_t))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7aec9cf0",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6acb647e",
   "metadata": {},
   "outputs": [],
   "source": [
    "from scale_bar import add_scalebar\n",
    "\n",
    "classes = [\n",
    "    'T-shirt/top',\n",
    "    'Trouser',\n",
    "    'Pullover',\n",
    "    'Dress',\n",
    "    'Coat',\n",
    "    'Sandal',\n",
    "    'Shirt',\n",
    "    'Sneaker',\n",
    "    'Bag',\n",
    "    'Ankle boot']\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "scatter = ax.scatter(fembs[0, :, 0], fembs[0, :, 1], c=fy_train, s=0.01, cmap='Spectral')\n",
    "\n",
    "# Request one handle per unique label (num=None): the default num='auto' is not\n",
    "# guaranteed to emit every label once there are more than 9 classes, which would\n",
    "# silently misalign the class names. Handles come in increasing label order, so\n",
    "# classes[j] names the j-th smallest label (assumes labels 0..9 — TODO confirm).\n",
    "# `fontsize` is not passed: Legend ignores it whenever `prop` is given.\n",
    "handles = scatter.legend_elements(num=None)[0]\n",
    "if len(handles) != len(classes):\n",
    "    raise ValueError('expected {} distinct labels, found {}'.format(len(classes), len(handles)))\n",
    "lgd = ax.legend(handles, classes, loc=\"upper left\", ncol=1, prop={'size': 8})\n",
    "add_scalebar(ax)\n",
    "\n",
    "fig.text(0.9, 0.8, 'SIL= {:0.3f}'.format(fsil_scores[0, 0]), fontsize=15, horizontalalignment='right')\n",
    "fig.text(0.9, 0.75, 'T= {:0.3f}'.format(ft_scores[0, 0]), fontsize=15, horizontalalignment='right')\n",
    "\n",
    "ax.set_title('PCA Initialization', fontsize=30)\n",
    "fig.savefig('figure_fmnist/fmnist_init.png', dpi=100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fc683c66",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# fembs holds one 2-D embedding per (alpha, beta) pair; after this reshape\n",
    "# embs2[i, j] is the run for (falphas[i], fbetas[j]) (C order: flat index i*len(fbetas) + j).\n",
    "embs2 = fembs.reshape(len(falphas), len(fbetas), fX_train.shape[0], 2)\n",
    "\n",
    "# One row per alpha and one column per beta. The loop fills subplots in row-major\n",
    "# order (alpha outer, beta inner), so the grid must be len(falphas) x len(fbetas)\n",
    "# for every panel to land in its own cell when the two lists differ in length.\n",
    "N_height = len(falphas)\n",
    "N_width = len(fbetas)\n",
    "\n",
    "# Scale the default figure size by the number of panels in each direction.\n",
    "default_figsize = plt.rcParams['figure.figsize']\n",
    "new_figsize_width = default_figsize[0] * N_width\n",
    "new_figsize_height = default_figsize[1] * N_height\n",
    "plt.figure(figsize=(new_figsize_width, new_figsize_height))\n",
    "\n",
    "k = 0\n",
    "for i in range(len(falphas)):\n",
    "    for j in range(len(fbetas)):\n",
    "        k += 1\n",
    "        ax = plt.subplot(N_height, N_width, k)\n",
    "        plt.scatter(embs2[i, j][:, 0], embs2[i, j][:, 1], c=fy_train, s=0.1, cmap='Spectral')\n",
    "        plt.title(r'$k_1$=' + str(falphas[i]) + ' $k_2$=' + str(fbetas[j]))\n",
    "\n",
    "plt.savefig('figure_fmnist/fmnist_2.png')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ec41a915",
   "metadata": {},
   "source": [
    "<h1>Macosko</h1>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "27a461a8",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "from scipy.io import savemat, loadmat\n",
    "\n",
    "# Arrays stored in the .mat file under the keys 'X_data' and 'y_data';\n",
    "# the labels are flattened to 1-D.\n",
    "macosko_mat = loadmat('data/macosko.mat')\n",
    "X_data = macosko_mat['X_data']\n",
    "y_data = macosko_mat['y_data'].reshape(-1)\n",
    "\n",
    "mX_train, my_train = X_data, y_data\n",
    "\n",
    "# Five arrays stored back to back in the results file.\n",
    "with open('macosko_alpha_beta.npy', 'rb') as f:\n",
    "    membs, malphas, mbetas, msil_scores, mt_scores = (np.load(f) for _ in range(5))\n",
    "\n",
    "# Max / min over the grid, excluding row 0 and column 0.\n",
    "for score_grid in (msil_scores, mt_scores):\n",
    "    interior = score_grid[1:, 1:]\n",
    "    print(np.max(interior), np.min(interior))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "79a13a08",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Grid and per-row figures for the Macosko results, sharing one output prefix.\n",
    "# NOTE(review): presumably vmins/vmaxs bound the SIL panel and vmint/vmaxt the T panel — confirm in helper_codes.\n",
    "macosko_prefix = 'figure_macosko/macosko'\n",
    "plot_matrix(membs, malphas, mbetas, msil_scores, mt_scores, my_train,\n",
    "            vs=True, vmins=-0.02, vmaxs=0.45,\n",
    "            vt=True, vmint=0.75, vmaxt=0.95,\n",
    "            savename=macosko_prefix)\n",
    "\n",
    "plot_rows(membs, malphas, mbetas, msil_scores, mt_scores, my_train, savename=macosko_prefix)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b401505",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "from scale_bar import add_scalebar\n",
    "\n",
    "classes = ['Amacrine cells',\n",
    "           'Astrocytes',\n",
    "           'Bipolar cells',\n",
    "           'Cones',\n",
    "           'Fibroblasts',\n",
    "           'Horizontal cells',\n",
    "           'Microglia',\n",
    "           'Muller glia',\n",
    "           'Pericytes',\n",
    "           'Retinal ganglion cells',\n",
    "           'Rods',\n",
    "           'Vascular endothelium']\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "scatter = ax.scatter(membs[0, :, 0], membs[0, :, 1], c=my_train, s=0.01, cmap='Spectral')\n",
    "\n",
    "# Request one handle per unique label (num=None): the default num='auto' is not\n",
    "# guaranteed to emit every label once there are more than 9 classes, which would\n",
    "# silently misalign the class names. Handles come in increasing label order, so\n",
    "# classes[j] names the j-th smallest label (assumes labels 0..11 — TODO confirm).\n",
    "# `fontsize` is not passed: Legend ignores it whenever `prop` is given.\n",
    "handles = scatter.legend_elements(num=None)[0]\n",
    "if len(handles) != len(classes):\n",
    "    raise ValueError('expected {} distinct labels, found {}'.format(len(classes), len(handles)))\n",
    "lgd = ax.legend(handles, classes, loc=\"upper left\", ncol=1, prop={'size': 8})\n",
    "add_scalebar(ax)\n",
    "\n",
    "fig.text(0.9, 0.8, 'SIL= {:0.3f}'.format(msil_scores[0, 0]), fontsize=15, horizontalalignment='right')\n",
    "fig.text(0.9, 0.75, 'T= {:0.3f}'.format(mt_scores[0, 0]), fontsize=15, horizontalalignment='right')\n",
    "\n",
    "ax.set_title('PCA Initialization', fontsize=30)\n",
    "fig.savefig('figure_macosko/macosko_init.png', dpi=100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "158f663e",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d0ec6c0",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pysegPy3.10",
   "language": "python",
   "name": "pysegpy3.10"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
