{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import logging\n",
    "import os\n",
    "import sys\n",
    "import shutil\n",
    "import tempfile\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "import numpy as np\n",
    "\n",
    "import monai\n",
    "from monai.apps import download_and_extract\n",
    "from monai.config import print_config\n",
    "from monai.data import DataLoader, ImageDataset\n",
    "from monai.transforms import (\n",
    "    EnsureChannelFirst,\n",
    "    Compose,\n",
    "    RandRotate90,\n",
    "    Resize,\n",
    "    ScaleIntensity,\n",
    ")\n",
    "\n",
    "from monai.networks.nets import  DenseNet121\n",
    "from pe.datatasets.data_loaders.rsna_pe_simple import NiftiDataset\n",
    "from torch.utils.data import  DataLoader\n",
    "import torch \n",
    "import torch\n",
    "from torch.utils.data import DataLoader, Subset, random_split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MONAI version: 1.3.0\n",
      "Numpy version: 1.25.2\n",
      "Pytorch version: 2.0.1+cu117\n",
      "MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False\n",
      "MONAI rev id: 865972f7a791bf7b42efbcd87c8402bd865b329e\n",
      "MONAI __file__: /share/pi/nigam/alejandro/envs/pe_env/lib/python3.9/site-packages/monai/__init__.py\n",
      "\n",
      "Optional dependencies:\n",
      "Pytorch Ignite version: NOT INSTALLED or UNKNOWN VERSION.\n",
      "ITK version: NOT INSTALLED or UNKNOWN VERSION.\n",
      "Nibabel version: 5.1.0\n",
      "scikit-image version: 0.21.0\n",
      "scipy version: 1.9.3\n",
      "Pillow version: 10.0.0\n",
      "Tensorboard version: NOT INSTALLED or UNKNOWN VERSION.\n",
      "gdown version: NOT INSTALLED or UNKNOWN VERSION.\n",
      "TorchVision version: 0.15.2+cu117\n",
      "tqdm version: 4.66.1\n",
      "lmdb version: NOT INSTALLED or UNKNOWN VERSION.\n",
      "psutil version: 5.9.5\n",
      "pandas version: 2.1.0\n",
      "einops version: NOT INSTALLED or UNKNOWN VERSION.\n",
      "transformers version: NOT INSTALLED or UNKNOWN VERSION.\n",
      "mlflow version: NOT INSTALLED or UNKNOWN VERSION.\n",
      "pynrrd version: NOT INSTALLED or UNKNOWN VERSION.\n",
      "clearml version: NOT INSTALLED or UNKNOWN VERSION.\n",
      "\n",
      "For details about installing the optional dependencies, please visit:\n",
      "    https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies\n",
      "\n"
     ]
    }
   ],
   "source": [
    "pin_memory = torch.cuda.is_available()\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "logging.basicConfig(stream=sys.stdout, level=logging.INFO)\n",
    "print_config()\n",
    "\n",
    "# Define transforms\n",
    "train_transforms = Compose([ScaleIntensity(), Resize((224, 224, 200)), RandRotate90()])\n",
    "\n",
    "# Define Nifti dataset and data loader for the new dataset\n",
    "csv_file = '/share/pi/nigam/data/RSNAPE/simplified_labels/train.csv'\n",
    "root_dir = '/share/pi/nigam/data/RSNAPE/nifti_crop/'\n",
    "\n",
    "#nifti_dataset = NiftiDataset(csv_file=csv_file, root_dir=root_dir,transform=train_transforms)\n",
    "nifti_dataset = NiftiDataset(csv_file=csv_file, root_dir=root_dir)\n",
    "loader        = DataLoader(nifti_dataset, batch_size=1, shuffle=True)\n",
    "\n",
    "dataset_size = len(loader.dataset)\n",
    "test_size    = int(0.2 * dataset_size)\n",
    "train_size   = dataset_size - test_size\n",
    "\n",
    "# Create indices for train and test subsets\n",
    "indices = list(range(dataset_size))\n",
    "train_indices, test_indices = indices[:train_size], indices[train_size:]\n",
    "\n",
    "# Create Subset data loaders for train and test\n",
    "train_loader = DataLoader(Subset(loader.dataset, train_indices), batch_size=1, shuffle=True)\n",
    "test_loader  = DataLoader(Subset(loader.dataset, test_indices), batch_size=1, shuffle=False)\n",
    "\n",
    "# Create a model (Swin or DenseNet)\n",
    "model = DenseNet121(spatial_dims=3, in_channels=1, out_channels=1).to(device)\n",
    "\n",
    "# Create loss function and optimizer\n",
    "loss_function = torch.nn.CrossEntropyLoss()\n",
    "optimizer     = torch.optim.Adam(model.parameters(), 1e-4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "for batch_data in train_loader:\n",
    "    inputs, labels = batch_data[0].to(device), batch_data[1].to(device)\n",
    "    break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "inputs_a = torch.unsqueeze(train_transforms(inputs),dim=0)\n",
    "outputs = model(inputs_a)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn.functional as F "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "ename": "ValueError",
     "evalue": "Target size (torch.Size([1])) must be the same as input size (torch.Size([1, 1]))",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mValueError\u001b[0m                                Traceback (most recent call last)",
      "\u001b[1;32m/share/pi/nigam/alejandro/Repositories/pe/examples/train.ipynb Cell 7\u001b[0m line \u001b[0;36m1\n\u001b[0;32m----> <a href='vscode-notebook-cell://ssh-remote%2Blogin.carina.stanford.edu/share/pi/nigam/alejandro/Repositories/pe/examples/train.ipynb#X11sdnNjb2RlLXJlbW90ZQ%3D%3D?line=0'>1</a>\u001b[0m F\u001b[39m.\u001b[39;49mbinary_cross_entropy_with_logits(outputs,labels)\n",
      "File \u001b[0;32m/share/pi/nigam/alejandro/envs/pe_env/lib/python3.9/site-packages/torch/nn/functional.py:3146\u001b[0m, in \u001b[0;36mbinary_cross_entropy_with_logits\u001b[0;34m(input, target, weight, size_average, reduce, reduction, pos_weight)\u001b[0m\n\u001b[1;32m   3110\u001b[0m \u001b[39m\u001b[39m\u001b[39mr\u001b[39m\u001b[39m\"\"\"Function that measures Binary Cross Entropy between target and input\u001b[39;00m\n\u001b[1;32m   3111\u001b[0m \u001b[39mlogits.\u001b[39;00m\n\u001b[1;32m   3112\u001b[0m \n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m   3143\u001b[0m \u001b[39m     >>> loss.backward()\u001b[39;00m\n\u001b[1;32m   3144\u001b[0m \u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m   3145\u001b[0m \u001b[39mif\u001b[39;00m has_torch_function_variadic(\u001b[39minput\u001b[39m, target, weight, pos_weight):\n\u001b[0;32m-> 3146\u001b[0m     \u001b[39mreturn\u001b[39;00m handle_torch_function(\n\u001b[1;32m   3147\u001b[0m         binary_cross_entropy_with_logits,\n\u001b[1;32m   3148\u001b[0m         (\u001b[39minput\u001b[39;49m, target, weight, pos_weight),\n\u001b[1;32m   3149\u001b[0m         \u001b[39minput\u001b[39;49m,\n\u001b[1;32m   3150\u001b[0m         target,\n\u001b[1;32m   3151\u001b[0m         weight\u001b[39m=\u001b[39;49mweight,\n\u001b[1;32m   3152\u001b[0m         size_average\u001b[39m=\u001b[39;49msize_average,\n\u001b[1;32m   3153\u001b[0m         reduce\u001b[39m=\u001b[39;49mreduce,\n\u001b[1;32m   3154\u001b[0m         reduction\u001b[39m=\u001b[39;49mreduction,\n\u001b[1;32m   3155\u001b[0m         pos_weight\u001b[39m=\u001b[39;49mpos_weight,\n\u001b[1;32m   3156\u001b[0m     )\n\u001b[1;32m   3157\u001b[0m \u001b[39mif\u001b[39;00m size_average \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m \u001b[39mor\u001b[39;00m reduce \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[1;32m   3158\u001b[0m     reduction_enum \u001b[39m=\u001b[39m _Reduction\u001b[39m.\u001b[39mlegacy_get_enum(size_average, reduce)\n",
      "File \u001b[0;32m/share/pi/nigam/alejandro/envs/pe_env/lib/python3.9/site-packages/torch/overrides.py:1551\u001b[0m, in \u001b[0;36mhandle_torch_function\u001b[0;34m(public_api, relevant_args, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1545\u001b[0m     warnings\u001b[39m.\u001b[39mwarn(\u001b[39m\"\u001b[39m\u001b[39mDefining your `__torch_function__ as a plain method is deprecated and \u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m   1546\u001b[0m                   \u001b[39m\"\u001b[39m\u001b[39mwill be an error in future, please define it as a classmethod.\u001b[39m\u001b[39m\"\u001b[39m,\n\u001b[1;32m   1547\u001b[0m                   \u001b[39mDeprecationWarning\u001b[39;00m)\n\u001b[1;32m   1549\u001b[0m \u001b[39m# Use `public_api` instead of `implementation` so __torch_function__\u001b[39;00m\n\u001b[1;32m   1550\u001b[0m \u001b[39m# implementations can do equality/identity comparisons.\u001b[39;00m\n\u001b[0;32m-> 1551\u001b[0m result \u001b[39m=\u001b[39m torch_func_method(public_api, types, args, kwargs)\n\u001b[1;32m   1553\u001b[0m \u001b[39mif\u001b[39;00m result \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNotImplemented\u001b[39m:\n\u001b[1;32m   1554\u001b[0m     \u001b[39mreturn\u001b[39;00m result\n",
      "File \u001b[0;32m/share/pi/nigam/alejandro/envs/pe_env/lib/python3.9/site-packages/monai/data/meta_tensor.py:282\u001b[0m, in \u001b[0;36mMetaTensor.__torch_function__\u001b[0;34m(cls, func, types, args, kwargs)\u001b[0m\n\u001b[1;32m    280\u001b[0m \u001b[39mif\u001b[39;00m kwargs \u001b[39mis\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[1;32m    281\u001b[0m     kwargs \u001b[39m=\u001b[39m {}\n\u001b[0;32m--> 282\u001b[0m ret \u001b[39m=\u001b[39m \u001b[39msuper\u001b[39;49m()\u001b[39m.\u001b[39;49m__torch_function__(func, types, args, kwargs)\n\u001b[1;32m    283\u001b[0m \u001b[39m# if `out` has been used as argument, metadata is not copied, nothing to do.\u001b[39;00m\n\u001b[1;32m    284\u001b[0m \u001b[39m# if \"out\" in kwargs:\u001b[39;00m\n\u001b[1;32m    285\u001b[0m \u001b[39m#     return ret\u001b[39;00m\n\u001b[1;32m    286\u001b[0m \u001b[39mif\u001b[39;00m _not_requiring_metadata(ret):\n",
      "File \u001b[0;32m/share/pi/nigam/alejandro/envs/pe_env/lib/python3.9/site-packages/torch/_tensor.py:1295\u001b[0m, in \u001b[0;36mTensor.__torch_function__\u001b[0;34m(cls, func, types, args, kwargs)\u001b[0m\n\u001b[1;32m   1292\u001b[0m     \u001b[39mreturn\u001b[39;00m \u001b[39mNotImplemented\u001b[39m\n\u001b[1;32m   1294\u001b[0m \u001b[39mwith\u001b[39;00m _C\u001b[39m.\u001b[39mDisableTorchFunctionSubclass():\n\u001b[0;32m-> 1295\u001b[0m     ret \u001b[39m=\u001b[39m func(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m   1296\u001b[0m     \u001b[39mif\u001b[39;00m func \u001b[39min\u001b[39;00m get_default_nowrap_functions():\n\u001b[1;32m   1297\u001b[0m         \u001b[39mreturn\u001b[39;00m ret\n",
      "File \u001b[0;32m/share/pi/nigam/alejandro/envs/pe_env/lib/python3.9/site-packages/torch/nn/functional.py:3163\u001b[0m, in \u001b[0;36mbinary_cross_entropy_with_logits\u001b[0;34m(input, target, weight, size_average, reduce, reduction, pos_weight)\u001b[0m\n\u001b[1;32m   3160\u001b[0m     reduction_enum \u001b[39m=\u001b[39m _Reduction\u001b[39m.\u001b[39mget_enum(reduction)\n\u001b[1;32m   3162\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (target\u001b[39m.\u001b[39msize() \u001b[39m==\u001b[39m \u001b[39minput\u001b[39m\u001b[39m.\u001b[39msize()):\n\u001b[0;32m-> 3163\u001b[0m     \u001b[39mraise\u001b[39;00m \u001b[39mValueError\u001b[39;00m(\u001b[39m\"\u001b[39m\u001b[39mTarget size (\u001b[39m\u001b[39m{}\u001b[39;00m\u001b[39m) must be the same as input size (\u001b[39m\u001b[39m{}\u001b[39;00m\u001b[39m)\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m.\u001b[39mformat(target\u001b[39m.\u001b[39msize(), \u001b[39minput\u001b[39m\u001b[39m.\u001b[39msize()))\n\u001b[1;32m   3165\u001b[0m \u001b[39mreturn\u001b[39;00m torch\u001b[39m.\u001b[39mbinary_cross_entropy_with_logits(\u001b[39minput\u001b[39m, target, weight, pos_weight, reduction_enum)\n",
      "\u001b[0;31mValueError\u001b[0m: Target size (torch.Size([1])) must be the same as input size (torch.Size([1, 1]))"
     ]
    }
   ],
   "source": [
    "F.binary_cross_entropy_with_logits(outputs,labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "metatensor([[0.3415]], grad_fn=<AliasBackward0>)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "loss_function( ,)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "----------\n",
      "epoch 1/5\n",
      "1/387, train_loss: -0.0000\n",
      "2/387, train_loss: -0.0000\n",
      "3/387, train_loss: -0.0000\n",
      "4/387, train_loss: -0.0000\n",
      "5/387, train_loss: -0.0000\n",
      "6/387, train_loss: -0.0000\n",
      "7/387, train_loss: -0.0000\n",
      "8/387, train_loss: -0.0000\n",
      "9/387, train_loss: -0.0000\n",
      "10/387, train_loss: -0.0000\n",
      "11/387, train_loss: -0.0000\n",
      "12/387, train_loss: -0.0000\n",
      "13/387, train_loss: -0.0000\n",
      "14/387, train_loss: -0.0000\n",
      "15/387, train_loss: -0.0000\n",
      "16/387, train_loss: -0.0000\n",
      "17/387, train_loss: -0.0000\n",
      "18/387, train_loss: -0.0000\n",
      "19/387, train_loss: -0.0000\n",
      "20/387, train_loss: -0.0000\n",
      "21/387, train_loss: -0.0000\n"
     ]
    }
   ],
   "source": [
    "# start a typical PyTorch training\n",
    "val_interval      = 2\n",
    "best_metric       = -1\n",
    "best_metric_epoch = -1\n",
    "epoch_loss_values = []\n",
    "metric_values     = []\n",
    "max_epochs        = 5\n",
    "\n",
    "for epoch in range(max_epochs):\n",
    "    print(\"-\" * 10)\n",
    "    print(f\"epoch {epoch + 1}/{max_epochs}\")\n",
    "    model.train()\n",
    "    epoch_loss = 0\n",
    "    step = 0\n",
    "\n",
    "    for batch_data in train_loader:\n",
    "        step += 1\n",
    "        inputs, labels = batch_data[0].to(device), batch_data[1].to(device)\n",
    "        inputs         = torch.unsqueeze(train_transforms(inputs),dim=0)\n",
    "        optimizer.zero_grad()\n",
    "        # import pdb;pdb.set_trace()\n",
    "        outputs  = model(inputs)\n",
    "        loss     = loss_function(outputs,  labels.reshape_as(outputs).float())\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        epoch_loss += loss.item()\n",
    "        epoch_len   = train_size // train_loader.batch_size\n",
    "        print(f\"{step}/{epoch_len}, train_loss: {loss.item():.4f}\")\n",
    "    \n",
    "\n",
    "    epoch_loss /= step\n",
    "    epoch_loss_values.append(epoch_loss)\n",
    "    print(f\"epoch {epoch + 1} average loss: {epoch_loss:.4f}\")\n",
    "\n",
    "    if (epoch + 1) % val_interval == 0:\n",
    "        model.eval()\n",
    "\n",
    "        num_correct = 0.0\n",
    "        metric_count = 0\n",
    "        for val_data in test_loader:\n",
    "            val_images, val_labels = val_data[0].to(device), val_data[1].to(device)\n",
    "            with torch.no_grad():\n",
    "                val_outputs = model(val_images)\n",
    "                value = torch.eq(val_outputs.argmax(dim=1), val_labels.argmax(dim=1))\n",
    "                metric_count += len(value)\n",
    "                num_correct += value.sum().item()\n",
    "\n",
    "        metric = num_correct / metric_count\n",
    "        metric_values.append(metric)\n",
    "\n",
    "        if metric > best_metric:\n",
    "            best_metric = metric\n",
    "            best_metric_epoch = epoch + 1\n",
    "            torch.save(model.state_dict(), \"best_metric_model_classification3d_array.pth\")\n",
    "            print(\"saved new best metric model\")\n",
    "\n",
    "        print(f\"Current epoch: {epoch+1} current accuracy: {metric:.4f} \")\n",
    "        print(f\"Best accuracy: {best_metric:.4f} at epoch {best_metric_epoch}\")\n",
    "        \n",
    "\n",
    "print(f\"Training completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pe_env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
