{
 "cells": [
  {
   "cell_type": "markdown",
   "source": [
    "## Imports"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "1ae2c8fd3734c812"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "from healnet.models import HealNet\n",
    "from healnet.etl import MMDataset\n",
    "import torch\n",
    "import einops\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "from typing import *\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "787ff0241bf48627"
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Synthetic modalities\n",
    "\n",
    "We instantiate a synthetic multimodal dataset (one tabular and one image modality, filled with random values) for demo purposes."
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "5de637307d506881"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Fix the RNG seed so the synthetic data is identical after Restart & Run All.\n",
    "torch.manual_seed(42)\n",
    "\n",
    "n = 1000 # number of samples\n",
    "b = 4 # batch size\n",
    "img_c = 3 # image channels\n",
    "tab_c = 1 # tabular channels\n",
    "tab_d = 5000 # tabular features\n",
    "h = 512 # image height\n",
    "w = 512 # image width\n",
    "n_classes = 4 # number of target classes (classification)\n",
    "\n",
    "# Random stand-ins for the two modalities; torch.rand draws uniformly from [0, 1).\n",
    "tab_tensor = torch.rand(size=(n, tab_c, tab_d)) # (n, c, d): 5k tabular features per sample\n",
    "img_tensor = torch.rand(size=(n, img_c, h, w)) # (n, c, h, w)\n",
    "\n",
    "# Classification target: one integer class index in [0, n_classes) per sample.\n",
    "# (torch.rand would produce floats in [0, 1), which are not valid class labels.)\n",
    "target = torch.randint(low=0, high=n_classes, size=(n,))"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "233819ff1c11e04a"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "data = MMDataset([tab_tensor, img_tensor], target)\n",
    "\n",
    "# 70-15-15 split, unpacked in train / test / val order.\n",
    "# A seeded generator keeps the split identical across kernel restarts.\n",
    "split_generator = torch.Generator().manual_seed(42)\n",
    "train, test, val = torch.utils.data.random_split(data, [0.7, 0.15, 0.15], generator=split_generator)\n",
    "\n",
    "# Arguments shared by all three loaders.\n",
    "loader_args = {\n",
    "    \"batch_size\": b, # without this the loaders fall back to DataLoader's default batch_size=1\n",
    "    \"num_workers\": 8, \n",
    "    \"pin_memory\": True, \n",
    "    \"multiprocessing_context\": \"fork\", \n",
    "    \"persistent_workers\": True, \n",
    "}\n",
    "\n",
    "# Only the training loader shuffles; val/test keep a fixed order so evaluation is repeatable.\n",
    "train_loader = DataLoader(train, shuffle=True, **loader_args)\n",
    "val_loader = DataLoader(val, shuffle=False, **loader_args)\n",
    "test_loader = DataLoader(test, shuffle=False, **loader_args)"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "1df31a7b4a69e7aa"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# example use: fetch the first (modalities, label) pair from the dataset.\n",
    "# The label is bound to `target_sample` so the full `target` tensor defined earlier\n",
    "# is not overwritten (re-running the dataset cell would otherwise pick up a single label).\n",
    "[tab_sample, img_sample], target_sample = data[0]\n",
    "\n",
    "# emulate batch dimension: prepend an axis of size 1 to each modality\n",
    "tab_sample = einops.repeat(tab_sample, 'c d -> b c d', b=1)\n",
    "# for the image, the pattern also merges h and w into a single axis\n",
    "img_sample = einops.repeat(img_sample, 'c h w -> b c (h w)', b=1)"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "45ab1e79c8bc3a21"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# display the dimensions of the batched image sample\n",
    "img_sample.size()"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "38724a29f4142f8e"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# HealNet configuration for the two synthetic modalities (tabular first, image second).\n",
    "healnet_kwargs = dict(\n",
    "    modalities=2,\n",
    "    input_channels=[tab_c, img_c],\n",
    "    input_axes=[1, 1], # channel axes (0-indexed) -- NOTE(review): confirm against the HealNet docs\n",
    "    num_classes=n_classes,\n",
    ")\n",
    "model = HealNet(**healnet_kwargs)"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "504d15c029cf68e0"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# forward pass on the single-sample batch; the bare last expression displays the result\n",
    "output = model([tab_sample, img_sample])\n",
    "output"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "d9e77fe00f07c904"
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
