{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# data preprocessing and network architectures is modified from https://github.com/ahmedbesbes/mrnet"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# This is a notebook to train a model on the MRNet dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import shutil\n",
    "import os\n",
    "import time\n",
    "from datetime import datetime\n",
    "import argparse\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\"\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.autograd import Variable\n",
    "from torchvision import transforms\n",
    "import torch.nn.functional as F\n",
    "\n",
    "from dataloader_mrnet import MRDataset\n",
    "import model_mrnet\n",
    "\n",
    "from sklearn import metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "task = 'acl' # choices are 'abnormal', 'acl', 'meniscus'\n",
    "plane = 'sagittal' \n",
    "epochs = 50\n",
    "lr = 1e-4\n",
    "log_every = 100\n",
    "threshold = 20 # instead of s=5, as the batch size is 1, \n",
    "# we set this to be much larger than before so that the fluctuations are reasonable\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_model(model, train_loader, epoch, num_epochs, optimizer, current_lr, log_every=100):\n",
    "    model.train()\n",
    "\n",
    "    if torch.cuda.is_available():\n",
    "        model.cuda()\n",
    "\n",
    "    y_preds = []\n",
    "    y_trues = []\n",
    "    losses = []\n",
    "\n",
    "    for i, (image, label, weight) in enumerate(train_loader):\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        if torch.cuda.is_available():\n",
    "            image = image.cuda()\n",
    "            label = label.cuda()\n",
    "            weight = weight.cuda()\n",
    "        \n",
    "        label = label[0]\n",
    "        weight = weight[0]\n",
    "\n",
    "        prediction = model.forward(image.float())\n",
    "\n",
    "        loss = torch.nn.BCEWithLogitsLoss(weight=weight)(prediction, label)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        loss_value = loss.item()\n",
    "        losses.append(loss_value)\n",
    "\n",
    "        probas = torch.sigmoid(prediction)\n",
    "\n",
    "        y_trues.append(int(label[0][1]))\n",
    "        y_preds.append(probas[0][1].item())\n",
    "\n",
    "        try:\n",
    "            auc = metrics.roc_auc_score(y_trues, y_preds)\n",
    "        except:\n",
    "            auc = 0.5\n",
    "\n",
    "        if (i % log_every == 0) & (i > 0):\n",
    "            print('''[Epoch: {0} / {1} |Single batch number : {2} / {3} ]| avg train loss {4} | train auc : {5} | lr : {6}'''.\n",
    "                  format(\n",
    "                      epoch + 1,\n",
    "                      num_epochs,\n",
    "                      i,\n",
    "                      len(train_loader),\n",
    "                      np.round(np.mean(losses), 4),\n",
    "                      np.round(auc, 4),\n",
    "                      current_lr\n",
    "                  )\n",
    "                  )\n",
    "\n",
    "    train_loss_epoch = np.round(np.mean(losses), 4)\n",
    "    train_auc_epoch = np.round(auc, 4)\n",
    "    return train_loss_epoch, train_auc_epoch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate_model(model, val_loader, epoch, num_epochs, current_lr, log_every=20):\n",
    "    model.eval()\n",
    "\n",
    "    if torch.cuda.is_available():\n",
    "        model.cuda()\n",
    "\n",
    "    y_trues = []\n",
    "    y_preds = []\n",
    "    losses = []\n",
    "\n",
    "    for i, (image, label, weight) in enumerate(val_loader):\n",
    "\n",
    "        if torch.cuda.is_available():\n",
    "            image = image.cuda()\n",
    "            label = label.cuda()\n",
    "            weight = weight.cuda()\n",
    "\n",
    "        label = label[0]\n",
    "        weight = weight[0]\n",
    "\n",
    "        prediction = model.forward(image.float())\n",
    "\n",
    "        loss = torch.nn.BCEWithLogitsLoss(weight=weight)(prediction, label)\n",
    "\n",
    "        loss_value = loss.item()\n",
    "        losses.append(loss_value)\n",
    "\n",
    "        probas = torch.sigmoid(prediction)\n",
    "\n",
    "        y_trues.append(int(label[0][1]))\n",
    "        y_preds.append(probas[0][1].item())\n",
    "\n",
    "        try:\n",
    "            auc = metrics.roc_auc_score(y_trues, y_preds)\n",
    "        except:\n",
    "            auc = 0.5\n",
    "\n",
    "        if (i % log_every == 0) & (i > 0):\n",
    "            print('''[Epoch: {0} / {1} |Single batch number : {2} / {3} ] | avg val loss {4} | val auc : {5} | lr : {6}'''.\n",
    "                  format(\n",
    "                      epoch + 1,\n",
    "                      num_epochs,\n",
    "                      i,\n",
    "                      len(val_loader),\n",
    "                      np.round(np.mean(losses), 4),\n",
    "                      np.round(auc, 4),\n",
    "                      current_lr\n",
    "                  )\n",
    "                  )\n",
    "\n",
    "    val_loss_epoch = np.round(np.mean(losses), 4)\n",
    "    val_auc_epoch = np.round(auc, 4)\n",
    "    return val_loss_epoch, val_auc_epoch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_lr(optimizer):\n",
    "    for param_group in optimizer.param_groups:\n",
    "        return param_group['lr']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def find_GD(model, train_loader):\n",
    "    \"\"\"the function for computing the gradient disparity\n",
    "    this functions gives the following output:\n",
    "    avg_grad_dis: the avg gradient disparity between pairs of samples of the dataset\n",
    "    \"\"\"\n",
    "    lr = 0.01\n",
    "    opt = optim.SGD(model.parameters(), lr=lr)\n",
    "    it = iter(train_loader)\n",
    "    Ls = []\n",
    "    # because the batch size is 1, we manually find loss std\n",
    "    for i in range(40):\n",
    "        image, label, weight = next(it)\n",
    "        image = Variable(image.cuda(), requires_grad=True)\n",
    "        label = Variable(label).cuda()\n",
    "        weight = Variable(weight).cuda()\n",
    "\n",
    "        label = label[0]\n",
    "        weight = weight[0]\n",
    "        opt.zero_grad()\n",
    "\n",
    "        criterion = torch.nn.BCEWithLogitsLoss(weight=weight)\n",
    "\n",
    "        prediction = model.forward(image.float())\n",
    "\n",
    "        loss = criterion(prediction, label).item()\n",
    "        Ls.append(loss)\n",
    "        \n",
    "    \n",
    "    loss_std = np.std(np.array(Ls), axis=0)\n",
    "    # set model in training mode (need this because of dropout)\n",
    "    model.train() \n",
    "    cnt = 0\n",
    "    avg_grad_dis = 0\n",
    "    Grads = []\n",
    "    it = iter(train_loader)\n",
    "    for i in range(threshold):\n",
    "        image, label, weight = next(it)\n",
    "        image = Variable(image.cuda(), requires_grad=True)\n",
    "        label = Variable(label).cuda()\n",
    "        weight = Variable(weight).cuda()\n",
    "\n",
    "        label = label[0]\n",
    "        weight = weight[0]\n",
    "        opt.zero_grad()\n",
    "            \n",
    "        criterion = torch.nn.BCEWithLogitsLoss(weight=weight)\n",
    "\n",
    "        prediction = model.forward(image.float())\n",
    "\n",
    "        loss = criterion(prediction, label)           \n",
    "        loss1_s = loss/loss_std\n",
    "        \n",
    "        loss1_s.backward(retain_graph=True)\n",
    "        grads1_s = []\n",
    "        for name, param in model.named_parameters():\n",
    "            if param.grad != None:\n",
    "                grads1_s.append(param.grad.view(-1))\n",
    "        grads1_s = torch.cat(grads1_s)\n",
    "        Grads.append(grads1_s.data.cpu().numpy())\n",
    "        cnt += 1\n",
    "    \n",
    "    Grads = np.array(Grads)\n",
    "    cnt2 = 0\n",
    "    for i in range(cnt):\n",
    "        for j in range(cnt):\n",
    "            if i < j:\n",
    "                grads1 = Grads[i]\n",
    "                grads2 = Grads[j]\n",
    "                avg_grad_dis += np.linalg.norm(grads1-grads2)\n",
    "                cnt2 += 1\n",
    "    \n",
    "        \n",
    "    avg_grad_dis /= cnt2\n",
    "    return avg_grad_dis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_avgs = 5\n",
    "num_epochs = 100\n",
    "\n",
    "# these lists store the results for multiple runs\n",
    "tLosses_avg = []\n",
    "tAUCs_avg = []\n",
    "vLosses_avg = []\n",
    "vAUCs_avg = []\n",
    "GD_avg = []\n",
    "\n",
    "for i in range(num_avgs):\n",
    "    train_dataset = MRDataset('./data/', task,\n",
    "                              plane, train=True)\n",
    "    train_loader = torch.utils.data.DataLoader(\n",
    "        train_dataset, batch_size=1, shuffle=True, drop_last=False)\n",
    "    validation_dataset = MRDataset(\n",
    "        './data/', task, plane, train=False)\n",
    "    validation_loader = torch.utils.data.DataLoader(\n",
    "        validation_dataset, batch_size=1, shuffle=-True, drop_last=False)\n",
    "\n",
    "    mrnet = model_mrnet.MRNet()\n",
    "\n",
    "    if torch.cuda.is_available():\n",
    "        mrnet = mrnet.cuda()\n",
    "    lr = 1e-4\n",
    "    optimizer = optim.SGD(mrnet.parameters(), lr=lr)\n",
    "\n",
    "    tLosses = []\n",
    "    tAUCs = []\n",
    "    vLosses = []\n",
    "    vAUCs = []\n",
    "    GD = []\n",
    "    t_start_training = time.time()\n",
    "    for epoch in range(num_epochs):\n",
    "        current_lr = lr\n",
    "        t_start = time.time()\n",
    "        \n",
    "        gd = find_GD(mrnet, train_loader)\n",
    "        train_loss, train_auc = train_model(\n",
    "            mrnet, train_loader, epoch, num_epochs, optimizer, current_lr, log_every)\n",
    "        val_loss, val_auc = evaluate_model(\n",
    "            mrnet, validation_loader, epoch, num_epochs,  current_lr)\n",
    "\n",
    "        tLosses.append(train_loss)\n",
    "        tAUCs.append(train_auc)\n",
    "        vLosses.append(val_loss)\n",
    "        vAUCs.append(val_auc)\n",
    "        GD.append(gd)\n",
    "\n",
    "        t_end = time.time()\n",
    "        delta = t_end - t_start\n",
    "\n",
    "        print(\"train loss : {0} | train auc {1} | val loss {2} | val auc {3} | pac {4} | elapsed time {5} s\".format(\n",
    "            train_loss, train_auc, val_loss, val_auc, gd, delta))\n",
    "        print('-' * 30)\n",
    "\n",
    "    t_end_training = time.time()\n",
    "    print('training took {%s - %s} s'% (t_end_training, t_start_training))\n",
    "    tLosses_avg.append(tLosses)\n",
    "    tAUCs_avg.append(tAUCs)\n",
    "    vLosses_avg.append(vLosses)\n",
    "    vAUCs_avg.append(vAUCs)\n",
    "    GD_avg.append(GD)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# to save the results in a file\n",
    "List = [vLosses_avg, tLosses_avg, vAUCs_avg, tAUCs_avg, GD_avg]\n",
    "\n",
    "# rename the file to have the results saved on\n",
    "with open('temp.data', 'wb') as filehandle:\n",
    "    # store the data as binary data stream\n",
    "    for ls in List:\n",
    "        pickle.dump(ls, filehandle) \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
