{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# The clustering experiment on the Fashion-MNIST data set"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Importing libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import ot\n",
    "import torch\n",
    "import torchvision\n",
    "import torchvision.transforms as transforms\n",
    "from sklearn.cluster import KMeans\n",
    "from sklearn.cluster import SpectralClustering\n",
    "from sklearn.metrics import adjusted_rand_score\n",
    "from sklearn.metrics.cluster import normalized_mutual_info_score\n",
    "\n",
    "from DW import DW\n",
    "from utils2 import sampled_sphere\n",
    "from utils2 import Tukey_Depth, Projection_Depth, SW, Sinkhorn"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Experiment functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_dist_matrix(X):\n",
    "    \"\"\"Pairwise distances between the point clouds X[0], ..., X[n_distrib - 1].\n",
    "\n",
    "    X is indexed as (n_distrib, n_samples, d); each X[k] is a torch tensor of\n",
    "    n_samples points, weighted uniformly. For every pair four quantities are\n",
    "    computed: DW with projection depth, SW, the optimal-transport cost ot.emd2\n",
    "    on the ot.dist cost matrix, and mmd.\n",
    "\n",
    "    Returns (DM_DR, DM_W, DM_SW, DM_MMD): symmetric (n_distrib, n_distrib)\n",
    "    torch tensors with a zero diagonal.\n",
    "    \"\"\"\n",
    "    n_distrib, n_samples, _ = X.shape\n",
    "    uniform = np.full(n_samples, 1 / n_samples)\n",
    "    DM_DR, DM_W, DM_SW, DM_MMD = (torch.zeros(n_distrib, n_distrib) for _ in range(4))\n",
    "    for row in range(n_distrib):\n",
    "        a = X[row].numpy()\n",
    "        # Only the strict lower triangle is filled; it is mirrored on return.\n",
    "        for col in range(row):\n",
    "            b = X[col].numpy()\n",
    "            # Keep the DW -> SW call order: both take ndirs and may draw random\n",
    "            # directions, so reordering could change results under a fixed seed.\n",
    "            DM_DR[row, col] = DW(a, b, ndirs=100, data_depth='Projection')\n",
    "            DM_SW[row, col] = SW(a, b, ndirs=100)\n",
    "            DM_W[row, col] = ot.emd2(uniform, uniform, ot.dist(a, b))\n",
    "            DM_MMD[row, col] = mmd(X[row], X[col])\n",
    "    return tuple(D + D.T for D in (DM_DR, DM_W, DM_SW, DM_MMD))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def bag_of_pixels_gray(X,size=100):\n",
    "    ## Bag of pixels\n",
    "    Z = torch.zeros(50000, 28*28, 3)\n",
    "    for i in range(28):\n",
    "        for j in range(28):\n",
    "            Z[:,i*28 + j, 0] = i\n",
    "            Z[:,i*28+ j, 1] = j\n",
    "            Z[:,i*28+ j, 2] = X[0][:,0,i,j]\n",
    "    data = []\n",
    "    for i in range(10):\n",
    "        data.append(Z[X[1] == i])\n",
    "    ZZ = torch.zeros(10*size, 784, 3)\n",
    "    for i in range(10):\n",
    "        ZZ[(i*size):(i+1) * size,:,:] = data[i][:size]\n",
    "    return ZZ\n",
    "def bag_of_pixels_color(X,size=100):\n",
    "    ## Bag of pixels\n",
    "    Z = torch.zeros(50000, 28*28, 5)\n",
    "    for i in range(28):\n",
    "        for j in range(28):\n",
    "            Z[:,i*28 + j, 0] = i\n",
    "            Z[:,i*28+ j, 1] = j\n",
    "            Z[:,i*28+ j, 2] = X[0][:,0,i,j]\n",
    "            Z[:,i*28+ j, 3] = X[0][:,1,i,j]\n",
    "            Z[:,i*28+ j, 4] = X[0][:,2,i,j]\n",
    "    data = []\n",
    "    for i in range(10):\n",
    "        data.append(Z[X[1] == i])\n",
    "    ZZ = torch.zeros(10*size, 784, 5)\n",
    "    for i in range(10):\n",
    "        ZZ[(i*size):(i+1) * size,:,:] = data[i][:size]\n",
    "    return ZZ"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def clustering_result(R, y_true, n_clusters=10):\n",
    "    \"\"\"K-means on each matrix in R, scored against the ground-truth labels.\n",
    "\n",
    "    Each row of R[i] is used as the feature vector of one sample. Returns an\n",
    "    array of shape (len(R), 2) holding [NMI, ARI] for every matrix.\n",
    "    \"\"\"\n",
    "    result = np.zeros((len(R), 2))\n",
    "    for i, D in enumerate(R):\n",
    "        # `precompute_distances` is not passed: it was deprecated in scikit-learn 0.23\n",
    "        # and has since been removed from KMeans (passing it raises TypeError).\n",
    "        labels = KMeans(n_clusters=n_clusters, random_state=0).fit(D).labels_\n",
    "        result[i, 0] = normalized_mutual_info_score(y_true, labels)\n",
    "        result[i, 1] = adjusted_rand_score(y_true, labels)\n",
    "    return result\n",
    "def Sclustering_result_KM(R, y_true, n_clusters=10):\n",
    "    \"\"\"Spectral clustering on each distance matrix in R, scored against y_true.\n",
    "\n",
    "    Despite the _KM suffix this uses SpectralClustering, not k-means. Each\n",
    "    distance matrix D is turned into the affinity 1 / (1 + D): zero distance\n",
    "    gives affinity 1 and larger distances give smaller affinities. Returns an\n",
    "    array of shape (len(R), 2) holding [NMI, ARI] for every matrix.\n",
    "    \"\"\"\n",
    "    result = np.zeros((len(R), 2))\n",
    "    for i, D in enumerate(R):\n",
    "        model = SpectralClustering(n_clusters=n_clusters, random_state=0, affinity='precomputed')\n",
    "        labels = model.fit(1 / (1 + D)).labels_\n",
    "        result[i, 0] = normalized_mutual_info_score(y_true, labels)\n",
    "        result[i, 1] = adjusted_rand_score(y_true, labels)\n",
    "    return result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def mmd(x1, x2, sigma=1):\n",
    "    x1x1 = gaussian_kernel(x1, x1, sigma)\n",
    "    x1x2 = gaussian_kernel(x1, x2, sigma)\n",
    "    x2x2 = gaussian_kernel(x2, x2, sigma)\n",
    "    diff = x1x1.mean() - 2 * x1x2.mean() + x2x2.mean()\n",
    "    return diff\n",
    "\n",
    "def gaussian_kernel(x1, x2, sigma = 1.0):\n",
    "    #r = x1.dimshuffle(0,'x',1)\n",
    "    return np.exp(-np.linalg.norm(x1-x2, axis=1)/(2*sigma **2))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Importing Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ToTensor scales pixels to [0, 1]; Normalize(0.5, 0.5) then maps them to [-1, 1].\n",
    "transform = transforms.Compose(\n",
    "    [transforms.ToTensor(),\n",
    "     transforms.Normalize((0.5,), (0.5, ))])\n",
    "trainset = torchvision.datasets.FashionMNIST(root='./data', train=True,\n",
    "                                        download=True, transform=transform)\n",
    "# One batch holding the first 50000 training images; shuffle=False keeps it deterministic.\n",
    "trainloader = torch.utils.data.DataLoader(trainset, batch_size=50000,\n",
    "                                          shuffle=False, num_workers=0)\n",
    "# next(iter(...)) rather than iterator.next(): the .next() alias is not available\n",
    "# on DataLoader iterators in current PyTorch releases.\n",
    "FM = next(iter(trainloader))  # (images, labels)\n",
    "X_FM = bag_of_pixels_gray(FM)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Computing distribution distances (load the precomputed results unless you want to recompute them, which takes around 14 hours)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 1d 4h 7min 21s, sys: 2d 1h 32min 14s, total: 3d 5h 39min 35s\n",
      "Wall time: 14h 18min 7s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# Seed NumPy's global RNG before any (possibly randomized) distance computation.\n",
    "np.random.seed(0)\n",
    "# Recomputing takes ~14 h of wall time; by default the cached matrices are loaded.\n",
    "RECOMPUTE = False\n",
    "DIST_CACHE = 'clustering_FM.pt'\n",
    "if RECOMPUTE:\n",
    "    S = compute_dist_matrix(X_FM)\n",
    "    torch.save(S, DIST_CACHE)\n",
    "else:\n",
    "    # torch.load unpickles the file: only load caches from a trusted source.\n",
    "    S = torch.load(DIST_CACHE)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Computing clustering"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "y = np.array([0,1,2,3,4,5,6,7,8,9])\n",
    "y_true = np.repeat(y, 100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[0.57776882 0.43417426]\n",
      " [0.54851991 0.39325595]\n",
      " [0.49686503 0.35347194]\n",
      " [0.53809939 0.36972508]]\n"
     ]
    }
   ],
   "source": [
    "# One row per distance matrix in S (presumably DR, W, SW, MMD, the order returned\n",
    "# by compute_dist_matrix); columns are [NMI, ARI].\n",
    "result_FMKM = Sclustering_result_KM(S, y_true)\n",
    "print(result_FMKM)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Computing direct clustering on raw pixels (Euclidean in the table)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Raw-pixel baseline on the same images as X_FM: the first 100 images of each\n",
    "# class, stacked class by class and flattened to 28*28 = 784-dimensional vectors.\n",
    "FM_images, FM_labels = FM\n",
    "per_class_idx = [(FM_labels == c).nonzero().flatten()[:100] for c in range(10)]\n",
    "assert all(len(idx) == 100 for idx in per_class_idx), 'every class needs 100 images'\n",
    "selected_idx = torch.cat(per_class_idx)\n",
    "FM_real_clustering = FM_images[selected_idx].reshape(len(selected_idx), -1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.4997447157674264\n",
      "0.322703704433886\n",
      "CPU times: user 1.12 s, sys: 1.28 s, total: 2.4 s\n",
      "Wall time: 204 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# ot.dist defaults to the squared Euclidean metric, so this baseline clusters on\n",
    "# squared pixel-space distances.\n",
    "M = ot.dist(FM_real_clustering)\n",
    "# Reuse the spectral-clustering pipeline (affinity 1 / (1 + M), seed 0) that\n",
    "# produced result_FMKM, so the two results are directly comparable.\n",
    "NMI_normal, ARS_normal = Sclustering_result_KM([M], y_true)[0]\n",
    "print(NMI_normal)\n",
    "print(ARS_normal)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
