{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "language_info": {
      "name": "python"
    },
    "orig_nbformat": 4,
    "colab": {
      "name": "cords_SSL_CIFAR10_VAT_RETRIEVE_Dataloader_Example.ipynb",
      "provenance": [],
      "collapsed_sections": [],
      "machine_shape": "hm"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "# Integrating subset selection dataloaders with custom SSL training loop\n",
        "\n",
        "In this tutorial, we will look at an example showing how to integrate RETRIEVEDataLoader with a custom SSL training loop."
      ],
      "metadata": {
        "id": "wCu1m4TTeCyb"
      }
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nRBrJb8I_vUv"
      },
      "source": [
        "### Cloning CORDS repository"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "x35Mfc-RnKkX",
        "outputId": "fab97f07-9d5f-4a58-aa68-3842bd36a2df"
      },
      "source": [
        "!git clone https://github.com/decile-team/cords.git\n",
        "%cd cords/\n",
        "%ls"
      ],
      "execution_count": 1,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Cloning into 'cords'...\n",
            "remote: Enumerating objects: 3920, done.\u001b[K\n",
            "remote: Counting objects: 100% (2542/2542), done.\u001b[K\n",
            "remote: Compressing objects: 100% (1155/1155), done.\u001b[K\n",
            "remote: Total 3920 (delta 1654), reused 2178 (delta 1349), pack-reused 1378\u001b[K\n",
            "Receiving objects: 100% (3920/3920), 54.62 MiB | 12.55 MiB/s, done.\n",
            "Resolving deltas: 100% (2391/2391), done.\n",
            "/content/cords\n",
            "\u001b[0m\u001b[01;34mbenchmarks\u001b[0m/  \u001b[01;34mdocs\u001b[0m/        README.md      \u001b[01;34mtests\u001b[0m/        train_ssl.py\n",
            "\u001b[01;34mconfigs\u001b[0m/     \u001b[01;34mexamples\u001b[0m/    \u001b[01;34mrequirements\u001b[0m/  train_hpo.py\n",
            "\u001b[01;34mcords\u001b[0m/       LICENSE.txt  setup.py       train_sl.py\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gAA3K0cVnyd9"
      },
      "source": [
        "### Install prerequisite libraries of CORDS"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "6CXZ4L1ynmcp",
        "outputId": "16f0ec69-4e83-4cec-922b-73b9483c45e8"
      },
      "source": [
        "!pip install dotmap\n",
        "!pip install apricot-select\n",
        "!pip install ray[default]\n",
        "!pip install ray[tune]\n",
        "!pip install datasets"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Collecting dotmap\n",
            "  Downloading dotmap-1.3.26-py3-none-any.whl (11 kB)\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gzyCsbnJn3_L"
      },
      "source": [
        "### Import necessary libraries"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "GazQv21pjIKd"
      },
      "source": [
        "import logging\n",
        "import numpy, random, time, json, copy\n",
        "import numpy as np\n",
        "import os.path as osp\n",
        "import torch\n",
        "import torch.nn.functional as F\n",
        "import torch.optim as optim\n",
        "from torch.utils.data import DataLoader, Subset\n",
        "from cords.utils.data.data_utils import WeightedSubset\n",
        "from cords.utils.models import WideResNet, ShakeNet, CNN13, CNN\n",
        "from cords.utils.data.datasets.SSL import utils as dataset_utils\n",
        "from cords.selectionstrategies.helpers.ssl_lib.algs import utils as alg_utils\n",
        "from cords.utils.models import utils as model_utils\n",
        "from cords.utils.data.datasets.SSL import gen_dataset\n",
        "from cords.selectionstrategies.helpers.ssl_lib.param_scheduler import scheduler\n",
        "from cords.selectionstrategies.helpers.ssl_lib.misc.meter import Meter\n",
        "from cords.utils.config_utils import load_config_data\n",
        "import time\n",
        "import os\n",
        "import sys"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "uYOP5EWU_UD7"
      },
      "source": [
        "### Get logger object for logging"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "LZ8H3_Hx_TtF"
      },
      "source": [
        "def __get_logger(results_dir):\n",
        "    \"\"\"Create the notebook's logger, writing to stdout and to a results file.\n",
        "\n",
        "    Parameters\n",
        "    ----------\n",
        "    results_dir: str\n",
        "        Directory that receives results.log; created if it does not exist.\n",
        "\n",
        "    Returns\n",
        "    -------\n",
        "    logging.Logger\n",
        "        The logger registered under __name__, with propagation disabled.\n",
        "    \"\"\"\n",
        "    os.makedirs(results_dir, exist_ok=True)\n",
        "    # setup logger\n",
        "    plain_formatter = logging.Formatter(\"[%(asctime)s] %(name)s %(levelname)s: %(message)s\",\n",
        "                                        datefmt=\"%m/%d %H:%M:%S\")\n",
        "    logger = logging.getLogger(__name__)\n",
        "    logger.setLevel(logging.INFO)\n",
        "    # logging.getLogger returns the same object on every call, so drop the handlers left\n",
        "    # by a previous run of this cell; otherwise each message would be emitted once per run.\n",
        "    for stale_handler in list(logger.handlers):\n",
        "        logger.removeHandler(stale_handler)\n",
        "        stale_handler.close()\n",
        "    # Console handler: INFO and above go to stdout.\n",
        "    s_handler = logging.StreamHandler(stream=sys.stdout)\n",
        "    s_handler.setFormatter(plain_formatter)\n",
        "    s_handler.setLevel(logging.INFO)\n",
        "    logger.addHandler(s_handler)\n",
        "    # File handler: accepts DEBUG and above, but the logger itself is set to INFO,\n",
        "    # so only records at INFO or higher ever reach results.log.\n",
        "    f_handler = logging.FileHandler(os.path.join(results_dir, \"results.log\"))\n",
        "    f_handler.setFormatter(plain_formatter)\n",
        "    f_handler.setLevel(logging.DEBUG)\n",
        "    logger.addHandler(f_handler)\n",
        "    # Keep records away from the root logger's handlers.\n",
        "    logger.propagate = False\n",
        "    return logger\n",
        "\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Defining the results directory and getting the results logger object"
      ],
      "metadata": {
        "id": "bwUZJRQxBLyx"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "results_dir = 'results/'\n",
        "logger = __get_logger(results_dir)\n"
      ],
      "metadata": {
        "id": "NUSLouwCBLRB"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MyMJdoeqok49"
      },
      "source": [
        "### Loading configuration file with predefined arguments:\n",
        "\n",
        "We have a set of predefined configuration files added to CORDS for SSL under cords/configs/SSL/ which can be used directly by loading them as a dotmap object. \n",
        "\n",
        "An example of predefined configuration for CIFAR10 using VAT as SSL algorithm and RETRIEVE as subset selection strategy can be found below:\n",
        "\n",
        "```Python3\n",
        "\n",
        "# Learning setting\n",
        "config = dict(setting=\"SSL\",\n",
        "              dataset=dict(name=\"cifar10\",\n",
        "                           root=\"../data\",\n",
        "                           feature=\"dss\",\n",
        "                           type=\"pre-defined\",\n",
        "                           num_labels=4000,\n",
        "                           val_ratio=0.1,\n",
        "                           ood_ratio=0.5,\n",
        "                           random_split=False,\n",
        "                           whiten=False,\n",
        "                           zca=True,\n",
        "                           labeled_aug='WA',\n",
        "                           unlabeled_aug='WA',\n",
        "                           wa='t.t.f',\n",
        "                           strong_aug=False),\n",
        "\n",
        "              dataloader=dict(shuffle=True,\n",
        "                              pin_memory=True,\n",
        "                              num_workers=8,\n",
        "                              l_batch_size=50,\n",
        "                              ul_batch_size=50),\n",
        "\n",
        "              model=dict(architecture='wrn',\n",
        "                         type='pre-defined',\n",
        "                         numclasses=10),\n",
        "\n",
        "              ckpt=dict(is_load=False,\n",
        "                        is_save=True,\n",
        "                        checkpoint_model='model.ckpt',\n",
        "                        checkpoint_optimizer='optimizer.ckpt',\n",
        "                        start_iter=None,\n",
        "                        checkpoint=10000),\n",
        "\n",
        "              loss=dict(type='CrossEntropyLoss',\n",
        "                        use_sigmoid=False),\n",
        "\n",
        "              optimizer=dict(type=\"sgd\",\n",
        "                             momentum=0.9,\n",
        "                             lr=0.03,\n",
        "                             weight_decay=0,\n",
        "                             nesterov=True,\n",
        "                             tsa=False,\n",
        "                             tsa_schedule='linear'),\n",
        "\n",
        "              scheduler=dict(lr_decay=\"cos\",\n",
        "                             warmup_iter=0),\n",
        "\n",
        "              ssl_args=dict(alg='vat',\n",
        "                            coef=0.3,\n",
        "                            ema_teacher=False,\n",
        "                            ema_teacher_warmup=False,\n",
        "                            ema_teacher_factor=0.999,\n",
        "                            ema_apply_wd=False,\n",
        "                            em=0,\n",
        "                            threshold=None,\n",
        "                            sharpen=None,\n",
        "                            temp_softmax=None,\n",
        "                            consis='ce',\n",
        "                            eps=6,\n",
        "                            xi=1e-6,\n",
        "                            vat_iter=1\n",
        "                            ),\n",
        "\n",
        "              ssl_eval_args=dict(weight_average=False,\n",
        "                                 wa_ema_factor=0.999,\n",
        "                                 wa_apply_wd=False),\n",
        "\n",
        "              dss_args=dict(type=\"RETRIEVE-Warm\",\n",
        "                            fraction=0.1,\n",
        "                            select_every=20,\n",
        "                            kappa=0.5,\n",
        "                            linear_layer=False,\n",
        "                            selection_type='Supervised',\n",
        "                            greedy='Stochastic',\n",
        "                            valid=True),\n",
        "\n",
        "              train_args=dict(iteration=500000,\n",
        "                              max_iter=-1,\n",
        "                              device=\"cuda\",\n",
        "                              results_dir='results/',\n",
        "                              disp=256,\n",
        "                              seed=96)\n",
        "              )\n",
        "\n",
        "```\n",
        "\n",
        "Please find detailed documentation of the available configuration parameters on the CORDS [readthedocs page](https://cords.readthedocs.io/en/latest/).\n",
        "\n",
        "***Loading the predefined configuration file directly using the load_config_data function in CORDS***"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "from cords.utils.config_utils import load_config_data\n",
        "\n",
        "# The clone cell at the top of the notebook changed into the repository root, so the\n",
        "# bundled config is addressed relative to it rather than through a Colab-only absolute path.\n",
        "cfg = load_config_data('configs/SSL/config_retrieve-warm_vat_cifar10.py')"
      ],
      "metadata": {
        "id": "vozeGsg3CenF"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Loading the CIFAR10 dataset for SSL\n",
        "\n",
        "CIFAR10 is a predefined SSL dataset in the CORDS repository, so you can load it with the gen_dataset function defined in cords/utils/data/datasets/SSL/builder.py.\n",
        "\n",
        "**Input parameters of gen_dataset function:**\n",
        "\n",
        "Parameters\n",
        "-----------\n",
        "    root: str\n",
        "        root directory in which data is present or needs to be downloaded\n",
        "    dataset: str\n",
        "        dataset name,\n",
        "        Existing dataset choices: ['cifar10', 'cifar100', 'svhn', 'stl10', 'cifarOOD', 'mnistOOD', 'cifarImbalance']\n",
        "    validation_split: bool\n",
        "        if True, return validation loader.\n",
        "        We use 10% random split of training data as validation data\n",
        "    cfg: argparse.Namespace or dict\n",
        "        Dictionary containing necessary arguments for generating the dataset\n",
        "    logger: logging.Logger\n",
        "        Logger class for logging the information\n"
      ],
      "metadata": {
        "id": "b54OTNMpBqm7"
      }
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "rkjRkzs2olSD"
      },
      "source": [
        "lt_data, ult_data, test_data, num_classes, img_size = gen_dataset('data/', 'cifar10',\n",
        "                                                                  False, cfg, logger)\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WGI8vbIF1IOS"
      },
      "source": [
        "### Defining Model\n",
        "\n",
        "CORDS has a set of predefined models built into its utils folder. You can import them directly and instantiate them by passing the required arguments for the model.\n",
        "\n",
        "In this notebook, we are going to use a WideResNet model that takes in the following arguments:\n",
        "\n",
        "```\n",
        "WideResNet Parameters\n",
        "-----------\n",
        "  num_classes: int\n",
        "      number of classes\n",
        "  filters: int\n",
        "      number of filters\n",
        "  scales: int\n",
        "      number of scales\n",
        "  repeat: int\n",
        "      number of residual blocks per scale\n",
        "  dropout: float\n",
        "      dropout ratio (None indicates dropout is unused)\n",
        "\n",
        "```\n",
        "\n",
        "We have numclasses which is a part of model arguments in the config file and can be accessed by cfg.model.numclasses\n",
        "\n",
        "***Note: Instead of as dictionary objects, we load config files as dotmap objects. Hence, we can use dot notation (e.g., cfg.model) or original dictionary notation (e.g., cfg['model']) to access the elements. However, we suggest the usage of dot notation for consistency purposes***\n",
        "\n",
        "\n",
        "      "
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "f97m03ZbqvNK"
      },
      "source": [
        "from cords.utils.models import WideResNet\n",
        "\n",
        "# ceil(log2(img_size)): the smallest integer s with 2**s >= img_size; passed to the model below\n",
        "scale = int(np.ceil(np.log2(img_size)))\n",
        "\n",
        "# Build the model and move it to the device named by train_args.device in the config file.\n",
        "# NOTE(review): the positional arguments presumably follow the documented order\n",
        "# (num_classes, filters, scales, repeat) -- confirm against the WideResNet signature.\n",
        "model = WideResNet(cfg.model.numclasses, 32, scale, 4).to(cfg.train_args.device)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Defining Teacher Model\n",
        "\n",
        "Some SSL algorithms use a teacher model to estimate the consistency loss. The argument cfg.ssl_args.ema_teacher in the config file is a boolean flag indicating whether a teacher model is used. Our example uses the VAT algorithm, which does not use a teacher model, so we set cfg.ssl_args.ema_teacher to False.\n",
        "\n",
        "In cases where we use teacher model, we may need to mention additional arguments like cfg.ssl_args.ema_teacher_warmup and cfg.ssl_args.ema_teacher_factor which are specifically required for calculating the teacher model properties using exponential moving average."
      ],
      "metadata": {
        "id": "kY3HImbxbRQ3"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Build the teacher only when the config asks for an EMA teacher; otherwise it stays None.\n",
        "scale = int(np.ceil(np.log2(img_size)))\n",
        "teacher_model = None\n",
        "if cfg.ssl_args.ema_teacher:\n",
        "    # Same constructor arguments as the student, initialised from the student's current weights.\n",
        "    teacher_model = WideResNet(cfg.model.numclasses, 32, scale, 4).to(cfg.train_args.device)\n",
        "    teacher_model.load_state_dict(model.state_dict())"
      ],
      "metadata": {
        "id": "w6doDzn1bQ9S"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Defining Evaluation Model\n",
        "\n",
        "We can evaluate SSL algorithms either on an exponential moving average (EMA) of the model weights or on the model itself. The argument cfg.ssl_eval_args.weight_average in the config file is a boolean flag indicating whether the weight-averaged model is used for evaluation. In our example,\n",
        "we will not be using weight averaging for evaluation, so we set cfg.ssl_eval_args.weight_average to False.\n",
        "\n",
        "In cases where we use weight averaging, we may need to set the additional arguments cfg.ssl_eval_args.wa_ema_factor and cfg.ssl_eval_args.wa_apply_wd, which control how the evaluation model's weights are updated using the exponential moving average."
      ],
      "metadata": {
        "id": "qmcH9ePeBjHT"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Evaluation model: a separate weight-averaged copy is built only when the config enables it.\n",
        "scale = int(np.ceil(np.log2(img_size)))\n",
        "average_model = None\n",
        "if cfg.ssl_eval_args.weight_average:\n",
        "    # Same constructor arguments as the training model, initialised from its current weights.\n",
        "    average_model = WideResNet(cfg.model.numclasses, 32, scale, 4).to(cfg.train_args.device)\n",
        "    average_model.load_state_dict(model.state_dict())\n"
      ],
      "metadata": {
        "id": "EmIUf1VFc2u6"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Get SSL consistency loss functions \n",
        "\n",
        "The gen_consistency function is implemented in 'cords/selectionstrategies/helpers/ssl_lib/consistency/builder.py' and can be imported as follows:\n",
        "```\n",
        "from cords.selectionstrategies.helpers.ssl_lib.consistency.builder import gen_consistency\n",
        "```\n",
        "Existing Consistency loss functions are:\n",
        "1.   Cross-Entropy Loss\n",
        "2.   Squared Loss\n",
        "\n",
        "**Note that we generate two versions of the loss function: one with mean reduction and one without. The version without mean reduction is used for data subset selection, since most subset selection strategies need individual loss gradients, and a loss function without reduction helps calculate these individual loss gradients.**\n",
        "\n",
        "We will be using ssl_args configuration arguments for generating the consistency function.\n",
        "\n"
      ],
      "metadata": {
        "id": "SiKPjvOADqMd"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from cords.selectionstrategies.helpers.ssl_lib.consistency.builder import gen_consistency\n",
        "\n",
        "consistency = gen_consistency(cfg.ssl_args.consis, cfg)\n",
        "consistency_nored = gen_consistency(cfg.ssl_args.consis + '_red', cfg)"
      ],
      "metadata": {
        "id": "zScOt6dHDo07"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Defining SSL algorithm\n",
        "\n",
        "We integrated various consistency based SSL algorithms implemented in this awesome [repository](https://github.com/perrying/pytorch-consistency-regularization) with cords. These SSL algorithms can be imported by using gen_ssl_alg function implemented in cords.selectionstrategies.helpers.ssl_lib.algs.builder which can be imported as follows:\n",
        "\n",
        "```\n",
        "from cords.selectionstrategies.helpers.ssl_lib.algs.builder import gen_ssl_alg\n",
        "```\n",
        "\n",
        "In our example, we will be using VAT as SSL algorithm."
      ],
      "metadata": {
        "id": "QHhKdWovCelB"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from cords.selectionstrategies.helpers.ssl_lib.algs.builder import gen_ssl_alg\n",
        "\n",
        "ssl_alg = gen_ssl_alg(cfg.ssl_args.alg, cfg)"
      ],
      "metadata": {
        "id": "n-8Z9zKbCdub"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "max_iteration = int(cfg.train_args.iteration * cfg.dss_args.fraction)"
      ],
      "metadata": {
        "id": "fwEl97QQFs2i"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lnL-sve1qrnP"
      },
      "source": [
        "### Create unlabeled, labeled and test dataloaders"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "QnvXqaGbqnhH"
      },
      "source": [
        "# Sequential (unshuffled) loader over the full unlabeled set\n",
        "ult_seq_loader = DataLoader(\n",
        "    ult_data,\n",
        "    batch_size=cfg.dataloader.ul_batch_size,\n",
        "    shuffle=False,\n",
        "    pin_memory=True,\n",
        ")\n",
        "\n",
        "# Sequential (unshuffled) loader over the labeled set\n",
        "lt_seq_loader = DataLoader(\n",
        "    lt_data,\n",
        "    batch_size=cfg.dataloader.l_batch_size,\n",
        "    shuffle=False,\n",
        "    pin_memory=True,\n",
        ")\n",
        "\n",
        "# Test loader: one sample per batch, original order, no samples dropped\n",
        "test_loader = DataLoader(test_data, batch_size=1, shuffle=False, drop_last=False,\n",
        "                         num_workers=cfg.dataloader.num_workers)\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vq_ehn_0vPjZ"
      },
      "source": [
        "### Instantiating RETRIEVE subset selection dataloader for unlabeled data\n",
        "\n",
        "We instantiate subset dataloaders that can be used for training the models with adaptive subsets.\n",
        "\n",
        "Each subset dataloader needs data selection strategy arguments in the form of a dotmap dictionary, logger and dataloader specific arguments like batch size, shuffle etc. We will be using dss_args in config file along with some additional arguments required for RETRIEVE.\n",
        "\n",
        "Additional arguments required for RETRIEVEDataLoader on top of dss_args in the config file are:\n",
        "\n",
        "* model\n",
        "* teacher_model\n",
        "* ssl_alg\n",
        "* consistency_nored\n",
        "* num_classes\n",
        "* max_iteration\n",
        "* learning rate\n",
        "* device\n",
        "\n",
        "We are instantiating RETRIEVE dataloader here with warm start. But any dataloader can be instantiated in the same way by passing the required arguments\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "8TNMpF36xykF"
      },
      "source": [
        "from cords.utils.data.dataloader.SSL.adaptive import RETRIEVEDataLoader\n",
        "from dotmap import DotMap\n",
        "\n",
        "# Attach the objects built in the earlier cells to cfg.dss_args; this mutates the loaded\n",
        "# config in place, and cfg.dss_args is what gets handed to RETRIEVEDataLoader below.\n",
        "cfg.dss_args.model = model\n",
        "cfg.dss_args.tea_model = teacher_model  # None unless cfg.ssl_args.ema_teacher is set\n",
        "cfg.dss_args.ssl_alg = ssl_alg\n",
        "cfg.dss_args.loss = consistency_nored  # consistency loss generated with the '_red' suffix\n",
        "cfg.dss_args.num_classes = num_classes\n",
        "cfg.dss_args.num_iters = max_iteration\n",
        "cfg.dss_args.eta = cfg.optimizer.lr  # the optimizer's learning rate\n",
        "cfg.dss_args.device = cfg.train_args.device\n",
        "\n",
        "# Wrap the sequential unlabeled and labeled loaders; batch size, pin_memory and\n",
        "# num_workers are taken from the dataloader section of the config.\n",
        "ult_loader = RETRIEVEDataLoader(ult_seq_loader, lt_seq_loader, cfg.dss_args, logger=logger,\n",
        "                                batch_size=cfg.dataloader.ul_batch_size,\n",
        "                                pin_memory=cfg.dataloader.pin_memory,\n",
        "                                num_workers=cfg.dataloader.num_workers)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Get Optimizer\n",
        "\n",
        "We store optimizer related arguments in the optimizer option of the configuration file. In our example, we will be using \"sgd\" optimizer with Nesterov momentum without any weight decay. The config.optimizer arguments in our example are as follows:\n",
        "\n",
        "```\n",
        "optimizer=dict(type=\"sgd\",\n",
        "                momentum=0.9,\n",
        "                lr=0.03,\n",
        "                weight_decay=0,\n",
        "                nesterov=True,\n",
        "                tsa=False,\n",
        "                tsa_schedule='linear')\n",
        "```"
      ],
      "metadata": {
        "id": "f9TqCLfbS0Oz"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Build the optimizer named by cfg.optimizer.type over the model's parameters.\n",
        "if cfg.optimizer.type == \"sgd\":\n",
        "    optimizer = optim.SGD(\n",
        "        model.parameters(), lr=cfg.optimizer.lr, momentum=cfg.optimizer.momentum,\n",
        "        weight_decay=cfg.optimizer.weight_decay, nesterov=cfg.optimizer.nesterov)\n",
        "elif cfg.optimizer.type == \"adam\":\n",
        "    # cfg.optimizer.momentum doubles as Adam's first beta; the second is fixed at 0.999.\n",
        "    optimizer = optim.Adam(\n",
        "        model.parameters(), lr=cfg.optimizer.lr, betas=(cfg.optimizer.momentum, 0.999),\n",
        "        weight_decay=cfg.optimizer.weight_decay)\n",
        "else:\n",
        "    # Name the rejected value so a config typo is obvious from the traceback.\n",
        "    raise NotImplementedError(\n",
        "        f\"Unsupported cfg.optimizer.type {cfg.optimizer.type!r}; expected 'sgd' or 'adam'\")\n"
      ],
      "metadata": {
        "id": "MGq41g3aSzqn"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Get Scheduler\n",
        "\n",
        "We store scheduler related arguments in the scheduler option of the configuration file. In our example, we will be using cosine-annealing scheduler. The config.scheduler arguments in our example are as follows:\n",
        "\n",
        "```\n",
        "scheduler=dict(lr_decay=\"cos\",\n",
        "              warmup_iter=0),\n",
        "\n",
        "```"
      ],
      "metadata": {
        "id": "QX-iP0FbXnVv"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Set the learning-rate scheduler selected by cfg.scheduler.lr_decay.\n",
        "if cfg.scheduler.lr_decay == \"cos\":\n",
        "    # max_iteration was defined in an earlier cell as\n",
        "    # int(cfg.train_args.iteration * cfg.dss_args.fraction), so both branches anneal over\n",
        "    # the same horizon; the only difference is the int truncation.\n",
        "    if cfg.dss_args.type == 'Full':\n",
        "        lr_scheduler = scheduler.CosineAnnealingLR(optimizer, max_iteration)\n",
        "    else:\n",
        "        lr_scheduler = scheduler.CosineAnnealingLR(optimizer,\n",
        "                                                    cfg.train_args.iteration * cfg.dss_args.fraction)\n",
        "elif cfg.scheduler.lr_decay == \"step\":\n",
        "    # TODO: fixed milestones\n",
        "    # NOTE(review): lr_decay_rate is not among the scheduler keys of the example config shown\n",
        "    # in this notebook (lr_decay, warmup_iter) -- confirm it is defined before using 'step'.\n",
        "    lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, [400000, ], cfg.scheduler.lr_decay_rate)\n",
        "else:\n",
        "    raise NotImplementedError\n"
      ],
      "metadata": {
        "id": "-TzvQ2H-XinL"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "### SSL Model Parameters Update function"
      ],
      "metadata": {
        "id": "2Ytyl3sgZzoU"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "\"\"\"\n",
        "############################## Model Parameters Update ##############################\n",
        "\"\"\"\n",
        "\n",
        "def param_update(cfg,\n",
        "                cur_iteration,\n",
        "                model,\n",
        "                teacher_model,\n",
        "                optimizer,\n",
        "                ssl_alg,\n",
        "                consistency,\n",
        "                labeled_data,\n",
        "                ul_weak_data,\n",
        "                ul_strong_data,\n",
        "                labels,\n",
        "                average_model,\n",
        "                weights=None,\n",
        "                ood=False\n",
        "                ):\n",
        "    start_time = time.time()\n",
        "    # Concantenate labeled data, weakly augmented, and strongly augmented unlabeled data\n",
        "    all_data = torch.cat([labeled_data, ul_weak_data, ul_strong_data], 0)\n",
        "    forward_func = model.forward\n",
        "    stu_logits = forward_func(all_data)\n",
        "    labeled_preds = stu_logits[:labeled_data.shape[0]]\n",
        "\n",
        "    # Separate weak unlabeled logits, and strong unlabeled logits\n",
        "    stu_unlabeled_weak_logits, stu_unlabeled_strong_logits = torch.chunk(stu_logits[labels.shape[0]:], 2, dim=0)\n",
        "    \n",
        "    # Use training signal annealing (TSA)\n",
        "    if cfg.optimizer.tsa:\n",
        "        none_reduced_loss = F.cross_entropy(labeled_preds, labels, reduction=\"none\")\n",
        "        L_supervised = alg_utils.anneal_loss(\n",
        "            labeled_preds, labels, none_reduced_loss, cur_iteration + 1,\n",
        "            cfg.train_args.iteration, labeled_preds.shape[1], cfg.optimizer.tsa_schedule)\n",
        "    else:\n",
        "        L_supervised = F.cross_entropy(labeled_preds, labels)\n",
        "\n",
        "    # IF SSL coefficient is greater than zero, calculate the consistency loss\n",
        "    if cfg.ssl_args.coef > 0:\n",
        "        # get target values\n",
        "        if teacher_model is not None:  # get target values from teacher model\n",
        "            t_forward_func = teacher_model.forward\n",
        "            tea_logits = t_forward_func(all_data)\n",
        "            tea_unlabeled_weak_logits, _ = torch.chunk(tea_logits[labels.shape[0]:], 2, dim=0)\n",
        "        else:\n",
        "            t_forward_func = forward_func\n",
        "            tea_unlabeled_weak_logits = stu_unlabeled_weak_logits\n",
        "\n",
        "        # calculate consistency loss\n",
        "        model.update_batch_stats(False)\n",
        "        y, targets, mask = ssl_alg(\n",
        "            stu_preds=stu_unlabeled_strong_logits,\n",
        "            tea_logits=tea_unlabeled_weak_logits.detach(),\n",
        "            w_data=ul_strong_data,\n",
        "            subset=False,\n",
        "            stu_forward=forward_func,\n",
        "            tea_forward=t_forward_func\n",
        "        )\n",
        "        model.update_batch_stats(True)\n",
        "\n",
        "        # calculate weighted consistency loss\n",
        "        if weights is None:\n",
        "            L_consistency = consistency(y, targets, mask, weak_prediction=tea_unlabeled_weak_logits.softmax(1))\n",
        "        else:\n",
        "            L_consistency = consistency(y, targets, mask * weights,\n",
        "                                        weak_prediction=tea_unlabeled_weak_logits.softmax(1))\n",
        "    else:\n",
        "        L_consistency = torch.zeros_like(L_supervised)\n",
        "        mask = None\n",
        "\n",
        "    # calculate total loss\n",
        "    coef = scheduler.exp_warmup(cfg.ssl_args.coef, int(cfg.scheduler.warmup_iter), cur_iteration + 1)\n",
        "    loss = L_supervised + coef * L_consistency\n",
        "    if cfg.ssl_args.em > 0:\n",
        "        loss -= cfg.ssl_args.em * \\\n",
        "                (stu_unlabeled_weak_logits.softmax(1) * F.log_softmax(stu_unlabeled_weak_logits, 1)).sum(1).mean()\n",
        "\n",
        "    # update parameters\n",
        "    cur_lr = optimizer.param_groups[0][\"lr\"]\n",
        "    optimizer.zero_grad()\n",
        "    loss.backward()\n",
        "    if cfg.optimizer.weight_decay > 0:\n",
        "        decay_coeff = cfg.optimizer.weight_decay * cur_lr\n",
        "        model_utils.apply_weight_decay(model.modules(), decay_coeff)\n",
        "    optimizer.step()\n",
        "\n",
        "    # update teacher parameters by exponential moving average\n",
        "    if cfg.ssl_args.ema_teacher:\n",
        "        model_utils.ema_update(\n",
        "            teacher_model, model, cfg.ssl_args.ema_teacher_factor,\n",
        "            cfg.optimizer.weight_decay * cur_lr if cfg.ssl_args.ema_apply_wd else None,\n",
        "            cur_iteration if cfg.ssl_args.ema_teacher_warmup else None)\n",
        "    \n",
        "    # update evaluation model's parameters by exponential moving average\n",
        "    if cfg.ssl_eval_args.weight_average:\n",
        "        model_utils.ema_update(\n",
        "            average_model, model, cfg.ssl_eval_args.wa_ema_factor,\n",
        "            cfg.optimizer.weight_decay * cur_lr if cfg.ssl_eval_args.wa_apply_wd else None)\n",
        "\n",
        "    # calculate accuracy for labeled data\n",
        "    acc = (labeled_preds.max(1)[1] == labels).float().mean()\n",
        "\n",
        "    return {\n",
        "        \"acc\": acc,\n",
        "        \"loss\": loss.item(),\n",
        "        \"sup loss\": L_supervised.item(),\n",
        "        \"ssl loss\": L_consistency.item(),\n",
        "        \"mask\": mask.float().mean().item() if mask is not None else 1,\n",
        "        \"coef\": coef,\n",
        "        \"sec/iter\": (time.time() - start_time)\n",
        "    }\n"
      ],
      "metadata": {
        "id": "JmrRh3S1ZuUt"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "### SSL model evaluation function\n",
        "\n",
        "Function that evaluates the raw SSL model and EMA evaluation model if any on test dataloader to calculate accuracy and loss metrics"
      ],
      "metadata": {
        "id": "y71bNVrDbbqj"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "def evaluation(raw_model, eval_model, loader, device):\n",
        "    raw_model.eval()\n",
        "    eval_model.eval()\n",
        "    sum_raw_acc = sum_acc = sum_loss = 0\n",
        "    with torch.no_grad():\n",
        "        for (data, labels) in loader:\n",
        "            data, labels = data.to(device), labels.to(device)\n",
        "            preds = eval_model(data)\n",
        "            raw_preds = raw_model(data)\n",
        "            loss = F.cross_entropy(preds, labels)\n",
        "            sum_loss += loss.item()\n",
        "            acc = (preds.max(1)[1] == labels).float().mean()\n",
        "            raw_acc = (raw_preds.max(1)[1] == labels).float().mean()\n",
        "            sum_acc += acc.item()\n",
        "            sum_raw_acc += raw_acc.item()\n",
        "    mean_raw_acc = sum_raw_acc / len(loader)\n",
        "    mean_acc = sum_acc / len(loader)\n",
        "    mean_loss = sum_loss / len(loader)\n",
        "    raw_model.train()\n",
        "    eval_model.train()\n",
        "    return mean_raw_acc, mean_acc, mean_loss\n"
      ],
      "metadata": {
        "id": "FIiITLWKbbeX"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "### SSL Training loop\n",
        "\n",
        "In SSL training loop, we iterate over batches of labeled and unlabeled data subset selected. We can do this by iterating over labeled and RETRIEVEDataloader as follows:\n",
        "\n",
        "```\n",
        "for batch_idx, (l_data, ul_data) in enumerate(zip(lt_loader, ult_loader)):\n",
        "  # ult_loader is an object of RETRIEVEDataloader class\n",
        "```"
      ],
      "metadata": {
        "id": "-oAMCMlISbd1"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "model.train()\n",
        "logger.info(model)\n",
        "\n",
        "# init meter for metrics logging\n",
        "metric_meter = Meter()\n",
        "test_acc_list = []\n",
        "raw_acc_list = []\n",
        "logger.info(\"training\")\n",
        "\n",
        "iter_count = 1\n",
        "subset_selection_time = 0\n",
        "training_time = 0\n",
        "\n",
        "# Start training until maximum number of iterations are reached\n",
        "while iter_count <= max_iteration:\n",
        "    lt_loader = DataLoader(\n",
        "        lt_data,\n",
        "        cfg.dataloader.l_batch_size,\n",
        "        sampler=dataset_utils.InfiniteSampler(len(lt_data), len(list(\n",
        "            ult_loader.batch_sampler)) * cfg.dataloader.l_batch_size),\n",
        "        num_workers=cfg.dataloader.num_workers\n",
        "    )\n",
        "\n",
        "    logger.debug(\"Data loader iteration count is: {0:d}\".format(len(list(ult_loader.batch_sampler))))\n",
        "    # Enumerate on batches of labeled and unlabeled data. \n",
        "    # Note that the ult_loader enumerates only on subsets of unlabeled data selected by RETRIEVE\n",
        "    for batch_idx, (l_data, ul_data) in enumerate(zip(lt_loader, ult_loader)):\n",
        "        batch_start_time = time.time()\n",
        "        if iter_count > max_iteration:\n",
        "            break\n",
        "        l_aug, labels = l_data\n",
        "        ul_w_aug, ul_s_aug, _, weights = ul_data\n",
        "        if cfg.dataset.feature in ['ood', 'classimb']:\n",
        "            ood = True\n",
        "        else:\n",
        "            ood = False\n",
        "        params = param_update(\n",
        "                cfg, iter_count, model, teacher_model, optimizer, ssl_alg,\n",
        "                consistency, l_aug.to(cfg.train_args.device), ul_w_aug.to(cfg.train_args.device),\n",
        "                ul_s_aug.to(cfg.train_args.device), labels.to(cfg.train_args.device),\n",
        "                average_model, weights=weights.to(cfg.train_args.device), ood=ood)\n",
        "        training_time += (time.time() - batch_start_time)\n",
        "        \n",
        "        # moving average for reporting losses and accuracy\n",
        "        metric_meter.add(params, ignores=[\"coef\"])\n",
        "        \n",
        "        # display losses every cfg.disp iterations\n",
        "        if ((iter_count + 1) % cfg.train_args.disp) == 0:\n",
        "            state = metric_meter.state(\n",
        "                header=f'[{iter_count + 1}/{max_iteration}]',\n",
        "                footer=f'ssl coef {params[\"coef\"]:.4g} | lr {optimizer.param_groups[0][\"lr\"]:.4g}'\n",
        "            )\n",
        "            logger.info(state)\n",
        "        lr_scheduler.step()\n",
        "        \n",
        "        # Checkpoint model at regular intervals\n",
        "        if ((iter_count + 1) % cfg.ckpt.checkpoint) == 0 or (iter_count + 1) == max_iteration:\n",
        "            with torch.no_grad():\n",
        "                if cfg.ssl_eval_args.weight_average:\n",
        "                    eval_model = average_model\n",
        "                else:\n",
        "                    eval_model = model\n",
        "                logger.info(\"test\")\n",
        "                mean_raw_acc, mean_test_acc, mean_test_loss = evaluation(model, eval_model, test_loader,\n",
        "                                                                              cfg.train_args.device)\n",
        "                logger.info(\"test loss %f | test acc. %f | raw acc. %f\", mean_test_loss, mean_test_acc,\n",
        "                            mean_raw_acc)\n",
        "                test_acc_list.append(mean_test_acc)\n",
        "                raw_acc_list.append(mean_raw_acc)\n",
        "            torch.save(model.state_dict(), os.path.join(cfg.train_args.out_dir, \"model_checkpoint.pth\"))\n",
        "            torch.save(optimizer.state_dict(),\n",
        "                        os.path.join(cfg.train_args.out_dir, \"optimizer_checkpoint.pth\"))\n",
        "        iter_count += 1\n"
      ],
      "metadata": {
        "id": "j3SyrvRKSbE3"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JSHYpMlR9JII"
      },
      "source": [
        "# Using default SSL training loop directly\n",
        "\n",
        "We have incorporated the above training loop in train_ssl.py file of CORDS which can be used by directly importing the TrainClassifier class from train_ssl function as follows:\n",
        "\n",
        "```\n",
        "from train_ssl import TrainClassifier\n",
        "```\n",
        "\n",
        "Importing Semi-Supervised learning default training loop"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "_7--My3U9JIJ"
      },
      "source": [
        "from train_ssl import TrainClassifier"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "C90z1L0j9JIK"
      },
      "source": [
        "### Loading default RETRIEVE config file for CIFAR10 dataset\n",
        "\n",
        "We can load other subset selection strategies like CRAIG, GradMatch, Random for CIFAR10 dataset by loading their respective config files.\n",
        "\n",
        "Here we give an example of instantiating a SSL training loop using RETRIEVE config file"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "pycharm": {
          "name": "#%%\n"
        },
        "id": "f8QhXaz99JIK"
      },
      "source": [
        "fraction = 0.1\n",
        "retrieve_config_file = '/content/cords/configs/SSL/config_retrieve-warm_vat_cifar10.py'\n",
        "\n",
        "from cords.utils.config_utils import load_config_data\n",
        "\n",
        "cfg = load_config_data(retrieve_config_file)\n",
        "retrieve_trn = TrainClassifier(cfg)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7cR3wmtZ9JIK"
      },
      "source": [
        "### Default config args can be modified in the following manner\n",
        "\n",
        "We can modify the default arguments of the config file by just assigning them a new file"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "WVcxHn0g9JIL"
      },
      "source": [
        "retrieve_trn.cfg.train_args.disp = 256\n",
        "retrieve_trn.cfg.train_args.device = 'cuda'\n",
        "retrieve_trn.cfg.dss_args.fraction = fraction"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "19rABQo19JIL"
      },
      "source": [
        "### Start the training process"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "CxbRePmJ9JIM"
      },
      "source": [
        "retrieve_trn.train()"
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}