{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook presents the core implementation of our method and uses it to train ensembles of 2 or 3 ResNet-18 models on CIFAR-10. It computes the transferability rates for ```Orig```, ```C=1.0```, and ```LOTOS C=1.0 mal=0.8``` using white-box attacks, as explained in our paper, and generates plots similar to the ones in the paper's experiments section:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 130,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.utils.data import DataLoader\n",
    "import torch.optim as optim\n",
    "import torchvision.transforms as transforms\n",
    "import torchvision\n",
    "from advertorch.attacks.utils import attack_whole_dataset\n",
    "from advertorch.attacks import LinfPGDAttack\n",
    "\n",
    "from models.resnet import ResNet18\n",
    "from models.resnet_orig import ResNet18_orig\n",
    "\n",
    "import random\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import time\n",
    "\n",
    "from utils.Empirical.utils_ensemble import AverageMeter, requires_grad_\n",
    "\n",
    "# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.\n",
    "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Default parameter values used in our experiments:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 131,
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- Data / architecture ---\n",
    "in_chan = 3          # input channels passed to the ResNet-18 constructors\n",
    "num_classes = 10     # output classes (CIFAR-10)\n",
    "num_models = 2       # ensemble size read by the training and evaluation helpers\n",
    "bn_flag = False      # not read by train_ensemble or its helpers\n",
    "\n",
    "# --- Optimisation schedule ---\n",
    "epochs = 121         # number of training epochs run by train_ensemble\n",
    "opt_iter = 1         # not read by train_ensemble or its helpers\n",
    "clip_steps = 100     # clip_steps argument of the clipped ResNet18\n",
    "\n",
    "# --- LOTOS regularisation ---\n",
    "freq = 50                # passed as mal_freq: period (in batches) of the non-conv penalty\n",
    "conv_freq = 1            # period (in batches) of the conv-layer penalty\n",
    "effective_epoch = 0      # passed as no_effect_epochs\n",
    "conv_only = True         # regularise convolutional layers only\n",
    "conv_1st_only = False    # if True, regularise only the first conv layer\n",
    "bottom_clip = 0.8        # mal value: conv-layer gain allowed before the penalty applies\n",
    "cat_bottom_clip = 0.8    # mal value for the concatenation when batchnorm is used\n",
    "conv_factor = 0.05       # weight of the conv-layer penalty\n",
    "cat_factor = 0.05        # weight of the non-conv penalty\n",
    "\n",
    "# --- Evaluation ---\n",
    "adv_eps = 0.04           # L-inf radius of the PGD attack"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Helper functions to extract and average the transferability rate, robustness, and accuracy of the models in the ensemble for each evaluated epoch:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 132,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_trans(sub_df, model_count=3):\n",
    "    \"\"\"Average transferability rate, robustness and accuracy of one evaluated epoch.\n",
    "\n",
    "    ``sub_df`` is expected to hold one row per source model ``i`` (as built by\n",
    "    ``evaltrans_correct``): column ``t{j}`` is the fraction of model ``i``'s successful\n",
    "    adversarial examples that also fool model ``j`` (on the diagonal: model ``i``'s own\n",
    "    white-box attack success rate) and ``acc`` is model ``i``'s clean accuracy.\n",
    "\n",
    "    Args:\n",
    "        sub_df: DataFrame with columns ``t0`` .. ``t{model_count-1}`` and ``acc``,\n",
    "            normally with exactly ``model_count`` rows.\n",
    "        model_count: ensemble size (any integer >= 1).\n",
    "\n",
    "    Returns:\n",
    "        ``(trans_rate, robustness, acc_avg)``: mean of the off-diagonal transfer\n",
    "        rates (NaN for a single model), one minus the mean diagonal attack success\n",
    "        rate, and the mean of the ``acc`` column.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``model_count`` is smaller than 1.\n",
    "    \"\"\"\n",
    "    if model_count < 1:\n",
    "        raise ValueError('model_count must be >= 1, got %r' % (model_count,))\n",
    "\n",
    "    # Row i = source model (attack crafted on it), column j = target model.\n",
    "    trans_mat = sub_df[['t' + str(k) for k in range(model_count)]].values\n",
    "    acc_avg = sub_df['acc'].mean()\n",
    "\n",
    "    diag_sum = np.trace(trans_mat)\n",
    "    # Every ordered pair (i, j) with i != j contributes one off-diagonal entry.\n",
    "    num_pairs = model_count * (model_count - 1)\n",
    "    off_diagonal_sum = np.sum(trans_mat) - diag_sum\n",
    "    trans_rate = off_diagonal_sum / num_pairs if num_pairs else float('nan')\n",
    "\n",
    "    robustness = 1. - diag_sum / model_count\n",
    "\n",
    "    return trans_rate, robustness, acc_avg"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 133,
   "metadata": {},
   "outputs": [],
   "source": [
    "def summarize(df, num_models):\n",
    "    \"\"\"Reduce per-model evaluation rows to one summary row per epoch.\n",
    "\n",
    "    Args:\n",
    "        df: DataFrame with an ``epoch`` column plus the ``t*`` / ``acc`` columns\n",
    "            that ``compute_trans`` reads.\n",
    "        num_models: ensemble size; only the first ``num_models`` rows of each\n",
    "            epoch are used.\n",
    "\n",
    "    Returns:\n",
    "        DataFrame with columns ``epoch``, ``trans``, ``robust`` and ``acc``, one row\n",
    "        per distinct epoch, in groupby order.\n",
    "    \"\"\"\n",
    "    rows = []\n",
    "    for epoch_value, epoch_rows in df.groupby('epoch'):\n",
    "        stats = compute_trans(epoch_rows.iloc[:num_models], num_models)\n",
    "        rows.append([int(epoch_value), *stats])\n",
    "    return pd.DataFrame(rows, columns=['epoch', 'trans', 'robust', 'acc'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Helper function that uses white-box attacks (see our paper for details) to compute the transferability rate of adversarial examples between each ordered pair of models in the ensemble:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 134,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _safe_rate(count, total):\n",
    "    \"\"\"Return ``count / total`` as a float, or NaN when ``total`` is 0.\"\"\"\n",
    "    return count / total if total else float('nan')\n",
    "\n",
    "\n",
    "def evaltrans_correct(loader, models, criterion, epoch, clip_min=0., clip_max=1.):\n",
    "    \"\"\"White-box PGD transferability rates between every ordered pair of models.\n",
    "\n",
    "    An L-inf PGD attack (radius ``adv_eps``, 50 steps of ``adv_eps / 10``, random start) is\n",
    "    run against each model ``i`` over the whole of ``loader``. Then\n",
    "\n",
    "    * ``trans[i][i]``: among samples model ``i`` classifies correctly, the fraction whose\n",
    "      adversarial version fools model ``i`` (white-box attack success rate);\n",
    "    * ``trans[i][j]`` (``i != j``): among samples that models ``i`` and ``j`` both classify\n",
    "      correctly and whose model-``i`` adversarial version fools ``i``, the fraction that\n",
    "      also fools model ``j``.\n",
    "\n",
    "    A rate with an empty denominator is reported as NaN.\n",
    "\n",
    "    Args:\n",
    "        loader: evaluation DataLoader. It must yield the samples in the same order on\n",
    "            every pass (no shuffling): the labels of the last attack pass are compared\n",
    "            element-wise with the predictions of every pass.\n",
    "        models: list of ensemble members; the ensemble size is ``len(models)``.\n",
    "        criterion: loss function handed to the PGD attack.\n",
    "        epoch: value stored in the first field of every returned tuple.\n",
    "        clip_min, clip_max: input bounds given to the attack. The defaults (0, 1) keep\n",
    "            the original behaviour. NOTE(review): if ``loader`` yields normalised images,\n",
    "            these bounds clip valid inputs; pass the normalised range instead.\n",
    "\n",
    "    Returns:\n",
    "        One tuple per source model ``i``:\n",
    "        ``(epoch, trans[i][0], ..., trans[i][n - 1], clean_accuracy_of_i)``.\n",
    "    \"\"\"\n",
    "    n_models = len(models)\n",
    "    for model in models:\n",
    "        model.eval()\n",
    "\n",
    "    attackers = []\n",
    "    adv_samples_lst = []\n",
    "    pred_lst = []\n",
    "    advpred_lst = []\n",
    "    label = None\n",
    "    for model in models:\n",
    "        attacker = LinfPGDAttack(\n",
    "            model, loss_fn=criterion, eps=adv_eps,\n",
    "            nb_iter=50, eps_iter=adv_eps / 10, rand_init=True, clip_min=clip_min, clip_max=clip_max,\n",
    "            targeted=False)\n",
    "        attackers.append(attacker)\n",
    "        # `label` is overwritten on every pass; see the note on `loader` above.\n",
    "        adv_samples, label, pred, advpred = attack_whole_dataset(attacker, loader, device=device)\n",
    "        adv_samples_lst.append(adv_samples)\n",
    "        pred_lst.append(pred)\n",
    "        advpred_lst.append(advpred)\n",
    "\n",
    "    chunk = 200  # adversarial samples classified per forward pass\n",
    "    trans_list = []\n",
    "    accs = np.zeros(n_models)\n",
    "    trans = np.zeros((n_models, n_models))\n",
    "    for i in range(n_models):\n",
    "        adv_samples, pred, advpred = adv_samples_lst[i], pred_lst[i], advpred_lst[i]\n",
    "        correct_i = label == pred                   # clean prediction of model i is right\n",
    "        fooled_i = correct_i & (advpred != label)   # ...and its own attack flips it\n",
    "\n",
    "        for j in range(n_models):\n",
    "            if j == i:\n",
    "                n_correct = correct_i.sum().item()\n",
    "                accs[i] = _safe_rate(n_correct, label.size(0))\n",
    "                trans[i][j] = _safe_rate(fooled_i.sum().item(), n_correct)\n",
    "            else:\n",
    "                mask = fooled_i & (label == pred_lst[j])\n",
    "                inputc = adv_samples[mask]\n",
    "                y = label[mask]\n",
    "                print('model: ', i, inputc.size(0), ' out of ', adv_samples.size(0))\n",
    "\n",
    "                fooled_j = 0\n",
    "                with torch.no_grad():\n",
    "                    for start in range(0, inputc.size(0), chunk):\n",
    "                        logits = attackers[j].predict(inputc[start:start + chunk])\n",
    "                        output = logits.max(1, keepdim=False)[1]\n",
    "                        fooled_j += (output != y[start:start + chunk]).sum().item()\n",
    "                trans[i][j] = _safe_rate(fooled_j, y.size(0))\n",
    "\n",
    "            print(i, j, trans[i][j])\n",
    "\n",
    "        trans_list.append((epoch, *trans[i], accs[i]))\n",
    "\n",
    "    return trans_list"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The function used to train the ensemble for one epoch. If ```lotos_flag``` is set to True, it trains with LOTOS; otherwise it uses regular training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 135,
   "metadata": {},
   "outputs": [],
   "source": [
    "def Naive_Trainer_ortho(loader: DataLoader, models, criterion, optimizer, epoch: int, device: torch.device, lotos_flag=False,\n",
    "                        catclip=False, no_effect_epochs=0, batch_counter=0, mal_freq=100, conv_freq=50, layer_1_only=False,\n",
    "                        conv_1st_only=False, lsv_list_dict=None, lsv_list_dict_conv=None, conv_only=False, conv_factor=0.05, cat_factor=0.05):\n",
    "    \"\"\"Train every ensemble member for one epoch, optionally adding the LOTOS penalty.\n",
    "\n",
    "    Each batch sums ``criterion`` over all models. With ``lotos_flag`` set, every layer\n",
    "    owning a ``weight_VT`` buffer (each ``Conv2d``; with ``catclip`` and not ``conv_only``\n",
    "    also layers that are not Conv2d/BatchNorm2d/Linear) is penalised for stretching the\n",
    "    ``weight_VT`` vectors most recently stored for the *other* models. For a stored\n",
    "    vector v, the gain ||m(v) - m(0)|| / ||v|| in excess of ``bottom_clip`` (conv) or\n",
    "    ``cat_bottom_clip`` (other layers) is averaged with decaying weights and scaled by\n",
    "    ``conv_factor`` / ``cat_factor``.\n",
    "\n",
    "    Args:\n",
    "        loader: training DataLoader yielding ``(inputs, targets)`` batches.\n",
    "        models: ensemble members; all of them are trained (ensemble size = ``len(models)``).\n",
    "        criterion: loss applied to every model's logits.\n",
    "        optimizer: stepped once per batch on the summed loss.\n",
    "        epoch: current epoch (compared with ``no_effect_epochs`` and printed).\n",
    "        device: device the batches are moved to.\n",
    "        lotos_flag: enable the LOTOS penalty.\n",
    "        catclip: also consider non-conv layers that carry ``weight_VT``.\n",
    "        no_effect_epochs: before this epoch, reaching a non-conv layer zeroes both\n",
    "            penalty factors for the rest of the call.\n",
    "        batch_counter: global batch index carried across epochs; drives the penalty schedule.\n",
    "        mal_freq: period (in batches) of the non-conv penalty.\n",
    "        conv_freq: period (in batches) of the conv penalty.\n",
    "        layer_1_only: restrict the non-conv penalty to the first such layer.\n",
    "        conv_1st_only: restrict the conv penalty to the first conv layer.\n",
    "        lsv_list_dict, lsv_list_dict_conv: per-model lists of stored ``weight_VT`` tensors\n",
    "            (non-conv / conv), carried between calls; ``None`` starts from scratch.\n",
    "        conv_only: skip every non-conv layer.\n",
    "        conv_factor, cat_factor: scale of the conv / non-conv penalty.\n",
    "\n",
    "    Returns:\n",
    "        ``(losses.avg, batch_counter, lsv_list_dict, lsv_list_dict_conv)``: mean total loss\n",
    "        of the epoch, the advanced batch counter and the stored vectors.\n",
    "\n",
    "    Also reads the notebook globals ``bottom_clip`` and ``cat_bottom_clip``.\n",
    "    \"\"\"\n",
    "    # A literal `{}` default would be one dict shared by every call; create fresh ones instead.\n",
    "    if lsv_list_dict is None:\n",
    "        lsv_list_dict = {}\n",
    "    if lsv_list_dict_conv is None:\n",
    "        lsv_list_dict_conv = {}\n",
    "    # Use the ensemble actually passed in rather than the notebook-level `num_models`.\n",
    "    num_models = len(models)\n",
    "\n",
    "    batch_time = AverageMeter()\n",
    "    data_time = AverageMeter()\n",
    "    losses = AverageMeter()\n",
    "    reg_losses = AverageMeter()\n",
    "    ortho_losses = AverageMeter()\n",
    "\n",
    "    end = time.time()\n",
    "    ortho_total = 0.\n",
    "\n",
    "    ortho_flag = lotos_flag\n",
    "    # Linearly decaying weights 1.00, 0.99, ..., 0.01 for the per-vector penalties of one layer.\n",
    "    # NOTE(review): assumes a layer stores at most 100 weight_VT vectors.\n",
    "    decrement = 0.01\n",
    "    weights = torch.from_numpy(np.array([1 - decrement * w for w in range(100)])).to(device)\n",
    "\n",
    "    print_freq = max(1000, 10 * conv_freq)\n",
    "    if len(lsv_list_dict) == 0:\n",
    "        print('initiating lsv list dict')\n",
    "        for j in range(num_models):\n",
    "            lsv_list_dict[j] = None\n",
    "            lsv_list_dict_conv[j] = None\n",
    "\n",
    "    for i in range(num_models):\n",
    "        models[i].train()\n",
    "        requires_grad_(models[i], True)\n",
    "\n",
    "    cat_counter_info = 0\n",
    "    conv_counter_info = 0\n",
    "    for i, (inputs, targets) in enumerate(loader):\n",
    "        data_time.update(time.time() - end)\n",
    "        inputs, targets = inputs.to(device), targets.to(device)\n",
    "        batch_size = inputs.size(0)\n",
    "        loss_std = 0\n",
    "        ortho_loss = 0        # non-conv penalty of this batch\n",
    "        ortho_loss_conv = 0   # conv penalty of this batch\n",
    "\n",
    "        for j in range(num_models):\n",
    "            logits = models[j](inputs)\n",
    "            loss = criterion(logits, targets)\n",
    "            loss_std += loss\n",
    "\n",
    "            if not ortho_flag:\n",
    "                continue\n",
    "            # The last batch of the epoch adds no penalty and does not refresh the stored vectors.\n",
    "            if i == len(loader) - 1:\n",
    "                continue\n",
    "\n",
    "            VT_list = []        # this batch's weight_VT of model j, non-conv layers\n",
    "            VT_list_conv = []   # this batch's weight_VT of model j, conv layers\n",
    "            idx = 0             # index of the current non-conv layer\n",
    "            conv_count = 0      # index of the current conv layer\n",
    "            cat_counter_info = 0\n",
    "            conv_counter_info = 0\n",
    "\n",
    "            for _name, m in models[j].named_modules():\n",
    "                condition = isinstance(m, (torch.nn.Conv2d))\n",
    "                condition_conv = isinstance(m, (torch.nn.Conv2d))\n",
    "                if not condition_conv and conv_only:\n",
    "                    continue\n",
    "                if catclip:\n",
    "                    condition = not conv_only and not isinstance(m, (torch.nn.Conv2d)) and (not isinstance(m, torch.nn.BatchNorm2d) and not isinstance(m, torch.nn.Linear))\n",
    "\n",
    "                # Before `no_effect_epochs`, reaching a non-conv layer zeroes both factors for the rest of this call.\n",
    "                if not condition_conv and epoch < no_effect_epochs:\n",
    "                    conv_factor = 0.0\n",
    "                    cat_factor = 0.0\n",
    "\n",
    "                if not (condition or condition_conv):\n",
    "                    continue\n",
    "                buffers = vars(m).get('_buffers', {})\n",
    "                if 'weight_VT' not in buffers:\n",
    "                    continue\n",
    "                VT = buffers['weight_VT']\n",
    "\n",
    "                if batch_counter != 0:\n",
    "                    for k in range(num_models):\n",
    "                        if k == j:\n",
    "                            # Diagnostic gains of model j on its own stored vectors; the results are unused.\n",
    "                            # NOTE(review): kept in case m's forward has side effects -- confirm, then drop.\n",
    "                            if condition_conv:\n",
    "                                if batch_counter % print_freq != 0:\n",
    "                                    continue\n",
    "                                sing_vector = lsv_list_dict_conv[k][conv_count]\n",
    "                            else:\n",
    "                                if batch_counter % mal_freq != 0:\n",
    "                                    continue\n",
    "                                sing_vector = lsv_list_dict[k][idx]\n",
    "\n",
    "                            sing_vector = torch.nn.parameter.Parameter(data=sing_vector, requires_grad=False)\n",
    "                            op_shape = list(range(1, len(sing_vector.shape)))\n",
    "                            lsv_check = torch.sqrt(torch.sum(m(sing_vector) ** 2, axis=op_shape)) / torch.sqrt(torch.sum(sing_vector ** 2, axis=op_shape))\n",
    "                            lsv_check_noBias = torch.sqrt(torch.sum((m(sing_vector) - m(torch.zeros_like(sing_vector))) ** 2, axis=op_shape)) / torch.sqrt(torch.sum(sing_vector ** 2, axis=op_shape))\n",
    "                            continue\n",
    "\n",
    "                        if condition_conv:\n",
    "                            prev_VT_list = lsv_list_dict_conv[k]\n",
    "                        else:\n",
    "                            prev_VT_list = lsv_list_dict[k]\n",
    "\n",
    "                        if not condition_conv and layer_1_only and idx > 0:\n",
    "                            continue\n",
    "                        if condition_conv and conv_1st_only and conv_count > 0:\n",
    "                            continue\n",
    "\n",
    "                        if condition_conv:\n",
    "                            bad_vector = prev_VT_list[conv_count]\n",
    "                        else:\n",
    "                            bad_vector = prev_VT_list[idx]\n",
    "\n",
    "                        if batch_counter % mal_freq != 0 and not condition_conv:\n",
    "                            continue\n",
    "                        if batch_counter % conv_freq != 0 and condition_conv:\n",
    "                            continue\n",
    "\n",
    "                        # Gain of layer m (bias removed via m(0)) on each of model k's stored vectors.\n",
    "                        bad_vector = torch.nn.parameter.Parameter(data=bad_vector, requires_grad=False)\n",
    "                        op_shape = list(range(1, len(bad_vector.shape)))\n",
    "                        bad_vec_length = torch.sqrt(torch.sum((m(bad_vector) - m(torch.zeros_like(bad_vector))) ** 2, axis=op_shape)) / torch.sqrt(torch.sum(bad_vector ** 2, axis=op_shape))\n",
    "\n",
    "                        # Only the part of the gain above the mal threshold is penalised.\n",
    "                        if condition_conv:\n",
    "                            bad_vec_length_thresh = torch.nn.functional.relu(bad_vec_length - bottom_clip)\n",
    "                        else:\n",
    "                            bad_vec_length_thresh = torch.nn.functional.relu(bad_vec_length - cat_bottom_clip)\n",
    "\n",
    "                        layer_weights = weights[:len(bad_vec_length_thresh)]\n",
    "                        bad_vec_length_weighted = torch.sum(torch.mul(bad_vec_length_thresh, layer_weights)) / torch.sum(layer_weights)\n",
    "\n",
    "                        if condition_conv:\n",
    "                            ortho_loss_conv += conv_factor * bad_vec_length_weighted\n",
    "                            conv_counter_info += 1\n",
    "                        else:\n",
    "                            ortho_loss += cat_factor * bad_vec_length_weighted\n",
    "                            cat_counter_info += 1\n",
    "\n",
    "                # Store this batch's vectors of model j for the other models to be compared against.\n",
    "                if condition_conv:\n",
    "                    VT_list_conv.append(VT.detach())\n",
    "                    conv_count += 1\n",
    "                else:\n",
    "                    VT_list.append(VT.detach())\n",
    "                    idx += 1\n",
    "\n",
    "            lsv_list_dict_conv[j] = VT_list_conv\n",
    "            if not conv_only:\n",
    "                lsv_list_dict[j] = VT_list\n",
    "\n",
    "        reg_losses.update(loss_std.item(), batch_size)\n",
    "        loss = loss_std\n",
    "\n",
    "        if ortho_flag:\n",
    "            batch_ortho = 0.\n",
    "            pair_count = num_models * (num_models - 1) / 2\n",
    "            # The counters hold the last model's count over its (num_models - 1) partners.\n",
    "            conv_counter_info = conv_counter_info // max(num_models - 1, 1)\n",
    "            cat_counter_info = cat_counter_info // max(num_models - 1, 1)\n",
    "\n",
    "            # Per-pair normalisation is disabled: penalties are added as accumulated.\n",
    "            conv_normalizer = 1.\n",
    "            cat_normalizer = 1.\n",
    "\n",
    "            if ortho_loss_conv > 0 and (batch_counter % 200 != 199) and batch_counter > 0 and conv_counter_info > 0 and (batch_counter % conv_freq) == 0:\n",
    "                loss += ortho_loss_conv / conv_normalizer\n",
    "                ortho_total += ortho_loss_conv.item() / conv_normalizer\n",
    "                batch_ortho += ortho_loss_conv.item() / conv_normalizer\n",
    "                if batch_counter % print_freq == 0:\n",
    "                    print('pairs', pair_count,  'conv', conv_counter_info,  'ortho loss conv: ', ortho_loss_conv.item()/ conv_normalizer )\n",
    "\n",
    "            # NOTE(review): `ortho_loss` / `cat_counter_info` only grow on batches with\n",
    "            # batch_counter % mal_freq == 0, but this branch requires % mal_freq == 50, so as\n",
    "            # written it never adds the non-conv penalty. Confirm the intended schedule.\n",
    "            if not conv_only and batch_counter > 0 and cat_counter_info > 0 and (batch_counter % mal_freq) == 50:\n",
    "                loss += ortho_loss / cat_normalizer\n",
    "                ortho_total += ortho_loss.item() / cat_normalizer\n",
    "                batch_ortho += ortho_loss.item() / cat_normalizer\n",
    "                if batch_counter % mal_freq == 0:\n",
    "                    print('pairs', pair_count, 'cat', cat_counter_info, 'ortho loss cat: ', ortho_loss.item() / cat_normalizer)\n",
    "\n",
    "            # Feed the meter so the epoch summary reports the actual penalty, not a constant 0.\n",
    "            ortho_losses.update(batch_ortho, batch_size)\n",
    "\n",
    "        losses.update(loss.item(), batch_size)\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        # measure elapsed time\n",
    "        batch_time.update(time.time() - end)\n",
    "        end = time.time()\n",
    "        batch_counter += 1\n",
    "\n",
    "    print('Epoch: ', epoch, 'Loss: ', losses.avg, 'Loss_std: ', reg_losses.avg, 'Ortho_loss: ', ortho_losses.avg, 'ortho total:', ortho_total)\n",
    "\n",
    "    return losses.avg, batch_counter, lsv_list_dict, lsv_list_dict_conv"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Reading CIFAR-10 dataset and preparing data loaders:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 136,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n"
     ]
    }
   ],
   "source": [
    "### Reading CIFAR10:\n",
    "# Per-channel statistics shared by the train and test normalisation.\n",
    "cifar_mean = (0.4914, 0.4822, 0.4465)\n",
    "cifar_std = (0.2023, 0.1994, 0.2010)\n",
    "# NOTE(review): the models therefore see normalised inputs, while the PGD attack in\n",
    "# evaltrans_correct clips adversarial inputs to [0, 1] -- confirm this is intended.\n",
    "\n",
    "# Training images get random 4px-padded crops and horizontal flips before normalisation.\n",
    "transform_train = transforms.Compose([\n",
    "    transforms.RandomCrop(32, padding=4),\n",
    "    transforms.RandomHorizontalFlip(),\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize(cifar_mean, cifar_std),\n",
    "])\n",
    "# Test images are only converted to tensors and normalised.\n",
    "transform_test = transforms.Compose([transforms.ToTensor(), transforms.Normalize(cifar_mean, cifar_std)])\n",
    "\n",
    "loader_workers = 1\n",
    "trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)\n",
    "train_loader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=loader_workers)\n",
    "\n",
    "# Not shuffled: the evaluation pairs predictions from separate passes over this loader sample by sample.\n",
    "testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)\n",
    "test_loader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False, num_workers=loader_workers)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Function to train an ensemble. For ```Orig``` ensembles, both ```clip_flag``` and ```lotos_flag``` should be set to False. For ```C=1```, ```clip_flag``` should be set to True. For ```LOTOS```, both flags should be set to True:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 137,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_ensemble(num_models, clip_flag=False, lotos_flag=False, seed_val=0, eval_every=20, eval_start=60):\n",
    "    \"\"\"Train a ResNet-18 ensemble on CIFAR-10 and record transferability during training.\n",
    "\n",
    "    Args:\n",
    "        num_models: number of ensemble members.\n",
    "        clip_flag: build the clipped ``ResNet18`` (with ``clip_steps`` from the config\n",
    "            cell) instead of ``ResNet18_orig``.\n",
    "        lotos_flag: add the LOTOS penalty in ``Naive_Trainer_ortho``. NOTE(review): that\n",
    "            penalty only acts on modules carrying a ``weight_VT`` buffer, presumably\n",
    "            provided by the clipped model only -- combine with ``clip_flag``.\n",
    "        seed_val: seed applied to torch (CPU and CUDA), numpy and ``random``.\n",
    "        eval_every, eval_start: evaluate transferability after every epoch ``e`` with\n",
    "            ``e % eval_every == 0`` and ``e >= eval_start``; the defaults keep the\n",
    "            original schedule.\n",
    "\n",
    "    Returns:\n",
    "        DataFrame with columns ``epoch``, ``t0`` .. ``t{num_models - 1}`` and ``acc``,\n",
    "        one row per (evaluated epoch, source model) as returned by ``evaltrans_correct``.\n",
    "\n",
    "    Reads the config-cell globals ``epochs``, ``in_chan``, ``num_classes``, ``clip_steps``,\n",
    "    ``effective_epoch``, ``freq``, ``conv_freq``, ``conv_1st_only``, ``conv_factor``,\n",
    "    ``cat_factor`` and ``conv_only``, plus ``train_loader`` / ``test_loader``.\n",
    "    \"\"\"\n",
    "    torch.manual_seed(seed_val)\n",
    "    torch.cuda.manual_seed_all(seed_val)\n",
    "    np.random.seed(seed_val)\n",
    "    random.seed(seed_val)\n",
    "\n",
    "    model = []\n",
    "    for _ in range(num_models):\n",
    "        if clip_flag:\n",
    "            submodel = ResNet18(in_chan=in_chan, num_classes=num_classes, device=device, clip_flag=True, bn=False, clip_steps=clip_steps, writer=None)\n",
    "        else:\n",
    "            submodel = ResNet18_orig(in_chan=in_chan, num_classes=num_classes, bn=False, device=device)\n",
    "        model.append(nn.DataParallel(submodel))\n",
    "    print(\"Model loaded\")\n",
    "\n",
    "    # Follow the notebook-wide `device` instead of hard-coding CUDA.\n",
    "    criterion = nn.CrossEntropyLoss().to(device)\n",
    "\n",
    "    # A single optimizer/scheduler over the parameters of every ensemble member.\n",
    "    param = [p for member in model for p in member.parameters()]\n",
    "    optimizer = optim.SGD(param, lr=0.1, momentum=0.9, weight_decay=5e-4)\n",
    "    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=40, gamma=0.1)\n",
    "\n",
    "    trans_list = []\n",
    "    batch_counter = 0\n",
    "    lsv_list_dict = {}\n",
    "    lsv_list_dict_conv = {}\n",
    "    for epoch in range(epochs):\n",
    "        start = time.time()\n",
    "        _train_loss, batch_counter, lsv_list_dict, lsv_list_dict_conv = Naive_Trainer_ortho(\n",
    "            train_loader, model, criterion, optimizer, epoch, device, lotos_flag=lotos_flag, catclip=False,\n",
    "            no_effect_epochs=effective_epoch, batch_counter=batch_counter, mal_freq=freq, conv_freq=conv_freq,\n",
    "            layer_1_only=False, conv_1st_only=conv_1st_only, lsv_list_dict=lsv_list_dict,\n",
    "            lsv_list_dict_conv=lsv_list_dict_conv, conv_factor=conv_factor, cat_factor=cat_factor,\n",
    "            conv_only=conv_only)\n",
    "        print('time: ', time.time() - start)\n",
    "\n",
    "        if epoch % eval_every == 0 and epoch >= eval_start:\n",
    "            trans_list += evaltrans_correct(test_loader, model, criterion, epoch)\n",
    "\n",
    "        scheduler.step()\n",
    "\n",
    "    col_names = ['epoch'] + ['t' + str(k) for k in range(num_models)] + ['acc']\n",
    "    return pd.DataFrame(trans_list, columns=col_names)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now running the training methods for 3 cases: ```Orig```, ```C=1.0```, and ```LOTOS C=1.0 mal=0.8```:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 138,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model loaded\n",
      "initiating lsv list dict\n",
      "Epoch:  0 Loss:  4.161355045471192 Loss_std:  4.161355045471192 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.58352780342102\n",
      "Epoch:  1 Loss:  3.545693136672974 Loss_std:  3.545693136672974 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.557411670684814\n",
      "Epoch:  2 Loss:  3.102024624786377 Loss_std:  3.102024624786377 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.703930139541626\n",
      "Epoch:  3 Loss:  2.747564366836548 Loss_std:  2.747564366836548 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.93328595161438\n",
      "Epoch:  4 Loss:  2.491742465057373 Loss_std:  2.491742465057373 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.96542477607727\n",
      "Epoch:  5 Loss:  2.2254140531921385 Loss_std:  2.2254140531921385 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.92059588432312\n",
      "Epoch:  6 Loss:  2.0448360148620606 Loss_std:  2.0448360148620606 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.97882604598999\n",
      "Epoch:  7 Loss:  1.862692717590332 Loss_std:  1.862692717590332 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.995654344558716\n",
      "Epoch:  8 Loss:  1.7104009420013428 Loss_std:  1.7104009420013428 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  43.035223960876465\n",
      "Epoch:  9 Loss:  1.5933418035888671 Loss_std:  1.5933418035888671 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  43.013548374176025\n",
      "Epoch:  10 Loss:  1.5189277465820312 Loss_std:  1.5189277465820312 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  43.096226930618286\n",
      "Epoch:  11 Loss:  1.4361217595672608 Loss_std:  1.4361217595672608 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.946436405181885\n",
      "Epoch:  12 Loss:  1.3605998261642456 Loss_std:  1.3605998261642456 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.77836465835571\n",
      "Epoch:  13 Loss:  1.3171176036834717 Loss_std:  1.3171176036834717 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.745774269104004\n",
      "Epoch:  14 Loss:  1.2640174899673462 Loss_std:  1.2640174899673462 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.87689518928528\n",
      "Epoch:  15 Loss:  1.2280123067474364 Loss_std:  1.2280123067474364 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.525652170181274\n",
      "Epoch:  16 Loss:  1.1940137855529784 Loss_std:  1.1940137855529784 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.60247015953064\n",
      "Epoch:  17 Loss:  1.1617285887145996 Loss_std:  1.1617285887145996 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.5027129650116\n",
      "Epoch:  18 Loss:  1.1394893115234375 Loss_std:  1.1394893115234375 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.557518005371094\n",
      "Epoch:  19 Loss:  1.1192133003997802 Loss_std:  1.1192133003997802 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.476762771606445\n",
      "Epoch:  20 Loss:  1.0719614544296265 Loss_std:  1.0719614544296265 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.45003914833069\n",
      "Epoch:  21 Loss:  1.080829122619629 Loss_std:  1.080829122619629 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.46037817001343\n",
      "Epoch:  22 Loss:  1.0564193648910523 Loss_std:  1.0564193648910523 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.442426443099976\n",
      "Epoch:  23 Loss:  1.032963249206543 Loss_std:  1.032963249206543 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.517460346221924\n",
      "Epoch:  24 Loss:  1.0219222017288208 Loss_std:  1.0219222017288208 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.435980796813965\n",
      "Epoch:  25 Loss:  1.0085052013015747 Loss_std:  1.0085052013015747 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.55070424079895\n",
      "Epoch:  26 Loss:  0.9883160879516601 Loss_std:  0.9883160879516601 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.43706202507019\n",
      "Epoch:  27 Loss:  0.9884413487243653 Loss_std:  0.9884413487243653 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.4173800945282\n",
      "Epoch:  28 Loss:  0.9896609563827514 Loss_std:  0.9896609563827514 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.43077754974365\n",
      "Epoch:  29 Loss:  0.9535386048698425 Loss_std:  0.9535386048698425 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.430877923965454\n",
      "Epoch:  30 Loss:  0.9446393131256103 Loss_std:  0.9446393131256103 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.53767704963684\n",
      "Epoch:  31 Loss:  0.9356701823806762 Loss_std:  0.9356701823806762 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.4440336227417\n",
      "Epoch:  32 Loss:  0.9504644769287109 Loss_std:  0.9504644769287109 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.521652698516846\n",
      "Epoch:  33 Loss:  0.920756837310791 Loss_std:  0.920756837310791 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.424336671829224\n",
      "Epoch:  34 Loss:  0.9200079004669189 Loss_std:  0.9200079004669189 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.424474000930786\n",
      "Epoch:  35 Loss:  0.89968614944458 Loss_std:  0.89968614944458 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  41.285314083099365\n",
      "Epoch:  36 Loss:  0.9024299261093139 Loss_std:  0.9024299261093139 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  41.927793979644775\n",
      "Epoch:  37 Loss:  0.895915977973938 Loss_std:  0.895915977973938 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.53885531425476\n",
      "Epoch:  38 Loss:  0.8822295738601684 Loss_std:  0.8822295738601684 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.41643023490906\n",
      "Epoch:  39 Loss:  0.8755537427139283 Loss_std:  0.8755537427139283 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.50546717643738\n",
      "Epoch:  40 Loss:  0.5115322870826721 Loss_std:  0.5115322870826721 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.42005705833435\n",
      "Epoch:  41 Loss:  0.4197578287410736 Loss_std:  0.4197578287410736 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.480597734451294\n",
      "Epoch:  42 Loss:  0.38518776956558226 Loss_std:  0.38518776956558226 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.48913764953613\n",
      "Epoch:  43 Loss:  0.3617219450187683 Loss_std:  0.3617219450187683 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.53076219558716\n",
      "Epoch:  44 Loss:  0.3393652365684509 Loss_std:  0.3393652365684509 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.63236331939697\n",
      "Epoch:  45 Loss:  0.32344858827590944 Loss_std:  0.32344858827590944 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.528428077697754\n",
      "Epoch:  46 Loss:  0.3082397258281708 Loss_std:  0.3082397258281708 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.66911292076111\n",
      "Epoch:  47 Loss:  0.29536478872299193 Loss_std:  0.29536478872299193 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.58341145515442\n",
      "Epoch:  48 Loss:  0.28365767727851865 Loss_std:  0.28365767727851865 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.57987904548645\n",
      "Epoch:  49 Loss:  0.2752990455341339 Loss_std:  0.2752990455341339 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.591219425201416\n",
      "Epoch:  50 Loss:  0.2671731155061722 Loss_std:  0.2671731155061722 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.60153770446777\n",
      "Epoch:  51 Loss:  0.2561947023200989 Loss_std:  0.2561947023200989 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.75539827346802\n",
      "Epoch:  52 Loss:  0.25003630960941314 Loss_std:  0.25003630960941314 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.73017168045044\n",
      "Epoch:  53 Loss:  0.24610885557174683 Loss_std:  0.24610885557174683 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.63389182090759\n",
      "Epoch:  54 Loss:  0.23910668659210205 Loss_std:  0.23910668659210205 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.643393754959106\n",
      "Epoch:  55 Loss:  0.23244191820144652 Loss_std:  0.23244191820144652 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.66779661178589\n",
      "Epoch:  56 Loss:  0.22779031145095824 Loss_std:  0.22779031145095824 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.64316487312317\n",
      "Epoch:  57 Loss:  0.22075133193969726 Loss_std:  0.22075133193969726 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.65918231010437\n",
      "Epoch:  58 Loss:  0.216165275888443 Loss_std:  0.216165275888443 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.78619360923767\n",
      "Epoch:  59 Loss:  0.21476261549949646 Loss_std:  0.21476261549949646 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.78046989440918\n",
      "Epoch:  60 Loss:  0.20869914726257324 Loss_std:  0.20869914726257324 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.66390872001648\n",
      "0 0 0.8229178188253512\n",
      "model:  0 7194  out of  10000\n",
      "0 1 0.8013622463163748\n",
      "model:  1 7253  out of  10000\n",
      "1 0 0.7937405211636565\n",
      "1 1 0.8311617212307019\n",
      "Epoch:  61 Loss:  0.20461271640300752 Loss_std:  0.20461271640300752 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.86274313926697\n",
      "Epoch:  62 Loss:  0.20621674251556396 Loss_std:  0.20621674251556396 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.95169425010681\n",
      "Epoch:  63 Loss:  0.19816621606826781 Loss_std:  0.19816621606826781 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.980406284332275\n",
      "Epoch:  64 Loss:  0.1932109046268463 Loss_std:  0.1932109046268463 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.83087372779846\n",
      "Epoch:  65 Loss:  0.1979705990755558 Loss_std:  0.1979705990755558 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.835004568099976\n",
      "Epoch:  66 Loss:  0.19157652341365813 Loss_std:  0.19157652341365813 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.83834958076477\n",
      "Epoch:  67 Loss:  0.18084211794376373 Loss_std:  0.18084211794376373 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.916046380996704\n",
      "Epoch:  68 Loss:  0.1914961391425133 Loss_std:  0.1914961391425133 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.78612756729126\n",
      "Epoch:  69 Loss:  0.18136364421367646 Loss_std:  0.18136364421367646 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.73892116546631\n",
      "Epoch:  70 Loss:  0.18135834797859193 Loss_std:  0.18135834797859193 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.76445198059082\n",
      "Epoch:  71 Loss:  0.18132388767719268 Loss_std:  0.18132388767719268 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.63132333755493\n",
      "Epoch:  72 Loss:  0.1778282821393013 Loss_std:  0.1778282821393013 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.62344813346863\n",
      "Epoch:  73 Loss:  0.18337069101810455 Loss_std:  0.18337069101810455 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.64125728607178\n",
      "Epoch:  74 Loss:  0.18174754853248595 Loss_std:  0.18174754853248595 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.60726857185364\n",
      "Epoch:  75 Loss:  0.1744025723361969 Loss_std:  0.1744025723361969 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.72846508026123\n",
      "Epoch:  76 Loss:  0.17039530079841614 Loss_std:  0.17039530079841614 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.610737323760986\n",
      "Epoch:  77 Loss:  0.16999358876228332 Loss_std:  0.16999358876228332 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.67780518531799\n",
      "Epoch:  78 Loss:  0.1648142089176178 Loss_std:  0.1648142089176178 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.69255709648132\n",
      "Epoch:  79 Loss:  0.17149478332996368 Loss_std:  0.17149478332996368 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.60296320915222\n",
      "Epoch:  80 Loss:  0.08680317085504533 Loss_std:  0.08680317085504533 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.69421911239624\n",
      "0 0 0.8134709397066812\n",
      "model:  0 7258  out of  10000\n",
      "0 1 0.8452741802149353\n",
      "model:  1 7498  out of  10000\n",
      "1 0 0.7832755401440384\n",
      "1 1 0.8394793926247288\n",
      "Epoch:  81 Loss:  0.0635719159245491 Loss_std:  0.0635719159245491 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.52898716926575\n",
      "Epoch:  82 Loss:  0.05704024207830429 Loss_std:  0.05704024207830429 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.661837577819824\n",
      "Epoch:  83 Loss:  0.051838845088481904 Loss_std:  0.051838845088481904 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.57950448989868\n",
      "Epoch:  84 Loss:  0.047410607083439825 Loss_std:  0.047410607083439825 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.68063759803772\n",
      "Epoch:  85 Loss:  0.04376999249219894 Loss_std:  0.04376999249219894 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.57516002655029\n",
      "Epoch:  86 Loss:  0.042712540754675864 Loss_std:  0.042712540754675864 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.590766191482544\n",
      "Epoch:  87 Loss:  0.04265688340544701 Loss_std:  0.04265688340544701 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.5765438079834\n",
      "Epoch:  88 Loss:  0.03770122833609581 Loss_std:  0.03770122833609581 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.572298765182495\n",
      "Epoch:  89 Loss:  0.03689147403240204 Loss_std:  0.03689147403240204 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.68733859062195\n",
      "Epoch:  90 Loss:  0.03762853614687919 Loss_std:  0.03762853614687919 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.60014033317566\n",
      "Epoch:  91 Loss:  0.03482256926357746 Loss_std:  0.03482256926357746 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.72102379798889\n",
      "Epoch:  92 Loss:  0.03425301469869912 Loss_std:  0.03425301469869912 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.5893349647522\n",
      "Epoch:  93 Loss:  0.03360223310083151 Loss_std:  0.03360223310083151 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.73547911643982\n",
      "Epoch:  94 Loss:  0.03036640510737896 Loss_std:  0.03036640510737896 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.61659598350525\n",
      "Epoch:  95 Loss:  0.029271773434579373 Loss_std:  0.029271773434579373 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.6106231212616\n",
      "Epoch:  96 Loss:  0.029234748888611793 Loss_std:  0.029234748888611793 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.735918283462524\n",
      "Epoch:  97 Loss:  0.029833041535168886 Loss_std:  0.029833041535168886 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.57906913757324\n",
      "Epoch:  98 Loss:  0.026617150537967683 Loss_std:  0.026617150537967683 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.582358598709106\n",
      "Epoch:  99 Loss:  0.02704890827268362 Loss_std:  0.02704890827268362 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.60776090621948\n",
      "Epoch:  100 Loss:  0.02634244772851467 Loss_std:  0.02634244772851467 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.71265959739685\n",
      "0 0 0.8279488291413704\n",
      "model:  0 7410  out of  10000\n",
      "0 1 0.8399460188933873\n",
      "model:  1 7497  out of  10000\n",
      "1 0 0.8221955448846205\n",
      "1 1 0.8372748969407681\n",
      "Epoch:  101 Loss:  0.0265846846306324 Loss_std:  0.0265846846306324 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.873791217803955\n",
      "Epoch:  102 Loss:  0.025124271444678308 Loss_std:  0.025124271444678308 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.884053468704224\n",
      "Epoch:  103 Loss:  0.024236032368540765 Loss_std:  0.024236032368540765 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.925156116485596\n",
      "Epoch:  104 Loss:  0.02281848882853985 Loss_std:  0.02281848882853985 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  43.0122287273407\n",
      "Epoch:  105 Loss:  0.02406702029079199 Loss_std:  0.02406702029079199 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.937674045562744\n",
      "Epoch:  106 Loss:  0.024839938998520373 Loss_std:  0.024839938998520373 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.935367584228516\n",
      "Epoch:  107 Loss:  0.021902689902186395 Loss_std:  0.021902689902186395 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  43.05766010284424\n",
      "Epoch:  108 Loss:  0.023231919558048247 Loss_std:  0.023231919558048247 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.961655616760254\n",
      "Epoch:  109 Loss:  0.022417283083200453 Loss_std:  0.022417283083200453 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.985919713974\n",
      "Epoch:  110 Loss:  0.022111390767991543 Loss_std:  0.022111390767991543 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.99000549316406\n",
      "Epoch:  111 Loss:  0.02204952717423439 Loss_std:  0.02204952717423439 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  43.0758011341095\n",
      "Epoch:  112 Loss:  0.023457384964227675 Loss_std:  0.023457384964227675 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.97906541824341\n",
      "Epoch:  113 Loss:  0.021554049358218908 Loss_std:  0.021554049358218908 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.995274782180786\n",
      "Epoch:  114 Loss:  0.02197903710871935 Loss_std:  0.02197903710871935 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  43.06655550003052\n",
      "Epoch:  115 Loss:  0.02181302911490202 Loss_std:  0.02181302911490202 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.96926164627075\n",
      "Epoch:  116 Loss:  0.021076167548298837 Loss_std:  0.021076167548298837 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.98505926132202\n",
      "Epoch:  117 Loss:  0.02119017274528742 Loss_std:  0.02119017274528742 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.961575746536255\n",
      "Epoch:  118 Loss:  0.02148345782637596 Loss_std:  0.02148345782637596 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  43.09492468833923\n",
      "Epoch:  119 Loss:  0.018434472149312497 Loss_std:  0.018434472149312497 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  42.98840379714966\n",
      "Epoch:  120 Loss:  0.017745198669433592 Loss_std:  0.017745198669433592 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  43.10299468040466\n",
      "0 0 0.8198892628379112\n",
      "model:  0 7350  out of  10000\n",
      "0 1 0.8444897959183674\n",
      "model:  1 7467  out of  10000\n",
      "1 0 0.8184009642426677\n",
      "1 1 0.8329539295392954\n"
     ]
    }
   ],
   "source": [
    "df_orig = train_ensemble(num_models, clip_flag=False, lotos_flag=False, seed_val=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 139,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "!!!!!!! Clipping is active !!!!!!!! clip val:  1.0\n",
      "def iter:  1\n",
      "opt step size:  0.35\n",
      "!!!!!!! Clipping is active !!!!!!!! clip val:  1.0\n",
      "def iter:  1\n",
      "opt step size:  0.35\n",
      "!!!!!!! Clipping is active !!!!!!!! clip val:  1.0\n",
      "def iter:  1\n",
      "opt step size:  0.35\n",
      "!!!!!!! Clipping is active !!!!!!!! clip val:  1.0\n",
      "def iter:  1\n",
      "opt step size:  0.35\n",
      "!!!!!!! Clipping is active !!!!!!!! clip val:  1.0\n",
      "def iter:  1\n",
      "opt step size:  0.35\n",
      "!!!!!!! Clipping is active !!!!!!!! clip val:  1.0\n",
      "def iter:  1\n",
      "opt step size:  0.35\n",
      "!!!!!!! Clipping is active !!!!!!!! clip val:  1.0\n",
      "def iter:  1\n",
      "opt step size:  0.35\n",
      "!!!!!!! Clipping is active !!!!!!!! clip val:  1.0\n",
      "def iter:  1\n",
      "opt step size:  0.35\n",
      "!!!!!!! Clipping is active !!!!!!!! clip val:  1.0\n",
      "def iter:  1\n",
      "opt step size:  0.35\n",
      "!!!!!!! Clipping is active !!!!!!!! clip val:  1.0\n",
      "def iter:  1\n",
      "opt step size:  0.35\n",
      "!!!!!!! Clipping is active !!!!!!!! clip val:  1.0\n",
      "def iter:  1\n",
      "opt step size:  0.35\n",
      "!!!!!!! Clipping is active !!!!!!!! clip val:  1.0\n",
      "def iter:  1\n",
      "opt step size:  0.35\n",
      "!!!!!!! Clipping is active !!!!!!!! clip val:  1.0\n",
      "def iter:  1\n",
      "opt step size:  0.35\n",
      "!!!!!!! Clipping is active !!!!!!!! clip val:  1.0\n",
      "def iter:  1\n",
      "opt step size:  0.35\n",
      "!!!!!!! Clipping is active !!!!!!!! clip val:  1.0\n",
      "def iter:  1\n",
      "opt step size:  0.35\n",
      "!!!!!!! Clipping is active !!!!!!!! clip val:  1.0\n",
      "def iter:  1\n",
      "opt step size:  0.35\n",
      "!!!!!!! Clipping is active !!!!!!!! clip val:  1.0\n",
      "def iter:  1\n",
      "opt step size:  0.35\n",
      "!!!!!!! Clipping is active !!!!!!!! clip val:  1.0\n",
      "def iter:  1\n",
      "opt step size:  0.35\n",
      "!!!!!!! Clipping is active !!!!!!!! clip val:  1.0\n",
      "def iter:  1\n",
      "opt step size:  0.35\n",
      "!!!!!!! Clipping is active !!!!!!!! clip val:  1.0\n",
      "def iter:  1\n",
      "opt step size:  0.35\n",
      "!!!!!!! Clipping is active !!!!!!!! clip val:  1.0\n",
      "def iter:  1\n",
      "opt step size:  0.35\n",
      "!!!!!!! Clipping is active !!!!!!!! clip val:  1.0\n",
      "def iter:  1\n",
      "opt step size:  0.35\n",
      "!!!!!!! Clipping is active !!!!!!!! clip val:  1.0\n",
      "def iter:  1\n",
      "opt step size:  0.35\n",
      "!!!!!!! Clipping is active !!!!!!!! clip val:  1.0\n",
      "def iter:  1\n",
      "opt step size:  0.35\n",
      "!!!!!!! Clipping is active !!!!!!!! clip val:  1.0\n",
      "def iter:  1\n",
      "opt step size:  0.35\n",
      "!!!!!!! Clipping is active !!!!!!!! clip val:  1.0\n",
      "def iter:  1\n",
      "opt step size:  0.35\n",
      "!!!!!!! Clipping is active !!!!!!!! clip val:  1.0\n",
      "def iter:  1\n",
      "opt step size:  0.35\n",
      "!!!!!!! Clipping is active !!!!!!!! clip val:  1.0\n",
      "def iter:  1\n",
      "opt step size:  0.35\n",
      "!!!!!!! Clipping is active !!!!!!!! clip val:  1.0\n",
      "def iter:  1\n",
      "opt step size:  0.35\n",
      "!!!!!!! Clipping is active !!!!!!!! clip val:  1.0\n",
      "def iter:  1\n",
      "opt step size:  0.35\n",
      "!!!!!!! Clipping is active !!!!!!!! clip val:  1.0\n",
      "def iter:  1\n",
      "opt step size:  0.35\n",
      "!!!!!!! Clipping is active !!!!!!!! clip val:  1.0\n",
      "def iter:  1\n",
      "opt step size:  0.35\n",
      "!!!!!!! Clipping is active !!!!!!!! clip val:  1.0\n",
      "def iter:  1\n",
      "opt step size:  0.35\n",
      "!!!!!!! Clipping is active !!!!!!!! clip val:  1.0\n",
      "def iter:  1\n",
      "opt step size:  0.35\n",
      "!!!!!!! Clipping is active !!!!!!!! clip val:  1.0\n",
      "def iter:  1\n",
      "opt step size:  0.35\n",
      "!!!!!!! Clipping is active !!!!!!!! clip val:  1.0\n",
      "def iter:  1\n",
      "opt step size:  0.35\n",
      "!!!!!!! Clipping is active !!!!!!!! clip val:  1.0\n",
      "def iter:  1\n",
      "opt step size:  0.35\n",
      "!!!!!!! Clipping is active !!!!!!!! clip val:  1.0\n",
      "def iter:  1\n",
      "opt step size:  0.35\n",
      "!!!!!!! Clipping is active !!!!!!!! clip val:  1.0\n",
      "def iter:  1\n",
      "opt step size:  0.35\n",
      "!!!!!!! Clipping is active !!!!!!!! clip val:  1.0\n",
      "def iter:  1\n",
      "opt step size:  0.35\n",
      "!!!!!!! Clipping is active !!!!!!!! clip val:  1.0\n",
      "def iter:  1\n",
      "opt step size:  0.35\n",
      "!!!!!!! Clipping is active !!!!!!!! clip val:  1.0\n",
      "def iter:  1\n",
      "opt step size:  0.35\n",
      "Model loaded\n",
      "initiating lsv list dict\n",
      "(1, 3, 32, 32)\n",
      "(1, 64, 32, 32)\n",
      "(1, 64, 32, 32)\n",
      "(1, 64, 32, 32)\n",
      "(1, 64, 32, 32)\n",
      "(1, 64, 32, 32)\n",
      "(1, 128, 16, 16)\n",
      "(1, 64, 32, 32)\n",
      "(1, 128, 16, 16)\n",
      "(1, 128, 16, 16)\n",
      "(1, 128, 16, 16)\n",
      "(1, 256, 8, 8)\n",
      "(1, 128, 16, 16)\n",
      "(1, 256, 8, 8)\n",
      "(1, 256, 8, 8)\n",
      "(1, 256, 8, 8)\n",
      "(1, 512, 4, 4)\n",
      "(1, 256, 8, 8)\n",
      "(1, 512, 4, 4)\n",
      "(1, 512, 4, 4)\n",
      "(1, 512)\n",
      "(1, 3, 32, 32)\n",
      "(1, 64, 32, 32)\n",
      "(1, 64, 32, 32)\n",
      "(1, 64, 32, 32)\n",
      "(1, 64, 32, 32)\n",
      "(1, 64, 32, 32)\n",
      "(1, 128, 16, 16)\n",
      "(1, 64, 32, 32)\n",
      "(1, 128, 16, 16)\n",
      "(1, 128, 16, 16)\n",
      "(1, 128, 16, 16)\n",
      "(1, 256, 8, 8)\n",
      "(1, 128, 16, 16)\n",
      "(1, 256, 8, 8)\n",
      "(1, 256, 8, 8)\n",
      "(1, 256, 8, 8)\n",
      "(1, 512, 4, 4)\n",
      "(1, 256, 8, 8)\n",
      "(1, 512, 4, 4)\n",
      "(1, 512, 4, 4)\n",
      "(1, 512)\n",
      "Epoch:  0 Loss:  4.180389265518189 Loss_std:  4.180389265518189 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  61.87830686569214\n",
      "Epoch:  1 Loss:  3.6663756079101564 Loss_std:  3.6663756079101564 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  58.62053656578064\n",
      "Epoch:  2 Loss:  3.2725829370117188 Loss_std:  3.2725829370117188 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  60.25988793373108\n",
      "Epoch:  3 Loss:  2.922439958496094 Loss_std:  2.922439958496094 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  58.752126693725586\n",
      "Epoch:  4 Loss:  2.650546341094971 Loss_std:  2.650546341094971 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  59.511630058288574\n",
      "Epoch:  5 Loss:  2.374340210876465 Loss_std:  2.374340210876465 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  60.651432037353516\n",
      "Epoch:  6 Loss:  2.2096485959625243 Loss_std:  2.2096485959625243 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  61.137065172195435\n",
      "Epoch:  7 Loss:  2.025381053466797 Loss_std:  2.025381053466797 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  60.34798860549927\n",
      "Epoch:  8 Loss:  1.885373081855774 Loss_std:  1.885373081855774 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  58.88641357421875\n",
      "Epoch:  9 Loss:  1.7707129062652587 Loss_std:  1.7707129062652587 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  58.693052530288696\n",
      "Epoch:  10 Loss:  1.6786228678894044 Loss_std:  1.6786228678894044 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  60.02962255477905\n",
      "Epoch:  11 Loss:  1.589192261619568 Loss_std:  1.589192261619568 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  58.815672397613525\n",
      "Epoch:  12 Loss:  1.5173029791259767 Loss_std:  1.5173029791259767 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  62.47355771064758\n",
      "Epoch:  13 Loss:  1.4725828413391113 Loss_std:  1.4725828413391113 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  61.8296263217926\n",
      "Epoch:  14 Loss:  1.4008786275482177 Loss_std:  1.4008786275482177 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  61.61164832115173\n",
      "Epoch:  15 Loss:  1.3774346772766113 Loss_std:  1.3774346772766113 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  62.226953983306885\n",
      "Epoch:  16 Loss:  1.3521957118988037 Loss_std:  1.3521957118988037 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  62.16164422035217\n",
      "Epoch:  17 Loss:  1.3336308197021485 Loss_std:  1.3336308197021485 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  61.34267330169678\n",
      "Epoch:  18 Loss:  1.2793305777740478 Loss_std:  1.2793305777740478 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  62.11543607711792\n",
      "Epoch:  19 Loss:  1.2642519131469727 Loss_std:  1.2642519131469727 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  62.06529140472412\n",
      "Epoch:  20 Loss:  1.2362165296173095 Loss_std:  1.2362165296173095 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  61.958879709243774\n",
      "Epoch:  21 Loss:  1.2185683418273925 Loss_std:  1.2185683418273925 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  59.77850866317749\n",
      "Epoch:  22 Loss:  1.2201054372024536 Loss_std:  1.2201054372024536 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  58.5645112991333\n",
      "Epoch:  23 Loss:  1.1873366444206237 Loss_std:  1.1873366444206237 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  59.74406838417053\n",
      "Epoch:  24 Loss:  1.1701759261322022 Loss_std:  1.1701759261322022 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  60.120356798172\n",
      "Epoch:  25 Loss:  1.141319479484558 Loss_std:  1.141319479484558 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  61.885268211364746\n",
      "Epoch:  26 Loss:  1.1356819524383546 Loss_std:  1.1356819524383546 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  61.268046379089355\n",
      "Epoch:  27 Loss:  1.1382049885177612 Loss_std:  1.1382049885177612 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  61.06837749481201\n",
      "Epoch:  28 Loss:  1.1218011653137208 Loss_std:  1.1218011653137208 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  60.83095908164978\n",
      "Epoch:  29 Loss:  1.117606410369873 Loss_std:  1.117606410369873 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  61.91104483604431\n",
      "Epoch:  30 Loss:  1.0854231718063354 Loss_std:  1.0854231718063354 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  60.4574978351593\n",
      "Epoch:  31 Loss:  1.086968031349182 Loss_std:  1.086968031349182 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  61.867822885513306\n",
      "Epoch:  32 Loss:  1.1071453704071046 Loss_std:  1.1071453704071046 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  59.591575384140015\n",
      "Epoch:  33 Loss:  1.0756062436676026 Loss_std:  1.0756062436676026 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  56.807642221450806\n",
      "Epoch:  34 Loss:  1.0461039889526367 Loss_std:  1.0461039889526367 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  59.43059062957764\n",
      "Epoch:  35 Loss:  1.0411066751480103 Loss_std:  1.0411066751480103 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  60.69830274581909\n",
      "Epoch:  36 Loss:  1.054400485458374 Loss_std:  1.054400485458374 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  60.93140912055969\n",
      "Epoch:  37 Loss:  1.054210004272461 Loss_std:  1.054210004272461 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  58.72170925140381\n",
      "Epoch:  38 Loss:  1.0316430490875244 Loss_std:  1.0316430490875244 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  59.781636476516724\n",
      "Epoch:  39 Loss:  1.0106574132537842 Loss_std:  1.0106574132537842 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  58.58736228942871\n",
      "Epoch:  40 Loss:  0.5870960460281373 Loss_std:  0.5870960460281373 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  57.68207859992981\n",
      "Epoch:  41 Loss:  0.4934381718444824 Loss_std:  0.4934381718444824 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  59.79116177558899\n",
      "Epoch:  42 Loss:  0.463774265832901 Loss_std:  0.463774265832901 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  60.48008847236633\n",
      "Epoch:  43 Loss:  0.4558922172546387 Loss_std:  0.4558922172546387 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  61.96838927268982\n",
      "Epoch:  44 Loss:  0.43064164306640623 Loss_std:  0.43064164306640623 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  60.15360641479492\n",
      "Epoch:  45 Loss:  0.4282535615158081 Loss_std:  0.4282535615158081 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  60.845362424850464\n",
      "Epoch:  46 Loss:  0.4169449020385742 Loss_std:  0.4169449020385742 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  59.491010427474976\n",
      "Epoch:  47 Loss:  0.4144841004562378 Loss_std:  0.4144841004562378 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  61.45522332191467\n",
      "Epoch:  48 Loss:  0.4067897211837769 Loss_std:  0.4067897211837769 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  61.680909872055054\n",
      "Epoch:  49 Loss:  0.3981354513168335 Loss_std:  0.3981354513168335 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  61.704978942871094\n",
      "Epoch:  50 Loss:  0.3884531519508362 Loss_std:  0.3884531519508362 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  61.520243644714355\n",
      "Epoch:  51 Loss:  0.3885910443496704 Loss_std:  0.3885910443496704 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  60.17863321304321\n",
      "Epoch:  52 Loss:  0.37861841722488404 Loss_std:  0.37861841722488404 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  59.99867343902588\n",
      "Epoch:  53 Loss:  0.38126842291355134 Loss_std:  0.38126842291355134 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  59.936179399490356\n",
      "Epoch:  54 Loss:  0.37318453435897825 Loss_std:  0.37318453435897825 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  60.13222670555115\n",
      "Epoch:  55 Loss:  0.3618897628974915 Loss_std:  0.3618897628974915 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  58.562267541885376\n",
      "Epoch:  56 Loss:  0.3589271186065674 Loss_std:  0.3589271186065674 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  60.26020550727844\n",
      "Epoch:  57 Loss:  0.3585422385978699 Loss_std:  0.3585422385978699 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  59.916260957717896\n",
      "Epoch:  58 Loss:  0.35303465045928956 Loss_std:  0.35303465045928956 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  60.130536794662476\n",
      "Epoch:  59 Loss:  0.34264353132247927 Loss_std:  0.34264353132247927 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  60.878350257873535\n",
      "Epoch:  60 Loss:  0.34445993463516233 Loss_std:  0.34445993463516233 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  61.024266719818115\n",
      "0 0 0.7379424778761062\n",
      "model:  0 6389  out of  10000\n",
      "0 1 0.8411331976835186\n",
      "model:  1 6496  out of  10000\n",
      "1 0 0.8201970443349754\n",
      "1 1 0.7502774694783574\n",
      "Epoch:  61 Loss:  0.33623211168289185 Loss_std:  0.33623211168289185 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  57.76126551628113\n",
      "Epoch:  62 Loss:  0.33308582107543944 Loss_std:  0.33308582107543944 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  58.53254818916321\n",
      "Epoch:  63 Loss:  0.3151042035675049 Loss_std:  0.3151042035675049 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  59.70335507392883\n",
      "Epoch:  64 Loss:  0.32341262837409973 Loss_std:  0.32341262837409973 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  61.6667697429657\n",
      "Epoch:  65 Loss:  0.3203533421230316 Loss_std:  0.3203533421230316 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  60.02855706214905\n",
      "Epoch:  66 Loss:  0.31945226957321166 Loss_std:  0.31945226957321166 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  60.70489573478699\n",
      "Epoch:  67 Loss:  0.3091120443153381 Loss_std:  0.3091120443153381 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  61.44179701805115\n",
      "Epoch:  68 Loss:  0.3099606702613831 Loss_std:  0.3099606702613831 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  61.380866289138794\n",
      "Epoch:  69 Loss:  0.2992178147315979 Loss_std:  0.2992178147315979 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  61.73953032493591\n",
      "Epoch:  70 Loss:  0.30510905151367185 Loss_std:  0.30510905151367185 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  60.70459532737732\n",
      "Epoch:  71 Loss:  0.30010727605819704 Loss_std:  0.30010727605819704 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  62.43834066390991\n",
      "Epoch:  72 Loss:  0.28994310033798215 Loss_std:  0.28994310033798215 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  62.27299213409424\n",
      "Epoch:  73 Loss:  0.2979374610519409 Loss_std:  0.2979374610519409 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  61.684131383895874\n",
      "Epoch:  74 Loss:  0.29107199764251707 Loss_std:  0.29107199764251707 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  60.51157283782959\n",
      "Epoch:  75 Loss:  0.2772536003708839 Loss_std:  0.2772536003708839 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  60.79025411605835\n",
      "Epoch:  76 Loss:  0.27560978708267214 Loss_std:  0.27560978708267214 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  60.42483854293823\n",
      "Epoch:  77 Loss:  0.27512218242645264 Loss_std:  0.27512218242645264 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  59.38979458808899\n",
      "Epoch:  78 Loss:  0.2674104031181335 Loss_std:  0.2674104031181335 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  60.63501811027527\n",
      "Epoch:  79 Loss:  0.2641879362869263 Loss_std:  0.2641879362869263 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  60.581525802612305\n",
      "Epoch:  80 Loss:  0.15793841863632202 Loss_std:  0.15793841863632202 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  60.349026679992676\n",
      "0 0 0.7316408803660929\n",
      "model:  0 6528  out of  10000\n",
      "0 1 0.8426776960784313\n",
      "model:  1 6518  out of  10000\n",
      "1 0 0.8432034366370053\n",
      "1 1 0.7316436419014467\n",
      "Epoch:  81 Loss:  0.1436445176410675 Loss_std:  0.1436445176410675 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  58.979647636413574\n",
      "Epoch:  82 Loss:  0.14833731568336486 Loss_std:  0.14833731568336486 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  60.47687554359436\n",
      "Epoch:  83 Loss:  0.15416006804466248 Loss_std:  0.15416006804466248 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  61.20858025550842\n",
      "Epoch:  84 Loss:  0.1615983250951767 Loss_std:  0.1615983250951767 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  59.54197144508362\n",
      "Epoch:  85 Loss:  0.16565737963676452 Loss_std:  0.16565737963676452 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  59.4595422744751\n",
      "Epoch:  86 Loss:  0.17945342081069945 Loss_std:  0.17945342081069945 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  59.67027449607849\n",
      "Epoch:  87 Loss:  0.1872028828716278 Loss_std:  0.1872028828716278 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  60.4395866394043\n",
      "Epoch:  88 Loss:  0.19852157256603242 Loss_std:  0.19852157256603242 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  59.05012798309326\n",
      "Epoch:  89 Loss:  0.2069710892057419 Loss_std:  0.2069710892057419 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  60.544514179229736\n",
      "Epoch:  90 Loss:  0.20926155185699463 Loss_std:  0.20926155185699463 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  60.479891538619995\n",
      "Epoch:  91 Loss:  0.2209977927494049 Loss_std:  0.2209977927494049 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  60.70980620384216\n",
      "Epoch:  92 Loss:  0.2260087097120285 Loss_std:  0.2260087097120285 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  61.03909635543823\n",
      "Epoch:  93 Loss:  0.23409772639751433 Loss_std:  0.23409772639751433 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  60.90102028846741\n",
      "Epoch:  94 Loss:  0.2394346473455429 Loss_std:  0.2394346473455429 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  60.21434664726257\n",
      "Epoch:  95 Loss:  0.23825438638687133 Loss_std:  0.23825438638687133 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  60.56713676452637\n",
      "Epoch:  96 Loss:  0.2449680916595459 Loss_std:  0.2449680916595459 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  61.46127915382385\n",
      "Epoch:  97 Loss:  0.2503879780673981 Loss_std:  0.2503879780673981 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  61.84310507774353\n",
      "Epoch:  98 Loss:  0.2539322603225708 Loss_std:  0.2539322603225708 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  61.36397910118103\n",
      "Epoch:  99 Loss:  0.2528915721511841 Loss_std:  0.2528915721511841 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  61.72047424316406\n",
      "Epoch:  100 Loss:  0.2560807319831848 Loss_std:  0.2560807319831848 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  60.788373708724976\n",
      "0 0 0.6840650226694681\n",
      "model:  0 6008  out of  10000\n",
      "0 1 0.8558588548601864\n",
      "model:  1 5862  out of  10000\n",
      "1 0 0.885022176731491\n",
      "1 1 0.6706414927680248\n",
      "Epoch:  101 Loss:  0.25567502818107607 Loss_std:  0.25567502818107607 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  61.85138392448425\n",
      "Epoch:  102 Loss:  0.2598476601028442 Loss_std:  0.2598476601028442 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  61.931172132492065\n",
      "Epoch:  103 Loss:  0.2623800147151947 Loss_std:  0.2623800147151947 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  61.8173406124115\n",
      "Epoch:  104 Loss:  0.25626810124874116 Loss_std:  0.25626810124874116 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  61.95087456703186\n",
      "Epoch:  105 Loss:  0.25708955739974976 Loss_std:  0.25708955739974976 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  58.981910943984985\n",
      "Epoch:  106 Loss:  0.2589504161453247 Loss_std:  0.2589504161453247 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  61.56222653388977\n",
      "Epoch:  107 Loss:  0.26247314723968507 Loss_std:  0.26247314723968507 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  61.14615321159363\n",
      "Epoch:  108 Loss:  0.26247782257080077 Loss_std:  0.26247782257080077 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  61.62514114379883\n",
      "Epoch:  109 Loss:  0.26311412437438964 Loss_std:  0.26311412437438964 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  60.414064168930054\n",
      "Epoch:  110 Loss:  0.2655480202960968 Loss_std:  0.2655480202960968 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  61.64467144012451\n",
      "Epoch:  111 Loss:  0.2611598613452911 Loss_std:  0.2611598613452911 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  60.276366233825684\n",
      "Epoch:  112 Loss:  0.2672363122367859 Loss_std:  0.2672363122367859 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  61.0160653591156\n",
      "Epoch:  113 Loss:  0.26614490203619 Loss_std:  0.26614490203619 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  61.08387732505798\n",
      "Epoch:  114 Loss:  0.26284752948760987 Loss_std:  0.26284752948760987 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  61.48948049545288\n",
      "Epoch:  115 Loss:  0.26291532200813295 Loss_std:  0.26291532200813295 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  61.5800359249115\n",
      "Epoch:  116 Loss:  0.26402881258010863 Loss_std:  0.26402881258010863 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  61.39694809913635\n",
      "Epoch:  117 Loss:  0.2620609988021851 Loss_std:  0.2620609988021851 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  60.71695637702942\n",
      "Epoch:  118 Loss:  0.26396001113891604 Loss_std:  0.26396001113891604 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  60.70814371109009\n",
      "Epoch:  119 Loss:  0.25697466289520265 Loss_std:  0.25697466289520265 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  60.14555096626282\n",
      "Epoch:  120 Loss:  0.2427034397792816 Loss_std:  0.2427034397792816 Ortho_loss:  0 ortho total: 0.0\n",
      "time:  60.00899934768677\n",
      "0 0 0.656305114638448\n",
      "model:  0 5767  out of  10000\n",
      "0 1 0.8829547425004335\n",
      "model:  1 5881  out of  10000\n",
      "1 0 0.8595476959700731\n",
      "1 1 0.6703042328042328\n"
     ]
    }
   ],
   "source": [
    "df_clip = train_ensemble(num_models, clip_flag=True, lotos_flag=False, seed_val=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 140,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "!!!!!!! Clipping is active !!!!!!!! clip val:  1.0\n",
      "def iter:  1\n",
      "opt step size:  0.35\n",
      "!!!!!!! Clipping is active !!!!!!!! clip val:  1.0\n",
      "def iter:  1\n",
      "opt step size:  0.35\n",
      "!!!!!!! Clipping is active !!!!!!!! clip val:  1.0\n",
      "def iter:  1\n",
      "opt step size:  0.35\n",
      "!!!!!!! Clipping is active !!!!!!!! clip val:  1.0\n",
      "def iter:  1\n",
      "opt step size:  0.35\n",
      "!!!!!!! Clipping is active !!!!!!!! clip val:  1.0\n",
      "def iter:  1\n",
      "opt step size:  0.35\n",
      "!!!!!!! Clipping is active !!!!!!!! clip val:  1.0\n",
      "def iter:  1\n",
      "opt step size:  0.35\n",
      "!!!!!!! Clipping is active !!!!!!!! clip val:  1.0\n",
      "def iter:  1\n",
      "opt step size:  0.35\n",
      "!!!!!!! Clipping is active !!!!!!!! clip val:  1.0\n",
      "def iter:  1\n",
      "opt step size:  0.35\n",
      "!!!!!!! Clipping is active !!!!!!!! clip val:  1.0\n",
      "def iter:  1\n",
      "opt step size:  0.35\n",
      "!!!!!!! Clipping is active !!!!!!!! clip val:  1.0\n",
      "def iter:  1\n",
      "opt step size:  0.35\n",
      "!!!!!!! Clipping is active !!!!!!!! clip val:  1.0\n",
      "def iter:  1\n",
      "opt step size:  0.35\n",
      "!!!!!!! Clipping is active !!!!!!!! clip val:  1.0\n",
      "def iter:  1\n",
      "opt step size:  0.35\n",
      "!!!!!!! Clipping is active !!!!!!!! clip val:  1.0\n",
      "def iter:  1\n",
      "opt step size:  0.35\n",
      "!!!!!!! Clipping is active !!!!!!!! clip val:  1.0\n",
      "def iter:  1\n",
      "opt step size:  0.35\n",
      "!!!!!!! Clipping is active !!!!!!!! clip val:  1.0\n",
      "def iter:  1\n",
      "opt step size:  0.35\n",
      "!!!!!!! Clipping is active !!!!!!!! clip val:  1.0\n",
      "def iter:  1\n",
      "opt step size:  0.35\n",
      "!!!!!!! Clipping is active !!!!!!!! clip val:  1.0\n",
      "def iter:  1\n",
      "opt step size:  0.35\n",
      "!!!!!!! Clipping is active !!!!!!!! clip val:  1.0\n",
      "def iter:  1\n",
      "opt step size:  0.35\n",
      "!!!!!!! Clipping is active !!!!!!!! clip val:  1.0\n",
      "def iter:  1\n",
      "opt step size:  0.35\n",
      "!!!!!!! Clipping is active !!!!!!!! clip val:  1.0\n",
      "def iter:  1\n",
      "opt step size:  0.35\n",
      "!!!!!!! Clipping is active !!!!!!!! clip val:  1.0\n",
      "def iter:  1\n",
      "opt step size:  0.35\n",
      "!!!!!!! Clipping is active !!!!!!!! clip val:  1.0\n",
      "def iter:  1\n",
      "opt step size:  0.35\n",
      "!!!!!!! Clipping is active !!!!!!!! clip val:  1.0\n",
      "def iter:  1\n",
      "opt step size:  0.35\n",
      "!!!!!!! Clipping is active !!!!!!!! clip val:  1.0\n",
      "def iter:  1\n",
      "opt step size:  0.35\n",
      "!!!!!!! Clipping is active !!!!!!!! clip val:  1.0\n",
      "def iter:  1\n",
      "opt step size:  0.35\n",
      "!!!!!!! Clipping is active !!!!!!!! clip val:  1.0\n",
      "def iter:  1\n",
      "opt step size:  0.35\n",
      "!!!!!!! Clipping is active !!!!!!!! clip val:  1.0\n",
      "def iter:  1\n",
      "opt step size:  0.35\n",
      "!!!!!!! Clipping is active !!!!!!!! clip val:  1.0\n",
      "def iter:  1\n",
      "opt step size:  0.35\n",
      "!!!!!!! Clipping is active !!!!!!!! clip val:  1.0\n",
      "def iter:  1\n",
      "opt step size:  0.35\n",
      "!!!!!!! Clipping is active !!!!!!!! clip val:  1.0\n",
      "def iter:  1\n",
      "opt step size:  0.35\n",
      "!!!!!!! Clipping is active !!!!!!!! clip val:  1.0\n",
      "def iter:  1\n",
      "opt step size:  0.35\n",
      "!!!!!!! Clipping is active !!!!!!!! clip val:  1.0\n",
      "def iter:  1\n",
      "opt step size:  0.35\n",
      "!!!!!!! Clipping is active !!!!!!!! clip val:  1.0\n",
      "def iter:  1\n",
      "opt step size:  0.35\n",
      "!!!!!!! Clipping is active !!!!!!!! clip val:  1.0\n",
      "def iter:  1\n",
      "opt step size:  0.35\n",
      "!!!!!!! Clipping is active !!!!!!!! clip val:  1.0\n",
      "def iter:  1\n",
      "opt step size:  0.35\n",
      "!!!!!!! Clipping is active !!!!!!!! clip val:  1.0\n",
      "def iter:  1\n",
      "opt step size:  0.35\n",
      "!!!!!!! Clipping is active !!!!!!!! clip val:  1.0\n",
      "def iter:  1\n",
      "opt step size:  0.35\n",
      "!!!!!!! Clipping is active !!!!!!!! clip val:  1.0\n",
      "def iter:  1\n",
      "opt step size:  0.35\n",
      "!!!!!!! Clipping is active !!!!!!!! clip val:  1.0\n",
      "def iter:  1\n",
      "opt step size:  0.35\n",
      "!!!!!!! Clipping is active !!!!!!!! clip val:  1.0\n",
      "def iter:  1\n",
      "opt step size:  0.35\n",
      "!!!!!!! Clipping is active !!!!!!!! clip val:  1.0\n",
      "def iter:  1\n",
      "opt step size:  0.35\n",
      "!!!!!!! Clipping is active !!!!!!!! clip val:  1.0\n",
      "def iter:  1\n",
      "opt step size:  0.35\n",
      "Model loaded\n",
      "initiating lsv list dict\n",
      "(1, 3, 32, 32)\n",
      "(1, 64, 32, 32)\n",
      "(1, 64, 32, 32)\n",
      "(1, 64, 32, 32)\n",
      "(1, 64, 32, 32)\n",
      "(1, 64, 32, 32)\n",
      "(1, 128, 16, 16)\n",
      "(1, 64, 32, 32)\n",
      "(1, 128, 16, 16)\n",
      "(1, 128, 16, 16)\n",
      "(1, 128, 16, 16)\n",
      "(1, 256, 8, 8)\n",
      "(1, 128, 16, 16)\n",
      "(1, 256, 8, 8)\n",
      "(1, 256, 8, 8)\n",
      "(1, 256, 8, 8)\n",
      "(1, 512, 4, 4)\n",
      "(1, 256, 8, 8)\n",
      "(1, 512, 4, 4)\n",
      "(1, 512, 4, 4)\n",
      "(1, 512)\n",
      "(1, 3, 32, 32)\n",
      "(1, 64, 32, 32)\n",
      "(1, 64, 32, 32)\n",
      "(1, 64, 32, 32)\n",
      "(1, 64, 32, 32)\n",
      "(1, 64, 32, 32)\n",
      "(1, 128, 16, 16)\n",
      "(1, 64, 32, 32)\n",
      "(1, 128, 16, 16)\n",
      "(1, 128, 16, 16)\n",
      "(1, 128, 16, 16)\n",
      "(1, 256, 8, 8)\n",
      "(1, 128, 16, 16)\n",
      "(1, 256, 8, 8)\n",
      "(1, 256, 8, 8)\n",
      "(1, 256, 8, 8)\n",
      "(1, 512, 4, 4)\n",
      "(1, 256, 8, 8)\n",
      "(1, 512, 4, 4)\n",
      "(1, 512, 4, 4)\n",
      "(1, 512)\n",
      "Epoch:  0 Loss:  4.331428911590576 Loss_std:  4.30791587387085 Ortho_loss:  0 ortho total: 9.184778550267222\n",
      "time:  81.22298336029053\n",
      "Epoch:  1 Loss:  3.855189222564697 Loss_std:  3.8344614894104003 Ortho_loss:  0 ortho total: 8.096771186590189\n",
      "time:  75.01001715660095\n",
      "pairs 1.0 conv 20 ortho loss conv:  0.06294432282447816\n",
      "Epoch:  2 Loss:  3.4638632223510744 Loss_std:  3.4522372518920896 Ortho_loss:  0 ortho total: 4.54139577448368\n",
      "time:  74.58823227882385\n",
      "Epoch:  3 Loss:  3.1080672568511964 Loss_std:  3.0919850583648683 Ortho_loss:  0 ortho total: 6.2821093589067525\n",
      "time:  71.48049354553223\n",
      "Epoch:  4 Loss:  2.778277930984497 Loss_std:  2.7688508667755127 Ortho_loss:  0 ortho total: 3.682446408271789\n",
      "time:  71.88268780708313\n",
      "pairs 1.0 conv 20 ortho loss conv:  0.14125147461891174\n",
      "Epoch:  5 Loss:  2.495051594085693 Loss_std:  2.484294691925049 Ortho_loss:  0 ortho total: 4.201913446187972\n",
      "time:  71.20928692817688\n",
      "Epoch:  6 Loss:  2.2760241912078856 Loss_std:  2.258571792526245 Ortho_loss:  0 ortho total: 6.8173434734344465\n",
      "time:  71.78639197349548\n",
      "pairs 1.0 conv 20 ortho loss conv:  0.134126028418541\n",
      "Epoch:  7 Loss:  2.0787868640136717 Loss_std:  2.058934369506836 Ortho_loss:  0 ortho total: 7.754881906509402\n",
      "time:  73.45573687553406\n",
      "Epoch:  8 Loss:  1.919315755996704 Loss_std:  1.8988805875396728 Ortho_loss:  0 ortho total: 7.982487177848811\n",
      "time:  73.49940490722656\n",
      "Epoch:  9 Loss:  1.7849530498504638 Loss_std:  1.770169524154663 Ortho_loss:  0 ortho total: 5.774813744425776\n",
      "time:  77.52485036849976\n",
      "pairs 1.0 conv 20 ortho loss conv:  0.11057372093200685\n",
      "Epoch:  10 Loss:  1.704808023071289 Loss_std:  1.6897850900268554 Ortho_loss:  0 ortho total: 5.868333789706227\n",
      "time:  78.8165123462677\n",
      "Epoch:  11 Loss:  1.6257781272888183 Loss_std:  1.6078932643127442 Ortho_loss:  0 ortho total: 6.986274060606961\n",
      "time:  73.96356153488159\n",
      "pairs 1.0 conv 20 ortho loss conv:  0.1471812963485718\n",
      "Epoch:  12 Loss:  1.5503138889312744 Loss_std:  1.5292838988494872 Ortho_loss:  0 ortho total: 8.21483996510506\n",
      "time:  76.3677146434784\n",
      "Epoch:  13 Loss:  1.5015919664764403 Loss_std:  1.4836114916229248 Ortho_loss:  0 ortho total: 7.023622608184818\n",
      "time:  74.22441482543945\n",
      "Epoch:  14 Loss:  1.4743177335357667 Loss_std:  1.4469558469390869 Ortho_loss:  0 ortho total: 10.688237398862848\n",
      "time:  73.84966278076172\n",
      "pairs 1.0 conv 20 ortho loss conv:  0.14466152787208558\n",
      "Epoch:  15 Loss:  1.4461554804611205 Loss_std:  1.4231396754074097 Ortho_loss:  0 ortho total: 8.99054881036282\n",
      "time:  73.28329205513\n",
      "Epoch:  16 Loss:  1.4037282023620605 Loss_std:  1.385325799407959 Ortho_loss:  0 ortho total: 7.188438361883166\n",
      "time:  72.11008763313293\n",
      "pairs 1.0 conv 20 ortho loss conv:  0.13728268444538116\n",
      "Epoch:  17 Loss:  1.3597224082183839 Loss_std:  1.3374477425384521 Ortho_loss:  0 ortho total: 8.701040685176837\n",
      "time:  72.20902824401855\n",
      "Epoch:  18 Loss:  1.3408521501922608 Loss_std:  1.319550103225708 Ortho_loss:  0 ortho total: 8.321111184358585\n",
      "time:  72.4146077632904\n",
      "Epoch:  19 Loss:  1.3380710703277587 Loss_std:  1.3164598708343507 Ortho_loss:  0 ortho total: 8.441874077916138\n",
      "time:  73.36146092414856\n",
      "pairs 1.0 conv 20 ortho loss conv:  0.14471341073513033\n",
      "Epoch:  20 Loss:  1.2859454647064208 Loss_std:  1.2650964775848388 Ortho_loss:  0 ortho total: 8.144136193394662\n",
      "time:  74.68603754043579\n",
      "Epoch:  21 Loss:  1.281507181930542 Loss_std:  1.2611827103424071 Ortho_loss:  0 ortho total: 7.939247643947602\n",
      "time:  77.19296026229858\n",
      "Epoch:  22 Loss:  1.2739876969909667 Loss_std:  1.2509648622131349 Ortho_loss:  0 ortho total: 8.993295583128935\n",
      "time:  74.58003664016724\n",
      "pairs 1.0 conv 20 ortho loss conv:  0.11846583783626558\n",
      "Epoch:  23 Loss:  1.262002939682007 Loss_std:  1.241511625442505 Ortho_loss:  0 ortho total: 8.004420644044874\n",
      "time:  75.25158858299255\n",
      "Epoch:  24 Loss:  1.2581999604415894 Loss_std:  1.2362179078292848 Ortho_loss:  0 ortho total: 8.586738950014126\n",
      "time:  73.50807428359985\n",
      "pairs 1.0 conv 20 ortho loss conv:  0.12182242274284363\n",
      "Epoch:  25 Loss:  1.2204307902526856 Loss_std:  1.1964755438232422 Ortho_loss:  0 ortho total: 9.35751708745957\n",
      "time:  74.12567448616028\n",
      "Epoch:  26 Loss:  1.2474780373001098 Loss_std:  1.22070527759552 Ortho_loss:  0 ortho total: 10.458109015226364\n",
      "time:  76.4134669303894\n",
      "Epoch:  27 Loss:  1.2278141609954834 Loss_std:  1.2049318935394286 Ortho_loss:  0 ortho total: 8.938385209441185\n",
      "time:  76.2867968082428\n",
      "pairs 1.0 conv 20 ortho loss conv:  0.09864979684352876\n",
      "Epoch:  28 Loss:  1.212348731994629 Loss_std:  1.1895085221862793 Ortho_loss:  0 ortho total: 8.921957942843443\n",
      "time:  75.7971363067627\n",
      "Epoch:  29 Loss:  1.1989577058792114 Loss_std:  1.17999533908844 Ortho_loss:  0 ortho total: 7.407174068689345\n",
      "time:  76.90994381904602\n",
      "pairs 1.0 conv 20 ortho loss conv:  0.12644104957580568\n",
      "Epoch:  30 Loss:  1.173960670185089 Loss_std:  1.1527275947761535 Ortho_loss:  0 ortho total: 8.294169405102735\n",
      "time:  76.31301164627075\n",
      "Epoch:  31 Loss:  1.1750450562667847 Loss_std:  1.1548196804428101 Ortho_loss:  0 ortho total: 7.900538280606271\n",
      "time:  74.7634756565094\n",
      "Epoch:  32 Loss:  1.1601181720352174 Loss_std:  1.14099762134552 Ortho_loss:  0 ortho total: 7.468964329361913\n",
      "time:  75.76122808456421\n",
      "pairs 1.0 conv 20 ortho loss conv:  0.13346891403198244\n",
      "Epoch:  33 Loss:  1.155918051071167 Loss_std:  1.1393144205474854 Ortho_loss:  0 ortho total: 6.485793715715411\n",
      "time:  77.56212043762207\n",
      "Epoch:  34 Loss:  1.1246073973846435 Loss_std:  1.1037831020355224 Ortho_loss:  0 ortho total: 8.134490737318995\n",
      "time:  77.80403184890747\n",
      "pairs 1.0 conv 20 ortho loss conv:  0.12233560979366304\n",
      "Epoch:  35 Loss:  1.1341837409973146 Loss_std:  1.1126592080688475 Ortho_loss:  0 ortho total: 8.408021044731138\n",
      "time:  79.7057433128357\n",
      "Epoch:  36 Loss:  1.1227170432662963 Loss_std:  1.1065611347579956 Ortho_loss:  0 ortho total: 6.310902673006061\n",
      "time:  72.92620897293091\n",
      "Epoch:  37 Loss:  1.1201819271850586 Loss_std:  1.0997883726501465 Ortho_loss:  0 ortho total: 7.9662323087453855\n",
      "time:  74.60076570510864\n",
      "pairs 1.0 conv 20 ortho loss conv:  0.14660874009132385\n",
      "Epoch:  38 Loss:  1.1024219323348998 Loss_std:  1.0845694985580445 Ortho_loss:  0 ortho total: 6.973606917262077\n",
      "time:  74.39808130264282\n",
      "Epoch:  39 Loss:  1.108846252822876 Loss_std:  1.0887972577667235 Ortho_loss:  0 ortho total: 7.831637486815451\n",
      "time:  75.37370276451111\n",
      "pairs 1.0 conv 20 ortho loss conv:  0.04331260621547699\n",
      "Epoch:  40 Loss:  0.6487241990375519 Loss_std:  0.6355277679920196 Ortho_loss:  0 ortho total: 5.15485568344593\n",
      "time:  74.41674757003784\n",
      "Epoch:  41 Loss:  0.5565885959625244 Loss_std:  0.5434876020050049 Ortho_loss:  0 ortho total: 5.117575678229331\n",
      "time:  76.09761357307434\n",
      "Epoch:  42 Loss:  0.5183228420448304 Loss_std:  0.509169671382904 Ortho_loss:  0 ortho total: 3.575457391142846\n",
      "time:  76.30989789962769\n",
      "pairs 1.0 conv 20 ortho loss conv:  0.021151864528656007\n",
      "Epoch:  43 Loss:  0.501616706314087 Loss_std:  0.4961108657073975 Ortho_loss:  0 ortho total: 2.1507192581892016\n",
      "time:  73.89349174499512\n",
      "Epoch:  44 Loss:  0.48113945838928224 Loss_std:  0.4762758034133911 Ortho_loss:  0 ortho total: 1.899864941835405\n",
      "time:  70.33972430229187\n",
      "Epoch:  45 Loss:  0.46988483840942386 Loss_std:  0.4652651537704468 Ortho_loss:  0 ortho total: 1.8045643359422694\n",
      "time:  72.86867785453796\n",
      "pairs 1.0 conv 20 ortho loss conv:  0.012269952893257143\n",
      "Epoch:  46 Loss:  0.45935480237960813 Loss_std:  0.456565561542511 Ortho_loss:  0 ortho total: 1.0895472496747975\n",
      "time:  71.2244815826416\n",
      "Epoch:  47 Loss:  0.45383058692932127 Loss_std:  0.45151901565551755 Ortho_loss:  0 ortho total: 0.9029574543237685\n",
      "time:  73.44878125190735\n",
      "pairs 1.0 conv 20 ortho loss conv:  0.01511407494544983\n",
      "Epoch:  48 Loss:  0.43914314779281616 Loss_std:  0.4350233550453186 Ortho_loss:  0 ortho total: 1.6092940062284458\n",
      "time:  74.10432696342468\n",
      "Epoch:  49 Loss:  0.43241416973114016 Loss_std:  0.42769224117279053 Ortho_loss:  0 ortho total: 1.8445030897855756\n",
      "time:  76.25927305221558\n",
      "Epoch:  50 Loss:  0.42597715314865114 Loss_std:  0.4210105285835266 Ortho_loss:  0 ortho total: 1.9400875508785262\n",
      "time:  73.21324825286865\n",
      "pairs 1.0 conv 20 ortho loss conv:  0.02037807106971741\n",
      "Epoch:  51 Loss:  0.4227888200759888 Loss_std:  0.41908532176971436 Ortho_loss:  0 ortho total: 1.4466788619756705\n",
      "time:  73.48598957061768\n",
      "Epoch:  52 Loss:  0.41368230228424074 Loss_std:  0.4104650142288208 Ortho_loss:  0 ortho total: 1.2567529559135435\n",
      "time:  74.74709725379944\n",
      "pairs 1.0 conv 20 ortho loss conv:  0.021896290779113772\n",
      "Epoch:  53 Loss:  0.40928556420326234 Loss_std:  0.4050532726764679 Ortho_loss:  0 ortho total: 1.6532390803098684\n",
      "time:  73.37075185775757\n",
      "Epoch:  54 Loss:  0.4014022096061707 Loss_std:  0.39786967630386355 Ortho_loss:  0 ortho total: 1.3798958063125613\n",
      "time:  71.03226923942566\n",
      "Epoch:  55 Loss:  0.39437196184158324 Loss_std:  0.38927006635665895 Ortho_loss:  0 ortho total: 1.9929280489683145\n",
      "time:  75.12351870536804\n",
      "pairs 1.0 conv 20 ortho loss conv:  0.011876329779624939\n",
      "Epoch:  56 Loss:  0.38904742729187014 Loss_std:  0.38701096313476563 Ortho_loss:  0 ortho total: 0.795493695139885\n",
      "time:  76.0308723449707\n",
      "Epoch:  57 Loss:  0.3810337649536133 Loss_std:  0.37631457439422605 Ortho_loss:  0 ortho total: 1.8434337317943577\n",
      "time:  74.95878434181213\n",
      "pairs 1.0 conv 20 ortho loss conv:  0.01638985574245453\n",
      "Epoch:  58 Loss:  0.3770767170238495 Loss_std:  0.373359978017807 Ortho_loss:  0 ortho total: 1.4518509894609448\n",
      "time:  77.56494331359863\n",
      "Epoch:  59 Loss:  0.3700493775367737 Loss_std:  0.36400679973602296 Ortho_loss:  0 ortho total: 2.3603819280862806\n",
      "time:  77.82520723342896\n",
      "Epoch:  60 Loss:  0.37347948955535887 Loss_std:  0.3692995154380798 Ortho_loss:  0 ortho total: 1.6328021734952922\n",
      "time:  76.88828730583191\n",
      "0 0 0.7718690649089487\n",
      "model:  0 6563  out of  10000\n",
      "0 1 0.6643303367362486\n",
      "model:  1 6750  out of  10000\n",
      "1 0 0.5976296296296296\n",
      "1 1 0.7913733378031065\n",
      "pairs 1.0 conv 20 ortho loss conv:  0.02208298146724701\n",
      "Epoch:  61 Loss:  0.3655836133956909 Loss_std:  0.36251178352355956 Ortho_loss:  0 ortho total: 1.1999337524175655\n",
      "time:  76.89503407478333\n",
      "Epoch:  62 Loss:  0.3533384543895721 Loss_std:  0.3489787203121185 Ortho_loss:  0 ortho total: 1.7030210554599772\n",
      "time:  73.91574692726135\n",
      "pairs 1.0 conv 20 ortho loss conv:  0.02538833320140839\n",
      "Epoch:  63 Loss:  0.34235117469787596 Loss_std:  0.3380826026916504 Ortho_loss:  0 ortho total: 1.6674109041690826\n",
      "time:  73.8852927684784\n",
      "Epoch:  64 Loss:  0.35150682929039 Loss_std:  0.34599912329673765 Ortho_loss:  0 ortho total: 2.1514476418495185\n",
      "time:  74.9897186756134\n",
      "Epoch:  65 Loss:  0.3427953746080399 Loss_std:  0.3389295674943924 Ortho_loss:  0 ortho total: 1.5100809156894677\n",
      "time:  75.73641753196716\n",
      "pairs 1.0 conv 20 ortho loss conv:  0.04787673354148866\n",
      "Epoch:  66 Loss:  0.34566131167411807 Loss_std:  0.33999869545936584 Ortho_loss:  0 ortho total: 2.2119595646858214\n",
      "time:  73.98017454147339\n",
      "Epoch:  67 Loss:  0.33526055275917055 Loss_std:  0.3293483132648468 Ortho_loss:  0 ortho total: 2.30946865081787\n",
      "time:  74.52979230880737\n",
      "Epoch:  68 Loss:  0.3347396578502655 Loss_std:  0.32889648270606997 Ortho_loss:  0 ortho total: 2.282490143179894\n",
      "time:  72.79305529594421\n",
      "pairs 1.0 conv 20 ortho loss conv:  0.03401496708393097\n",
      "Epoch:  69 Loss:  0.32666585157394407 Loss_std:  0.3215663397026062 Ortho_loss:  0 ortho total: 1.9919965058565128\n",
      "time:  70.8165054321289\n",
      "Epoch:  70 Loss:  0.33127708904266356 Loss_std:  0.3274371641159058 Ortho_loss:  0 ortho total: 1.4999706238508235\n",
      "time:  73.04203963279724\n",
      "pairs 1.0 conv 20 ortho loss conv:  0.024848881363868716\n",
      "Epoch:  71 Loss:  0.3195217065143585 Loss_std:  0.31445727963447573 Ortho_loss:  0 ortho total: 1.9782917112112053\n",
      "time:  73.1978166103363\n",
      "Epoch:  72 Loss:  0.31497862730503084 Loss_std:  0.3094388375711441 Ortho_loss:  0 ortho total: 2.163980436325072\n",
      "time:  74.11616158485413\n",
      "Epoch:  73 Loss:  0.3201652805995941 Loss_std:  0.313745152463913 Ortho_loss:  0 ortho total: 2.5078625798225413\n",
      "time:  74.83897399902344\n",
      "pairs 1.0 conv 20 ortho loss conv:  0.01938474476337433\n",
      "Epoch:  74 Loss:  0.3069514772796631 Loss_std:  0.3035901818466187 Ortho_loss:  0 ortho total: 1.3130063235759737\n",
      "time:  74.11206674575806\n",
      "Epoch:  75 Loss:  0.2942286875724793 Loss_std:  0.2899951491355896 Ortho_loss:  0 ortho total: 1.6537260234355933\n",
      "time:  73.26736569404602\n",
      "pairs 1.0 conv 20 ortho loss conv:  0.02665601372718811\n",
      "Epoch:  76 Loss:  0.29233754274368284 Loss_std:  0.28809324743270875 Ortho_loss:  0 ortho total: 1.6579278409481046\n",
      "time:  73.86899065971375\n",
      "Epoch:  77 Loss:  0.31060875368118285 Loss_std:  0.3054309922599793 Ortho_loss:  0 ortho total: 2.02256306707859\n",
      "time:  72.4760262966156\n",
      "Epoch:  78 Loss:  0.298281402425766 Loss_std:  0.2921365631008148 Ortho_loss:  0 ortho total: 2.400327923893926\n",
      "time:  73.74611616134644\n",
      "pairs 1.0 conv 20 ortho loss conv:  0.014500027894973755\n",
      "Epoch:  79 Loss:  0.29038094989776614 Loss_std:  0.28534082481384276 Ortho_loss:  0 ortho total: 1.9687986880540846\n",
      "time:  74.02505850791931\n",
      "Epoch:  80 Loss:  0.17791159154891967 Loss_std:  0.17149928107261658 Ortho_loss:  0 ortho total: 2.504808840155601\n",
      "time:  74.6604790687561\n",
      "0 0 0.7602627257799671\n",
      "model:  0 6684  out of  10000\n",
      "0 1 0.6440754039497307\n",
      "model:  1 6850  out of  10000\n",
      "1 0 0.5944525547445255\n",
      "1 1 0.7781786378047447\n",
      "pairs 1.0 conv 20 ortho loss conv:  0.03518314063549042\n",
      "Epoch:  81 Loss:  0.16435426228523253 Loss_std:  0.15888781304359437 Ortho_loss:  0 ortho total: 2.135331657528877\n",
      "time:  73.0841281414032\n",
      "Epoch:  82 Loss:  0.17425248761177062 Loss_std:  0.1629957776737213 Ortho_loss:  0 ortho total: 4.397152289748192\n",
      "time:  74.65287733078003\n",
      "Epoch:  83 Loss:  0.18814296782493592 Loss_std:  0.17193305806159973 Ortho_loss:  0 ortho total: 6.331996074318883\n",
      "time:  75.02175045013428\n",
      "pairs 1.0 conv 20 ortho loss conv:  0.01589880883693695\n",
      "Epoch:  84 Loss:  0.1861021744251251 Loss_std:  0.1795742575740814 Ortho_loss:  0 ortho total: 2.549967476725576\n",
      "time:  73.27544951438904\n",
      "Epoch:  85 Loss:  0.1961166337776184 Loss_std:  0.1892192192840576 Ortho_loss:  0 ortho total: 2.6943023353815088\n",
      "time:  74.96618413925171\n",
      "pairs 1.0 conv 20 ortho loss conv:  0.009248861670494081\n",
      "Epoch:  86 Loss:  0.21018183779716493 Loss_std:  0.20174049550056458 Ortho_loss:  0 ortho total: 3.2973994046449673\n",
      "time:  76.65308833122253\n",
      "Epoch:  87 Loss:  0.21955671873569488 Loss_std:  0.21246926115512849 Ortho_loss:  0 ortho total: 2.7685381174087538\n",
      "time:  76.0348687171936\n",
      "Epoch:  88 Loss:  0.2246747110414505 Loss_std:  0.21835055376529694 Ortho_loss:  0 ortho total: 2.4703739345073688\n",
      "time:  72.15627694129944\n",
      "pairs 1.0 conv 20 ortho loss conv:  0.014909300208091735\n",
      "Epoch:  89 Loss:  0.2374364891242981 Loss_std:  0.22853209812164307 Ortho_loss:  0 ortho total: 3.478277596831323\n",
      "time:  76.43275046348572\n",
      "Epoch:  90 Loss:  0.24230348755359649 Loss_std:  0.23436338020801545 Ortho_loss:  0 ortho total: 3.101604551076891\n",
      "time:  76.20073127746582\n",
      "Epoch:  91 Loss:  0.25209490725517275 Loss_std:  0.24341853625297546 Ortho_loss:  0 ortho total: 3.389207550883292\n",
      "time:  76.05182480812073\n",
      "pairs 1.0 conv 20 ortho loss conv:  0.008734339475631714\n",
      "Epoch:  92 Loss:  0.256563246846199 Loss_std:  0.24765972754001617 Ortho_loss:  0 ortho total: 3.477937144041057\n",
      "time:  78.3696711063385\n",
      "Epoch:  93 Loss:  0.2645182808256149 Loss_std:  0.2550752869939804 Ortho_loss:  0 ortho total: 3.688669493794439\n",
      "time:  75.69370079040527\n",
      "pairs 1.0 conv 20 ortho loss conv:  0.015537688136100772\n",
      "Epoch:  94 Loss:  0.26688684036254884 Loss_std:  0.259723108215332 Ortho_loss:  0 ortho total: 2.798333033919336\n",
      "time:  76.90108609199524\n",
      "Epoch:  95 Loss:  0.26954829389572144 Loss_std:  0.25999161943435667 Ortho_loss:  0 ortho total: 3.7330757945776\n",
      "time:  76.16808724403381\n",
      "Epoch:  96 Loss:  0.2739651102924347 Loss_std:  0.2660494844341278 Ortho_loss:  0 ortho total: 3.092041343450547\n",
      "time:  78.26322078704834\n",
      "pairs 1.0 conv 20 ortho loss conv:  0.008345431089401245\n",
      "Epoch:  97 Loss:  0.28403618550300597 Loss_std:  0.27346469201087953 Ortho_loss:  0 ortho total: 4.129489868879317\n",
      "time:  76.48630547523499\n",
      "Epoch:  98 Loss:  0.281391543712616 Loss_std:  0.27484963857650757 Ortho_loss:  0 ortho total: 2.555431962013245\n",
      "time:  78.51569199562073\n",
      "pairs 1.0 conv 20 ortho loss conv:  0.014325198531150819\n",
      "Epoch:  99 Loss:  0.28055381112098693 Loss_std:  0.2698452528476715 Ortho_loss:  0 ortho total: 4.183030876517296\n",
      "time:  76.39497303962708\n",
      "Epoch:  100 Loss:  0.2841077447986603 Loss_std:  0.27464404313087465 Ortho_loss:  0 ortho total: 3.696758633852006\n",
      "time:  75.14196252822876\n",
      "0 0 0.6837483325922632\n",
      "model:  0 5886  out of  10000\n",
      "0 1 0.7116887529731566\n",
      "model:  1 5847  out of  10000\n",
      "1 0 0.7007012142979305\n",
      "1 1 0.6790945406125166\n",
      "Epoch:  101 Loss:  0.28232747854232787 Loss_std:  0.27444189836502075 Ortho_loss:  0 ortho total: 3.0803048044443133\n",
      "time:  74.68369364738464\n",
      "pairs 1.0 conv 20 ortho loss conv:  0.007074850797653198\n",
      "Epoch:  102 Loss:  0.28754354824066164 Loss_std:  0.2808369934844971 Ortho_loss:  0 ortho total: 2.619747993350029\n",
      "time:  75.3682963848114\n",
      "Epoch:  103 Loss:  0.2892452564716339 Loss_std:  0.2783803469944 Ortho_loss:  0 ortho total: 4.244105473160742\n",
      "time:  76.62563014030457\n",
      "pairs 1.0 conv 20 ortho loss conv:  0.012569928169250488\n",
      "Epoch:  104 Loss:  0.28995182493209837 Loss_std:  0.2782585465049744 Ortho_loss:  0 ortho total: 4.567686933279038\n",
      "time:  76.56050443649292\n",
      "Epoch:  105 Loss:  0.2836316668319702 Loss_std:  0.2750022313117981 Ortho_loss:  0 ortho total: 3.3708730578422585\n",
      "time:  76.56915831565857\n",
      "Epoch:  106 Loss:  0.2903322972393036 Loss_std:  0.27796060046195986 Ortho_loss:  0 ortho total: 4.832694038748744\n",
      "time:  73.34788393974304\n",
      "pairs 1.0 conv 20 ortho loss conv:  0.006811332702636719\n",
      "Epoch:  107 Loss:  0.28895976554870606 Loss_std:  0.27922927921295165 Ortho_loss:  0 ortho total: 3.8009712576866144\n",
      "time:  75.96966981887817\n",
      "Epoch:  108 Loss:  0.28945557113647463 Loss_std:  0.27950934621810913 Ortho_loss:  0 ortho total: 3.885244214534758\n",
      "time:  76.05141091346741\n",
      "pairs 1.0 conv 20 ortho loss conv:  0.00723315179347992\n",
      "Epoch:  109 Loss:  0.288254916973114 Loss_std:  0.2799020282554626 Ortho_loss:  0 ortho total: 3.262847092747687\n",
      "time:  75.90567708015442\n",
      "Epoch:  110 Loss:  0.29196242694854735 Loss_std:  0.285142540473938 Ortho_loss:  0 ortho total: 2.664018312096595\n",
      "time:  73.21693992614746\n",
      "Epoch:  111 Loss:  0.28999528453826906 Loss_std:  0.27904956604003905 Ortho_loss:  0 ortho total: 4.275671270489688\n",
      "time:  74.16793298721313\n",
      "pairs 1.0 conv 20 ortho loss conv:  0.006445214152336121\n",
      "Epoch:  112 Loss:  0.2951645044517517 Loss_std:  0.2855890114784241 Ortho_loss:  0 ortho total: 3.740427002310749\n",
      "time:  76.26703310012817\n",
      "Epoch:  113 Loss:  0.29146797869205476 Loss_std:  0.2833240669107437 Ortho_loss:  0 ortho total: 3.1812156230211253\n",
      "time:  75.49091649055481\n",
      "Epoch:  114 Loss:  0.2891507978916168 Loss_std:  0.278977954454422 Ortho_loss:  0 ortho total: 3.9737670212984044\n",
      "time:  75.74672245979309\n",
      "pairs 1.0 conv 20 ortho loss conv:  0.008593541383743287\n",
      "Epoch:  115 Loss:  0.2921630104160309 Loss_std:  0.2822825752353668 Ortho_loss:  0 ortho total: 3.8595451176166518\n",
      "time:  75.14831972122192\n",
      "Epoch:  116 Loss:  0.2876631886291504 Loss_std:  0.28207057762146 Ortho_loss:  0 ortho total: 2.18461371064186\n",
      "time:  75.01854586601257\n",
      "pairs 1.0 conv 20 ortho loss conv:  0.0058367192745208745\n",
      "Epoch:  117 Loss:  0.28993335145950316 Loss_std:  0.28353034692764284 Ortho_loss:  0 ortho total: 2.5011736780405043\n",
      "time:  72.77498888969421\n",
      "Epoch:  118 Loss:  0.28505933674812317 Loss_std:  0.2774606486606598 Ortho_loss:  0 ortho total: 2.9682373911142363\n",
      "time:  74.70140981674194\n",
      "Epoch:  119 Loss:  0.2853650315952301 Loss_std:  0.2760208959674835 Ortho_loss:  0 ortho total: 3.650052857398985\n",
      "time:  74.34011936187744\n",
      "pairs 1.0 conv 20 ortho loss conv:  0.003304219245910645\n",
      "Epoch:  120 Loss:  0.27520604871749876 Loss_std:  0.26729805068969725 Ortho_loss:  0 ortho total: 3.089061641693117\n",
      "time:  73.87103128433228\n",
      "0 0 0.6805339265850946\n",
      "model:  0 5868  out of  10000\n",
      "0 1 0.7215405589638718\n",
      "model:  1 5774  out of  10000\n",
      "1 0 0.7386560443366816\n",
      "1 1 0.6684438520493169\n"
     ]
    }
   ],
   "source": [
    "# LOTOS ensemble: both clip_flag and lotos_flag enabled, seed_val=0\n",
    "df_lotos = train_ensemble(num_models, clip_flag=True, lotos_flag=True, seed_val=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 141,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "   epoch        t0        t1     acc\n",
      "0     60  0.822918  0.801362  0.9041\n",
      "1     60  0.793741  0.831162  0.9133\n",
      "2     80  0.813471  0.845274  0.9205\n",
      "3     80  0.783276  0.839479  0.9220\n",
      "4    100  0.827949  0.839946  0.9224\n",
      "5    100  0.822196  0.837275  0.9218\n",
      "6    120  0.819889  0.844490  0.9211\n",
      "7    120  0.818401  0.832954  0.9225\n"
     ]
    }
   ],
   "source": [
    "# Raw results of the Orig ensemble: one row per model per evaluated epoch (t0/t1 presumably transferability to models 0/1 -- see train_ensemble)\n",
    "print(df_orig)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 142,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "   epoch        t0        t1     acc\n",
      "0     60  0.737942  0.841133  0.9040\n",
      "1     60  0.820197  0.750277  0.9010\n",
      "2     80  0.731641  0.842678  0.9178\n",
      "3     80  0.843203  0.731644  0.9193\n",
      "4    100  0.684065  0.855859  0.9043\n",
      "5    100  0.885022  0.670641  0.9057\n",
      "6    120  0.656305  0.882955  0.9072\n",
      "7    120  0.859548  0.670304  0.9072\n"
     ]
    }
   ],
   "source": [
    "# Raw results of the C=1 (clipping only) ensemble: one row per model per evaluated epoch\n",
    "print(df_clip)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 143,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "   epoch        t0        t1     acc\n",
      "0     60  0.771869  0.664330  0.8951\n",
      "1     60  0.597630  0.791373  0.8949\n",
      "2     80  0.760263  0.644075  0.9135\n",
      "3     80  0.594453  0.778179  0.9147\n",
      "4    100  0.683748  0.711689  0.8996\n",
      "5    100  0.700701  0.679095  0.9012\n",
      "6    120  0.680534  0.721541  0.8990\n",
      "7    120  0.738656  0.668444  0.9003\n"
     ]
    }
   ],
   "source": [
    "# Raw results of the LOTOS ensemble: one row per model per evaluated epoch\n",
    "print(df_lotos)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Averaging, over the models of each ensemble, the transferability rate, robustness, and accuracy at every evaluated epoch:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 144,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collapse each epoch's per-model rows into one row with columns trans, robust, acc (summarize is defined earlier in the notebook)\n",
    "orig_results = summarize(df_orig, num_models)\n",
    "clip_results = summarize(df_clip, num_models)\n",
    "lotos_results = summarize(df_lotos, num_models)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 149,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Orig ensemble:\n",
      "    epoch     trans    robust      acc\n",
      "0     60  0.797551  0.172960  0.90870\n",
      "1     80  0.814275  0.173525  0.92125\n",
      "2    100  0.831071  0.167388  0.92210\n",
      "3    120  0.831445  0.173578  0.92180\n",
      "\n",
      "C=1 ensemble:\n",
      "    epoch     trans    robust      acc\n",
      "0     60  0.830665  0.255890  0.90250\n",
      "1     80  0.842941  0.268358  0.91855\n",
      "2    100  0.870441  0.322647  0.90500\n",
      "3    120  0.871251  0.336695  0.90720\n",
      "\n",
      "LOTOS ensemble\n",
      "    epoch     trans    robust      acc\n",
      "0     60  0.630980  0.218379  0.89500\n",
      "1     80  0.619264  0.230779  0.91410\n",
      "2    100  0.706195  0.318579  0.90040\n",
      "3    120  0.730098  0.325511  0.89965\n"
     ]
    },
    {
     "ename": "",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31mThe Kernel crashed while executing code in the current cell or a previous cell. \n",
      "\u001b[1;31mPlease review the code in the cell(s) to identify a possible cause of the failure. \n",
      "\u001b[1;31mClick <a href='https://aka.ms/vscodeJupyterKernelCrash'>here</a> for more info. \n",
      "\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
     ]
    }
   ],
   "source": [
    "# Per-epoch ensemble summaries; headers use the same 'name:' format for all three runs\n",
    "print('\\nOrig ensemble:\\n', orig_results)\n",
    "print('\\nC=1 ensemble:\\n', clip_results)\n",
    "print('\\nLOTOS ensemble:\\n', lotos_results)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Plotting transferability, robustness, and accuracy against epoch for the three ensembles (similar to the figures in the paper) and saving them to the ```figs``` folder:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 146,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "plt.clf()\n",
    "plt.plot(orig_results['epoch'], orig_results['trans'], label='Orig')\n",
    "plt.plot(clip_results['epoch'], clip_results['trans'], label='C=1')\n",
    "plt.plot(lotos_results['epoch'], lotos_results['trans'], label='LOTOS')\n",
    "plt.legend()\n",
    "plt.savefig('figs/transferability.png')\n",
    "plt.clf()\n",
    "\n",
    "plt.plot(orig_results['epoch'], orig_results['robust'], label='Orig')\n",
    "plt.plot(clip_results['epoch'], clip_results['robust'], label='C=1')\n",
    "plt.plot(lotos_results['epoch'], lotos_results['robust'], label='LOTOS')\n",
    "plt.legend()\n",
    "plt.savefig('figs/robustness.png')\n",
    "plt.clf()\n",
    "\n",
    "plt.plot(orig_results['epoch'], orig_results['acc'], label='Orig')\n",
    "plt.plot(clip_results['epoch'], clip_results['acc'], label='C=1')\n",
    "plt.plot(lotos_results['epoch'], lotos_results['acc'], label='LOTOS')\n",
    "plt.legend()\n",
    "plt.savefig('figs/accuracy.png')\n",
    "plt.clf()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "th",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
