{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9010eeb4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "from absl import app\n",
    "from absl import flags\n",
    "from absl import logging\n",
    "import numpy as np  # pylint: disable=unused-import\n",
    "import sys"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1d840156",
   "metadata": {},
   "outputs": [],
   "source": [
    "FLAGS = flags.FLAGS\n",
    "\n",
    "flags.DEFINE_float('batch_norm_decay', 0.9, 'Batch norm decay parameter.')\n",
    "\n",
    "flags.DEFINE_float('rr_weight', 0.0,\n",
    "                   'Weight for the redundancy reduction term.')\n",
    "\n",
    "flags.DEFINE_bool('use_rr_loss', False, 'Use redundancy reduction term.')\n",
    "\n",
    "flags.DEFINE_bool('class_specific_rr_loss', True, 'Use class specific RR loss.')\n",
    "\n",
    "flags.DEFINE_float('learning_rate', 0.001, 'Initial learning rate.')\n",
    "\n",
    "flags.DEFINE_float('momentum', 0.9, 'momentum for SGD.')\n",
    "\n",
    "flags.DEFINE_integer('train_batch_size', 64, 'Batch size for training.')\n",
    "\n",
    "flags.DEFINE_integer('train_epochs', 100, 'Number of epochs to train for.')\n",
    "\n",
    "flags.DEFINE_integer('train_epochs_finetune', 50,\n",
    "                     'Number of epochs to finetune for.')\n",
    "\n",
    "flags.DEFINE_integer('train_steps', 0,\n",
    "                     'Number of train steps. If  > 0, overrides train epochs')\n",
    "\n",
    "flags.DEFINE_integer(\n",
    "    'train_steps_finetune', 0,\n",
    "    'Number of train finetune steps. If  > 0, overrides train epochs finetune')\n",
    "\n",
    "flags.DEFINE_integer(\n",
    "    'eval_steps', 0,\n",
    "    'Number of steps to eval for. If not provided, evals over entire dataset.')\n",
    "\n",
    "flags.DEFINE_integer('lr_decay_gap', 50, 'Gap between decaying lr')\n",
    "\n",
    "flags.DEFINE_float('lr_decay_factor', 0.5, 'Value to decay lr by')\n",
    "\n",
    "flags.DEFINE_integer('val_batch_size', 256, 'Batch size for eval.')\n",
    "\n",
    "flags.DEFINE_integer('test_batch_size', 256, 'Batch size for eval.')\n",
    "\n",
    "flags.DEFINE_integer('num_train', -1, 'Num training examples.')\n",
    "\n",
    "flags.DEFINE_integer('num_test', -1, 'Num test examples.')\n",
    "\n",
    "flags.DEFINE_integer('num_train_finetune', 20000,\n",
    "                     'Num training examples for finetuning.')\n",
    "\n",
    "flags.DEFINE_integer('num_test_finetune', 2000,\n",
    "                     'Num test examples for finetuning.')\n",
    "\n",
    "flags.DEFINE_integer('buffer_size', 256, 'Buffer size for shuffling.')\n",
    "\n",
    "flags.DEFINE_integer('checkpoint_epochs', 1,\n",
    "                     'Number of epochs between checkpoints/summaries.')\n",
    "\n",
    "flags.DEFINE_string('dataset', 'imagenette', 'Name of a dataset.')\n",
    "\n",
    "flags.DEFINE_bool(\n",
    "    'cache_dataset', False,\n",
    "    'Whether to cache the entire dataset in memory. If the dataset is '\n",
    "    'ImageNet, this is a very bad idea, but for smaller datasets it can '\n",
    "    'improve performance.')\n",
    "\n",
    "flags.DEFINE_bool('use_dropout_pretrain', False,\n",
    "                  'Whether to use dropout on the final layer for pretraining.')\n",
    "\n",
    "flags.DEFINE_float('dropout_rate', 0.0,\n",
    "                   'Dropout rate for final layer for pretraining.')\n",
    "\n",
    "flags.DEFINE_bool(\n",
    "    'finetune_from_random', False,\n",
    "    'Whether to initialize the finetuning model using random weights.')\n",
    "\n",
    "flags.DEFINE_enum(\n",
    "    'train_mode', 'pretrain', ['pretrain', 'finetune'],\n",
    "    'The train mode controls different objectives and trainable components.')\n",
    "\n",
    "flags.DEFINE_string('platform', 'GPU', 'To be run on GPU or TPU.')\n",
    "\n",
    "flags.DEFINE_float('validation_split', 0.2,\n",
    "                   'Validation split to use while training.')\n",
    "\n",
    "flags.DEFINE_string(\n",
    "    'checkpoint', None,\n",
    "    'Loading from the given checkpoint for fine-tuning if a finetuning '\n",
    "    'checkpoint does not already exist in model_dir.')\n",
    "\n",
    "flags.DEFINE_string('optimizer', 'sgd', 'Optimizer to be used.')\n",
    "\n",
    "flags.DEFINE_bool('project_out_prev_w', False, 'Project out previous w')\n",
    "\n",
    "flags.DEFINE_bool(\n",
    "    'zero_init_logits_layer', False,\n",
    "    'If True, zero initialize layers after avg_pool for supervised learning.')\n",
    "\n",
    "flags.DEFINE_bool('use_data_aug_with_DRO', False, 'True or False')\n",
    "\n",
    "flags.DEFINE_integer(\n",
    "    'fine_tune_after_block', -1,\n",
    "    'The layers after which block that we will fine-tune. -1 means fine-tuning '\n",
    "    'everything. 0 means fine-tuning after stem block. 4 means fine-tuning '\n",
    "    'just the linear head.')\n",
    "\n",
    "flags.DEFINE_integer('keep_checkpoint_max', 5,\n",
    "                     'Maximum number of checkpoints to keep.')\n",
    "\n",
    "flags.DEFINE_enum('lr_decay_type', 'step_decay', ['step_decay', 'cosine_decay', 'warmup_cosine_decay'],\n",
    "                  'Kind of decay to be used for learning rate')\n",
    "\n",
    "flags.DEFINE_enum('learning_rate_scaling', 'linear', ['linear', 'sqrt'],\n",
    "                  'Learning rate scaling to use')\n",
    "\n",
    "flags.DEFINE_integer('warmup_steps', 0, 'Warmup steps for warmup and cosine decay')\n",
    "\n",
    "flags.DEFINE_integer('warmup_epochs', 20, 'Warmup epochs for warmup and cosine decay')\n",
    "\n",
    "flags.DEFINE_boolean(\n",
    "    'global_bn', True,\n",
    "    'Whether to aggregate BN statistics across distributed cores.')\n",
    "\n",
    "flags.DEFINE_integer('proj_dim', 20, 'Output dimension of projection head')\n",
    "\n",
    "flags.DEFINE_integer(\n",
    "    'num_heads', 1,\n",
    "    'Number of heads across which to decorrelate. One head is the standard config'\n",
    ")\n",
    "\n",
    "flags.DEFINE_integer('width_multiplier', 1,\n",
    "                     'Multiplier to change width of network.')\n",
    "\n",
    "flags.DEFINE_integer('resnet_depth', 50, 'Depth of ResNet.')\n",
    "\n",
    "flags.DEFINE_float(\n",
    "    'sk_ratio', 0.,\n",
    "    'If it is bigger than 0, it will enable SK. Recommendation: 0.0625.')\n",
    "\n",
    "flags.DEFINE_float('se_ratio', 0., 'If it is bigger than 0, it will enable SE.')\n",
    "\n",
    "flags.DEFINE_float('weight_decay', 1e-4, 'weight decay to be used')\n",
    "\n",
    "flags.DEFINE_float('logit_decay', 0.0, 'Decay to be used on logits')\n",
    "\n",
    "flags.DEFINE_integer('image_size', 32, 'Input image size.')\n",
    "\n",
    "flags.DEFINE_integer('val_epochs_gap', 10, 'Epoch gap to run validation after')\n",
    "\n",
    "flags.DEFINE_integer('finetune_val_epochs_gap', 10,\n",
    "                     'Epoch gap to run validation after')\n",
    "\n",
    "flags.DEFINE_integer('val_steps_gap', 0, 'Steps gap to run validation after')\n",
    "\n",
    "flags.DEFINE_integer('finetune_val_steps_gap', 0,\n",
    "                     'Steps gap to run validation after')\n",
    "\n",
    "flags.DEFINE_boolean('use_pretrained', True, 'whether to use pretrained model')\n",
    "\n",
    "flags.DEFINE_string(\n",
    "    'path', './',\n",
    "    'path for the dataset')\n",
    "\n",
    "flags.DEFINE_bool('use_OOD_transform', False,\n",
    "                  'Use data preprocessing specific to OOD dataset')\n",
    "\n",
    "flags.DEFINE_float('clip_norm', None, 'global clip norm for the gradient')\n",
    "\n",
    "flags.DEFINE_integer('num_runs', 1,\n",
    "                     'Number of runs to average the evaluations over')\n",
    "\n",
    "flags.DEFINE_integer('num_head_layers', 1, 'Number of layers to use in head')\n",
    "\n",
    "flags.DEFINE_bool('use_early_stopping', False, 'Whether to use early stopping based on validation accuracy or not')\n",
    "\n",
    "flags.DEFINE_integer('proj_layer', 0,\n",
    "                     'Layer in head where applying projection for rr')\n",
    "\n",
    "flags.DEFINE_integer('head_dim', 512, 'Dimension of head layers')\n",
    "\n",
    "flags.DEFINE_string('model_dir', None, 'Path for loading/saving a model')\n",
    "\n",
    "flags.DEFINE_string('model_finetune_dir', None,\n",
    "                    'Path for loading/saving the model after finetuning')\n",
    "\n",
    "flags.DEFINE_bool('use_seq_rr', True, 'Whether to use rr loss sequentially')\n",
    "\n",
    "flags.DEFINE_integer('num_seq_models', 2, 'Number of sequential models to train')\n",
    "\n",
    "flags.DEFINE_bool('load_model', False, 'whether to try to load model')\n",
    "\n",
    "flags.DEFINE_bool('save_model', False, 'whether to save model')\n",
    "\n",
    "flags.DEFINE_bool('lowerbound_rr', False, 'Whether to lowerbound rr loss')\n",
    "\n",
    "flags.DEFINE_float('lowerbound_factor', 0.5, 'If lowerbound rr, by what factor of expected value')\n",
    "\n",
    "flags.DEFINE_bool('use_exp_var_loss', False, 'Use explained away variance as loss')\n",
    "\n",
    "flags.DEFINE_bool('use_MI_loss', False, 'Use MI based loss function')\n",
    "\n",
    "flags.DEFINE_integer('CIFAR_label_1', 1, 'CIFAR class 1 for MNIST-CIFAR dataset')\n",
    "\n",
    "flags.DEFINE_integer('CIFAR_label_2', 9, 'CIFAR class 2 for MNIST-CIFAR dataset')\n",
    "\n",
    "flags.DEFINE_integer('MNIST_label_1', 0, 'MNIST class 1 for MNIST-CIFAR dataset')\n",
    "\n",
    "flags.DEFINE_integer('MNIST_label_2', 1, 'MNIST class 2 for MNIST-CIFAR dataset')\n",
    "\n",
    "flags.DEFINE_float('corr_frac', 1.0, 'Correlation factor of MNIST and CIFAR for MNIST-CIFAR dataset')\n",
    "\n",
    "flags.DEFINE_bool('use_proj_head', True, 'Use a projection head in the model')\n",
    "\n",
    "flags.DEFINE_bool('normalize_MI', False, 'Normalize MI based loss')\n",
    "\n",
    "flags.DEFINE_bool('use_logit_decorr', False, 'Use logit decorrelation')\n",
    "\n",
    "flags.DEFINE_bool('use_prob_decorr', False, 'Use probability decorrelation')\n",
    "\n",
    "flags.DEFINE_bool('use_val_for_MI', False, 'Use validation set for MI based loss')\n",
    "\n",
    "flags.DEFINE_bool('use_cifar_aug', False, 'Use CIFAR augmentation in MNIST-CIFAR dataset')\n",
    "\n",
    "flags.DEFINE_bool('use_mnist_aug', False, 'Use MNIST augmentation in MNIST-CIFAR dataset')\n",
    "\n",
    "flags.DEFINE_integer('num_classes', 10, 'Number of classes')\n",
    "\n",
    "flags.DEFINE_bool('monitor_rr_grad_norms', False, 'Monitor rr gradient norms')\n",
    "\n",
    "flags.DEFINE_float('use_rr_after_frac', 0.0, 'Use rr after a fraction of steps')\n",
    "\n",
    "flags.DEFINE_bool('use_sq_MI', False, 'Use squared MI as loss instead of MI, for better gradients')\n",
    "\n",
    "flags.DEFINE_bool('use_disagr_loss', False, 'use disagreement based loss')\n",
    "\n",
    "flags.DEFINE_bool('normalize_MI_random', False, 'normalize MI by randomly shuffling probabilities')\n",
    "\n",
    "flags.DEFINE_bool('use_num_sq_MI', False, 'Use the MI loss of the form (MI^2)/y so as to get properly scaled gradients')\n",
    "\n",
    "flags.DEFINE_bool('use_stop_grad', False, 'Use stop gradient for MI normalization factor')\n",
    "\n",
    "flags.DEFINE_bool('use_HSIC_loss', False, 'Use HSIC based independence test loss')\n",
    "\n",
    "flags.DEFINE_bool('use_HSIC_diff', False, 'Use HSIC on logit difference')\n",
    "\n",
    "flags.DEFINE_bool('lin_scale_rr_weight', False, 'Linearly scale down rr weight as number of sequential models goes up')\n",
    "\n",
    "flags.DEFINE_integer('dataset_dim', 2, 'Dimension of LMS dataset')\n",
    "\n",
    "flags.DEFINE_integer('num_lin', 1, 'Number of linear dimensions')\n",
    "\n",
    "flags.DEFINE_integer('num_3_slabs', 1, 'Number of 3 slabs')\n",
    "\n",
    "flags.DEFINE_integer('num_5_slabs', 0, 'Number of 5 slabs')\n",
    "\n",
    "flags.DEFINE_integer('num_7_slabs', 0, 'Number of 7 slabs')\n",
    "\n",
    "flags.DEFINE_bool('use_random_transform', False, 'Use random transformation of input coordinates')\n",
    "\n",
    "flags.DEFINE_float('lin_margin', 0.1, 'Linear coordinate margin')\n",
    "\n",
    "flags.DEFINE_float('slab_margin', 0.05, 'Slab coordinate margin')\n",
    "\n",
    "flags.DEFINE_integer('fcn_layers', 3, 'Number of layers in FCN net for lms dataset')\n",
    "\n",
    "flags.DEFINE_integer('hidden_dim', 512, 'Hidden dimension for FCN')\n",
    "\n",
    "flags.DEFINE_bool('randomize_linear', False, 'Randomize linear coordinate in the dataset')\n",
    "\n",
    "flags.DEFINE_bool('randomize_slabs', False, 'Randomize slab coordinates in the dataset')\n",
    "\n",
    "flags.DEFINE_bool('turn_off_randomize_later', False, 'Turn off coordinate randomization later')\n",
    "\n",
    "flags.DEFINE_bool('use_L4_reg', False, 'Use L4 instead of L2 regularization')\n",
    "\n",
    "flags.DEFINE_bool('use_bn', False, 'Use BN in architecture')\n",
    "\n",
    "flags.DEFINE_bool('use_HSIC_on_features', False, 'Use HSIC based loss on feature layers')\n",
    "\n",
    "flags.DEFINE_integer('HSIC_feature_layer', 0, 'Feature layer to use HSIC loss on')\n",
    "\n",
    "flags.DEFINE_multi_integer('HSIC_feature_layers', None, 'Feature layers to use HSIC loss on')\n",
    "\n",
    "flags.DEFINE_bool('use_all_features_HSIC', False, 'Use features at all the layers for HSIC loss')\n",
    "\n",
    "flags.DEFINE_bool('use_sq_HSIC', False, 'Square HSIC loss to manage gradients as loss goes down')\n",
    "\n",
    "flags.DEFINE_bool('use_GAP_HSIC_features', True, 'Use GAP on HSIC features')\n",
    "\n",
    "flags.DEFINE_bool('use_random_projections', False, 'Use random projections')\n",
    "\n",
    "flags.DEFINE_integer('random_proj_dim', 1, 'Random projection dimension')\n",
    "\n",
    "flags.DEFINE_bool('use_prev_logits_HSIC_features', False, 'Use logits of previous models for computing HSIC on features')\n",
    "\n",
    "flags.DEFINE_bool('use_MNIST_labels', False, 'Use MNIST labels in MNIST-CIFAR')\n",
    "\n",
    "flags.DEFINE_bool('switch_corr_later', False, 'Change correlation after first model')\n",
    "\n",
    "flags.DEFINE_bool('switch_labels_later', False, 'Switch whether to use CIFAR or MNIST labels after first model')\n",
    "\n",
    "flags.DEFINE_bool('monitor_EG_overlap', False, 'Monitor expected gradients overlap across models')\n",
    "\n",
    "flags.DEFINE_bool('monitor_robustness_measures', False, 'Monitor Gaussian, mask and RDE based robustness measures')\n",
    "\n",
    "flags.DEFINE_bool('monitor_error_diversity', False, 'Monitor error diversity')\n",
    "\n",
    "flags.DEFINE_bool('monitor_logit_correlation', False, 'Monitor logit correlation')\n",
    "\n",
    "flags.DEFINE_bool('sep_short_direct_branch', False, 'Separately make shortcut and direct branch independent of previous model')\n",
    "\n",
    "flags.DEFINE_bool('use_pretrained_model_1', False, 'Utilise a pretrained first model.')\n",
    "\n",
    "flags.DEFINE_string('pretrained_model_path', None, 'Path for first pretrained model')\n",
    "\n",
    "flags.DEFINE_multi_string('pretrained_model_paths', None, 'Paths for pretrained models')\n",
    "\n",
    "flags.DEFINE_multi_string('pretrained_checkpoint_paths', None, 'Paths for pretrained checkpoints')\n",
    "\n",
    "flags.DEFINE_bool('use_indexed_checkpoints', False, 'Use particular checkpoint index')\n",
    "\n",
    "flags.DEFINE_bool('check_tf_func', False, 'Check tf func')\n",
    "\n",
    "flags.DEFINE_bool('use_FCN', False, 'Use FCN architecture')\n",
    "\n",
    "flags.DEFINE_bool('monitor_EG_loss', False, 'Monitor EG loss')\n",
    "\n",
    "flags.DEFINE_bool('use_equal_split', False, 'Use equal split for DRO setting')\n",
    "\n",
    "flags.DEFINE_bool('use_EG_loss', False, 'Use expected gradients loss for avoiding collapse')\n",
    "\n",
    "flags.DEFINE_integer('num_ref_EG_loss', 2, 'Number of referneces in EG loss')\n",
    "\n",
    "flags.DEFINE_float('EG_loss_weight', 1e-3, 'Weight of EG loss')\n",
    "\n",
    "flags.DEFINE_bool('binary_classification', True, 'Use binary classification and logistic loss')\n",
    "\n",
    "flags.DEFINE_bool('use_color_labels', False, 'Use colors for label in color-MNIST or binary-color-MNIST')\n",
    "\n",
    "flags.DEFINE_bool('use_CNN', False, 'Use custom CNN architecture')\n",
    "\n",
    "flags.DEFINE_bool('use_HSIC_ratio', False, 'Use HSIC ratio as the loss')\n",
    "\n",
    "flags.DEFINE_string(\n",
    "    'master', 'local',\n",
    "    \"BNS name of the TensorFlow master to use. 'local' for GPU.\")\n",
    "\n",
    "flags.DEFINE_integer('project_out_rank', 0, 'Projecting certain dimensions out of input')\n",
    "\n",
    "flags.DEFINE_float('project_out_factor', 0.0, 'Projecting out factor')\n",
    "\n",
    "flags.DEFINE_float('eig_cutoff_factor', 0.0, 'Eigenvalue cutoff factor')\n",
    "\n",
    "flags.DEFINE_integer('check_ranks_max', 10, 'Check rank of 1st hidden matrix')\n",
    "\n",
    "flags.DEFINE_multi_integer('filters', [16, 32, 64], 'Filters to be used in a CNN')\n",
    "\n",
    "flags.DEFINE_bool('standardize_mean_reps', True, 'Standardize mean of the representations')\n",
    "\n",
    "flags.DEFINE_multi_integer('kernel_sizes', [3, 3, 3], 'kernel sizes to be used in a CNN')\n",
    "\n",
    "flags.DEFINE_multi_integer('strides', [1, 2, 1], 'Strides to be used in a CNN')\n",
    "\n",
    "flags.DEFINE_multi_integer('project_out_vecs', [1,2], 'Number of top SVD vectors to project out')\n",
    "\n",
    "flags.DEFINE_bool('use_chizat_init', False, 'Whether to use chizat-bach initialization in head')\n",
    "\n",
    "flags.DEFINE_bool('project_out_w', False, 'Project out w from representations')\n",
    "\n",
    "flags.DEFINE_bool('use_complete_corr', False, 'Use complete correlation in DRO setting')\n",
    "\n",
    "flags.DEFINE_bool('use_complete_corr_test', False, 'Use complete correlation in DRO setting')\n",
    "\n",
    "flags.DEFINE_bool('flip_err_div_for_minority', False, 'Flip error diversity calc for minority classes')\n",
    "\n",
    "flags.DEFINE_bool('measure_feat_robust', False, 'Measure robustness of features')\n",
    "\n",
    "flags.DEFINE_float('max_gauss_noise_std', 5.0,  'Maximum gaussian noise std')\n",
    "\n",
    "flags.DEFINE_boolean('use_tpu', True, 'Should we use TPU?')\n",
    "\n",
    "flags.DEFINE_bool('check_torch_reps', False, 'Check torch reps')\n",
    "\n",
    "flags.DEFINE_boolean(\n",
    "    'train_split', 1,\n",
    "    'Use train validation split while training, If set to false, use entire training dataset'\n",
    ")\n",
    "\n",
    "flags.DEFINE_bool('finetune_only_linear_head', False, 'Finetune only linear head')\n",
    "\n",
    "_FRAC_POISON = flags.DEFINE_float('frac_poison', 0.,\n",
    "                                  'Fraction of poisoned examples.')\n",
    "\n",
    "_TASK_ID = flags.DEFINE_enum('task_id', 'Imagenette', [\n",
    "    'Data-poisoning', 'DRO', 'Few-shot', 'CIFAR-10.2', 'CIFAR-10.2-finetune',\n",
    "    'CINIC', 'CINIC-finetune', 'Imagenette'\n",
    "], 'Specify the task that needs to be run.')\n",
    "\n",
    "_FINETUNE_ONLY_HEAD = flags.DEFINE_bool(\n",
    "    'finetune_only_head', False, 'whether to finetune head or the entire model')\n",
    "\n",
    "_TRAIN_CLASSES = flags.DEFINE_multi_integer(\n",
    "    'train_classes', [0, 1, 2, 3, 4], 'classes to train for few-shot learning')\n",
    "\n",
    "_FINETUNE_CLASSES = flags.DEFINE_multi_integer(\n",
    "    'finetune_classes', [5, 6, 7, 8, 9],\n",
    "    'classes to finetune for few-shot learning')\n",
    "\n",
    "TASK_STAGES = {\n",
    "    'Data-poisoning': ['Train'],\n",
    "    'DRO': ['Train'],\n",
    "    'Few-shot': ['Train', 'Finetune'],\n",
    "    'CIFAR-10.2': ['Train'],\n",
    "    'CIFAR-10.2-finetune': ['Train'],\n",
    "    'CINIC': ['Train'],\n",
    "    'CINIC-finetune': ['Train'],\n",
    "    'CIFAR-MNIST': ['Train'],\n",
    "    'LMS': ['Train'],\n",
    "    'MNIST': ['Train'],\n",
    "    'color-MNIST': ['Train'],\n",
    "    'Imagenette': ['Train'],\n",
    "}\n",
    "\n",
    "flags.DEFINE_string('f', '', 'kernel')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0dc50f9b",
   "metadata": {},
   "outputs": [],
   "source": [
    "class head_model(tf.keras.Model):\n",
    "\n",
    "    def __init__(self,\n",
    "               base_model,\n",
    "               num_classes,\n",
    "               proj_dim,\n",
    "               proj_layer,\n",
    "               head_dims,\n",
    "               use_proj=True,\n",
    "               use_bn=True,\n",
    "               use_relu=True,\n",
    "               dropout_rate=0.0):\n",
    "        super(head_model, self).__init__()\n",
    "        self.base_model = base_model\n",
    "        if proj_layer > len(head_dims):\n",
    "            raise ValueError('proj layer must be less than head dimensions')\n",
    "        self.proj_head = []\n",
    "        for i in range(proj_layer):\n",
    "            self.proj_head.append(tf.keras.layers.Dense(head_dims[i]))\n",
    "            if use_bn:\n",
    "                self.proj_head.append(norm_layer())\n",
    "            if use_relu:\n",
    "                self.proj_head.append(tf.keras.layers.ReLU())\n",
    "\n",
    "        if use_proj:\n",
    "            self.proj_head.append(tf.keras.layers.Dense(proj_dim))\n",
    "\n",
    "        self.head = []\n",
    "        self.head.append(tf.keras.layers.Dropout(rate=dropout_rate))\n",
    "        for i in range(len(head_dims) - proj_layer):\n",
    "            self.head.append(tf.keras.layers.Dense(head_dims[proj_layer + i]))\n",
    "            if use_bn:\n",
    "                self.head.append(norm_layer())\n",
    "            if use_relu:\n",
    "                self.head.append(tf.keras.layers.ReLU())\n",
    "\n",
    "        if FLAGS.use_chizat_init:\n",
    "            if len(head_dims) > 1:\n",
    "                raise ValueError('Rich regime initialization only for 1 layer net')\n",
    "            else:\n",
    "                self.out = tf.keras.layers.Dense(num_classes, use_bias=False)\n",
    "        else:\n",
    "            self.out = tf.keras.layers.Dense(num_classes)\n",
    "\n",
    "        self.head_dims = head_dims\n",
    "\n",
    "    def call(self, x, training, only_head=False, only_linear_head=False, init=False, project_out_w=False, project_out_mat=None,\n",
    "           gauss_noise_feats=False, sigma=1.0, rand_sigma=False):\n",
    "        if FLAGS.use_pretrained:\n",
    "            x = self.base_model(x, training=training and not only_head)\n",
    "            if gauss_noise_feats:\n",
    "                if rand_sigma:\n",
    "                    std_devs = sigma * tf.random.uniform([tf.shape(x)[0]])\n",
    "                    final_shape = tf.constant([-1], dtype=tf.int32)\n",
    "                    final_shape = tf.concat([final_shape, tf.ones([tf.rank(x)-1], dtype=tf.int32)], axis=0)\n",
    "                    z = tf.reshape(std_devs, final_shape) * tf.random.normal(tf.shape(x))\n",
    "                else:\n",
    "                    z = sigma * tf.random.normal(tf.shape(x))\n",
    "                x = x + tf.cast(z, x.dtype)\n",
    "            feat = []\n",
    "        else:\n",
    "            x, feat = self.base_model(x, training=training and not only_head)\n",
    "        for layer in self.proj_head:\n",
    "            x = layer(x, training=training and not only_linear_head)\n",
    "        reps = x\n",
    "        if project_out_w and FLAGS.use_pretrained:\n",
    "            for ind, layer in enumerate(self.head):\n",
    "                if isinstance(layer, tf.keras.layers.Dense):\n",
    "                    for var in layer.trainable_variables:\n",
    "                        if 'kernel' in var.name:\n",
    "                            W_norm = tf.norm(var, axis=0, keepdims=True)**2\n",
    "                            sample_ind = tf.random.categorical(tf.math.log(W_norm), 1)[0,0]\n",
    "                            sample_W = tf.reshape(var[:, sample_ind], [-1,1])\n",
    "                            ctx = tf.distribute.get_replica_context()\n",
    "                            sample_W_gather = ctx.all_gather(sample_W, axis=1)\n",
    "                            sample_W = tf.reshape(sample_W_gather[:,0], [-1,1])\n",
    "                            sample_W = sample_W/tf.norm(sample_W, axis=0, keepdims=True)\n",
    "                            proj_mat = tf.eye(tf.shape(var)[0]) - tf.linalg.matmul(sample_W, tf.transpose(sample_W))\n",
    "                            x = tf.transpose(tf.linalg.matmul(proj_mat, tf.transpose(x)))\n",
    "                            break\n",
    "        elif project_out_mat is not None:\n",
    "            pinv = tf.linalg.pinv(tf.linalg.matmul(tf.transpose(project_out_mat), project_out_mat))\n",
    "            proj_mat = tf.linalg.matmul(project_out_mat, tf.linalg.matmul(pinv, tf.transpose(project_out_mat)))\n",
    "            proj_x = tf.linalg.matmul(proj_mat, tf.transpose(x))\n",
    "            x = x - tf.transpose(proj_x)\n",
    "        curr_ind = 0\n",
    "        for layer in self.head:\n",
    "            if FLAGS.use_chizat_init and isinstance(layer, tf.keras.layers.Dense):\n",
    "                if init:\n",
    "                    tf.print('assigned')\n",
    "                    z = tf.random.normal((tf.shape(x)[1]+1, self.head_dims[0]))\n",
    "                    z = z/tf.norm(z, axis=0, keepdims=True)\n",
    "                    layer.kernel.assign(z[0:tf.shape(x)[1], :])\n",
    "                    layer.bias.assign(z[tf.shape(x)[1], :])\n",
    "            x = layer(x, training=training and not only_linear_head)\n",
    "            if FLAGS.use_pretrained:\n",
    "                if isinstance(layer, tf.keras.layers.Dense):\n",
    "                    if curr_ind in FLAGS.HSIC_feature_layers or FLAGS.use_all_features_HSIC:\n",
    "                        feat.append(x)\n",
    "                    curr_ind += 1\n",
    "        if FLAGS.use_chizat_init:\n",
    "            if init:\n",
    "                tf.print('out assigned')\n",
    "                samples = tf.random.categorical(tf.math.log([[0.5, 0.5]]), tf.shape(x)[1]*FLAGS.num_classes)\n",
    "                self.out.kernel.assign(tf.reshape(2.0*tf.cast(samples, tf.float32) - 1.0, self.out.kernel.shape))\n",
    "        x = self.out(x, training=training)\n",
    "\n",
    "        if FLAGS.use_chizat_init:\n",
    "            x = x/tf.cast(self.head_dims[0], tf.float32)\n",
    "        if FLAGS.use_all_features_HSIC or -1 in FLAGS.HSIC_feature_layers:\n",
    "            feat.append(x)\n",
    "        return reps, x, feat\n",
    "\n",
    "def createmodel(num_classes,\n",
    "                head_dim,\n",
    "                head_layers,\n",
    "                proj_dim,\n",
    "                proj_layer,\n",
    "                use_proj=True,\n",
    "                resnet_base=None,\n",
    "                dropout_rate=0.0,\n",
    "                num_heads=1):\n",
    "    \"\"\"Create a model with resnet base and a linear head.\"\"\"\n",
    "    if resnet_base is None:\n",
    "        resnet_base = resnet.resnet(\n",
    "            resnet_depth=FLAGS.resnet_depth,\n",
    "            width_multiplier=FLAGS.width_multiplier,\n",
    "            cifar_stem=FLAGS.image_size <= 32)\n",
    "    if num_heads > 1:\n",
    "        model = multihead_model(\n",
    "            resnet_base,\n",
    "            num_classes,\n",
    "            proj_dim,\n",
    "            proj_layer,\n",
    "            head_dim + np.zeros(head_layers),\n",
    "            num_heads,\n",
    "            use_proj=use_proj,\n",
    "            dropout_rate=dropout_rate)\n",
    "    else:\n",
    "        model = head_model(\n",
    "            resnet_base,\n",
    "            num_classes,\n",
    "            proj_dim,\n",
    "            proj_layer,\n",
    "            head_dim + np.zeros(head_layers),\n",
    "            use_proj = use_proj,\n",
    "            dropout_rate=dropout_rate,\n",
    "            use_bn=FLAGS.use_bn)\n",
    "\n",
    "    return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6baf9432",
   "metadata": {},
   "outputs": [],
   "source": [
    "FLAGS.rr_weight = 0.0\n",
    "FLAGS.use_rr_loss = False\n",
    "FLAGS.class_specific_rr_loss = False\n",
    "FLAGS.learning_rate = 0.5\n",
    "FLAGS.momentum = 0.0\n",
    "FLAGS.train_batch_size = 256\n",
    "FLAGS.train_epochs = 300\n",
    "FLAGS.train_epochs_finetune = 50\n",
    "FLAGS.train_steps = 10\n",
    "FLAGS.train_steps_finetune = 5\n",
    "FLAGS.val_batch_size = 256\n",
    "FLAGS.test_batch_size = 256\n",
    "FLAGS.buffer_size = 10000\n",
    "FLAGS.checkpoint_epochs = 50\n",
    "FLAGS.dataset = 'imagenette'\n",
    "FLAGS.dropout_rate = 0.0\n",
    "FLAGS.platform = 'GPU'\n",
    "FLAGS.checkpoint = None\n",
    "FLAGS.optimizer = 'sgd'\n",
    "FLAGS.project_out_prev_w = False\n",
    "FLAGS.use_data_aug_with_DRO = False\n",
    "FLAGS.lr_decay_type = 'warmup_cosine_decay'\n",
    "FLAGS.learning_rate_scaling = 'linear'\n",
    "FLAGS.warmup_steps = 500\n",
    "FLAGS.warmup_epochs = 20\n",
    "FLAGS.proj_dim = 20\n",
    "FLAGS.num_heads = 1\n",
    "FLAGS.resnet_depth = 50\n",
    "FLAGS.weight_decay = 1e-4\n",
    "FLAGS.image_size = 32\n",
    "FLAGS.val_epochs_gap = 1\n",
    "FLAGS.finetune_val_epochs_gap = 1\n",
    "FLAGS.val_steps_gap = 0\n",
    "FLAGS.finetune_val_steps_gap = 0\n",
    "FLAGS.use_pretrained = True\n",
    "FLAGS.path = './'\n",
    "FLAGS.use_OOD_transform = False\n",
    "FLAGS.num_runs = 1\n",
    "FLAGS.num_head_layers = 1\n",
    "FLAGS.use_early_stopping = False\n",
    "FLAGS.proj_layer = 0\n",
    "FLAGS.head_dim = 100\n",
    "FLAGS.model_dir = './'\n",
    "FLAGS.model_finetune_dir = None\n",
    "FLAGS.use_seq_rr = False\n",
    "FLAGS.num_seq_models = 1\n",
    "FLAGS.load_model = False\n",
    "FLAGS.save_model = True\n",
    "FLAGS.use_exp_var_loss = False\n",
    "FLAGS.use_MI_loss = False\n",
    "FLAGS.CIFAR_label_1 = 1\n",
    "FLAGS.CIFAR_label_2 = 9\n",
    "FLAGS.MNIST_label_1 = 0\n",
    "FLAGS.MNIST_label_2 = 1\n",
    "FLAGS.corr_frac = 1.0\n",
    "FLAGS.use_proj_head = False\n",
    "FLAGS.num_classes = 2\n",
    "FLAGS.monitor_rr_grad_norms = False\n",
    "FLAGS.use_HSIC_loss = False\n",
    "FLAGS.use_HSIC_diff = False\n",
    "FLAGS.lin_scale_rr_weight = False\n",
    "FLAGS.dataset_dim = 2\n",
    "FLAGS.num_lin = 1\n",
    "FLAGS.num_3_slabs = 1\n",
    "FLAGS.num_5_slabs = 0\n",
    "FLAGS.num_7_slabs = 0\n",
    "FLAGS.use_random_transform = False\n",
    "FLAGS.lin_margin = 0.1\n",
    "FLAGS.slab_margin = 0.05\n",
    "FLAGS.fcn_layers = 3\n",
    "FLAGS.hidden_dim = 512\n",
    "FLAGS.randomize_linear = False\n",
    "FLAGS.randomize_slabs = False\n",
    "FLAGS.turn_off_randomize_later = False\n",
    "FLAGS.use_L4_reg = False\n",
    "FLAGS.use_bn = False\n",
    "FLAGS.use_HSIC_on_features = False\n",
    "FLAGS.HSIC_feature_layer = 0\n",
    "FLAGS.HSIC_feature_layers = [0,1]\n",
    "FLAGS.use_all_features_HSIC = False\n",
    "FLAGS.use_sq_HSIC = False\n",
    "FLAGS.use_GAP_HSIC_features = True\n",
    "FLAGS.use_random_projections = False\n",
    "FLAGS.random_proj_dim = 1\n",
    "FLAGS.use_prev_logits_HSIC_features = False\n",
    "FLAGS.use_MNIST_labels = False\n",
    "FLAGS.monitor_EG_overlap = False\n",
    "FLAGS.monitor_robustness_measures = True\n",
    "FLAGS.monitor_error_diversity = True\n",
    "FLAGS.monitor_logit_correlation = True\n",
    "FLAGS.use_pretrained_model_1 = False\n",
    "FLAGS.pretrained_model_path = None\n",
    "FLAGS.pretrained_model_paths = None\n",
    "FLAGS.pretrained_checkpoint_paths = None\n",
    "FLAGS.use_indexed_checkpoints = False\n",
    "FLAGS.use_FCN = False\n",
    "FLAGS.use_EG_loss = False\n",
    "FLAGS.num_ref_EG_loss = 2\n",
    "FLAGS.EG_loss_weight = 1e-3\n",
    "FLAGS.binary_classification = True\n",
    "FLAGS.use_CNN = False\n",
    "FLAGS.master = 'local'\n",
    "FLAGS.project_out_rank = 0\n",
    "FLAGS.check_ranks_max = 30\n",
    "FLAGS.filters = [16, 32, 64]\n",
    "FLAGS.kernel_sizes = [3, 3, 3]\n",
    "FLAGS.strides = [1, 2, 1]\n",
    "FLAGS.project_out_vecs = [1,2]\n",
    "FLAGS.use_chizat_init = True\n",
    "FLAGS.project_out_w = False\n",
    "FLAGS.use_complete_corr = False\n",
    "FLAGS.use_complete_corr_test = False\n",
    "FLAGS.flip_err_div_for_minority = True\n",
    "FLAGS.measure_feat_robust = True\n",
    "FLAGS.max_gauss_noise_std = 5.0\n",
    "FLAGS.use_tpu = False\n",
    "FLAGS.check_torch_reps = False\n",
    "FLAGS.train_split = True\n",
    "FLAGS.finetune_only_linear_head = False\n",
    "FLAGS.task_id = 'Imagenette'\n",
    "FLAGS.finetune_only_head = True\n",
    "FLAGS.standardize_mean_reps = True\n",
    "\n",
    "_SHUFFLE_BUFFER_SIZE = flags.DEFINE_integer('shuffle_buffer_size', 1000,\n",
    "                                            'Buffer size for shuffling.')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d9fc2e49",
   "metadata": {},
   "outputs": [],
   "source": [
    "FLAGS(sys.argv)\n",
    "\n",
    "FLAGS.binary_classification=False\n",
    "FLAGS.num_classes = 10\n",
    "FLAGS.dataset = 'mnist-cifar'\n",
    "\n",
    "base_model = tf.keras.Sequential()\n",
    "base_model.add(tf.keras.layers.Layer())\n",
    "model = createmodel(\n",
    "                  FLAGS.num_classes,\n",
    "                  100,\n",
    "                  1,\n",
    "                  100,\n",
    "                  0,\n",
    "                  resnet_base=base_model,\n",
    "                  dropout_rate=0.0,\n",
    "                  num_heads=1,\n",
    "                  use_proj=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6380748c",
   "metadata": {},
   "outputs": [],
   "source": [
    "save_vars = {}\n",
    "for i in range(len(model.layers)):\n",
    "    if i != 0:\n",
    "      save_vars['model{}'.format(str(i))] = model.layers[i]\n",
    "ckpt = tf.train.Checkpoint(**save_vars)\n",
    "manager = tf.train.CheckpointManager(\n",
    "    ckpt, directory=\"./imagenette-multirun-mean/run-2/2/0\", max_to_keep=1)\n",
    "status = ckpt.restore(manager.latest_checkpoint)\n",
    "print(status)\n",
    "print(manager.latest_checkpoint)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b392856c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def imagenette_train_transform(x):\n",
    "    mean = [0.485, 0.456, 0.406]\n",
    "    std = [0.229, 0.224, 0.225]\n",
    "    image = tf.image.resize(x['image'], size=[256, 256])\n",
    "    image = tf.image.random_crop(image, size=(224, 224, 3))\n",
    "    image = tf.image.random_flip_left_right(image)\n",
    "    image = image/255.0\n",
    "    image = (image - mean)/std\n",
    "    return {'image': image, 'label': x['label']}\n",
    "\n",
    "def imagenette_val_transform(x):\n",
    "    mean = [0.485, 0.456, 0.406]\n",
    "    std = [0.229, 0.224, 0.225]\n",
    "    image = tf.image.resize(x['image'], size=[256,256])\n",
    "    image = tf.image.central_crop(image, central_fraction=224.0/256.0)\n",
    "    image = image/255.0\n",
    "    image = (image - mean) / std\n",
    "\n",
    "    return {'image': image, 'label': x['label']}\n",
    "\n",
    "def load_imagenette_reps(path, binary=True, batched=True):\n",
    "  if binary:\n",
    "    filepath = os.path.join(path, 'b-imagenette.pkl')\n",
    "  else:\n",
    "    filepath = os.path.join(path, 'imagenette.pkl')\n",
    "  with open(filepath, 'rb') as f:\n",
    "    data = pickle.load(f)\n",
    "  \n",
    "  train_data = data['train']\n",
    "  train_image = train_data['images']\n",
    "  if FLAGS.standardize_mean_reps:\n",
    "    train_image = train_image - np.mean(train_image, axis=0, keepdims=True)\n",
    "  train_label = train_data['labels']\n",
    "  train_ds = tf.data.Dataset.from_tensor_slices({\n",
    "      'image': train_image,\n",
    "      'label': train_label,\n",
    "  })\n",
    "  train_ds = train_ds.repeat(-1).shuffle(\n",
    "      _SHUFFLE_BUFFER_SIZE.value)\n",
    "    \n",
    "  val_data = data['val']\n",
    "  val_image = val_data['images']\n",
    "  if FLAGS.standardize_mean_reps:\n",
    "    val_image = val_image - np.mean(val_image, axis=0, keepdims=True)\n",
    "  val_label = val_data['labels']\n",
    "  val_ds = tf.data.Dataset.from_tensor_slices({\n",
    "      'image': val_image,\n",
    "      'label': val_label,\n",
    "  })\n",
    "  val_ds = val_ds.shuffle(\n",
    "      _SHUFFLE_BUFFER_SIZE.value)\n",
    "  #use validation as test\n",
    "  test_ds = val_ds\n",
    "  \n",
    "  if binary:\n",
    "    return train_ds, val_ds, test_ds, train_image.shape[0]/100\n",
    "  else:\n",
    "    return train_ds, val_ds, test_ds, train_image.shape[0]/20\n",
    "\n",
    "def load_mnistcifar_reps(path, batched=True):\n",
    "  filepath = os.path.join(path, 'cifar-mnist.pkl')\n",
    "  with tf.io.gfile.GFile(filepath, 'rb') as fobj:\n",
    "    data = pickle.load(fobj)\n",
    "  \n",
    "  if FLAGS.train_split:\n",
    "    train_data = data['train_split']\n",
    "  else:\n",
    "    train_data = data['train']\n",
    "  train_image = train_data['images']\n",
    "  if FLAGS.standardize_mean_reps:\n",
    "    train_image = train_image - np.mean(train_image, axis=0, keepdims=True)\n",
    "  train_label = train_data['labels']\n",
    "  train_ds = tf.data.Dataset.from_tensor_slices({\n",
    "      'image': train_image,\n",
    "      'label': train_label,\n",
    "  })\n",
    "  train_ds = train_ds.repeat(-1).shuffle(\n",
    "      _SHUFFLE_BUFFER_SIZE.value)\n",
    "    \n",
    "  val_data = data['val']\n",
    "  val_image = val_data['images']\n",
    "  if FLAGS.standardize_mean_reps:\n",
    "    val_image = val_image - np.mean(val_image, axis=0, keepdims=True)\n",
    "  val_label = val_data['labels']\n",
    "  val_ds = tf.data.Dataset.from_tensor_slices({\n",
    "      'image': val_image,\n",
    "      'label': val_label,\n",
    "  })\n",
    "  val_ds = val_ds.shuffle(\n",
    "      _SHUFFLE_BUFFER_SIZE.value)\n",
    "\n",
    "  test_data = data['test']\n",
    "  test_image = test_data['images']\n",
    "  if FLAGS.standardize_mean_reps:\n",
    "    test_image = test_image - np.mean(test_image, axis=0, keepdims=True)\n",
    "  test_label = test_data['labels']\n",
    "  test_ds = tf.data.Dataset.from_tensor_slices({\n",
    "      'image': test_image,\n",
    "      'label': test_label,\n",
    "  })\n",
    "  test_ds = test_ds.shuffle(\n",
    "      _SHUFFLE_BUFFER_SIZE.value)\n",
    "    \n",
    "  if FLAGS.train_split:\n",
    "    OOD_train_data = data['OOD_train_split']\n",
    "  else:\n",
    "    OOD_train_data = data['OOD_train']\n",
    "  OOD_train_image = OOD_train_data['images']\n",
    "  if FLAGS.standardize_mean_reps:\n",
    "    OOD_train_image = OOD_train_image - np.mean(OOD_train_image, axis=0, keepdims=True)\n",
    "  OOD_train_label = OOD_train_data['labels']\n",
    "  OOD_train_ds = tf.data.Dataset.from_tensor_slices({\n",
    "      'image': OOD_train_image,\n",
    "      'label': OOD_train_label,\n",
    "  })\n",
    "  OOD_train_ds = OOD_train_ds.repeat(-1).shuffle(\n",
    "      _SHUFFLE_BUFFER_SIZE.value)\n",
    "    \n",
    "  OOD_val_data = data['OOD_val']\n",
    "  OOD_val_image = OOD_val_data['images']\n",
    "  if FLAGS.standardize_mean_reps:\n",
    "    OOD_val_image = OOD_val_image - np.mean(OOD_val_image, axis=0, keepdims=True)\n",
    "  OOD_val_label = OOD_val_data['labels']\n",
    "  OOD_val_ds = tf.data.Dataset.from_tensor_slices({\n",
    "      'image': OOD_val_image,\n",
    "      'label': OOD_val_label,\n",
    "  })\n",
    "  OOD_val_ds = OOD_val_ds.shuffle(\n",
    "      _SHUFFLE_BUFFER_SIZE.value)\n",
    "\n",
    "  OOD_test_data = data['OOD_test']\n",
    "  OOD_test_image = OOD_test_data['images']\n",
    "  if FLAGS.standardize_mean_reps:\n",
    "    OOD_test_image = OOD_test_image - np.mean(OOD_test_image, axis=0, keepdims=True)\n",
    "  OOD_test_label = OOD_test_data['labels']\n",
    "  OOD_test_ds = tf.data.Dataset.from_tensor_slices({\n",
    "      'image': OOD_test_image,\n",
    "      'label': OOD_test_label,\n",
    "  })\n",
    "  OOD_test_ds = OOD_test_ds.shuffle(\n",
    "      _SHUFFLE_BUFFER_SIZE.value)\n",
    "    \n",
    "  return train_ds, val_ds, test_ds, OOD_train_ds, OOD_val_ds, OOD_test_ds, train_image.shape[0], OOD_train_image.shape[0]\n",
    "\n",
    "def load_waterbirds_reps_2(path):\n",
    "  filepath = os.path.join(path, 'waterbirds_train.pkl')\n",
    "  with open(filepath, 'rb') as f:\n",
    "    train_data = pickle.load(f)\n",
    "  \n",
    "  if FLAGS.use_complete_corr:\n",
    "    train_data = train_data['waterbirds_train_complete_corr']\n",
    "  else:\n",
    "    train_data = train_data['waterbirds_train']\n",
    "  train_image = train_data['inputs']\n",
    "  if FLAGS.standardize_mean_reps:\n",
    "    train_image = train_image - np.mean(train_image, axis=0, keepdims=True)\n",
    "  train_label = train_data['labels']\n",
    "  train_metadata = train_data['metadata']\n",
    "  train_ds = tf.data.Dataset.from_tensor_slices({\n",
    "      'image': train_image,\n",
    "      'label': train_label,\n",
    "      'metadata': train_metadata,\n",
    "  })\n",
    "  train_ds = train_ds.shuffle(\n",
    "      _SHUFFLE_BUFFER_SIZE.value)\n",
    "  \n",
    "  filepath = os.path.join(path, 'waterbirds_val_test.pkl')\n",
    "  with open(filepath, 'rb') as f:\n",
    "    val_test_data = pickle.load(f)\n",
    "  \n",
    "  if FLAGS.use_complete_corr_test:\n",
    "    val_data = val_test_data['waterbirds_val_complete_corr']\n",
    "  else:\n",
    "    val_data = val_test_data['waterbirds_val']\n",
    "  val_image = val_data['inputs']\n",
    "  if FLAGS.standardize_mean_reps:\n",
    "    val_image = val_image - np.mean(val_image, axis=0, keepdims=True)\n",
    "  val_label = val_data['labels']\n",
    "  val_metadata = val_data['metadata']\n",
    "  val_ds = tf.data.Dataset.from_tensor_slices({\n",
    "      'image': val_image,\n",
    "      'label': val_label,\n",
    "      'metadata': val_metadata,\n",
    "  })\n",
    "  val_ds = val_ds.shuffle(\n",
    "      _SHUFFLE_BUFFER_SIZE.value)\n",
    "  \n",
    "  if FLAGS.use_complete_corr_test:\n",
    "    test_data = val_test_data['waterbirds_test_complete_corr']\n",
    "  else:\n",
    "    test_data = val_test_data['waterbirds_test']\n",
    "  test_image = test_data['inputs']\n",
    "  if FLAGS.standardize_mean_reps:\n",
    "    test_image = test_image - np.mean(test_image, axis=0, keepdims=True)\n",
    "  test_label = test_data['labels']\n",
    "  test_metadata = test_data['metadata']\n",
    "  test_ds = tf.data.Dataset.from_tensor_slices({\n",
    "      'image': test_image,\n",
    "      'label': test_label,\n",
    "      'metadata': test_metadata,\n",
    "  })\n",
    "  test_ds = test_ds.shuffle(\n",
    "      _SHUFFLE_BUFFER_SIZE.value)\n",
    "  \n",
    "  if FLAGS.use_complete_corr:\n",
    "    return train_ds, val_ds, test_ds, 4555\n",
    "  else:\n",
    "    return train_ds, val_ds, test_ds, 4795"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bfb29086",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pickle\n",
    "\n",
    "FLAGS.train_split = True\n",
    "FLAGS.use_complete_corr = False\n",
    "FLAGS.use_complete_corr_test = False\n",
    "\n",
    "#train_ds, val_ds, test_ds, train_len = load_waterbirds_reps_2('./data_reps')\n",
    "train_ds, val_ds, test_ds, train_len = load_imagenette_reps('./data_reps', binary=False, batched=False)\n",
    "#train_ds, val_ds, test_ds, OOD_train_ds, OOD_val_ds, OOD_test_ds, train_len, OOD_train_len = load_mnistcifar_reps('./data_reps')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e69cdf4b",
   "metadata": {},
   "outputs": [],
   "source": [
    "FLAGS.measure_feat_robust = True\n",
    "\n",
    "def gauss_noise_robust_2(models, data, std_dev):\n",
    "  std_devs = [std_dev]\n",
    "  total_corr = 0.0\n",
    "  for std_dev in std_devs:\n",
    "    # tf.autograph.experimental.set_loop_options(\n",
    "    #     shape_invariants=[(total_corr, tf.TensorShape([None]))]\n",
    "    # )\n",
    "    if FLAGS.measure_feat_robust:\n",
    "      feats = models[0].layers[0](data['image'], training=False)\n",
    "      z = std_dev * tf.random.normal(tf.shape(feats))\n",
    "      noise_feats = feats + tf.cast(z, feats.dtype)\n",
    "      avg = 0.0\n",
    "      for i in range(len(models)):\n",
    "        x = noise_feats\n",
    "        for ind, layer in enumerate(models[i].layers):\n",
    "          if ind > 0:\n",
    "            x = layer(x, training=False)\n",
    "        avg = (i/(i+1.0))*avg + (1.0/(i+1.0))*x\n",
    "    if FLAGS.binary_classification:\n",
    "      corr = tf.cast(tf.reshape(avg, [-1]) > 0, tf.int64) == data['label']\n",
    "    else:\n",
    "      corr = tf.argmax(avg, axis=1) == data['label']\n",
    "\n",
    "    if FLAGS.dataset=='waterbirds':\n",
    "      corr_00 = (corr) & (data['metadata'][:,0]==0) & (data['metadata'][:,1]==0)\n",
    "      corr_00_frac = tf.reduce_sum(tf.cast(corr_00, tf.float32))/test_00\n",
    "      corr_01 = (corr) & (data['metadata'][:,0]==0) & (data['metadata'][:,1]==1)\n",
    "      corr_01_frac = tf.reduce_sum(tf.cast(corr_01, tf.float32))/test_01\n",
    "      corr_10 = (corr) & (data['metadata'][:,0]==1) & (data['metadata'][:,1]==0)\n",
    "      corr_10_frac = tf.reduce_sum(tf.cast(corr_10, tf.float32))/test_10\n",
    "      corr_11 = (corr) & (data['metadata'][:,0]==1) & (data['metadata'][:,1]==1)\n",
    "      corr_11_frac = tf.reduce_sum(tf.cast(corr_11, tf.float32))/test_11\n",
    "      corr_adj_frac = (corr_00_frac*3498 + corr_01_frac*184 + corr_10_frac*56 + corr_11_frac*1057)/4795\n",
    "      total_corr += corr_adj_frac * tf.cast(tf.shape(data['label'])[0], tf.float32)\n",
    "    else:\n",
    "      total_corr += tf.reduce_sum(tf.cast(corr, tf.float32))\n",
    "\n",
    "  final_score = tf.reduce_sum(total_corr)/tf.cast(tf.shape(data['label'])[0], tf.float32)\n",
    "  return final_score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ba468e5b",
   "metadata": {},
   "outputs": [],
   "source": [
    "test_ds_batch = test_ds.batch(len(test_ds))\n",
    "iterator = iter(test_ds_batch)\n",
    "data = next(iterator)\n",
    "# indices = (data['metadata'][:,0]==1) & (data['metadata'][:,1]==1)\n",
    "# data['image'] = data['image'][indices]\n",
    "# data['label'] = data['label'][indices]\n",
    "print(data['image'].shape)\n",
    "if FLAGS.dataset=='waterbirds':\n",
    "    test_00 = tf.reduce_sum(tf.cast((data['metadata'][:,0]==0) & (data['metadata'][:,1]==0), tf.float32))\n",
    "    test_01 = tf.reduce_sum(tf.cast((data['metadata'][:,0]==0) & (data['metadata'][:,1]==1), tf.float32))\n",
    "    test_10 = tf.reduce_sum(tf.cast((data['metadata'][:,0]==1) & (data['metadata'][:,1]==0), tf.float32))\n",
    "    test_11 = tf.reduce_sum(tf.cast((data['metadata'][:,0]==1) & (data['metadata'][:,1]==1), tf.float32))\n",
    "\n",
    "vals = []\n",
    "#std_devs = np.arange(0, 5, 0.1)\n",
    "std_devs = [0.0]\n",
    "for std_dev in std_devs:\n",
    "    vals.append(gauss_noise_robust_2([model], data, std_dev))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ff0127c5",
   "metadata": {},
   "outputs": [],
   "source": [
    "model2 = createmodel(\n",
    "                  FLAGS.num_classes,\n",
    "                  100,\n",
    "                  1,\n",
    "                  100,\n",
    "                  0,\n",
    "                  resnet_base=base_model,\n",
    "                  dropout_rate=0.0,\n",
    "                  num_heads=1,\n",
    "                  use_proj=False)\n",
    "\n",
    "_, out, _ = model2(data['image'], training=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "54f0cd72",
   "metadata": {},
   "outputs": [],
   "source": [
    "save_vars = {}\n",
    "for i in range(len(model.layers)):\n",
    "    if i != 0:\n",
    "      save_vars['model{}'.format(str(i))] = model2.layers[i]\n",
    "ckpt = tf.train.Checkpoint(**save_vars)\n",
    "manager = tf.train.CheckpointManager(\n",
    "    ckpt, directory=\"./imagenette-model2-multirun-mean/run-2/0/1\", max_to_keep=1)\n",
    "status = ckpt.restore(manager.latest_checkpoint)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "99a72c8e",
   "metadata": {},
   "outputs": [],
   "source": [
    "vals2 = []\n",
    "for std_dev in std_devs:\n",
    "    vals2.append(gauss_noise_robust_2([model2], data, std_dev))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "26a16a67",
   "metadata": {},
   "outputs": [],
   "source": [
    "vals3 = []\n",
    "for std_dev in std_devs:\n",
    "    vals3.append(gauss_noise_robust_2([model, model2], data, std_dev))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d2b0d310",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(vals)\n",
    "print(vals2)\n",
    "print(vals3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c7f66cff",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "sns.set_style('white')\n",
    "\n",
    "dir1 = './waterbirds-model2-multirun-mean/run-2/0'\n",
    "plt.plot(std_devs, vals, label='model1')\n",
    "plt.plot(std_devs, vals2, label='model2')\n",
    "plt.plot(std_devs, vals3, label='ensemble')\n",
    "plt.legend()\n",
    "plt.xlabel('std_dev', fontsize = 18)\n",
    "plt.ylabel('accuracy', fontsize = 18)\n",
    "plt.xticks(fontsize=14)\n",
    "plt.yticks(fontsize=14)\n",
    "plt.locator_params(axis='y', nbins=6)\n",
    "plt.locator_params(axis='x', nbins=6)\n",
    "plt.tight_layout()\n",
    "plt.savefig(dir1 + '/gauss_robust_test.pdf')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d44e2f13",
   "metadata": {},
   "outputs": [],
   "source": [
    "model3 = createmodel(\n",
    "                  FLAGS.num_classes,\n",
    "                  100,\n",
    "                  1,\n",
    "                  100,\n",
    "                  0,\n",
    "                  resnet_base=base_model,\n",
    "                  dropout_rate=0.0,\n",
    "                  num_heads=1,\n",
    "                  use_proj=False)\n",
    "\n",
    "_, out, _ = model3(data['image'], training=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8c710030",
   "metadata": {},
   "outputs": [],
   "source": [
    "save_vars = {}\n",
    "for i in range(len(model.layers)):\n",
    "    if i != 0:\n",
    "      save_vars['model{}'.format(str(i))] = model3.layers[i]\n",
    "ckpt = tf.train.Checkpoint(**save_vars)\n",
    "manager = tf.train.CheckpointManager(\n",
    "    ckpt, directory=\"./imagenette-model2-noproject-multirun-mean/run-2/0/1\", max_to_keep=1)\n",
    "status = ckpt.restore(manager.latest_checkpoint)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "48867247",
   "metadata": {},
   "outputs": [],
   "source": [
    "vals4 = []\n",
    "for std_dev in std_devs:\n",
    "    vals4.append(gauss_noise_robust_2([model3], data, std_dev))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c3253207",
   "metadata": {},
   "outputs": [],
   "source": [
    "vals5 = []\n",
    "for std_dev in std_devs:\n",
    "    vals5.append(gauss_noise_robust_2([model, model3], data, std_dev))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8b9f2115",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "sns.set_style('white')\n",
    "\n",
    "dir1 = './waterbirds-model2-noproject-multirun-mean/run-2/0'\n",
    "plt.plot(std_devs, vals, label='model1')\n",
    "plt.plot(std_devs, vals4, label='model2')\n",
    "plt.plot(std_devs, vals5, label='ensemble')\n",
    "plt.legend()\n",
    "plt.xlabel('std_dev', fontsize = 18)\n",
    "plt.ylabel('accuracy', fontsize = 18)\n",
    "plt.xticks(fontsize=14)\n",
    "plt.yticks(fontsize=14)\n",
    "plt.locator_params(axis='y', nbins=6)\n",
    "plt.locator_params(axis='x', nbins=6)\n",
    "plt.tight_layout()\n",
    "plt.savefig(dir1 + '/gauss_robust_test.pdf')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c68c9a78",
   "metadata": {},
   "outputs": [],
   "source": [
    "dir1 = './waterbirds-model2-multirun-mean/run-2/0'\n",
    "plt.plot(std_devs, vals, label='model1')\n",
    "#plt.plot(std_devs, vals2, label='model2-proj')\n",
    "plt.plot(std_devs, vals3, label='ensemble-proj')\n",
    "#plt.plot(std_devs, vals4, label='model2-rand')\n",
    "plt.plot(std_devs, vals5, label='ensemble-ind')\n",
    "plt.legend()\n",
    "plt.xlabel('std_dev', fontsize = 18)\n",
    "plt.ylabel('accuracy', fontsize = 18)\n",
    "plt.xticks(fontsize=14)\n",
    "plt.yticks(fontsize=14)\n",
    "plt.locator_params(axis='y', nbins=6)\n",
    "plt.locator_params(axis='x', nbins=6)\n",
    "plt.tight_layout()\n",
    "plt.savefig(dir1 + '/gauss_robust_test_overall_3_plots.pdf')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2f6b27ed",
   "metadata": {},
   "outputs": [],
   "source": [
    "base_model = tf.keras.Sequential()\n",
    "base_model.add(tf.keras.layers.Layer())\n",
    "model4 = createmodel(\n",
    "                  FLAGS.num_classes,\n",
    "                  100,\n",
    "                  1,\n",
    "                  100,\n",
    "                  0,\n",
    "                  resnet_base=base_model,\n",
    "                  dropout_rate=0.0,\n",
    "                  num_heads=1,\n",
    "                  use_proj=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "292f4dcd",
   "metadata": {},
   "outputs": [],
   "source": [
    "save_vars = {}\n",
    "for i in range(len(model.layers)):\n",
    "    if i != 0:\n",
    "      save_vars['model{}'.format(str(i))] = model4.layers[i]\n",
    "ckpt = tf.train.Checkpoint(**save_vars)\n",
    "manager = tf.train.CheckpointManager(\n",
    "    ckpt, directory=\"./imagenette-model2-multirun-mean/run-2/1/1\", max_to_keep=1)\n",
    "status = ckpt.restore(manager.latest_checkpoint)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5c264e87",
   "metadata": {},
   "outputs": [],
   "source": [
    "vals6 = []\n",
    "for std_dev in std_devs:\n",
    "    vals6.append(gauss_noise_robust_2([model4], data, std_dev))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "03c9a6f1",
   "metadata": {},
   "outputs": [],
   "source": [
    "vals7 = []\n",
    "for std_dev in std_devs:\n",
    "    vals7.append(gauss_noise_robust_2([model, model4], data, std_dev))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e0e2d1e4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "sns.set_style('white')\n",
    "\n",
    "dir1 = './waterbirds-model2-multirun-mean/run-2/1'\n",
    "plt.plot(std_devs, vals, label='model1')\n",
    "plt.plot(std_devs, vals6, label='model2')\n",
    "plt.plot(std_devs, vals7, label='ensemble')\n",
    "plt.legend()\n",
    "plt.xlabel('std_dev', fontsize = 18)\n",
    "plt.ylabel('accuracy', fontsize = 18)\n",
    "plt.xticks(fontsize=14)\n",
    "plt.yticks(fontsize=14)\n",
    "plt.locator_params(axis='y', nbins=6)\n",
    "plt.locator_params(axis='x', nbins=6)\n",
    "plt.tight_layout()\n",
    "plt.savefig(dir1 + '/gauss_robust_test.pdf')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ad6b198c",
   "metadata": {},
   "outputs": [],
   "source": [
    "base_model = tf.keras.Sequential()\n",
    "base_model.add(tf.keras.layers.Layer())\n",
    "model5 = createmodel(\n",
    "                  FLAGS.num_classes,\n",
    "                  100,\n",
    "                  1,\n",
    "                  100,\n",
    "                  0,\n",
    "                  resnet_base=base_model,\n",
    "                  dropout_rate=0.0,\n",
    "                  num_heads=1,\n",
    "                  use_proj=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2126413a",
   "metadata": {},
   "outputs": [],
   "source": [
    "save_vars = {}\n",
    "for i in range(len(model.layers)):\n",
    "    if i != 0:\n",
    "      save_vars['model{}'.format(str(i))] = model5.layers[i]\n",
    "ckpt = tf.train.Checkpoint(**save_vars)\n",
    "manager = tf.train.CheckpointManager(\n",
    "    ckpt, directory=\"./imagenette-model2-noproject-multirun-mean/run-2/1/1\", max_to_keep=1)\n",
    "status = ckpt.restore(manager.latest_checkpoint)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "22fe572a",
   "metadata": {},
   "outputs": [],
   "source": [
    "vals8 = []\n",
    "for std_dev in std_devs:\n",
    "    vals8.append(gauss_noise_robust_2([model5], data, std_dev))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7cde8db0",
   "metadata": {},
   "outputs": [],
   "source": [
    "vals9 = []\n",
    "for std_dev in std_devs:\n",
    "    vals9.append(gauss_noise_robust_2([model, model5], data, std_dev))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "424c6328",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "sns.set_style('white')\n",
    "\n",
    "dir1 = './waterbirds-model2-noproject-multirun-mean/run-2/1'\n",
    "plt.plot(std_devs, vals, label='model1')\n",
    "plt.plot(std_devs, vals8, label='model2')\n",
    "plt.plot(std_devs, vals9, label='ensemble')\n",
    "plt.legend()\n",
    "plt.xlabel('std_dev', fontsize = 18)\n",
    "plt.ylabel('accuracy', fontsize = 18)\n",
    "plt.xticks(fontsize=14)\n",
    "plt.yticks(fontsize=14)\n",
    "plt.locator_params(axis='y', nbins=6)\n",
    "plt.locator_params(axis='x', nbins=6)\n",
    "plt.tight_layout()\n",
    "plt.savefig(dir1 + '/gauss_robust_test.pdf')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cb79bce7",
   "metadata": {},
   "outputs": [],
   "source": [
    "dir1 = './waterbirds-model2-multirun-mean/run-2/1'\n",
    "plt.plot(std_devs, vals, label='model1')\n",
    "#plt.plot(std_devs, vals2, label='model2-proj')\n",
    "plt.plot(std_devs, vals7, label='ensemble-proj')\n",
    "#plt.plot(std_devs, vals4, label='model2-rand')\n",
    "plt.plot(std_devs, vals9, label='ensemble-ind')\n",
    "plt.legend()\n",
    "plt.xlabel('std_dev', fontsize = 18)\n",
    "plt.ylabel('accuracy', fontsize = 18)\n",
    "plt.xticks(fontsize=14)\n",
    "plt.yticks(fontsize=14)\n",
    "plt.locator_params(axis='y', nbins=6)\n",
    "plt.locator_params(axis='x', nbins=6)\n",
    "plt.tight_layout()\n",
    "plt.savefig(dir1 + '/gauss_robust_test_overall_3_plots.pdf')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "040c247a",
   "metadata": {},
   "outputs": [],
   "source": [
    "base_model = tf.keras.Sequential()\n",
    "base_model.add(tf.keras.layers.Layer())\n",
    "model6 = createmodel(\n",
    "                  FLAGS.num_classes,\n",
    "                  100,\n",
    "                  1,\n",
    "                  100,\n",
    "                  0,\n",
    "                  resnet_base=base_model,\n",
    "                  dropout_rate=0.0,\n",
    "                  num_heads=1,\n",
    "                  use_proj=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "63acf194",
   "metadata": {},
   "outputs": [],
   "source": [
    "save_vars = {}\n",
    "for i in range(len(model.layers)):\n",
    "    if i != 0:\n",
    "      save_vars['model{}'.format(str(i))] = model6.layers[i]\n",
    "ckpt = tf.train.Checkpoint(**save_vars)\n",
    "manager = tf.train.CheckpointManager(\n",
    "    ckpt, directory=\"./imagenette-model2-multirun-mean/run-2/2/1\", max_to_keep=1)\n",
    "status = ckpt.restore(manager.latest_checkpoint)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0232ae43",
   "metadata": {},
   "outputs": [],
   "source": [
    "vals10 = []\n",
    "for std_dev in std_devs:\n",
    "    vals10.append(gauss_noise_robust_2([model6], data, std_dev))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9a9725e1",
   "metadata": {},
   "outputs": [],
   "source": [
    "vals11 = []\n",
    "for std_dev in std_devs:\n",
    "    vals11.append(gauss_noise_robust_2([model, model6], data, std_dev))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d22d6042",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "sns.set_style('white')\n",
    "\n",
    "dir1 = './waterbirds-model2-multirun-mean/run-2/2'\n",
    "plt.plot(std_devs, vals, label='model1')\n",
    "plt.plot(std_devs, vals10, label='model2')\n",
    "plt.plot(std_devs, vals11, label='ensemble')\n",
    "plt.legend()\n",
    "plt.xlabel('std_dev', fontsize = 18)\n",
    "plt.ylabel('accuracy', fontsize = 18)\n",
    "plt.xticks(fontsize=14)\n",
    "plt.yticks(fontsize=14)\n",
    "plt.locator_params(axis='y', nbins=6)\n",
    "plt.locator_params(axis='x', nbins=6)\n",
    "plt.tight_layout()\n",
    "plt.savefig(dir1 + '/gauss_robust_test.pdf')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ddb912e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "base_model = tf.keras.Sequential()\n",
    "base_model.add(tf.keras.layers.Layer())\n",
    "model7 = createmodel(\n",
    "                  FLAGS.num_classes,\n",
    "                  100,\n",
    "                  1,\n",
    "                  100,\n",
    "                  0,\n",
    "                  resnet_base=base_model,\n",
    "                  dropout_rate=0.0,\n",
    "                  num_heads=1,\n",
    "                  use_proj=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "95803319",
   "metadata": {},
   "outputs": [],
   "source": [
    "save_vars = {}\n",
    "for i in range(len(model.layers)):\n",
    "    if i != 0:\n",
    "      save_vars['model{}'.format(str(i))] = model7.layers[i]\n",
    "ckpt = tf.train.Checkpoint(**save_vars)\n",
    "manager = tf.train.CheckpointManager(\n",
    "    ckpt, directory=\"./imagenette-model2-noproject-multirun-mean/run-2/2/1\", max_to_keep=1)\n",
    "status = ckpt.restore(manager.latest_checkpoint)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "639fc128",
   "metadata": {},
   "outputs": [],
   "source": [
    "# gauss_noise_robust_2 score (plotted as 'accuracy' below) for model7 on its own,\n",
    "# one entry per noise level in std_devs.\n",
    "vals12 = [gauss_noise_robust_2([model7], data, std_dev) for std_dev in std_devs]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aef44c2e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same sweep for the two-member ensemble [model, model7], one entry per std_dev.\n",
    "ensemble_members = [model, model7]\n",
    "vals13 = [gauss_noise_robust_2(ensemble_members, data, std_dev) for std_dev in std_devs]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "51d027b2",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "sns.set_style('white')\n",
    "\n",
    "# Robustness curves vs. noise std_dev: vals (model1), vals12 (model7 alone, labelled\n",
    "# 'model2') and vals13 (ensemble of model and model7).\n",
    "# NOTE(review): model7 was restored from an imagenette-* checkpoint directory, but this\n",
    "# figure is written under a waterbirds-* directory -- confirm which one is intended.\n",
    "dir1 = './waterbirds-model2-noproject-multirun-mean/run-2/2'\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(std_devs, vals, label='model1')\n",
    "ax.plot(std_devs, vals12, label='model2')\n",
    "ax.plot(std_devs, vals13, label='ensemble')\n",
    "ax.legend()\n",
    "ax.set_xlabel('std_dev', fontsize=18)\n",
    "ax.set_ylabel('accuracy', fontsize=18)\n",
    "# tick_params also styles ticks created later; plt.xticks only touches existing labels.\n",
    "ax.tick_params(axis='both', labelsize=14)\n",
    "ax.locator_params(axis='y', nbins=6)\n",
    "ax.locator_params(axis='x', nbins=6)\n",
    "fig.tight_layout()\n",
    "# savefig does not create missing directories.\n",
    "os.makedirs(dir1, exist_ok=True)\n",
    "fig.savefig(os.path.join(dir1, 'gauss_robust_test.pdf'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2bc36aac",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Compare model1 against the 'ensemble-proj' (vals11) and 'ensemble-ind'\n",
    "# (vals13: model + model7) curves.\n",
    "dir1 = './waterbirds-model2-multirun-mean/run-2/2'\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(std_devs, vals, label='model1')\n",
    "ax.plot(std_devs, vals11, label='ensemble-proj')\n",
    "ax.plot(std_devs, vals13, label='ensemble-ind')\n",
    "ax.legend()\n",
    "ax.set_xlabel('std_dev', fontsize=18)\n",
    "ax.set_ylabel('accuracy', fontsize=18)\n",
    "ax.tick_params(axis='both', labelsize=14)\n",
    "ax.locator_params(axis='y', nbins=6)\n",
    "ax.locator_params(axis='x', nbins=6)\n",
    "fig.tight_layout()\n",
    "# savefig does not create missing directories.\n",
    "os.makedirs(dir1, exist_ok=True)\n",
    "fig.savefig(os.path.join(dir1, 'gauss_robust_test_overall_3_plots.pdf'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a69570b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Aggregate the three runs of each configuration, element-wise per std_dev:\n",
    "#   ensemble-proj: vals3, vals7, vals11\n",
    "#   ensemble-ind:  vals5, vals9, vals13\n",
    "#   model2 alone:  vals2, vals6, vals10 (presumably -- vals10 is plotted as 'model2')\n",
    "# The spread is the population std (ddof=0, i.e. divide by 3). NOTE(review): the sample\n",
    "# std (ddof=1) would be sqrt(3/2) ~ 1.22x larger; ddof=0 is kept to match earlier results.\n",
    "proj_runs = np.stack([np.asarray(v) for v in (vals3, vals7, vals11)])\n",
    "rand_runs = np.stack([np.asarray(v) for v in (vals5, vals9, vals13)])\n",
    "model2_runs = np.stack([np.asarray(v) for v in (vals2, vals6, vals10)])\n",
    "\n",
    "vals_proj_mean = proj_runs.mean(axis=0)\n",
    "vals_rand_mean = rand_runs.mean(axis=0)\n",
    "vals_model2_mean = model2_runs.mean(axis=0)\n",
    "\n",
    "proj_std = proj_runs.std(axis=0)\n",
    "rand_std = rand_runs.std(axis=0)\n",
    "model2_std = model2_runs.std(axis=0)\n",
    "\n",
    "print(vals)\n",
    "print(vals_model2_mean)\n",
    "print(model2_std)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7fbfca7",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Run-averaged curves: model1 vs. the mean over runs of the 'ensemble-proj'\n",
    "# (vals_proj_mean) and 'ensemble-ind' (vals_rand_mean) results.\n",
    "dir1 = './waterbirds-model2-multirun-mean/run-2'\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(std_devs, vals, label='model1')\n",
    "ax.plot(std_devs, vals_proj_mean, label='ensemble-proj')\n",
    "ax.plot(std_devs, vals_rand_mean, label='ensemble-ind')\n",
    "ax.legend()\n",
    "ax.set_xlabel('std_dev', fontsize=18)\n",
    "ax.set_ylabel('accuracy', fontsize=18)\n",
    "ax.tick_params(axis='both', labelsize=14)\n",
    "ax.locator_params(axis='y', nbins=6)\n",
    "ax.locator_params(axis='x', nbins=6)\n",
    "fig.tight_layout()\n",
    "# savefig does not create missing directories.\n",
    "os.makedirs(dir1, exist_ok=True)\n",
    "fig.savefig(os.path.join(dir1, 'gauss_robust_test_overall_3_plots.pdf'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a7efd912",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Run-averaged curves with +/- proj_std / rand_std error bars (spread across runs).\n",
    "dir1 = './waterbirds-model2-multirun-mean/run-2'\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(std_devs, vals, label='model1')\n",
    "(proj_line,) = ax.plot(std_devs, vals_proj_mean, label='ensemble-proj')\n",
    "(rand_line,) = ax.plot(std_devs, vals_rand_mean, label='ensemble-ind')\n",
    "\n",
    "# Take each error bar's colour from its curve: hard-coded names such as 'orange'\n",
    "# (#FFA500) differ from the property-cycle colours the curves get (default C1 is #FF7F0E).\n",
    "ax.errorbar(std_devs, vals_proj_mean, yerr=proj_std, fmt='none', color=proj_line.get_color(),\n",
    "            capsize=3, alpha=0.6, elinewidth=2)\n",
    "ax.errorbar(std_devs, vals_rand_mean, yerr=rand_std, fmt='none', color=rand_line.get_color(),\n",
    "            capsize=3, alpha=0.6, elinewidth=2)\n",
    "ax.set_xlabel('std_dev', fontsize=18)\n",
    "ax.set_ylabel('accuracy', fontsize=18)\n",
    "ax.tick_params(axis='both', labelsize=14)\n",
    "ax.locator_params(axis='y', nbins=6)\n",
    "ax.locator_params(axis='x', nbins=6)\n",
    "ax.legend()\n",
    "fig.tight_layout()\n",
    "# savefig does not create missing directories.\n",
    "os.makedirs(dir1, exist_ok=True)\n",
    "fig.savefig(os.path.join(dir1, 'gauss_robust_test_overall_3_plots_stddev.pdf'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "60548903",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "# Persist the robustness curves and their run-to-run spread into dir1 (the directory\n",
    "# set by the most recent plotting cell). Reload only from trusted locations, since\n",
    "# pickle.load can execute arbitrary code.\n",
    "fin_content = dict(\n",
    "    model1=vals,\n",
    "    ensemble_proj=vals_proj_mean,\n",
    "    ensemble_ind=vals_rand_mean,\n",
    "    proj_std=proj_std,\n",
    "    rand_std=rand_std,\n",
    ")\n",
    "with open(dir1 + '/gauss_robust.pkl', 'wb') as f:\n",
    "    pickle.dump(fin_content, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eeddb890",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.10.6 64-bit",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  },
  "vscode": {
   "interpreter": {
    "hash": "b0fa6594d8f4cbf19f97940f81e996739fb7646882a419484c72d19e05852a7e"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
