{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# The clustering experiment on the Contaminated Fashion-MNIST data set"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Importing libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import time\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import ot\n",
    "import torch\n",
    "import torchvision\n",
    "import torchvision.transforms as transforms\n",
    "from sklearn.cluster import KMeans\n",
    "from sklearn.cluster import SpectralClustering\n",
    "from sklearn.metrics import adjusted_rand_score\n",
    "from sklearn.metrics.cluster import normalized_mutual_info_score\n",
    "\n",
    "from DW import DW\n",
    "from utils2 import sampled_sphere\n",
    "from utils2 import Tukey_Depth, Projection_Depth, SW, Sinkhorn"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Experiment functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_dist_matrix(X):\n",
    "    \"\"\"Pairwise distances between the empirical distributions stored in X.\n",
    "\n",
    "    X is unpacked as (n_distrib, n_samples, d): X[k] is the k-th point cloud,\n",
    "    weighted uniformly (1 / n_samples per point) for the OT cost.\n",
    "\n",
    "    Returns a tuple of four symmetric (n_distrib, n_distrib) torch tensors:\n",
    "    DW with 'Projection' depth, ot.emd2 on the ot.dist cost, SW, and mmd.\n",
    "    Only pairs (i, j) with j < i are computed; the diagonal stays 0.\n",
    "    \"\"\"\n",
    "    n_distrib, n_samples, _ = X.shape\n",
    "    uniform = np.full(n_samples, 1 / n_samples)\n",
    "    dr_mat, w_mat, sw_mat, mmd_mat = (torch.zeros(n_distrib, n_distrib) for _ in range(4))\n",
    "    for row in range(n_distrib):\n",
    "        cloud_a = X[row].numpy()\n",
    "        for col in range(row):\n",
    "            cloud_b = X[col].numpy()\n",
    "            # Call order (DW, SW, OT, MMD) is kept so that any RNG use inside\n",
    "            # DW/SW consumes random numbers in the same sequence.\n",
    "            dr_mat[row, col] = DW(cloud_a, cloud_b, ndirs=100, data_depth='Projection')\n",
    "            sw_mat[row, col] = SW(cloud_a, cloud_b, ndirs=100)\n",
    "            w_mat[row, col] = ot.emd2(uniform, uniform, ot.dist(cloud_a, cloud_b))\n",
    "            mmd_mat[row, col] = mmd(X[row], X[col])\n",
    "    # Mirror the strict lower triangle to obtain full symmetric matrices.\n",
    "    return tuple(mat + mat.T for mat in (dr_mat, w_mat, sw_mat, mmd_mat))\n",
    "\n",
    "def bag_of_pixels_gray(X,size=100):\n",
    "    ## Bag of pixels\n",
    "    Z = torch.zeros(50000, 28*28, 3)\n",
    "    for i in range(28):\n",
    "        for j in range(28):\n",
    "            Z[:,i*28 + j, 0] = i\n",
    "            Z[:,i*28+ j, 1] = j\n",
    "            Z[:,i*28+ j, 2] = X[0][:,0,i,j]\n",
    "    data = []\n",
    "    for i in range(10):\n",
    "        data.append(Z[X[1] == i])\n",
    "    ZZ = torch.zeros(10*size, 784, 3)\n",
    "    for i in range(10):\n",
    "        ZZ[(i*size):(i+1) * size,:,:] = data[i][:size]\n",
    "    return ZZ\n",
    "def clustering_result(R, y_true, n_clusters=10):\n",
    "    \"\"\"Score a K-means clustering of each matrix in R against y_true.\n",
    "\n",
    "    Each R[i] is given to KMeans as a feature matrix (one row per item, e.g. a\n",
    "    row of a distance matrix). Returns a (len(R), 2) array whose columns are\n",
    "    [NMI, ARI]; only the first len(y_true) labels are scored.\n",
    "    \"\"\"\n",
    "    result = np.zeros((len(R), 2))\n",
    "    for i, features in enumerate(R):\n",
    "        # `precompute_distances` is no longer passed: it was deprecated in\n",
    "        # scikit-learn 0.23 and removed in 1.0.\n",
    "        model = KMeans(n_clusters=n_clusters, random_state=0).fit(features)\n",
    "        labels = model.labels_[:len(y_true)]\n",
    "        result[i, 0] = normalized_mutual_info_score(y_true, labels)\n",
    "        result[i, 1] = adjusted_rand_score(y_true, labels)\n",
    "    return result\n",
    "def clustering_result_anom(R, y_true, n_clusters=10):\n",
    "    \"\"\"Score a K-means clustering of each matrix in R against y_true.\n",
    "\n",
    "    Same procedure as clustering_result: each R[i] is given to KMeans as a\n",
    "    feature matrix. Returns a (len(R), 2) array whose columns are [NMI, ARI];\n",
    "    only the first len(y_true) labels are scored.\n",
    "    \"\"\"\n",
    "    # The shape must be a tuple: np.zeros(n, 2) would read 2 as the dtype.\n",
    "    result = np.zeros((len(R), 2))\n",
    "    for i, features in enumerate(R):\n",
    "        # `precompute_distances` is no longer passed: it was deprecated in\n",
    "        # scikit-learn 0.23 and removed in 1.0.\n",
    "        model = KMeans(n_clusters=n_clusters, random_state=0).fit(features)\n",
    "        labels = model.labels_[:len(y_true)]\n",
    "        result[i, 0] = normalized_mutual_info_score(y_true, labels)\n",
    "        result[i, 1] = adjusted_rand_score(y_true, labels)\n",
    "    return result\n",
    "def mmd(x1, x2, sigma=1):\n",
    "    \"\"\"MMD^2-style discrepancy between the samples x1 and x2.\n",
    "\n",
    "    Returns mean(k(x1, x1)) - 2 * mean(k(x1, x2)) + mean(k(x2, x2)),\n",
    "    where k = gaussian_kernel(., ., sigma).\n",
    "    \"\"\"\n",
    "    pairs = ((x1, x1), (x1, x2), (x2, x2))\n",
    "    within_1, cross, within_2 = (gaussian_kernel(a, b, sigma).mean() for a, b in pairs)\n",
    "    return within_1 - 2 * cross + within_2\n",
    "\n",
    "def gaussian_kernel(x1, x2, sigma = 1.0):\n",
    "    #r = x1.dimshuffle(0,'x',1)\n",
    "    return np.exp(-np.linalg.norm(x1-x2, axis=1)/(2*sigma **2))\n",
    "def Sclustering_result_KM(R, y_true, n_clusters=10):\n",
    "    \"\"\"Score a spectral clustering of each distance matrix in R against y_true.\n",
    "\n",
    "    Each distance matrix D = R[i] is turned into the affinity 1 / (1 + D) and\n",
    "    clustered with SpectralClustering(affinity='precomputed'). Returns a\n",
    "    (len(R), 2) array whose columns are [NMI, ARI]; only the first len(y_true)\n",
    "    labels are scored.\n",
    "    \"\"\"\n",
    "    result = np.zeros((len(R), 2))\n",
    "    for i, dist in enumerate(R):\n",
    "        # Decreasing map: distance 0 -> affinity 1, large distance -> near 0.\n",
    "        model = SpectralClustering(n_clusters=n_clusters, random_state=0, affinity='precomputed')\n",
    "        model.fit(1 / (1 + dist))\n",
    "        labels = model.labels_[:len(y_true)]\n",
    "        result[i, 0] = normalized_mutual_info_score(y_true, labels)\n",
    "        result[i, 1] = adjusted_rand_score(y_true, labels)\n",
    "    return result"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Importing data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "transform = transforms.Compose(\n",
    "    [transforms.ToTensor(),\n",
    "     transforms.Normalize((0.5,), (0.5,))])\n",
    "trainset = torchvision.datasets.FashionMNIST(root='./data', train=True,\n",
    "                                             download=True, transform=transform)\n",
    "# One batch with the first 50000 training images; shuffle=False keeps the order fixed.\n",
    "trainloader = torch.utils.data.DataLoader(trainset, batch_size=50000,\n",
    "                                          shuffle=False, num_workers=0)\n",
    "# DataLoader iterators have no `.next()` method in recent PyTorch releases;\n",
    "# the built-in next() works on every version.\n",
    "FM = next(iter(trainloader))  # (images, labels)\n",
    "X_FM = bag_of_pixels_gray(FM)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Creating the contaminated data set by adding a white patch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.random.seed(0)\n",
    "# 50 of the 1000 distributions get a white patch. np.random.choice samples WITH\n",
    "# replacement by default, so duplicates can occur and fewer than 50 distinct\n",
    "# images may be contaminated; kept unchanged so the drawn indices (and the\n",
    "# results computed from them) stay reproducible.\n",
    "subsample = np.random.choice(1000, 50)\n",
    "# Rows and columns 4..13 of the 28x28 grid, i.e. a 10x10 patch.\n",
    "subsample_im = np.arange(4, 14)\n",
    "patch_pixels = (subsample_im[:, None] * 28 + subsample_im[None, :]).ravel()\n",
    "X_FM_anom = X_FM.clone()\n",
    "# Channel 2 holds the intensity; 1 is white after Normalize((0.5,), (0.5,)).\n",
    "X_FM_anom[torch.as_tensor(subsample)[:, None], torch.as_tensor(patch_pixels)[None, :], 2] = 1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Computing the distances between distributions. Load the saved results if you don't want to recompute them: the computation may take around 14h."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 2d 21h 15min 46s, sys: 12h 43min 11s, total: 3d 9h 58min 57s\n",
      "Wall time: 10h 46min 28s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "np.random.seed(0)\n",
    "\n",
    "# Set RECOMPUTE = True to rebuild the distance matrices (the recorded run took\n",
    "# about 11h wall time); otherwise the cached result is loaded.\n",
    "RECOMPUTE = False\n",
    "DIST_CACHE = 'Robust_clustering_FM.pt'\n",
    "\n",
    "if RECOMPUTE or not os.path.exists(DIST_CACHE):\n",
    "    S = compute_dist_matrix(X_FM_anom)\n",
    "    # Save immediately so a multi-hour computation is not lost.\n",
    "    torch.save(S, DIST_CACHE)\n",
    "else:\n",
    "    # torch.load unpickles the file: only load caches you created yourself.\n",
    "    S = torch.load(DIST_CACHE)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Computing clustering"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[0.54743537 0.41676395]\n",
      " [0.48054403 0.30647336]\n",
      " [0.46932183 0.33082909]\n",
      " [0.50797205 0.35972347]]\n"
     ]
    }
   ],
   "source": [
    "# Ground truth: 100 consecutive items per class 0..9, matching the\n",
    "# class-grouped order of the data (first 100 images of each class).\n",
    "y = np.arange(10)\n",
    "y_true = y.repeat(100)\n",
    "result_FMKM = Sclustering_result_KM(S, y_true)\n",
    "print(result_FMKM)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Computing direct clustering (Euclidean in the table)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Flatten each image to a 28*28 vector and keep the first 100 images per\n",
    "# class, grouped by class, for the plain pixel-space baseline.\n",
    "FM_real_clustering_ = FM[0].reshape(50000, 28 * 28)\n",
    "data = [FM_real_clustering_[FM[1] == label] for label in range(10)]\n",
    "FM_real_clustering = torch.zeros(1000, 784)\n",
    "for label, members in enumerate(data):\n",
    "    FM_real_clustering[label * 100:(label + 1) * 100, :] = members[:100]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same contamination as X_FM_anom: the patch spanned by subsample_im (rows and\n",
    "# columns) is set to 1 in the images indexed by subsample, here on the\n",
    "# flattened 28*28 pixel vectors.\n",
    "FM_real_clustering_anom = FM_real_clustering.clone()\n",
    "patch_cols = (subsample_im[:, None] * 28 + subsample_im[None, :]).ravel()\n",
    "FM_real_clustering_anom[torch.as_tensor(subsample)[:, None], torch.as_tensor(patch_cols)[None, :]] = 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.4834653462701407\n",
      "0.3044650200297357\n",
      "CPU times: user 1.23 s, sys: 1.38 s, total: 2.61 s\n",
      "Wall time: 242 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# Baseline: spectral clustering directly on pixel-space distances from ot.dist,\n",
    "# using the same 1 / (1 + D) affinity as Sclustering_result_KM.\n",
    "M = ot.dist(FM_real_clustering_anom)\n",
    "Clustering = SpectralClustering(n_clusters=10, random_state=0,\n",
    "                                affinity='precomputed').fit(1 / (1 + M))\n",
    "NMI_normal = normalized_mutual_info_score(y_true, Clustering.labels_)\n",
    "ARS_normal = adjusted_rand_score(y_true, Clustering.labels_)\n",
    "print(NMI_normal)\n",
    "print(ARS_normal)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
