{
 "cells": [
  {
   "cell_type": "markdown",
   "source": [
    "# Interpretable Deep Clustering\n",
    "\n",
    "This example presents how to train a clustering model on MNIST dataset\n"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Step 0:\n",
    "Assuming you have Python 3.8 or 3.10 (tested on both), install the requirements by running the next lines:"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "!pip install -r requirements.txt"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Step 1:\n",
    "Import some dependencies\n"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "import torch\n",
    "from omegaconf import OmegaConf\n",
    "import numpy as np\n",
    "from pytorch_lightning import Trainer, seed_everything\n",
    "import os\n",
    "from pytorch_lightning.loggers import TensorBoardLogger\n",
    "from pytorch_lightning.callbacks import LearningRateMonitor\n",
    "from idc_mnist import MNISTClustering"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Step 2:\n",
    "Modify the default configuration dictionary:\n"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "config = OmegaConf.create(dict(\n",
    "        # GatingNet\n",
    "        sigma=0.5,\n",
    "        gates_hidden_dim=784,\n",
    "        global_gates_reg_lambda=10,\n",
    "        local_gates_reg_lambda=100,\n",
    "        start_global_gates_training_on_epoch=100,\n",
    "\n",
    "        # Autoencoder:\n",
    "        ae_pretrain_epochs=100,\n",
    "        ae_non_gated_epochs=10,\n",
    "        mask_percentage=0.9,\n",
    "        latent_noise_std=0.01,\n",
    "\n",
    "        # MCRR:\n",
    "        gamma=4,\n",
    "        eps=0.1,\n",
    "\n",
    "        # Dataset:\n",
    "        dataset=\"MNIST\",\n",
    "        data_dir=\".\",\n",
    "        input_dim=784,\n",
    "        n_clusters=10,\n",
    "        batch_size=256,\n",
    "        repitions=20,\n",
    "        tau=100,\n",
    "\n",
    "        trainer=dict(\n",
    "            gpus=1,\n",
    "            auto_select_gpus=True,\n",
    "            max_epochs=700,\n",
    "            deterministic=True,\n",
    "            logger=True,\n",
    "            log_every_n_steps=20,\n",
    "            check_val_every_n_epoch=10,\n",
    "            enable_checkpointing=False,\n",
    "        )\n",
    "    ))"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Step 3:\n",
    "Start training!"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "torch.use_deterministic_algorithms(True)\n",
    "torch.backends.cudnn.deterministic = True\n",
    "torch.backends.cudnn.benchmark = False\n",
    "original_cfg = config.copy()\n",
    "seed_everything(777)\n",
    "np.random.seed(777)\n",
    "if not os.path.exists(config.dataset):\n",
    "    os.makedirs(config.dataset)\n",
    "model = MNISTClustering(config)\n",
    "logger = TensorBoardLogger(config.dataset, name=config.dataset, log_graph=False)\n",
    "trainer = Trainer(**config.trainer, callbacks=[LearningRateMonitor(logging_interval='step')])\n",
    "trainer.logger = logger\n",
    "trainer.fit(model)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "During the training we export both `sparse_model_best.pth` with best accuracy and `sparse_model_last.pth` which is the latest one."
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [],
   "metadata": {
    "collapsed": false
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
