{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "e3vZ2jMf9jHP",
        "outputId": "ba768bdb-1d18-4462-80b2-165207af27a0"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "__________________________________________________\n",
            "Experiment for Ratio: 1\n",
            "save_log_path: ./Saved_Training_Data/cifar10/R_1/\n",
            "Device:  cuda\n",
            "n_c_target: {0: 2525, 1: 2525, 2: 2525, 3: 2525, 4: 2525, 5: 2525, 6: 2525, 7: 2525, 8: 2525, 9: 2525}\n",
            "N_train_total: 25250\n",
            "Files already downloaded and verified\n",
            "img_num_per_cls: [2525, 2525, 2525, 2525, 2525, 2525, 2525, 2525, 2525, 2525]\n",
            "Files already downloaded and verified\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py:490: UserWarning: This DataLoader will create 4 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.\n",
            "  cpuset_checked))\n"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "++++++++++\n",
            "cls_num_list_test: {0: 1000, 1: 1000, 2: 1000, 3: 1000, 4: 1000, 5: 1000, 6: 1000, 7: 1000, 8: 1000, 9: 1000}\n",
            "++++++++++\n",
            "\n",
            "Total number of samples:  25250\n",
            "cls num list:\n",
            "[2525, 2525, 2525, 2525, 2525, 2525, 2525, 2525, 2525, 2525]\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "Train\t\tEpoch: 1 [197/198 (99%)] \tBatch Loss: 1.668739 \tBatch Accuracy: 0.367188:  99%|█████████▉| 197/198 [00:13<00:00, 15.09it/s]\n"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Training per_class_acc: {0: 0.27881188118811884, 1: 0.2613861386138614, 2: 0.15247524752475247, 3: 0.13821782178217823, 4: 0.19643564356435644, 5: 0.1603960396039604, 6: 0.2899009900990099, 7: 0.2621782178217822, 8: 0.37663366336633664, 9: 0.25702970297029704}\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "Analysis Mean\tEpoch: 1 [198/198 (100%)]: 100%|██████████| 198/198 [00:10<00:00, 19.37it/s]\n",
            "Analysis Cov\tEpoch: 1 [198/198 (100%)]: 100%|██████████| 198/198 [00:12<00:00, 15.31it/s]\n",
            "/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:225: RuntimeWarning: divide by zero encountered in reciprocal\n",
            "/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:225: RuntimeWarning: invalid value encountered in matmul\n",
            "/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:226: RuntimeWarning: invalid value encountered in matmul\n"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Test accuracy: 35.54 %\n",
            "Testing per_class_acc: {0: 0.389, 1: 0.307, 2: 0.029, 3: 0.469, 4: 0.586, 5: 0.0, 6: 0.208, 7: 0.362, 8: 0.713, 9: 0.491}\n",
            "Checkpoint saved. Epoch: 1 \n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "Train\t\tEpoch: 2 [197/198 (99%)] \tBatch Loss: 1.687701 \tBatch Accuracy: 0.390625:  99%|█████████▉| 197/198 [00:12<00:00, 15.44it/s]\n"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Training per_class_acc: {0: 0.37346534653465346, 1: 0.5611881188118812, 2: 0.13702970297029704, 3: 0.13980198019801982, 4: 0.29306930693069305, 5: 0.4087128712871287, 6: 0.45782178217821784, 7: 0.45465346534653467, 8: 0.5837623762376237, 9: 0.5017821782178218}\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "Analysis Mean\tEpoch: 2 [198/198 (100%)]: 100%|██████████| 198/198 [00:10<00:00, 18.21it/s]\n",
            "Analysis Cov\tEpoch: 2 [198/198 (100%)]: 100%|██████████| 198/198 [00:12<00:00, 15.45it/s]\n"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Test accuracy: 42.970000000000006 %\n",
            "Testing per_class_acc: {0: 0.518, 1: 0.455, 2: 0.167, 3: 0.084, 4: 0.356, 5: 0.46, 6: 0.574, 7: 0.528, 8: 0.583, 9: 0.572}\n",
            "Checkpoint saved. Epoch: 2 \n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "Train\t\tEpoch: 3 [197/198 (99%)] \tBatch Loss: 1.548126 \tBatch Accuracy: 0.382812:  99%|█████████▉| 197/198 [00:13<00:00, 14.76it/s]\n"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Training per_class_acc: {0: 0.41386138613861384, 1: 0.5900990099009901, 2: 0.19762376237623763, 3: 0.14257425742574256, 4: 0.26732673267326734, 5: 0.47287128712871285, 6: 0.5164356435643565, 7: 0.5176237623762376, 8: 0.6039603960396039, 9: 0.5398019801980198}\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "Analysis Mean\tEpoch: 3 [198/198 (100%)]: 100%|██████████| 198/198 [00:10<00:00, 18.87it/s]\n",
            "Analysis Cov\tEpoch: 3 [28/198 (14%)]:  14%|█▍        | 28/198 [00:02<00:10, 15.72it/s]"
          ]
        }
      ],
      "source": [
        "import os\n",
        "import sys\n",
        "import pickle\n",
        "import shutil\n",
        "\n",
        "import torch\n",
        "\n",
        "import numpy as np\n",
        "import torch.nn as nn\n",
        "import torch.optim as optim\n",
        "import matplotlib.pyplot as plt\n",
        "import torch.nn.functional as F\n",
        "import torchvision.models as models\n",
        "\n",
        "from tqdm import tqdm\n",
        "from collections import OrderedDict\n",
        "from scipy.sparse.linalg import svds\n",
        "from torchvision import datasets, transforms\n",
        "\n",
        "from generate_cifar import IMBALANCECIFAR10\n",
        "from generate_mnist import IMBALANCEMNIST\n",
        "\n",
        "\n",
        "#-------  analysis ---------------------------------------\n",
        "class graphs:\n",
        "    \"\"\"Per-checkpoint metric histories; every attribute starts as an empty list.\"\"\"\n",
        "\n",
        "    def __init__(self):\n",
        "        # Checkpoint epochs and the statistics gathered on the analysis loader.\n",
        "        self.cur_epochs = []\n",
        "        self.accuracy = []\n",
        "        self.loss = []\n",
        "        self.reg_loss = []\n",
        "\n",
        "        # Statistics gathered on the test loader.\n",
        "        self.test_loss = []\n",
        "        self.test_acc = []\n",
        "\n",
        "        # NC1: trace of Sw times the (pseudo-)inverse of Sb.\n",
        "        self.Sw_invSb = []\n",
        "\n",
        "        # NC2: coefficient of variation of the norms and mutual coherence,\n",
        "        # for the centred class means (M) and the classifier rows (W).\n",
        "        self.norm_M_CoV = []\n",
        "        self.norm_W_CoV = []\n",
        "        self.cos_M = []\n",
        "        self.cos_W = []\n",
        "\n",
        "        # NC3: squared distance between normalised W^T and centred means.\n",
        "        self.W_M_dist = []\n",
        "\n",
        "        # NC4: disagreement with the nearest-class-centre rule.\n",
        "        self.NCC_mismatch = []\n",
        "\n",
        "        # Decomposition terms (not written by the code in this cell).\n",
        "        self.MSE_wd_features = []\n",
        "        self.LNC1 = []\n",
        "        self.LNC23 = []\n",
        "        self.Lperp = []\n",
        "\n",
        "\n",
        "#------- train fcn ---------------------------------------\n",
        "def train(model, criterion, device, num_classes, train_loader, optimizer, epoch, n_c_train_target):\n",
        "    \"\"\"Run one training epoch and return the per-class training accuracy.\n",
        "\n",
        "    Batches whose size differs from the module-level ``batch_size`` (in\n",
        "    practice the last, partial batch) are skipped, and when the module-level\n",
        "    ``debug`` flag is set the epoch stops once ``batch_idx`` exceeds 20.\n",
        "\n",
        "    Args:\n",
        "        model: network to optimise; it is switched to train mode.\n",
        "        criterion: loss called as ``criterion(out, target)``.\n",
        "        device: device every batch is moved to.\n",
        "        num_classes: number of classes; labels are expected in ``range(num_classes)``.\n",
        "        train_loader: iterable of ``(data, target)`` batches that supports ``len()``.\n",
        "        optimizer: optimizer stepped once per processed batch.\n",
        "        epoch: epoch number, only used in the progress-bar text.\n",
        "        n_c_train_target: kept for backward compatibility and no longer read.\n",
        "            It used to be the accuracy denominator, which under-reported the\n",
        "            accuracy because samples of skipped batches were counted as misses.\n",
        "\n",
        "    Returns:\n",
        "        dict mapping each class index to the fraction of that class's samples\n",
        "        processed in this epoch that were predicted correctly (0.0 for a class\n",
        "        that was never seen).\n",
        "    \"\"\"\n",
        "    model.train()\n",
        "\n",
        "    # Counters stay on the device so that no per-class host sync is needed per batch.\n",
        "    correct_per_class = torch.zeros(num_classes, dtype=torch.long, device=device)\n",
        "    seen_per_class = torch.zeros(num_classes, dtype=torch.long, device=device)\n",
        "\n",
        "    # The context manager closes the bar even if a batch raises.\n",
        "    with tqdm(total=len(train_loader), position=0, leave=True) as pbar:\n",
        "        for batch_idx, (data, target) in enumerate(train_loader, start=1):\n",
        "            if data.shape[0] != batch_size:\n",
        "                continue\n",
        "\n",
        "            data, target = data.to(device), target.to(device)\n",
        "            optimizer.zero_grad()\n",
        "            out = model(data)\n",
        "            loss = criterion(out, target)\n",
        "            loss.backward()\n",
        "            optimizer.step()\n",
        "\n",
        "            predicted = torch.argmax(out.detach(), dim=1)\n",
        "            hits = predicted == target\n",
        "            accuracy = hits.float().mean().item()\n",
        "\n",
        "            pbar.update(1)\n",
        "            pbar.set_description(\n",
        "                'Train\\t\\tEpoch: {} [{}/{} ({:.0f}%)] \\t'\n",
        "                'Batch Loss: {:.6f} \\t'\n",
        "                'Batch Accuracy: {:.6f}'.format(\n",
        "                    epoch,\n",
        "                    batch_idx,\n",
        "                    len(train_loader),\n",
        "                    100. * batch_idx / len(train_loader),\n",
        "                    loss.item(),\n",
        "                    accuracy))\n",
        "\n",
        "            # One histogram per batch replaces the per-class Python loop.\n",
        "            seen_per_class += torch.bincount(target, minlength=num_classes)[:num_classes]\n",
        "            correct_per_class += torch.bincount(target[hits], minlength=num_classes)[:num_classes]\n",
        "\n",
        "            if debug and batch_idx > 20:\n",
        "                break\n",
        "\n",
        "    seen = seen_per_class.tolist()\n",
        "    correct = correct_per_class.tolist()\n",
        "    # Divide by the samples actually processed, not by the nominal class size.\n",
        "    per_class_acc = {\n",
        "        c: (correct[c] / seen[c] if seen[c] else 0.0)\n",
        "        for c in range(num_classes)\n",
        "    }\n",
        "    print(\"Training per_class_acc: \" + str(per_class_acc))\n",
        "\n",
        "    return per_class_acc\n",
        "\n",
        "\n",
        "\n",
        "#------- analysis fcn ---------------------------------------\n",
        "def _summed_criterion_loss(criterion_summed, output, target, num_classes):\n",
        "    \"\"\"Return the summed loss of one batch as a Python float.\n",
        "\n",
        "    Cross-entropy receives the integer targets directly and MSE receives\n",
        "    one-hot targets. Any other criterion contributes 0.0, as the original\n",
        "    inline branches did.\n",
        "    \"\"\"\n",
        "    if isinstance(criterion_summed, nn.CrossEntropyLoss):\n",
        "        return criterion_summed(output, target).item()\n",
        "    if isinstance(criterion_summed, nn.MSELoss):\n",
        "        one_hot = F.one_hot(target, num_classes=num_classes).float()\n",
        "        return criterion_summed(output, one_hot).item()\n",
        "    return 0.0\n",
        "\n",
        "\n",
        "def analysis(graph, model, criterion_summed, device, num_classes, loader, test_loader, NC_analysis=False, cls_num_list = None, features = None, epoch = None, classifier = None, cls_num_list_test = None):\n",
        "    \"\"\"Append neural-collapse and test metrics for one checkpoint to ``graph``.\n",
        "\n",
        "    When ``NC_analysis`` is true, ``loader`` is traversed twice. The 'Mean'\n",
        "    pass accumulates the class means of the features stored in\n",
        "    ``features.value`` together with the summed loss. The 'Cov' pass\n",
        "    accumulates the within-class covariance, the network accuracy and the\n",
        "    agreement with the nearest-class-centre rule. The derived NC1-NC4\n",
        "    quantities are appended to ``graph``. ``test_loader`` is always evaluated\n",
        "    afterwards. The module-level ``debug`` and ``weight_decay`` are read.\n",
        "\n",
        "    Args:\n",
        "        graph: object whose list attributes receive the new values.\n",
        "        model: network to evaluate; it is switched to eval mode.\n",
        "        criterion_summed: loss with summed reduction (cross-entropy or MSE).\n",
        "        device: device every batch is moved to.\n",
        "        num_classes: number of classes.\n",
        "        loader: batches used for the neural-collapse statistics.\n",
        "        test_loader: batches used for test loss and accuracy.\n",
        "        NC_analysis: whether to compute the neural-collapse statistics.\n",
        "        cls_num_list: accepted for compatibility; not read here.\n",
        "        features: object whose ``value`` attribute holds the classifier input\n",
        "            of the latest forward pass (required when ``NC_analysis`` is true).\n",
        "        epoch: epoch number, only used in the progress-bar text.\n",
        "        classifier: last linear layer; its ``weight`` is compared to the means.\n",
        "        cls_num_list_test: per-class test-set sizes used as denominators.\n",
        "\n",
        "    Returns:\n",
        "        ``(mu_c_save, per_class_acc)``: the list of class-mean feature tensors\n",
        "        (``None`` when ``NC_analysis`` is false) and a dict mapping each class\n",
        "        index to its test accuracy.\n",
        "    \"\"\"\n",
        "    model.eval()\n",
        "\n",
        "    N = [0] * num_classes\n",
        "    mean = [0] * num_classes\n",
        "    Sw = 0\n",
        "\n",
        "    mu_c_save = None\n",
        "\n",
        "    loss = 0\n",
        "    net_correct = 0\n",
        "    NCC_match_net = 0\n",
        "\n",
        "    if NC_analysis:\n",
        "\n",
        "        for computation in ['Mean', 'Cov']:\n",
        "            pbar = tqdm(total=len(loader), position=0, leave=True)\n",
        "            # Inference only: do not build an autograd graph for these passes.\n",
        "            with torch.no_grad():\n",
        "                for batch_idx, (data, target) in enumerate(loader, start=1):\n",
        "\n",
        "                    data, target = data.to(device), target.to(device)\n",
        "\n",
        "                    output = model(data)\n",
        "                    h = features.value.data.view(data.shape[0], -1)  # B CHW\n",
        "\n",
        "                    # the loss is accumulated while the class means are computed\n",
        "                    if computation == 'Mean':\n",
        "                        loss += _summed_criterion_loss(criterion_summed, output, target, num_classes)\n",
        "\n",
        "                    for c in range(num_classes):\n",
        "                        # features belonging to class c\n",
        "                        idxs = (target == c).nonzero(as_tuple=True)[0]\n",
        "\n",
        "                        if len(idxs) == 0:  # no class-c sample in this batch\n",
        "                            continue\n",
        "\n",
        "                        h_c = h[idxs, :]  # B CHW\n",
        "\n",
        "                        if computation == 'Mean':\n",
        "                            # update class means\n",
        "                            mean[c] += torch.sum(h_c, dim=0)  # CHW\n",
        "                            N[c] += h_c.shape[0]\n",
        "\n",
        "                        elif computation == 'Cov':\n",
        "                            # update within-class covariance; z^T z equals the sum of\n",
        "                            # per-sample outer products without the B x CHW x CHW tensor\n",
        "                            z = h_c - mean[c].unsqueeze(0)  # B CHW\n",
        "                            Sw += torch.matmul(z.T, z)\n",
        "\n",
        "                            # 1) network's accuracy\n",
        "                            net_pred = torch.argmax(output[idxs, :], dim=1)\n",
        "                            net_correct += (net_pred == target[idxs]).sum().item()\n",
        "\n",
        "                            # 2) agreement between prediction and nearest class centre\n",
        "                            NCC_scores = torch.norm(h_c.unsqueeze(1) - M.T.unsqueeze(0), dim=2)  # B C\n",
        "                            NCC_pred = torch.argmin(NCC_scores, dim=1)\n",
        "                            NCC_match_net += (NCC_pred == net_pred).sum().item()\n",
        "\n",
        "                    pbar.update(1)\n",
        "                    pbar.set_description(\n",
        "                        'Analysis {}\\t'\n",
        "                        'Epoch: {} [{}/{} ({:.0f}%)]'.format(\n",
        "                            computation,\n",
        "                            epoch,\n",
        "                            batch_idx,\n",
        "                            len(loader),\n",
        "                            100. * batch_idx / len(loader)))\n",
        "\n",
        "                    if debug and batch_idx > 20:\n",
        "                        break\n",
        "            pbar.close()\n",
        "\n",
        "            if computation == 'Mean':\n",
        "                for c in range(num_classes):\n",
        "                    mean[c] /= N[c]\n",
        "                # Stack once, after every class mean has been normalised.\n",
        "                M = torch.stack(mean).T\n",
        "                loss /= sum(N)\n",
        "                mu_c_save = mean\n",
        "            elif computation == 'Cov':\n",
        "                Sw /= sum(N)\n",
        "\n",
        "        graph.loss.append(loss)\n",
        "        graph.accuracy.append(net_correct / sum(N))\n",
        "        graph.NCC_mismatch.append(1 - NCC_match_net / sum(N))\n",
        "\n",
        "        # loss with weight decay\n",
        "        reg_loss = loss\n",
        "        for param in model.parameters():\n",
        "            reg_loss += 0.5 * weight_decay * torch.sum(param**2).item()\n",
        "        graph.reg_loss.append(reg_loss)\n",
        "\n",
        "        # global mean\n",
        "        muG = torch.mean(M, dim=1, keepdim=True)  # CHW 1\n",
        "\n",
        "        # between-class covariance\n",
        "        M_ = M - muG\n",
        "        Sb = torch.matmul(M_, M_.T) / num_classes\n",
        "\n",
        "        # avg norm\n",
        "        W = classifier.weight.detach()\n",
        "        M_norms = torch.norm(M_, dim=0)\n",
        "        W_norms = torch.norm(W.T, dim=0)\n",
        "\n",
        "        graph.norm_M_CoV.append((torch.std(M_norms) / torch.mean(M_norms)).item())\n",
        "        graph.norm_W_CoV.append((torch.std(W_norms) / torch.mean(W_norms)).item())\n",
        "\n",
        "        # tr{Sw Sb^-1}\n",
        "        Sw = Sw.cpu().numpy()\n",
        "        Sb = Sb.cpu().numpy()\n",
        "        eigvec, eigval, _ = svds(Sb, k=num_classes - 1)\n",
        "        # Pseudo-inverse: singular values that are numerically zero are dropped\n",
        "        # instead of being inverted into inf, which turned the trace into NaN.\n",
        "        cutoff = np.finfo(eigval.dtype).eps * max(Sb.shape) * eigval.max()\n",
        "        keep = eigval > cutoff\n",
        "        inv_eigval = np.zeros_like(eigval)\n",
        "        inv_eigval[keep] = 1.0 / eigval[keep]\n",
        "        inv_Sb = eigvec @ np.diag(inv_eigval) @ eigvec.T\n",
        "        graph.Sw_invSb.append(np.trace(Sw @ inv_Sb))\n",
        "\n",
        "        # ||W^T - M_||\n",
        "        normalized_M = M_ / torch.norm(M_, 'fro')\n",
        "        normalized_W = W.T / torch.norm(W.T, 'fro')\n",
        "        graph.W_M_dist.append((torch.norm(normalized_W - normalized_M)**2).item())\n",
        "\n",
        "        # mutual coherence\n",
        "        def coherence(V):\n",
        "            G = V.T @ V\n",
        "            G += torch.ones((num_classes, num_classes), device=device) / (num_classes - 1)\n",
        "            G -= torch.diag(torch.diag(G))\n",
        "            return torch.norm(G, 1).item() / (num_classes * (num_classes - 1))\n",
        "\n",
        "        graph.cos_M.append(coherence(M_ / M_norms))\n",
        "        graph.cos_W.append(coherence(W.T / W_norms))\n",
        "\n",
        "    # test error\n",
        "    correct = 0\n",
        "    test_loss = 0\n",
        "\n",
        "    per_class_acc = {c: 0 for c in range(num_classes)}\n",
        "\n",
        "    # since we're not training, we don't need to calculate the gradients for our outputs\n",
        "    with torch.no_grad():\n",
        "        for batch_idx, (data, target) in enumerate(test_loader, start=1):\n",
        "\n",
        "            data, target = data.to(device), target.to(device)\n",
        "\n",
        "            output = model(data)\n",
        "            test_loss += _summed_criterion_loss(criterion_summed, output, target, num_classes)\n",
        "\n",
        "            predicted = torch.argmax(output, dim=1)\n",
        "            hits = predicted == target\n",
        "            correct += hits.sum().item()\n",
        "\n",
        "            for c in range(num_classes):\n",
        "                per_class_acc[c] += (hits * (target == c)).sum().item()\n",
        "\n",
        "    test_loss /= len(test_loader.dataset)\n",
        "    acc = correct / len(test_loader.dataset)\n",
        "\n",
        "    for c in range(num_classes):\n",
        "        per_class_acc[c] /= cls_num_list_test[c]\n",
        "\n",
        "    graph.test_loss.append(test_loss)\n",
        "    graph.test_acc.append(acc)\n",
        "\n",
        "    print(f'Test accuracy: {100 * acc} %')\n",
        "    print(\"Testing per_class_acc: \" + str(per_class_acc))\n",
        "\n",
        "    return mu_c_save, per_class_acc\n",
        "\n",
        "###########################################################################################################################\n",
        "\n",
        "\n",
        "\n",
        "#------- parameters ---------------------------------------\n",
        "# Debug switch: when True, train()/analysis() leave each pass once batch_idx exceeds 20.\n",
        "debug = False\n",
        "\n",
        "# Dataset selection, image geometry and number of classes.\n",
        "dataset_name = \"cifar10\"\n",
        "C = 10\n",
        "im_size = 32\n",
        "padded_im_size = 32\n",
        "\n",
        "# Name of the optimisation criterion.\n",
        "loss_name = 'CrossEntropyLoss'\n",
        "\n",
        "# Optimisation hyperparameters.\n",
        "epochs = 350\n",
        "batch_size = 128\n",
        "momentum = 0.9\n",
        "weight_decay = 5e-4\n",
        "\n",
        "# Step learning-rate schedule: decay factor and the epochs at which it is applied\n",
        "# (one third and two thirds of training).\n",
        "lr_decay = 0.1\n",
        "epochs_lr_decay = [epochs // 3, epochs * 2 // 3]\n",
        "\n",
        "\n",
        "\n",
        "def run():\n",
        "    \"\"\"Train and analyse one ResNet-18 per imbalance ratio R in [1, 5, 10, 100].\n",
        "\n",
        "    For every ratio the results are written below\n",
        "    ``./Saved_Training_Data/<dataset_name>/R_<R>/``. A ratio whose folder\n",
        "    already holds the ``ExpComplete.txt`` marker is skipped; a folder without\n",
        "    the marker is deleted and the experiment is restarted from scratch.\n",
        "\n",
        "    Reads the module-level settings ``dataset_name``, ``C``, ``epochs``,\n",
        "    ``epochs_lr_decay``, ``lr_decay``, ``batch_size``, ``momentum``,\n",
        "    ``weight_decay``, ``im_size`` and ``padded_im_size``.\n",
        "\n",
        "    Raises:\n",
        "        ValueError: if ``dataset_name`` is neither \"cifar10\" nor \"mnist\".\n",
        "    \"\"\"\n",
        "    root_path = \"./Saved_Training_Data/\"\n",
        "    if dataset_name == \"cifar10\":\n",
        "        input_ch = 3\n",
        "        root_path += \"cifar10/\"\n",
        "    elif dataset_name == \"mnist\":\n",
        "        input_ch = 1\n",
        "        root_path += \"mnist/\"\n",
        "    else:\n",
        "        # Fail early instead of hitting an undefined name further down.\n",
        "        raise ValueError(\"Unsupported dataset_name: \" + repr(dataset_name))\n",
        "\n",
        "    # Epochs at which the analysis is run and a checkpoint is recorded.\n",
        "    epoch_list = [1,   2,   3,   4,   5,   6,   7,   8,   9,   10,   11,\n",
        "                  12,  13,  14,  16,  17,  19,  20,  22,  24,  27,   29,\n",
        "                  32,  35,  38,  42,  45,  50,  54,  59,  65,  71,   77,\n",
        "                  85,  92,  101, 110, 121, 132, 144, 158, 172, 188,  206,\n",
        "                  225, 245, 268, 293, 320, 350]\n",
        "\n",
        "    # Training samples per majority / minority class for each ratio.\n",
        "    N_maj_dict = {1: 2525, 2: 3366, 5: 4208, 10: 4591, 20: 4809, 50: 4950, 100: 5000}\n",
        "    N_min_dict = {1: 2525, 2: 1683, 5: 841, 10: 459, 20: 240, 50: 99, 100: 50}\n",
        "    maj_classes = [0, 1, 2, 3, 4]\n",
        "    min_classes = [5, 6, 7, 8, 9]\n",
        "    classes = maj_classes + min_classes\n",
        "\n",
        "    for R in [1, 5, 10, 100]:\n",
        "\n",
        "        print(\"_\" * 50)\n",
        "        print(\"Experiment for Ratio: \" + str(R))\n",
        "\n",
        "        experiment_name = \"R_\" + str(R)\n",
        "        save_log_path = root_path + experiment_name + \"/\"\n",
        "\n",
        "        experiment_complete_flag_file = os.path.join(save_log_path, \"ExpComplete.txt\")\n",
        "        print(\"save_log_path: \" + str(save_log_path))\n",
        "        if not os.path.exists(save_log_path):\n",
        "            os.makedirs(save_log_path, exist_ok=True)\n",
        "        elif not os.path.exists(experiment_complete_flag_file):\n",
        "            # Incomplete previous run: start over with an empty folder.\n",
        "            shutil.rmtree(save_log_path)\n",
        "            os.makedirs(save_log_path, exist_ok=True)\n",
        "        else:\n",
        "            print(\"Skipping this experiments, already done ...\")\n",
        "            continue\n",
        "\n",
        "        device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "        print(\"Device: \", device)\n",
        "\n",
        "        # Do not ask for more DataLoader workers than there are CPUs.\n",
        "        workers = min(4, os.cpu_count() or 1)\n",
        "        train_sampler = None\n",
        "        NC_analysis = True\n",
        "\n",
        "        n_c_train_target = {}\n",
        "        for c in maj_classes:\n",
        "            n_c_train_target[c] = N_maj_dict[R]\n",
        "        for c in min_classes:\n",
        "            n_c_train_target[c] = N_min_dict[R]\n",
        "        print(\"n_c_target: \" + str(n_c_train_target))\n",
        "        N_train_total = sum(n_c_train_target.values())\n",
        "        print(\"N_train_total: \" + str(N_train_total))\n",
        "\n",
        "        #-------  model ---------------------------------------\n",
        "        model = models.resnet18(pretrained=False, num_classes=C)\n",
        "        # Small dataset filter size used by He et al. (2015)\n",
        "        model.conv1 = nn.Conv2d(input_ch, model.conv1.weight.shape[0], 3, 1, 1, bias=False)\n",
        "        model.maxpool = nn.MaxPool2d(kernel_size=1, stride=1, padding=0)\n",
        "        model.fc = nn.Linear(in_features=model.fc.in_features, out_features=C, bias=False)\n",
        "        model = model.to(device)\n",
        "\n",
        "        class features:\n",
        "            pass\n",
        "\n",
        "        def hook(self, input, output):\n",
        "            features.value = input[0].clone()\n",
        "\n",
        "        # register hook that saves last-layer input into features\n",
        "        classifier = model.fc\n",
        "        classifier.register_forward_hook(hook)\n",
        "\n",
        "        #-------  dataset imbalance -------------------------------------\n",
        "        if dataset_name == \"cifar10\":\n",
        "            normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))\n",
        "            transform_train = transforms.Compose([\n",
        "                transforms.RandomCrop(32, padding=4),\n",
        "                transforms.RandomHorizontalFlip(),\n",
        "                transforms.ToTensor(),\n",
        "                normalize,\n",
        "            ])\n",
        "            transform_val = transforms.Compose([\n",
        "                transforms.ToTensor(),\n",
        "                normalize,\n",
        "            ])\n",
        "\n",
        "            train_dataset = IMBALANCECIFAR10(root='./class_imbalance/data', imb_type=\"step\", imb_factor=R,\n",
        "                                             rand_number=1, train=True, download=True,\n",
        "                                             transform=transform_train, n_c_train_target=n_c_train_target, classes=classes)\n",
        "            val_dataset = datasets.CIFAR10(root='./class_imbalance/data', train=False, download=True, transform=transform_val)\n",
        "\n",
        "            cls_num_list_test = {c: 1000 for c in range(0, C)}\n",
        "\n",
        "        elif dataset_name == \"mnist\":\n",
        "            transform = transforms.Compose([transforms.Pad((padded_im_size - im_size)//2),\n",
        "                                            transforms.ToTensor(),\n",
        "                                            transforms.Normalize(0.1307, 0.3081)])\n",
        "\n",
        "            train_dataset = IMBALANCEMNIST(root='./class_imbalance/data', imb_type=\"step\", imb_factor=R,\n",
        "                                           rand_number=1, train=True, download=True,\n",
        "                                           transform=transform, n_c_train_target=n_c_train_target, classes=classes)\n",
        "            val_dataset = datasets.MNIST(root='./class_imbalance/data', train=False, download=True, transform=transform)\n",
        "\n",
        "            cls_num_list_test = {c: 0 for c in range(0, C)}\n",
        "            for label in val_dataset.targets:\n",
        "                cls_num_list_test[label.item()] += 1\n",
        "\n",
        "            train_dataset.data = torch.tensor(train_dataset.data)\n",
        "        print(\"+\" * 10)\n",
        "        print(\"cls_num_list_test: \" + str(cls_num_list_test))\n",
        "        print(\"+\" * 10)\n",
        "\n",
        "        cls_num_list = train_dataset.get_cls_num_list()\n",
        "        print('\\nTotal number of samples: ', sum(cls_num_list))\n",
        "        print('cls num list:')\n",
        "        print(cls_num_list)\n",
        "\n",
        "        train_loader = torch.utils.data.DataLoader(\n",
        "            train_dataset, batch_size=batch_size, shuffle=(train_sampler is None),\n",
        "            num_workers=workers, pin_memory=True, sampler=train_sampler)\n",
        "\n",
        "        # NOTE(review): for cifar10 this loader reuses the shuffled dataset with the\n",
        "        # random crop/flip transform, so the 'Mean' and 'Cov' passes of analysis()\n",
        "        # see different augmentations of the images - confirm this is intended.\n",
        "        analysis_loader = torch.utils.data.DataLoader(\n",
        "            train_dataset, batch_size=batch_size, shuffle=(train_sampler is None),\n",
        "            num_workers=workers, pin_memory=True, sampler=train_sampler)\n",
        "\n",
        "        test_loader = torch.utils.data.DataLoader(\n",
        "            val_dataset, batch_size=batch_size, shuffle=False,\n",
        "            num_workers=workers, pin_memory=True)\n",
        "\n",
        "        #-------  optimizer ---------------------------------------\n",
        "        criterion = nn.CrossEntropyLoss()\n",
        "        criterion_summed = nn.CrossEntropyLoss(reduction='sum')\n",
        "\n",
        "        # Best lr after hyperparameter tuning\n",
        "        if dataset_name == \"mnist\":\n",
        "            lr = 0.0679\n",
        "        elif dataset_name == \"cifar10\":\n",
        "            lr = 1e-1\n",
        "        optimizer = optim.SGD(model.parameters(),\n",
        "                              lr=lr,\n",
        "                              momentum=momentum,\n",
        "                              weight_decay=weight_decay)\n",
        "\n",
        "        # The schedule uses the module-level settings instead of silently\n",
        "        # overriding them with local copies.\n",
        "        lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer,\n",
        "                                                      milestones=epochs_lr_decay,\n",
        "                                                      gamma=lr_decay)\n",
        "\n",
        "        graph = graphs()\n",
        "\n",
        "        cur_epochs = []\n",
        "\n",
        "        mu_c_list_train = []\n",
        "        W_list = []\n",
        "        B_list = []\n",
        "        train_accuracies_list = []\n",
        "        test_accuracies_list = []\n",
        "\n",
        "        for epoch in range(1, epochs + 1):\n",
        "            per_class_acc_train = train(model, criterion, device, C, train_loader, optimizer, epoch, n_c_train_target)\n",
        "            lr_scheduler.step()\n",
        "\n",
        "            if epoch not in epoch_list:\n",
        "                continue\n",
        "\n",
        "            cur_epochs.append(epoch)\n",
        "            mu_c_save, per_class_acc_test = analysis(\n",
        "                graph, model, criterion_summed, device, C, analysis_loader, test_loader,\n",
        "                NC_analysis=NC_analysis, cls_num_list=cls_num_list, features=features,\n",
        "                epoch=epoch, classifier=classifier, cls_num_list_test=cls_num_list_test)\n",
        "\n",
        "            graph.cur_epochs = cur_epochs\n",
        "            with open(save_log_path+'graphs_save.pkl', \"wb\") as f1:\n",
        "                pickle.dump(graph, f1)\n",
        "\n",
        "            # Detached clones: on CPU `.to(\"cpu\")` returns the live parameter, so\n",
        "            # every stored checkpoint would alias the final weights.\n",
        "            W_list.append(classifier.weight.detach().cpu().clone())\n",
        "            if classifier.bias is None:\n",
        "                B_list.append(torch.zeros(C))\n",
        "            else:\n",
        "                B_list.append(classifier.bias.detach().cpu().clone())\n",
        "\n",
        "            mu_c_list_train.append(mu_c_save)\n",
        "            train_accuracies_list.append(per_class_acc_train)\n",
        "            test_accuracies_list.append(per_class_acc_test)\n",
        "\n",
        "            print(f'Checkpoint saved. Epoch: {epoch} ')\n",
        "\n",
        "        torch.save(mu_c_list_train, save_log_path + \"mu_c_list_train\")\n",
        "        torch.save(W_list, save_log_path + \"W_list\")\n",
        "        torch.save(B_list, save_log_path + \"B_list\")\n",
        "        torch.save(train_accuracies_list, save_log_path + \"train_accuracies_list\")\n",
        "        torch.save(test_accuracies_list,  save_log_path + \"test_accuracies_list\")\n",
        "\n",
        "        # Mark the experiment as complete with an empty file (the original\n",
        "        # created a directory with this name; both satisfy the exists() check).\n",
        "        with open(experiment_complete_flag_file, \"w\"):\n",
        "            pass\n",
        "\n",
        "# Launch the sweep over all imbalance ratios when this cell is executed.\n",
        "run()\n"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [],
      "name": "NC_SELI_Imbalance.ipynb",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
