{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CUDA is available:  True\n",
      "GPU device name:  Tesla P100-PCIE-16GB\n",
      "Current work directory: /scratch/forest/attention-conv/notebooks\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import gc\n",
    "import datetime\n",
    "import sys\n",
    "import time\n",
    "from pathlib import Path\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import numpy as np\n",
    "from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay\n",
    "import matplotlib.pyplot as plt\n",
    "import cv2\n",
    "\n",
    "# Make the project sources importable. These entries are relative, so they resolve\n",
    "# against the current working directory (printed below) - expected to be notebooks/.\n",
    "sys.path.append('../src')\n",
    "sys.path.append('../src/models/DynamicRouting')\n",
    "import models\n",
    "from utils import load_model, handle_labels\n",
    "\n",
    "cuda_available = torch.cuda.is_available()\n",
    "print(\"CUDA is available: \", cuda_available)\n",
    "# torch.cuda.get_device_name fails on hosts without a usable CUDA device, so only\n",
    "# query it when CUDA is present; otherwise the CPU fallback below is never reached.\n",
    "if cuda_available:\n",
    "    print(\"GPU device name: \", torch.cuda.get_device_name(0))\n",
    "print(\"Current work directory: %s\" % os.getcwd())\n",
    "\n",
    "\n",
    "device = torch.device('cuda') if cuda_available else torch.device('cpu')\n",
    "# Alternative model constructors: swap which line is commented out to inspect another architecture.\n",
    "# model = models.KAUNet_e(in_ch=4, out_ch=5, att_mh=1, att_sm=10, att_ks=5, att_two_w=False).to(device)\n",
    "# model = models.UNet(in_ch=4, out_ch=5).to(device)\n",
    "# model = models.UNet_feature_supervised(in_ch=4, out_ch=5).to(device)\n",
    "# model = models.UNet_edge_implicit(in_ch=4, out_ch=5).to(device)\n",
    "# model = models.DeepLab(backbone=\"resnet\", in_ch=4, num_classes=5).to(device)\n",
    "# model = models.DeepLab_att(in_ch=4, num_classes=5, backbone=\"resnet-modified-34\").to(device)\n",
    "# model = models.Dynamic_C().to(device)\n",
    "model = models.DeepLab_101_nested_fixed_gate_5layers_decatt(in_ch=4, num_classes=5).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Smoke test: one forward pass on a random batch shaped (4, 4, 256, 256);\n",
    "# dim 1 matches in_ch=4 of the model above (assumes NCHW layout - confirm).\n",
    "# no_grad avoids building and retaining an autograd graph that is never used for backprop.\n",
    "with torch.no_grad():\n",
    "    x = torch.rand(4, 4, 256, 256, device=device)\n",
    "    y = model(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model contains 138.15M parameters\n"
     ]
    }
   ],
   "source": [
    "def count_parameters(model):\n",
    "    \"\"\"Return the total number of elements across all trainable\n",
    "    (requires_grad=True) parameters of `model`; frozen parameters are skipped.\"\"\"\n",
    "    total = 0\n",
    "    for param in model.parameters():\n",
    "        if param.requires_grad:\n",
    "            total += param.numel()\n",
    "    return total\n",
    "\n",
    "n_params = count_parameters(model)\n",
    "print(f\"Model contains {n_params/1e6:.2f}M parameters\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:.conda-forest-segmentation] *",
   "language": "python",
   "name": "conda-env-.conda-forest-segmentation-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
