{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "efa9a988",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append(\"code\")\n",
    "\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.metrics.pairwise import euclidean_distances\n",
    "from scipy import optimize\n",
    "\n",
    "from scipy.io import savemat\n",
    "\n",
    "from sklearn.neighbors import KNeighborsClassifier\n",
    "neigh = KNeighborsClassifier(n_neighbors=1)\n",
    "\n",
    "from sklearn.datasets import make_moons\n",
    "\n",
    "dataset = 2\n",
    "import mnist_reader \n",
    "    \n",
    "X_train, y_train = mnist_reader.load_mnist('data/mnist', kind='train')\n",
    "X_test, y_test = mnist_reader.load_mnist('data/mnist', kind='t10k')\n",
    "\n",
    "X_train = X_train.astype(np.float32)\n",
    "max_val = np.max(X_train)\n",
    "X_train = X_train/max_val\n",
    "\n",
    "X_test = X_test/max_val\n",
    "\n",
    "n = X_train.shape[0]\n",
    "\n",
    "classes = [\n",
    "    'T-shirt/top',\n",
    "    'Trouser',\n",
    "    'Pullover',\n",
    "    'Dress',\n",
    "    'Coat',\n",
    "    'Sandal',\n",
    "    'Shirt',\n",
    "    'Sneaker',\n",
    "    'Bag',\n",
    "    'Ankle boot']\n",
    "\n",
    "print(X_train.shape, y_train.shape, X_train.dtype)\n",
    "\n",
    "#Torch Setups\n",
    "from sklearn.decomposition import PCA\n",
    "\n",
    "n_components = 2\n",
    "\n",
    "pca = PCA(n_components = n_components)\n",
    "x_init = pca.fit_transform(X_train)\n",
    "x_init = x_init - np.mean(x_init, axis=0)\n",
    "\n",
    "def print_stats(X):\n",
    "    print('size: ', X.shape)\n",
    "    print('Mean:', np.mean(X))\n",
    "    print('Max: ', np.max(X))\n",
    "    print('Min: ', np.min(X))\n",
    "    print('STD: ', np.std(X))\n",
    "    \n",
    "    return\n",
    "\n",
    "print('Training Statistics')\n",
    "print_stats(X_train)\n",
    "print('Test Statistics')\n",
    "print_stats(X_test)\n",
    "\n",
    "\n",
    "epochs = 200\n",
    "n_neighbors= 15\n",
    "n_components = 2\n",
    "MIN_DIST = 0.1\n",
    "\n",
    "    \n",
    "%matplotlib inline\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5c828e10",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f3833bf8",
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('figure_neg/neg_t_sne_mnist_alpha_beta.npy', 'rb') as f:\n",
    "    embs = np.load(f)\n",
    "    alphas = np.load(f)\n",
    "    betas = np.load(f)\n",
    "    sil_scores = np.load(f)\n",
    "    t_scores = np.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8aff6b9a",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "675ad2b2",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8b1856eb",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3d57a464",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "print(np.max(sil_scores[1:,1:]), np.min(sil_scores[1:,1:]))\n",
    "print(np.max(t_scores[1:,1:]), np.min(t_scores[1:,1:]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "926561ef",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "from helper_codes import plot_matrix, plot_rows\n",
    "\n",
    "plot_matrix(embs,alphas,betas,sil_scores,t_scores,y_train,\n",
    "            vs=True, vmins= 0.1, vmaxs= 0.5,\n",
    "            vt=True, vmint=0.7, vmaxt=0.96,\n",
    "            neg_y_axis=True,savename='figure_neg/neg_mnist')\n",
    "\n",
    "plot_rows(embs,alphas,betas,sil_scores,t_scores,y_train,neg_y_axis=True,savename='figure_neg/neg_mnist')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "43594ac5",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "\n",
    "embs2 = embs.reshape(len(alphas),len(betas),X_train.shape[0],2)\n",
    "# Get the default figure size\n",
    "default_figsize = plt.rcParams['figure.figsize']\n",
    "\n",
    "N_width = len(alphas)\n",
    "N_height = len(betas)\n",
    "\n",
    "# Calculate the new figure size (9 times the default)\n",
    "new_figsize_width = default_figsize[0] * N_width\n",
    "new_figsize_height = default_figsize[1] * N_height\n",
    "\n",
    "# Create a figure with the new size\n",
    "plt.figure(figsize=(new_figsize_width, new_figsize_height))\n",
    "\n",
    "k=0\n",
    "for i in range(0,len(alphas)):\n",
    "    colum_id = i\n",
    "    for j in range(0, len(betas)):\n",
    "        row_id = len(betas) - j\n",
    "        k+=1\n",
    "        ax = plt.subplot(N_height, N_width, k)\n",
    "        \n",
    "        idx_ch = i*len(alphas)+j\n",
    "        plt.scatter(embs2[i,j][:,0], embs2[i,j][:,1], c=y_train, s=0.1, cmap='Spectral')\n",
    "        plt.title(r'$k_1$='+str(alphas[i])+' $k_2$='+str(betas[j]))\n",
    "        \n",
    "\n",
    "\n",
    "plt.savefig('figure_neg/neg_mnist_2.png')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c51d55ac",
   "metadata": {
    "scrolled": true
   },
   "source": [
    "<h1>fmnist</h1>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7aec9cf0",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "from helper_codes import plot_matrix, plot_rows\n",
    "\n",
    "fX_train, fy_train = mnist_reader.load_mnist('data/fashion', kind='train')\n",
    "\n",
    "with open('figure_neg_fmnist/neg_t_sne_fmnist_alpha_beta.npy', 'rb') as f:\n",
    "    fembs = np.load(f)\n",
    "    falphas = np.load(f)\n",
    "    fbetas = np.load(f)\n",
    "    fsil_scores = np.load(f)\n",
    "    ft_scores = np.load(f)\n",
    "    \n",
    "    \n",
    "print(np.max(fsil_scores[1:,1:]), np.min(fsil_scores[1:,1:]))\n",
    "print(np.max(ft_scores[1:,1:]), np.min(ft_scores[1:,1:]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6acb647e",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_matrix(fembs,falphas,fbetas,fsil_scores,ft_scores,fy_train,\n",
    "            vs=True, vmins=0.04, vmaxs=0.2,\n",
    "            vt=True, vmint=0.85, vmaxt=0.97,\n",
    "            savename='figure_neg_fmnist/neg_fmnist')\n",
    "\n",
    "plot_rows(fembs,falphas,fbetas,fsil_scores,ft_scores,fy_train,savename='figure_neg_fmnist/neg_fmnist')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4034fd46",
   "metadata": {},
   "source": [
    "<h1>macosko</h1>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "27a461a8",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "from scipy.io import savemat, loadmat\n",
    "\n",
    "d = loadmat('data/macosko.mat')\n",
    "\n",
    "X_data = d['X_data']\n",
    "y_data = d['y_data'].reshape(-1)\n",
    "\n",
    "mX_train = X_data\n",
    "my_train = y_data\n",
    "\n",
    "with open('figure_neg_macosko/neg_t_sne_macosko_alpha_beta.npy', 'rb') as f:\n",
    "    membs = np.load(f)\n",
    "    malphas = np.load(f)\n",
    "    mbetas = np.load(f)\n",
    "    msil_scores = np.load(f)\n",
    "    mt_scores = np.load(f)\n",
    "    \n",
    "    \n",
    "print(np.max(msil_scores[1:,1:]), np.min(msil_scores[1:,1:]))\n",
    "print(np.max(mt_scores[1:,1:]), np.min(mt_scores[1:,1:]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b401505",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "plot_matrix(membs,malphas,mbetas,msil_scores,mt_scores,\n",
    "            my_train,\n",
    "            vs=True, vmins=-0.02, vmaxs=0.45,\n",
    "            vt=True, vmint=0.75, vmaxt=0.95,\n",
    "            savename='figure_neg_macosko/neg_macosko')\n",
    "\n",
    "plot_rows(membs,malphas,mbetas,msil_scores,mt_scores,my_train,savename='figure_neg_macosko/neg_macosko')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "158f663e",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d0ec6c0",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pysegPy3.10",
   "language": "python",
   "name": "pysegpy3.10"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
