{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import torch\n",
    "from torch import nn\n",
    "from torch.autograd import Variable\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "import torchvision\n",
    "from torchvision import transforms\n",
    "from torchvision.datasets import MNIST, CIFAR10\n",
    "from torchvision.utils import save_image\n",
    "from tensorboardX import SummaryWriter\n",
    "from sklearn.manifold import TSNE\n",
    "from sklearn.decomposition import PCA\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import time\n",
    "from sklearn.cluster import KMeans"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class ConvNet1(nn.Module):\n",
    "    '''One unrolled block: conv (1 -> 512 channels) -> ReLU -> 1x1 conv (512 -> 1 channel).\n",
    "\n",
    "    Args:\n",
    "        kernel_size: kernel size of the first convolution (`layer1`).\n",
    "        padding: zero padding of `layer1`; with skip_connection=True it should keep the\n",
    "            spatial size unchanged (e.g. kernel_size=3, padding=1) so `y + x` lines up.\n",
    "        skip_connection: if True, the block input is added to the block output.\n",
    "    '''\n",
    "    def __init__(self, kernel_size, padding, skip_connection):\n",
    "        super(ConvNet1, self).__init__()\n",
    "\n",
    "        self.layer1 = nn.Conv2d(in_channels=1, out_channels=512, kernel_size=kernel_size, stride=1, padding=padding, bias=False)\n",
    "        self.act1 = nn.ReLU()\n",
    "        self.layer2 = nn.Conv2d(in_channels=512, out_channels=1, kernel_size=1, stride=1, padding=0, bias=False)\n",
    "        self.skip = skip_connection\n",
    "\n",
    "    def forward(self, x):\n",
    "        y = self.layer2(self.act1(self.layer1(x)))\n",
    "        if self.skip:\n",
    "            # out-of-place add: same value as `y += x` without an in-place op on a graph tensor\n",
    "            y = y + x\n",
    "        return y\n",
    "\n",
    "\n",
    "class UnrolledNet1(nn.Module):\n",
    "    '''`unrolls` ConvNet1 blocks applied one after another (1 channel in, 1 channel out).'''\n",
    "    def __init__(self, kernel_size=3, padding=1, skip_connection=False, unrolls=2):\n",
    "        super(UnrolledNet1, self).__init__()\n",
    "        self.convblocks = nn.ModuleList([ConvNet1(kernel_size, padding, skip_connection) for _ in range(unrolls)])\n",
    "\n",
    "    def forward(self, x):\n",
    "        for convnet in self.convblocks:\n",
    "            x = convnet(x)\n",
    "        return x\n",
    "\n",
    "    def representation(self, x):\n",
    "        '''Return each block's `layer1` output (pre-ReLU), one tensor per block.\n",
    "\n",
    "        Walks the same path as `forward`, including the skip connection, so the\n",
    "        activations of later blocks are computed from the inputs the model really sees.\n",
    "        '''\n",
    "        to_return = []\n",
    "        y = x\n",
    "        for convnet in self.convblocks:\n",
    "            inter = convnet.layer1(y)\n",
    "            to_return.append(inter)\n",
    "            out = convnet.layer2(convnet.act1(inter))\n",
    "            # apply the skip connection exactly as ConvNet1.forward does\n",
    "            y = out + y if convnet.skip else out\n",
    "        return to_return\n",
    "\n",
    "    def inputs(self, x):\n",
    "        '''Return [x, output of block 1, ..., output of block n]; the last entry equals forward(x).'''\n",
    "        to_return = [x]\n",
    "        for convnet in self.convblocks:\n",
    "            x = convnet(x)\n",
    "            to_return.append(x)\n",
    "        return to_return"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# NOTE(review): the checkpoint file is literally named '.pth' -- confirm this path.\n",
    "CHECKPOINT_PATH = '../paper_results/MNIST_twolayer_conv_relu_0.75_3_2/.pth'\n",
    "\n",
    "# Wrap before loading: the state dict presumably has nn.DataParallel ('module.'-prefixed) keys.\n",
    "model = nn.DataParallel(UnrolledNet1())\n",
    "# map_location='cpu' lets a GPU-saved checkpoint load on a CPU-only machine;\n",
    "# everything below works on CPU tensors (results are converted with .numpy()).\n",
    "model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location='cpu'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Unwrap nn.DataParallel. getattr keeps this cell idempotent: re-running it on the\n",
    "# already-unwrapped UnrolledNet1 leaves `model` unchanged instead of raising.\n",
    "model = getattr(model, 'module', model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Sanity check: `inputs` is defined on UnrolledNet1, so this should display the bound\n",
    "# method once `model` is unwrapped (presumably fails on the nn.DataParallel wrapper).\n",
    "model.inputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# presumably the usual MNIST training-set mean/std used when the model was trained\n",
    "normalize = transforms.Normalize((0.1307,), (0.3081,))\n",
    "bs = 5  # later cells index images A[0] .. A[4], so keep bs >= 5\n",
    "train_dataset = MNIST(\n",
    "    '', train=True, download=True,\n",
    "    transform=transforms.Compose([\n",
    "        transforms.ToTensor(),\n",
    "        normalize,\n",
    "    ]))\n",
    "\n",
    "# Take the first (unshuffled) batch as torch tensors: A holds the images, y the labels.\n",
    "# No pin_memory: nothing in this notebook is moved to a GPU.\n",
    "dummy_loader = torch.utils.data.DataLoader(\n",
    "    train_dataset, batch_size=bs, shuffle=False)\n",
    "A, y = next(iter(dummy_loader))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Run the batch through the three views of the network; no gradients are needed.\n",
    "with torch.no_grad():\n",
    "    outputs = model(A)\n",
    "    inputs = model.inputs(A)\n",
    "    activations = model.representation(A)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "first_output = outputs[0, 0]  # channel 0 of the network output for image 0\n",
    "plt.imshow(first_output, cmap='Greys')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# one entry per unrolled block: representation() appends one layer1 output per block\n",
    "len(activations)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "def plot_pattern_clusters(activations, sample, n_clusters=12, random_state=0):\n",
    "    '''Cluster the pixels of one image by their binary activation pattern, per block.\n",
    "\n",
    "    `activations` is the list returned by `UnrolledNet1.representation` (one\n",
    "    batch-first layer1 output per unrolled block). For image `sample`, each pixel\n",
    "    becomes a 0/1 vector marking which channels have a pre-activation >= 0; KMeans\n",
    "    groups pixels with similar patterns and the label map is plotted per block.\n",
    "\n",
    "    Returns the list of (H, W) label maps, one per block.\n",
    "    '''\n",
    "    label_maps = []\n",
    "    for layer, layer_act in enumerate(activations):\n",
    "        pattern = (layer_act[sample] >= 0).int()  # (channels, H, W)\n",
    "        height, width = pattern.shape[-2:]\n",
    "        # one row per pixel, one column per channel\n",
    "        pixel_patterns = pattern.permute(1, 2, 0).reshape((height * width, -1))\n",
    "        kmeans = KMeans(n_clusters=n_clusters, random_state=random_state).fit(pixel_patterns.numpy())\n",
    "        label_map = kmeans.labels_.reshape((height, width))\n",
    "        label_maps.append(label_map)\n",
    "        plt.imshow(label_map, cmap='Greys')\n",
    "        plt.title(f'image {sample}, block {layer}: activation-pattern clusters')\n",
    "        plt.axis('off')\n",
    "        plt.show()\n",
    "    return label_maps\n",
    "\n",
    "plot_pattern_clusters(activations, sample=0);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# inputs[0] is the raw batch; inputs[k] (k >= 1) is the output of the k-th block.\n",
    "for idx, block_input in enumerate(inputs):\n",
    "    plt.imshow(block_input[0, 0].detach().numpy(), cmap='Greys')\n",
    "    plt.title(f'inputs[{idx}], image 0')\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Same analysis for image 1; keep the last block's label map for later display.\n",
    "kmeans_clusters = plot_pattern_clusters(activations, sample=1)[-1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "inverted_digit = 1 - A[0, 0]  # normalized image 0, inverted for display\n",
    "plt.imshow(inverted_digit, cmap='Greys')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# NOTE(review): relies on `kmeans_clusters` left over from an earlier cell; run top to\n",
    "# bottom it is the last block's label map for image 1 from the clustering loop above.\n",
    "plt.imshow(kmeans_clusters, cmap='Greys')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "def plot_sample_clusters(activations, images, sample, layer=0, n_clusters=12, random_state=0):\n",
    "    '''Cluster the pixels of image `sample` by their raw layer1 pre-activations.\n",
    "\n",
    "    `activations` is the list returned by `UnrolledNet1.representation` (one\n",
    "    batch-first tensor per unrolled block), so both a block index (`layer`) and an\n",
    "    image index (`sample`) are needed to get a (channels, H, W) map. Each pixel\n",
    "    becomes a point whose features are its channel values; KMeans labels them.\n",
    "    Plots the inverted input image, then the label map, and returns the label map.\n",
    "    '''\n",
    "    # NOTE(review): block 0 is an assumed default -- confirm which block was meant.\n",
    "    sample_act = activations[layer][sample]  # (channels, H, W)\n",
    "    height, width = sample_act.shape[-2:]\n",
    "    pixel_features = sample_act.permute(1, 2, 0).reshape((height * width, -1))\n",
    "    kmeans = KMeans(n_clusters=n_clusters, random_state=random_state).fit(pixel_features.numpy())\n",
    "    label_map = kmeans.labels_.reshape((height, width))\n",
    "    plt.imshow(1 - images[sample, 0], cmap='Greys')\n",
    "    plt.show()\n",
    "    plt.imshow(label_map, cmap='Greys')\n",
    "    plt.show()\n",
    "    return label_map"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "kmeans_clusters = plot_sample_clusters(activations, A, sample=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "kmeans_clusters = plot_sample_clusters(activations, A, sample=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "kmeans_clusters = plot_sample_clusters(activations, A, sample=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "kmeans_clusters = plot_sample_clusters(activations, A, sample=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
