{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0849a6f5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "from absl import app\n",
    "from absl import flags\n",
    "from absl import logging\n",
    "import numpy as np  # pylint: disable=unused-import\n",
    "import sys"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4329b379",
   "metadata": {},
   "outputs": [],
   "source": [
    "FLAGS = flags.FLAGS\n",
    "\n",
    "flags.DEFINE_float('batch_norm_decay', 0.9, 'Batch norm decay parameter.')\n",
    "\n",
    "flags.DEFINE_float('rr_weight', 0.0,\n",
    "                   'Weight for the redundancy reduction term.')\n",
    "\n",
    "flags.DEFINE_bool('use_rr_loss', False, 'Use redundancy reduction term.')\n",
    "\n",
    "flags.DEFINE_bool('class_specific_rr_loss', True, 'Use class specific RR loss.')\n",
    "\n",
    "flags.DEFINE_float('learning_rate', 0.001, 'Initial learning rate.')\n",
    "\n",
    "flags.DEFINE_float('momentum', 0.9, 'momentum for SGD.')\n",
    "\n",
    "flags.DEFINE_integer('train_batch_size', 64, 'Batch size for training.')\n",
    "\n",
    "flags.DEFINE_integer('train_epochs', 100, 'Number of epochs to train for.')\n",
    "\n",
    "flags.DEFINE_integer('train_epochs_finetune', 50,\n",
    "                     'Number of epochs to finetune for.')\n",
    "\n",
    "flags.DEFINE_integer('train_steps', 0,\n",
    "                     'Number of train steps. If  > 0, overrides train epochs')\n",
    "\n",
    "flags.DEFINE_integer(\n",
    "    'train_steps_finetune', 0,\n",
    "    'Number of train finetune steps. If  > 0, overrides train epochs finetune')\n",
    "\n",
    "flags.DEFINE_integer(\n",
    "    'eval_steps', 0,\n",
    "    'Number of steps to eval for. If not provided, evals over entire dataset.')\n",
    "\n",
    "flags.DEFINE_integer('lr_decay_gap', 50, 'Gap between decaying lr')\n",
    "\n",
    "flags.DEFINE_float('lr_decay_factor', 0.5, 'Value to decay lr by')\n",
    "\n",
    "flags.DEFINE_integer('val_batch_size', 256, 'Batch size for eval.')\n",
    "\n",
    "flags.DEFINE_integer('test_batch_size', 256, 'Batch size for eval.')\n",
    "\n",
    "flags.DEFINE_integer('num_train', -1, 'Num training examples.')\n",
    "\n",
    "flags.DEFINE_integer('num_test', -1, 'Num test examples.')\n",
    "\n",
    "flags.DEFINE_integer('num_train_finetune', 20000,\n",
    "                     'Num training examples for finetuning.')\n",
    "\n",
    "flags.DEFINE_integer('num_test_finetune', 2000,\n",
    "                     'Num test examples for finetuning.')\n",
    "\n",
    "flags.DEFINE_integer('buffer_size', 256, 'Buffer size for shuffling.')\n",
    "\n",
    "flags.DEFINE_integer('checkpoint_epochs', 1,\n",
    "                     'Number of epochs between checkpoints/summaries.')\n",
    "\n",
    "flags.DEFINE_string('dataset', 'imagenette', 'Name of a dataset.')\n",
    "\n",
    "flags.DEFINE_bool(\n",
    "    'cache_dataset', False,\n",
    "    'Whether to cache the entire dataset in memory. If the dataset is '\n",
    "    'ImageNet, this is a very bad idea, but for smaller datasets it can '\n",
    "    'improve performance.')\n",
    "\n",
    "flags.DEFINE_bool('use_dropout_pretrain', False,\n",
    "                  'Whether to use dropout on the final layer for pretraining.')\n",
    "\n",
    "flags.DEFINE_float('dropout_rate', 0.0,\n",
    "                   'Dropout rate for final layer for pretraining.')\n",
    "\n",
    "flags.DEFINE_bool(\n",
    "    'finetune_from_random', False,\n",
    "    'Whether to initialize the finetuning model using random weights.')\n",
    "\n",
    "flags.DEFINE_enum(\n",
    "    'train_mode', 'pretrain', ['pretrain', 'finetune'],\n",
    "    'The train mode controls different objectives and trainable components.')\n",
    "\n",
    "flags.DEFINE_string('platform', 'GPU', 'To be run on GPU or TPU.')\n",
    "\n",
    "flags.DEFINE_float('validation_split', 0.2,\n",
    "                   'Validation split to use while training.')\n",
    "\n",
    "flags.DEFINE_string(\n",
    "    'checkpoint', None,\n",
    "    'Loading from the given checkpoint for fine-tuning if a finetuning '\n",
    "    'checkpoint does not already exist in model_dir.')\n",
    "\n",
    "flags.DEFINE_string('optimizer', 'sgd', 'Optimizer to be used.')\n",
    "\n",
    "flags.DEFINE_bool('project_out_prev_w', False, 'Project out previous w')\n",
    "\n",
    "flags.DEFINE_bool(\n",
    "    'zero_init_logits_layer', False,\n",
    "    'If True, zero initialize layers after avg_pool for supervised learning.')\n",
    "\n",
    "flags.DEFINE_bool('use_data_aug_with_DRO', False, 'True or False')\n",
    "\n",
    "flags.DEFINE_integer(\n",
    "    'fine_tune_after_block', -1,\n",
    "    'The layers after which block that we will fine-tune. -1 means fine-tuning '\n",
    "    'everything. 0 means fine-tuning after stem block. 4 means fine-tuning '\n",
    "    'just the linear head.')\n",
    "\n",
    "flags.DEFINE_integer('keep_checkpoint_max', 5,\n",
    "                     'Maximum number of checkpoints to keep.')\n",
    "\n",
    "flags.DEFINE_enum('lr_decay_type', 'step_decay', ['step_decay', 'cosine_decay', 'warmup_cosine_decay'],\n",
    "                  'Kind of decay to be used for learning rate')\n",
    "\n",
    "flags.DEFINE_enum('learning_rate_scaling', 'linear', ['linear', 'sqrt'],\n",
    "                  'Learning rate scaling to use')\n",
    "\n",
    "flags.DEFINE_integer('warmup_steps', 0, 'Warmup steps for warmup and cosine decay')\n",
    "\n",
    "flags.DEFINE_integer('warmup_epochs', 20, 'Warmup epochs for warmup and cosine decay')\n",
    "\n",
    "flags.DEFINE_boolean(\n",
    "    'global_bn', True,\n",
    "    'Whether to aggregate BN statistics across distributed cores.')\n",
    "\n",
    "flags.DEFINE_integer('proj_dim', 20, 'Output dimension of projection head')\n",
    "\n",
    "flags.DEFINE_integer(\n",
    "    'num_heads', 1,\n",
    "    'Number of heads across which to decorrelate. One head is the standard config'\n",
    ")\n",
    "\n",
    "flags.DEFINE_integer('width_multiplier', 1,\n",
    "                     'Multiplier to change width of network.')\n",
    "\n",
    "flags.DEFINE_integer('resnet_depth', 50, 'Depth of ResNet.')\n",
    "\n",
    "flags.DEFINE_float(\n",
    "    'sk_ratio', 0.,\n",
    "    'If it is bigger than 0, it will enable SK. Recommendation: 0.0625.')\n",
    "\n",
    "flags.DEFINE_float('se_ratio', 0., 'If it is bigger than 0, it will enable SE.')\n",
    "\n",
    "flags.DEFINE_float('weight_decay', 1e-4, 'weight decay to be used')\n",
    "\n",
    "flags.DEFINE_float('logit_decay', 0.0, 'Decay to be used on logits')\n",
    "\n",
    "flags.DEFINE_integer('image_size', 32, 'Input image size.')\n",
    "\n",
    "flags.DEFINE_integer('val_epochs_gap', 10, 'Epoch gap to run validation after')\n",
    "\n",
    "flags.DEFINE_integer('finetune_val_epochs_gap', 10,\n",
    "                     'Epoch gap to run validation after')\n",
    "\n",
    "flags.DEFINE_integer('val_steps_gap', 0, 'Steps gap to run validation after')\n",
    "\n",
    "flags.DEFINE_integer('finetune_val_steps_gap', 0,\n",
    "                     'Steps gap to run validation after')\n",
    "\n",
    "flags.DEFINE_boolean('use_pretrained', True, 'whether to use pretrained model')\n",
    "\n",
    "flags.DEFINE_string(\n",
    "    'path', './',\n",
    "    'path for the dataset')\n",
    "\n",
    "flags.DEFINE_bool('use_OOD_transform', False,\n",
    "                  'Use data preprocessing specific to OOD dataset')\n",
    "\n",
    "flags.DEFINE_float('clip_norm', None, 'global clip norm for the gradient')\n",
    "\n",
    "flags.DEFINE_integer('num_runs', 1,\n",
    "                     'Number of runs to average the evaluations over')\n",
    "\n",
    "flags.DEFINE_integer('num_head_layers', 1, 'Number of layers to use in head')\n",
    "\n",
    "flags.DEFINE_bool('use_early_stopping', False, 'Whether to use early stopping based on validation accuracy or not')\n",
    "\n",
    "flags.DEFINE_integer('proj_layer', 0,\n",
    "                     'Layer in head where applying projection for rr')\n",
    "\n",
    "flags.DEFINE_integer('head_dim', 512, 'Dimension of head layers')\n",
    "\n",
    "flags.DEFINE_string('model_dir', None, 'Path for loading/saving a model')\n",
    "\n",
    "flags.DEFINE_string('model_finetune_dir', None,\n",
    "                    'Path for loading/saving the model after finetuning')\n",
    "\n",
    "flags.DEFINE_bool('use_seq_rr', True, 'Whether to use rr loss sequentially')\n",
    "\n",
    "flags.DEFINE_integer('num_seq_models', 2, 'Number of sequential models to train')\n",
    "\n",
    "flags.DEFINE_bool('load_model', False, 'whether to try to load model')\n",
    "\n",
    "flags.DEFINE_bool('save_model', False, 'whether to save model')\n",
    "\n",
    "flags.DEFINE_bool('lowerbound_rr', False, 'Whether to lowerbound rr loss')\n",
    "\n",
    "flags.DEFINE_float('lowerbound_factor', 0.5, 'If lowerbound rr, by what factor of expected value')\n",
    "\n",
    "flags.DEFINE_bool('use_exp_var_loss', False, 'Use explained away variance as loss')\n",
    "\n",
    "flags.DEFINE_bool('use_MI_loss', False, 'Use MI based loss function')\n",
    "\n",
    "flags.DEFINE_integer('CIFAR_label_1', 1, 'CIFAR class 1 for MNIST-CIFAR dataset')\n",
    "\n",
    "flags.DEFINE_integer('CIFAR_label_2', 9, 'CIFAR class 2 for MNIST-CIFAR dataset')\n",
    "\n",
    "flags.DEFINE_integer('MNIST_label_1', 0, 'MNIST class 1 for MNIST-CIFAR dataset')\n",
    "\n",
    "flags.DEFINE_integer('MNIST_label_2', 1, 'MNIST class 2 for MNIST-CIFAR dataset')\n",
    "\n",
    "flags.DEFINE_float('corr_frac', 1.0, 'Correlation factor of MNIST and CIFAR for MNIST-CIFAR dataset')\n",
    "\n",
    "flags.DEFINE_bool('use_proj_head', True, 'Use a projection head in the model')\n",
    "\n",
    "flags.DEFINE_bool('normalize_MI', False, 'Normalize MI based loss')\n",
    "\n",
    "flags.DEFINE_bool('use_logit_decorr', False, 'Use logit decorrelation')\n",
    "\n",
    "flags.DEFINE_bool('use_prob_decorr', False, 'Use probability decorrelation')\n",
    "\n",
    "flags.DEFINE_bool('use_val_for_MI', False, 'Use validation set for MI based loss')\n",
    "\n",
    "flags.DEFINE_bool('use_cifar_aug', False, 'Use CIFAR augmentation in MNIST-CIFAR dataset')\n",
    "\n",
    "flags.DEFINE_bool('use_mnist_aug', False, 'Use MNIST augmentation in MNIST-CIFAR dataset')\n",
    "\n",
    "flags.DEFINE_integer('num_classes', 10, 'Number of classes')\n",
    "\n",
    "flags.DEFINE_bool('monitor_rr_grad_norms', False, 'Monitor rr gradient norms')\n",
    "\n",
    "flags.DEFINE_float('use_rr_after_frac', 0.0, 'Use rr after a fraction of steps')\n",
    "\n",
    "flags.DEFINE_bool('use_sq_MI', False, 'Use squared MI as loss instead of MI, for better gradients')\n",
    "\n",
    "flags.DEFINE_bool('use_disagr_loss', False, 'use disagreement based loss')\n",
    "\n",
    "flags.DEFINE_bool('normalize_MI_random', False, 'normalize MI by randomly shuffling probabilities')\n",
    "\n",
    "flags.DEFINE_bool('use_num_sq_MI', False, 'Use the MI loss of the form (MI^2)/y so as to get properly scaled gradients')\n",
    "\n",
    "flags.DEFINE_bool('use_stop_grad', False, 'Use stop gradient for MI normalization factor')\n",
    "\n",
    "flags.DEFINE_bool('use_HSIC_loss', False, 'Use HSIC based independence test loss')\n",
    "\n",
    "flags.DEFINE_bool('use_HSIC_diff', False, 'Use HSIC on logit difference')\n",
    "\n",
    "flags.DEFINE_bool('lin_scale_rr_weight', False, 'Linearly scale down rr weight as number of sequential models goes up')\n",
    "\n",
    "flags.DEFINE_integer('dataset_dim', 2, 'Dimension of LMS dataset')\n",
    "\n",
    "flags.DEFINE_integer('num_lin', 1, 'Number of linear dimensions')\n",
    "\n",
    "flags.DEFINE_integer('num_3_slabs', 1, 'Number of 3 slabs')\n",
    "\n",
    "flags.DEFINE_integer('num_5_slabs', 0, 'Number of 5 slabs')\n",
    "\n",
    "flags.DEFINE_integer('num_7_slabs', 0, 'Number of 7 slabs')\n",
    "\n",
    "flags.DEFINE_bool('use_random_transform', False, 'Use random transformation of input coordinates')\n",
    "\n",
    "flags.DEFINE_float('lin_margin', 0.1, 'Linear coordinate margin')\n",
    "\n",
    "flags.DEFINE_float('slab_margin', 0.05, 'Slab coordinate margin')\n",
    "\n",
    "flags.DEFINE_integer('fcn_layers', 3, 'Number of layers in FCN net for lms dataset')\n",
    "\n",
    "flags.DEFINE_integer('hidden_dim', 512, 'Hidden dimension for FCN')\n",
    "\n",
    "flags.DEFINE_bool('randomize_linear', False, 'Randomize linear coordinate in the dataset')\n",
    "\n",
    "flags.DEFINE_bool('randomize_slabs', False, 'Randomize slab coordinates in the dataset')\n",
    "\n",
    "flags.DEFINE_bool('turn_off_randomize_later', False, 'Turn off coordinate randomization later')\n",
    "\n",
    "flags.DEFINE_bool('use_L4_reg', False, 'Use L4 instead of L2 regularization')\n",
    "\n",
    "flags.DEFINE_bool('use_bn', False, 'Use BN in architecture')\n",
    "\n",
    "flags.DEFINE_bool('use_HSIC_on_features', False, 'Use HSIC based loss on feature layers')\n",
    "\n",
    "flags.DEFINE_integer('HSIC_feature_layer', 0, 'Feature layer to use HSIC loss on')\n",
    "\n",
    "flags.DEFINE_multi_integer('HSIC_feature_layers', None, 'Feature layers to use HSIC loss on')\n",
    "\n",
    "flags.DEFINE_bool('use_all_features_HSIC', False, 'Use features at all the layers for HSIC loss')\n",
    "\n",
    "flags.DEFINE_bool('use_sq_HSIC', False, 'Square HSIC loss to manage gradients as loss goes down')\n",
    "\n",
    "flags.DEFINE_bool('use_GAP_HSIC_features', True, 'Use GAP on HSIC features')\n",
    "\n",
    "flags.DEFINE_bool('use_random_projections', False, 'Use random projections')\n",
    "\n",
    "flags.DEFINE_integer('random_proj_dim', 1, 'Random projection dimension')\n",
    "\n",
    "flags.DEFINE_bool('use_prev_logits_HSIC_features', False, 'Use logits of previous models for computing HSIC on features')\n",
    "\n",
    "flags.DEFINE_bool('use_MNIST_labels', False, 'Use MNIST labels in MNIST-CIFAR')\n",
    "\n",
    "flags.DEFINE_bool('switch_corr_later', False, 'Change correlation after first model')\n",
    "\n",
    "flags.DEFINE_bool('switch_labels_later', False, 'Switch whether to use CIFAR or MNIST labels after first model')\n",
    "\n",
    "flags.DEFINE_bool('monitor_EG_overlap', False, 'Monitor expected gradients overlap across models')\n",
    "\n",
    "flags.DEFINE_bool('monitor_robustness_measures', False, 'Monitor Gaussian, mask and RDE based robustness measures')\n",
    "\n",
    "flags.DEFINE_bool('monitor_error_diversity', False, 'Monitor error diversity')\n",
    "\n",
    "flags.DEFINE_bool('monitor_logit_correlation', False, 'Monitor logit correlation')\n",
    "\n",
    "flags.DEFINE_bool('sep_short_direct_branch', False, 'Separately make shortcut and direct branch independent of previous model')\n",
    "\n",
    "flags.DEFINE_bool('use_pretrained_model_1', False, 'Utilise a pretrained first model.')\n",
    "\n",
    "flags.DEFINE_string('pretrained_model_path', None, 'Path for first pretrained model')\n",
    "\n",
    "flags.DEFINE_multi_string('pretrained_model_paths', None, 'Paths for pretrained models')\n",
    "\n",
    "flags.DEFINE_multi_string('pretrained_checkpoint_paths', None, 'Paths for pretrained checkpoints')\n",
    "\n",
    "flags.DEFINE_bool('use_indexed_checkpoints', False, 'Use particular checkpoint index')\n",
    "\n",
    "flags.DEFINE_bool('check_tf_func', False, 'Check tf func')\n",
    "\n",
    "flags.DEFINE_bool('use_FCN', False, 'Use FCN architecture')\n",
    "\n",
    "flags.DEFINE_bool('monitor_EG_loss', False, 'Monitor EG loss')\n",
    "\n",
    "flags.DEFINE_bool('use_equal_split', False, 'Use equal split for DRO setting')\n",
    "\n",
    "flags.DEFINE_bool('use_EG_loss', False, 'Use expected gradients loss for avoiding collapse')\n",
    "\n",
    "flags.DEFINE_integer('num_ref_EG_loss', 2, 'Number of references in EG loss')\n",
    "\n",
    "flags.DEFINE_float('EG_loss_weight', 1e-3, 'Weight of EG loss')\n",
    "\n",
    "flags.DEFINE_bool('binary_classification', True, 'Use binary classification and logistic loss')\n",
    "\n",
    "flags.DEFINE_bool('use_color_labels', False, 'Use colors for label in color-MNIST or binary-color-MNIST')\n",
    "\n",
    "flags.DEFINE_bool('use_CNN', False, 'Use custom CNN architecture')\n",
    "\n",
    "flags.DEFINE_bool('use_HSIC_ratio', False, 'Use HSIC ratio as the loss')\n",
    "\n",
    "flags.DEFINE_string(\n",
    "    'master', 'local',\n",
    "    \"BNS name of the TensorFlow master to use. 'local' for GPU.\")\n",
    "\n",
    "flags.DEFINE_integer('project_out_rank', 0, 'Projecting certain dimensions out of input')\n",
    "\n",
    "flags.DEFINE_float('project_out_factor', 0.0, 'Projecting out factor')\n",
    "\n",
    "flags.DEFINE_float('eig_cutoff_factor', 0.0, 'Eigenvalue cutoff factor')\n",
    "\n",
    "flags.DEFINE_integer('check_ranks_max', 10, 'Check rank of 1st hidden matrix')\n",
    "\n",
    "flags.DEFINE_multi_integer('filters', [16, 32, 64], 'Filters to be used in a CNN')\n",
    "\n",
    "flags.DEFINE_multi_integer('kernel_sizes', [3, 3, 3], 'kernel sizes to be used in a CNN')\n",
    "\n",
    "flags.DEFINE_multi_integer('strides', [1, 2, 1], 'Strides to be used in a CNN')\n",
    "\n",
    "flags.DEFINE_multi_integer('project_out_vecs', [1,2], 'Number of top SVD vectors to project out')\n",
    "\n",
    "flags.DEFINE_bool('use_chizat_init', False, 'Whether to use chizat-bach initialization in head')\n",
    "\n",
    "flags.DEFINE_bool('project_out_w', False, 'Project out w from representations')\n",
    "\n",
    "flags.DEFINE_bool('use_complete_corr', False, 'Use complete correlation in DRO setting')\n",
    "\n",
    "flags.DEFINE_bool('use_complete_corr_test', False, 'Use complete correlation in DRO setting')\n",
    "\n",
    "flags.DEFINE_bool('flip_err_div_for_minority', False, 'Flip error diversity calc for minority classes')\n",
    "\n",
    "flags.DEFINE_bool('measure_feat_robust', False, 'Measure robustness of features')\n",
    "\n",
    "flags.DEFINE_float('max_gauss_noise_std', 5.0,  'Maximum gaussian noise std')\n",
    "\n",
    "flags.DEFINE_boolean('use_tpu', True, 'Should we use TPU?')\n",
    "\n",
    "flags.DEFINE_bool('check_torch_reps', False, 'Check torch reps')\n",
    "\n",
    "# Boolean flag: the default is written as a bool rather than the integer 1.\n",
    "flags.DEFINE_boolean(\n",
    "    'train_split', True,\n",
    "    'Use train validation split while training, If set to false, use entire training dataset'\n",
    ")\n",
    "\n",
    "flags.DEFINE_bool('finetune_only_linear_head', False, 'Finetune only linear head')\n",
    "\n",
    "flags.DEFINE_bool('standardize_mean_reps', True, 'Standardize mean of the representations')\n",
    "\n",
    "_FRAC_POISON = flags.DEFINE_float('frac_poison', 0.,\n",
    "                                  'Fraction of poisoned examples.')\n",
    "\n",
    "_TASK_ID = flags.DEFINE_enum('task_id', 'Imagenette', [\n",
    "    'Data-poisoning', 'DRO', 'Few-shot', 'CIFAR-10.2', 'CIFAR-10.2-finetune',\n",
    "    'CINIC', 'CINIC-finetune', 'Imagenette'\n",
    "], 'Specify the task that needs to be run.')\n",
    "\n",
    "_FINETUNE_ONLY_HEAD = flags.DEFINE_bool(\n",
    "    'finetune_only_head', False, 'whether to finetune head or the entire model')\n",
    "\n",
    "_TRAIN_CLASSES = flags.DEFINE_multi_integer(\n",
    "    'train_classes', [0, 1, 2, 3, 4], 'classes to train for few-shot learning')\n",
    "\n",
    "_FINETUNE_CLASSES = flags.DEFINE_multi_integer(\n",
    "    'finetune_classes', [5, 6, 7, 8, 9],\n",
    "    'classes to finetune for few-shot learning')\n",
    "\n",
    "# Stages listed for each task id. Only 'Few-shot' lists a 'Finetune' stage in\n",
    "# addition to 'Train'.\n",
    "# NOTE(review): 'CIFAR-MNIST', 'LMS', 'MNIST' and 'color-MNIST' have entries\n",
    "# here but are not among the values accepted by the task_id enum flag defined\n",
    "# above -- confirm whether the enum should list them as well.\n",
    "TASK_STAGES = {\n",
    "    'Data-poisoning': ['Train'],\n",
    "    'DRO': ['Train'],\n",
    "    'Few-shot': ['Train', 'Finetune'],\n",
    "    'CIFAR-10.2': ['Train'],\n",
    "    'CIFAR-10.2-finetune': ['Train'],\n",
    "    'CINIC': ['Train'],\n",
    "    'CINIC-finetune': ['Train'],\n",
    "    'CIFAR-MNIST': ['Train'],\n",
    "    'LMS': ['Train'],\n",
    "    'MNIST': ['Train'],\n",
    "    'color-MNIST': ['Train'],\n",
    "    'Imagenette': ['Train'],\n",
    "}\n",
    "\n",
    "# Dummy string flag named 'f'. Presumably present so that the '-f <file>'\n",
    "# argument a notebook kernel receives in sys.argv is accepted when the flags\n",
    "# are parsed from sys.argv -- TODO confirm.\n",
    "flags.DEFINE_string('f', '', 'kernel')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "910381f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "class head_model(tf.keras.Model):\n",
    "    \"\"\"Base network followed by a projection stack, an MLP head and an output layer.\n",
    "\n",
    "    Layers created by the constructor:\n",
    "      * proj_head: `proj_layer` Dense(+norm)(+ReLU) groups, then Dense(proj_dim)\n",
    "        when `use_proj` is True.\n",
    "      * head: a Dropout layer, then the remaining `len(head_dims) - proj_layer`\n",
    "        Dense(+norm)(+ReLU) groups.\n",
    "      * out: Dense(num_classes); created without bias when\n",
    "        FLAGS.use_chizat_init is set.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self,\n",
    "                 base_model,\n",
    "                 num_classes,\n",
    "                 proj_dim,\n",
    "                 proj_layer,\n",
    "                 head_dims,\n",
    "                 use_proj=True,\n",
    "                 use_bn=True,\n",
    "                 use_relu=True,\n",
    "                 dropout_rate=0.0):\n",
    "        \"\"\"Builds the projection stack, the head and the output layer.\n",
    "\n",
    "        Args:\n",
    "          base_model: callable applied first in `call`. It must return a single\n",
    "            tensor when FLAGS.use_pretrained is True and an `(x, feat)` pair\n",
    "            otherwise.\n",
    "          num_classes: number of units of the output layer.\n",
    "          proj_dim: number of units of the optional projection layer.\n",
    "          proj_layer: how many leading entries of `head_dims` go into\n",
    "            `proj_head`; the remaining entries go into `head`.\n",
    "          head_dims: sequence of hidden layer widths.\n",
    "          use_proj: whether to append Dense(proj_dim) to `proj_head`.\n",
    "          use_bn: whether each hidden Dense is followed by `norm_layer()`.\n",
    "          use_relu: whether each hidden Dense is followed by a ReLU.\n",
    "          dropout_rate: rate of the Dropout layer that opens `head`.\n",
    "\n",
    "        Raises:\n",
    "          ValueError: if `proj_layer` exceeds `len(head_dims)`, or if\n",
    "            FLAGS.use_chizat_init is set and `head_dims` has more than one entry.\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "        self.base_model = base_model\n",
    "        if proj_layer > len(head_dims):\n",
    "            # proj_layer == len(head_dims) is accepted, so the message must not\n",
    "            # claim a strict inequality.\n",
    "            raise ValueError(\n",
    "                'proj layer must not exceed the number of head dimensions')\n",
    "\n",
    "        def hidden_group(units):\n",
    "            # One hidden block: Dense, then the optional norm and ReLU layers.\n",
    "            # norm_layer is not defined in this block; it has to be provided\n",
    "            # elsewhere in the notebook.\n",
    "            group = [tf.keras.layers.Dense(units)]\n",
    "            if use_bn:\n",
    "                group.append(norm_layer())\n",
    "            if use_relu:\n",
    "                group.append(tf.keras.layers.ReLU())\n",
    "            return group\n",
    "\n",
    "        self.proj_head = []\n",
    "        for i in range(proj_layer):\n",
    "            self.proj_head.extend(hidden_group(head_dims[i]))\n",
    "        if use_proj:\n",
    "            self.proj_head.append(tf.keras.layers.Dense(proj_dim))\n",
    "\n",
    "        self.head = [tf.keras.layers.Dropout(rate=dropout_rate)]\n",
    "        for i in range(len(head_dims) - proj_layer):\n",
    "            self.head.extend(hidden_group(head_dims[proj_layer + i]))\n",
    "\n",
    "        if FLAGS.use_chizat_init:\n",
    "            if len(head_dims) > 1:\n",
    "                raise ValueError('Rich regime initialization only for 1 layer net')\n",
    "            self.out = tf.keras.layers.Dense(num_classes, use_bias=False)\n",
    "        else:\n",
    "            self.out = tf.keras.layers.Dense(num_classes)\n",
    "\n",
    "        self.head_dims = head_dims\n",
    "\n",
    "    def call(self, x, training, only_head=False, only_linear_head=False,\n",
    "             init=False, project_out_w=False, project_out_mat=None,\n",
    "             gauss_noise_feats=False, sigma=1.0, rand_sigma=False):\n",
    "        \"\"\"Runs the base model, the projection stack, the head and the output layer.\n",
    "\n",
    "        Args:\n",
    "          x: input batch handed to `base_model`.\n",
    "          training: whether the layers run in training mode.\n",
    "          only_head: if True, the base model is always run with training=False.\n",
    "          only_linear_head: if True, the projection stack and the head layers\n",
    "            are run with training=False (the output layer still gets `training`).\n",
    "          init: if True and FLAGS.use_chizat_init is set, re-draws the head Dense\n",
    "            weights (Gaussian columns scaled to unit norm) and the output\n",
    "            kernel (random +1/-1 entries) before they are applied.\n",
    "          project_out_w: if True and FLAGS.use_pretrained is set, removes from\n",
    "            the representations the direction of a head kernel column sampled\n",
    "            with probability proportional to its squared norm.\n",
    "          project_out_mat: optional matrix; when given (and the project_out_w\n",
    "            branch is not taken) the projection of the representations onto its\n",
    "            column space is subtracted.\n",
    "          gauss_noise_feats: if True, adds Gaussian noise to the base model\n",
    "            output (only in the FLAGS.use_pretrained branch).\n",
    "          sigma: noise standard deviation, or its upper bound with `rand_sigma`.\n",
    "          rand_sigma: if True, draws one standard deviation per example\n",
    "            uniformly from [0, sigma).\n",
    "\n",
    "        Returns:\n",
    "          A `(reps, logits, feat)` tuple: the output of the projection stack\n",
    "          (before any projecting-out), the output layer result, and the list of\n",
    "          collected features.\n",
    "        \"\"\"\n",
    "        # The HSIC_feature_layers flag defaults to None; treat that as an empty\n",
    "        # selection so the membership tests below cannot raise TypeError.\n",
    "        hsic_layers = FLAGS.HSIC_feature_layers or ()\n",
    "\n",
    "        if FLAGS.use_pretrained:\n",
    "            x = self.base_model(x, training=training and not only_head)\n",
    "            if gauss_noise_feats:\n",
    "                if rand_sigma:\n",
    "                    # One std per example, reshaped to [-1, 1, ..., 1] so that it\n",
    "                    # broadcasts over every non-batch axis of x.\n",
    "                    std_devs = sigma * tf.random.uniform([tf.shape(x)[0]])\n",
    "                    final_shape = tf.constant([-1], dtype=tf.int32)\n",
    "                    final_shape = tf.concat([final_shape, tf.ones([tf.rank(x)-1], dtype=tf.int32)], axis=0)\n",
    "                    z = tf.reshape(std_devs, final_shape) * tf.random.normal(tf.shape(x))\n",
    "                else:\n",
    "                    z = sigma * tf.random.normal(tf.shape(x))\n",
    "                x = x + tf.cast(z, x.dtype)\n",
    "            feat = []\n",
    "        else:\n",
    "            x, feat = self.base_model(x, training=training and not only_head)\n",
    "        for layer in self.proj_head:\n",
    "            x = layer(x, training=training and not only_linear_head)\n",
    "        reps = x\n",
    "        if project_out_w and FLAGS.use_pretrained:\n",
    "            # NOTE(review): the break below only leaves the variable loop, so the\n",
    "            # projection is repeated for every Dense layer of the head -- confirm\n",
    "            # this is intended when the head has several hidden layers.\n",
    "            for layer in self.head:\n",
    "                if isinstance(layer, tf.keras.layers.Dense):\n",
    "                    for var in layer.trainable_variables:\n",
    "                        if 'kernel' in var.name:\n",
    "                            # Sample one kernel column with probability\n",
    "                            # proportional to its squared norm.\n",
    "                            W_norm = tf.norm(var, axis=0, keepdims=True)**2\n",
    "                            sample_ind = tf.random.categorical(tf.math.log(W_norm), 1)[0,0]\n",
    "                            sample_W = tf.reshape(var[:, sample_ind], [-1,1])\n",
    "                            # Keep the first column of the all-gathered result,\n",
    "                            # presumably so that every replica projects out the\n",
    "                            # same direction -- TODO confirm.\n",
    "                            ctx = tf.distribute.get_replica_context()\n",
    "                            sample_W_gather = ctx.all_gather(sample_W, axis=1)\n",
    "                            sample_W = tf.reshape(sample_W_gather[:,0], [-1,1])\n",
    "                            sample_W = sample_W/tf.norm(sample_W, axis=0, keepdims=True)\n",
    "                            # Multiply by (I - w w^T) to drop the component along w.\n",
    "                            proj_mat = tf.eye(tf.shape(var)[0]) - tf.linalg.matmul(sample_W, tf.transpose(sample_W))\n",
    "                            x = tf.transpose(tf.linalg.matmul(proj_mat, tf.transpose(x)))\n",
    "                            break\n",
    "        elif project_out_mat is not None:\n",
    "            # Subtract the projection of x onto the column space of project_out_mat.\n",
    "            pinv = tf.linalg.pinv(tf.linalg.matmul(tf.transpose(project_out_mat), project_out_mat))\n",
    "            proj_mat = tf.linalg.matmul(project_out_mat, tf.linalg.matmul(pinv, tf.transpose(project_out_mat)))\n",
    "            proj_x = tf.linalg.matmul(proj_mat, tf.transpose(x))\n",
    "            x = x - tf.transpose(proj_x)\n",
    "        curr_ind = 0\n",
    "        for layer in self.head:\n",
    "            if FLAGS.use_chizat_init and isinstance(layer, tf.keras.layers.Dense):\n",
    "                if init:\n",
    "                    # NOTE(review): kernel and bias are assigned before the layer\n",
    "                    # is applied, which presumably requires the layer to have\n",
    "                    # been built by an earlier forward pass -- TODO confirm.\n",
    "                    tf.print('assigned')\n",
    "                    # The extra row of z becomes the bias, so each column of\n",
    "                    # [kernel; bias] has unit norm.\n",
    "                    z = tf.random.normal((tf.shape(x)[1]+1, self.head_dims[0]))\n",
    "                    z = z/tf.norm(z, axis=0, keepdims=True)\n",
    "                    layer.kernel.assign(z[0:tf.shape(x)[1], :])\n",
    "                    layer.bias.assign(z[tf.shape(x)[1], :])\n",
    "            x = layer(x, training=training and not only_linear_head)\n",
    "            if FLAGS.use_pretrained and isinstance(layer, tf.keras.layers.Dense):\n",
    "                # curr_ind counts the Dense layers of the head seen so far.\n",
    "                if curr_ind in hsic_layers or FLAGS.use_all_features_HSIC:\n",
    "                    feat.append(x)\n",
    "                curr_ind += 1\n",
    "        if FLAGS.use_chizat_init:\n",
    "            if init:\n",
    "                tf.print('out assigned')\n",
    "                # One fair +1/-1 draw per output kernel entry. The count uses the\n",
    "                # output layer's own unit count so that it always matches the\n",
    "                # kernel shape used by the reshape.\n",
    "                samples = tf.random.categorical(tf.math.log([[0.5, 0.5]]), tf.shape(x)[1]*self.out.units)\n",
    "                self.out.kernel.assign(tf.reshape(2.0*tf.cast(samples, tf.float32) - 1.0, self.out.kernel.shape))\n",
    "        x = self.out(x, training=training)\n",
    "\n",
    "        if FLAGS.use_chizat_init:\n",
    "            # Scale the outputs by 1 / (hidden width).\n",
    "            x = x/tf.cast(self.head_dims[0], tf.float32)\n",
    "        if FLAGS.use_all_features_HSIC or -1 in hsic_layers:\n",
    "            # Index -1 selects the output layer result as a feature.\n",
    "            feat.append(x)\n",
    "        return reps, x, feat\n",
    "\n",
    "def createmodel(num_classes,\n",
    "                head_dim,\n",
    "                head_layers,\n",
    "                proj_dim,\n",
    "                proj_layer,\n",
    "                use_proj=True,\n",
    "                resnet_base=None,\n",
    "                dropout_rate=0.0,\n",
    "                num_heads=1):\n",
    "    \"\"\"Create a model with a resnet base and a head on top of it.\n",
    "\n",
    "    Args:\n",
    "      num_classes: number of output units.\n",
    "      head_dim: width shared by every hidden head layer.\n",
    "      head_layers: number of hidden head layers.\n",
    "      proj_dim: number of units of the optional projection layer.\n",
    "      proj_layer: number of head layers placed in the projection stack.\n",
    "      use_proj: whether to add the projection layer.\n",
    "      resnet_base: base network to reuse; when None a new one is built by\n",
    "        resnet.resnet from FLAGS.resnet_depth, FLAGS.width_multiplier and\n",
    "        FLAGS.image_size.\n",
    "      dropout_rate: dropout rate forwarded to the head.\n",
    "      num_heads: a multihead_model is built when this is greater than 1,\n",
    "        otherwise a head_model.\n",
    "\n",
    "    Returns:\n",
    "      The constructed multihead_model or head_model.\n",
    "    \"\"\"\n",
    "    if resnet_base is None:\n",
    "        # The CIFAR stem is selected for image sizes of at most 32.\n",
    "        resnet_base = resnet.resnet(\n",
    "            resnet_depth=FLAGS.resnet_depth,\n",
    "            width_multiplier=FLAGS.width_multiplier,\n",
    "            cifar_stem=FLAGS.image_size <= 32)\n",
    "\n",
    "    # Layer widths are unit counts, so keep them integral: adding head_dim to\n",
    "    # np.zeros(head_layers) would yield float64 widths.\n",
    "    head_dims = np.full(head_layers, head_dim, dtype=np.int64)\n",
    "\n",
    "    if num_heads > 1:\n",
    "        return multihead_model(\n",
    "            resnet_base,\n",
    "            num_classes,\n",
    "            proj_dim,\n",
    "            proj_layer,\n",
    "            head_dims,\n",
    "            num_heads,\n",
    "            use_proj=use_proj,\n",
    "            dropout_rate=dropout_rate)\n",
    "\n",
    "    return head_model(\n",
    "        resnet_base,\n",
    "        num_classes,\n",
    "        proj_dim,\n",
    "        proj_layer,\n",
    "        head_dims,\n",
    "        use_proj=use_proj,\n",
    "        dropout_rate=dropout_rate,\n",
    "        use_bn=FLAGS.use_bn)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b43c619c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Notebook-level flag overrides, applied in the listed order by attribute\n",
    "# assignment on FLAGS.\n",
    "_FLAG_OVERRIDES = (\n",
    "    ('rr_weight', 0.0),\n",
    "    ('use_rr_loss', False),\n",
    "    ('class_specific_rr_loss', False),\n",
    "    ('learning_rate', 0.5),\n",
    "    ('momentum', 0.0),\n",
    "    ('train_batch_size', 256),\n",
    "    ('train_epochs', 300),\n",
    "    ('train_epochs_finetune', 50),\n",
    "    ('train_steps', 10),\n",
    "    ('train_steps_finetune', 5),\n",
    "    ('val_batch_size', 256),\n",
    "    ('test_batch_size', 256),\n",
    "    ('buffer_size', 10000),\n",
    "    ('checkpoint_epochs', 50),\n",
    "    ('dataset', 'imagenette'),\n",
    "    ('dropout_rate', 0.0),\n",
    "    ('platform', 'GPU'),\n",
    "    ('checkpoint', None),\n",
    "    ('optimizer', 'sgd'),\n",
    "    ('project_out_prev_w', False),\n",
    "    ('use_data_aug_with_DRO', False),\n",
    "    ('lr_decay_type', 'warmup_cosine_decay'),\n",
    "    ('learning_rate_scaling', 'linear'),\n",
    "    ('warmup_steps', 500),\n",
    "    ('warmup_epochs', 20),\n",
    "    ('proj_dim', 20),\n",
    "    ('num_heads', 1),\n",
    "    ('resnet_depth', 50),\n",
    "    ('weight_decay', 1e-4),\n",
    "    ('image_size', 32),\n",
    "    ('val_epochs_gap', 1),\n",
    "    ('finetune_val_epochs_gap', 1),\n",
    "    ('val_steps_gap', 0),\n",
    "    ('finetune_val_steps_gap', 0),\n",
    "    ('use_pretrained', True),\n",
    "    ('path', './'),\n",
    "    ('use_OOD_transform', False),\n",
    "    ('num_runs', 1),\n",
    "    ('num_head_layers', 1),\n",
    "    ('use_early_stopping', False),\n",
    "    ('proj_layer', 0),\n",
    "    ('head_dim', 100),\n",
    "    ('model_dir', './b-imagenette-test'),\n",
    "    ('model_finetune_dir', None),\n",
    "    ('use_seq_rr', False),\n",
    "    ('num_seq_models', 1),\n",
    "    ('load_model', False),\n",
    "    ('save_model', True),\n",
    "    ('use_exp_var_loss', False),\n",
    "    ('use_MI_loss', False),\n",
    "    ('CIFAR_label_1', 1),\n",
    "    ('CIFAR_label_2', 9),\n",
    "    ('MNIST_label_1', 0),\n",
    "    ('MNIST_label_2', 1),\n",
    "    ('corr_frac', 1.0),\n",
    "    ('use_proj_head', False),\n",
    "    ('num_classes', 2),\n",
    "    ('monitor_rr_grad_norms', False),\n",
    "    ('use_HSIC_loss', False),\n",
    "    ('use_HSIC_diff', False),\n",
    "    ('lin_scale_rr_weight', False),\n",
    "    ('dataset_dim', 2),\n",
    "    ('num_lin', 1),\n",
    "    ('num_3_slabs', 1),\n",
    "    ('num_5_slabs', 0),\n",
    "    ('num_7_slabs', 0),\n",
    "    ('use_random_transform', False),\n",
    "    ('lin_margin', 0.1),\n",
    "    ('slab_margin', 0.05),\n",
    "    ('fcn_layers', 3),\n",
    "    ('hidden_dim', 512),\n",
    "    ('randomize_linear', False),\n",
    "    ('randomize_slabs', False),\n",
    "    ('turn_off_randomize_later', False),\n",
    "    ('use_L4_reg', False),\n",
    "    ('use_bn', False),\n",
    "    ('use_HSIC_on_features', False),\n",
    "    ('HSIC_feature_layer', 0),\n",
    "    ('HSIC_feature_layers', [0, 1]),\n",
    "    ('use_all_features_HSIC', False),\n",
    "    ('use_sq_HSIC', False),\n",
    "    ('use_GAP_HSIC_features', True),\n",
    "    ('use_random_projections', False),\n",
    "    ('random_proj_dim', 1),\n",
    "    ('use_prev_logits_HSIC_features', False),\n",
    "    ('use_MNIST_labels', False),\n",
    "    ('monitor_EG_overlap', False),\n",
    "    ('monitor_robustness_measures', True),\n",
    "    ('monitor_error_diversity', True),\n",
    "    ('monitor_logit_correlation', True),\n",
    "    ('use_pretrained_model_1', False),\n",
    "    ('pretrained_model_path', None),\n",
    "    ('pretrained_model_paths', None),\n",
    "    ('pretrained_checkpoint_paths', None),\n",
    "    ('use_indexed_checkpoints', False),\n",
    "    ('use_FCN', False),\n",
    "    ('use_EG_loss', False),\n",
    "    ('num_ref_EG_loss', 2),\n",
    "    ('EG_loss_weight', 1e-3),\n",
    "    ('binary_classification', True),\n",
    "    ('use_CNN', False),\n",
    "    ('master', 'local'),\n",
    "    ('project_out_rank', 0),\n",
    "    ('check_ranks_max', 30),\n",
    "    ('filters', [16, 32, 64]),\n",
    "    ('kernel_sizes', [3, 3, 3]),\n",
    "    ('strides', [1, 2, 1]),\n",
    "    ('project_out_vecs', [1, 2]),\n",
    "    ('use_chizat_init', True),\n",
    "    ('project_out_w', False),\n",
    "    ('use_complete_corr', False),\n",
    "    ('use_complete_corr_test', False),\n",
    "    ('flip_err_div_for_minority', True),\n",
    "    ('measure_feat_robust', True),\n",
    "    ('max_gauss_noise_std', 5.0),\n",
    "    ('use_tpu', False),\n",
    "    ('check_torch_reps', False),\n",
    "    ('train_split', True),\n",
    "    ('finetune_only_linear_head', False),\n",
    "    ('task_id', 'Imagenette'),\n",
    "    ('finetune_only_head', True),\n",
    "    ('standardize_mean_reps', True),\n",
    ")\n",
    "for _flag_name, _flag_value in _FLAG_OVERRIDES:\n",
    "    setattr(FLAGS, _flag_name, _flag_value)\n",
    "\n",
    "# Additional flag registered by this cell; the returned holder is kept in\n",
    "# _SHUFFLE_BUFFER_SIZE.\n",
    "_SHUFFLE_BUFFER_SIZE = flags.DEFINE_integer(\n",
    "    'shuffle_buffer_size', 1000, 'Buffer size for shuffling.')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d40920f2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Parse only the program name. Inside a Jupyter kernel sys.argv also holds the\n",
    "# kernel's '-f <connection file>' arguments, which absl rejects as an unknown\n",
    "# flag unless a flag named 'f' has been defined.\n",
    "FLAGS(sys.argv[:1])\n",
    "\n",
    "# Switch from the binary defaults assigned earlier to the 10-class Imagenette setup.\n",
    "FLAGS.binary_classification = False\n",
    "FLAGS.num_classes = 10\n",
    "FLAGS.use_chizat_init = False\n",
    "FLAGS.dataset = 'imagenette'\n",
    "\n",
    "# The base network is a Sequential holding one bare keras Layer.\n",
    "# NOTE(review): presumably a pass-through because the inputs are precomputed\n",
    "# representations - confirm against createmodel.\n",
    "base_model = tf.keras.Sequential()\n",
    "base_model.add(tf.keras.layers.Layer())\n",
    "# NOTE(review): the meaning of the positional arguments (100, 1, 100, 0) is\n",
    "# defined by createmodel, which is not part of this cell.\n",
    "model = createmodel(\n",
    "    FLAGS.num_classes,\n",
    "    100,\n",
    "    1,\n",
    "    100,\n",
    "    0,\n",
    "    resnet_base=base_model,\n",
    "    dropout_rate=0.0,\n",
    "    num_heads=1,\n",
    "    use_proj=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a0e7abb7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Register every layer except the first under the checkpoint key 'model<index>'.\n",
    "save_vars = {}\n",
    "for i, sub_layer in enumerate(model.layers):\n",
    "    if i == 0:\n",
    "        continue\n",
    "    save_vars['model{}'.format(i)] = sub_layer\n",
    "\n",
    "ckpt = tf.train.Checkpoint(**save_vars)\n",
    "manager = tf.train.CheckpointManager(\n",
    "    ckpt,\n",
    "    directory=\"./imagenette-he-init-model2-noproject-multirun-mean/run-1/0/1\",\n",
    "    max_to_keep=1)\n",
    "# The restored path is printed afterwards so the checkpoint in use is visible.\n",
    "status = ckpt.restore(manager.latest_checkpoint)\n",
    "print(manager.latest_checkpoint)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "47ce804c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def imagenette_train_transform(x):\n",
    "    \"\"\"Apply the training-time augmentation and normalisation to one example.\n",
    "\n",
    "    The image is resized to 256x256, randomly cropped to 224x224x3, randomly\n",
    "    flipped left/right, divided by 255 and standardised with fixed per-channel\n",
    "    mean and standard deviation.\n",
    "\n",
    "    Args:\n",
    "      x: mapping with an 'image' entry and a 'label' entry.\n",
    "\n",
    "    Returns:\n",
    "      Dict holding the transformed 'image' and the unchanged 'label'.\n",
    "    \"\"\"\n",
    "    channel_mean = [0.485, 0.456, 0.406]\n",
    "    channel_std = [0.229, 0.224, 0.225]\n",
    "    resized = tf.image.resize(x['image'], size=[256, 256])\n",
    "    cropped = tf.image.random_crop(resized, size=(224, 224, 3))\n",
    "    flipped = tf.image.random_flip_left_right(cropped)\n",
    "    scaled = flipped/255.0\n",
    "    normalized = (scaled - channel_mean)/channel_std\n",
    "    return {'image': normalized, 'label': x['label']}\n",
    "\n",
    "def imagenette_val_transform(x):\n",
    "    \"\"\"Apply the deterministic evaluation preprocessing to one example.\n",
    "\n",
    "    The image is resized to 256x256, centre-cropped with fraction 224/256,\n",
    "    divided by 255 and standardised with fixed per-channel mean and standard\n",
    "    deviation.\n",
    "\n",
    "    Args:\n",
    "      x: mapping with an 'image' entry and a 'label' entry.\n",
    "\n",
    "    Returns:\n",
    "      Dict holding the transformed 'image' and the unchanged 'label'.\n",
    "    \"\"\"\n",
    "    channel_mean = [0.485, 0.456, 0.406]\n",
    "    channel_std = [0.229, 0.224, 0.225]\n",
    "    resized = tf.image.resize(x['image'], size=[256,256])\n",
    "    cropped = tf.image.central_crop(resized, central_fraction=224.0/256.0)\n",
    "    scaled = cropped/255.0\n",
    "    normalized = (scaled - channel_mean) / channel_std\n",
    "    return {'image': normalized, 'label': x['label']}\n",
    "\n",
    "def load_imagenette_reps(path, binary=True, batched=True):\n",
    "  \"\"\"Load pickled Imagenette representations as tf.data datasets.\n",
    "\n",
    "  Reads 'b-imagenette.pkl' (binary=True) or 'imagenette.pkl' from `path`. The\n",
    "  pickle is indexed with 'train' and 'val', each providing 'images' and\n",
    "  'labels'. When FLAGS.standardize_mean_reps is set, each split has its own\n",
    "  mean over axis 0 subtracted.\n",
    "\n",
    "  NOTE(review): pickle.load can execute code embedded in the file; only use\n",
    "  this with trusted data.\n",
    "\n",
    "  Args:\n",
    "    path: directory containing the pickle file.\n",
    "    binary: selects the file name and the divisor of the last return value.\n",
    "    batched: not used by this function.\n",
    "\n",
    "  Returns:\n",
    "    Tuple (train_ds, val_ds, test_ds, n). train_ds repeats indefinitely and is\n",
    "    shuffled, val_ds is shuffled, test_ds is the same object as val_ds, and n\n",
    "    is the number of training rows divided by 100 (binary) or 20 (otherwise).\n",
    "  \"\"\"\n",
    "  filename = 'b-imagenette.pkl' if binary else 'imagenette.pkl'\n",
    "  with open(os.path.join(path, filename), 'rb') as f:\n",
    "    data = pickle.load(f)\n",
    "\n",
    "  def _center(images):\n",
    "    # Subtract the axis-0 mean of this split when the flag asks for it.\n",
    "    if FLAGS.standardize_mean_reps:\n",
    "      return images - np.mean(images, axis=0, keepdims=True)\n",
    "    return images\n",
    "\n",
    "  train_image = _center(data['train']['images'])\n",
    "  train_ds = tf.data.Dataset.from_tensor_slices(\n",
    "      {'image': train_image, 'label': data['train']['labels']})\n",
    "  train_ds = train_ds.repeat(-1).shuffle(_SHUFFLE_BUFFER_SIZE.value)\n",
    "\n",
    "  val_image = _center(data['val']['images'])\n",
    "  val_ds = tf.data.Dataset.from_tensor_slices(\n",
    "      {'image': val_image, 'label': data['val']['labels']})\n",
    "  val_ds = val_ds.shuffle(_SHUFFLE_BUFFER_SIZE.value)\n",
    "\n",
    "  # The validation split doubles as the test split.\n",
    "  test_ds = val_ds\n",
    "\n",
    "  divisor = 100 if binary else 20\n",
    "  return train_ds, val_ds, test_ds, train_image.shape[0]/divisor\n",
    "\n",
    "def load_mnistcifar_reps(path, batched=True):\n",
    "  \"\"\"Load pickled CIFAR-MNIST representations as tf.data datasets.\n",
    "\n",
    "  Reads 'cifar-mnist.pkl' from `path`. Six entries of the pickle are used:\n",
    "  a train, 'val' and 'test' split plus their 'OOD_' counterparts, each\n",
    "  providing 'images' and 'labels'. FLAGS.train_split selects 'train_split' /\n",
    "  'OOD_train_split' instead of 'train' / 'OOD_train'. When\n",
    "  FLAGS.standardize_mean_reps is set, each split has its own mean over axis 0\n",
    "  subtracted.\n",
    "\n",
    "  NOTE(review): pickle.load can execute code embedded in the file; only use\n",
    "  this with trusted data.\n",
    "\n",
    "  Args:\n",
    "    path: directory containing the pickle file.\n",
    "    batched: not used by this function.\n",
    "\n",
    "  Returns:\n",
    "    Tuple (train_ds, val_ds, test_ds, OOD_train_ds, OOD_val_ds, OOD_test_ds,\n",
    "    number of training rows, number of OOD training rows). The two training\n",
    "    datasets repeat indefinitely; all six datasets are shuffled.\n",
    "  \"\"\"\n",
    "  filepath = os.path.join(path, 'cifar-mnist.pkl')\n",
    "  with tf.io.gfile.GFile(filepath, 'rb') as fobj:\n",
    "    data = pickle.load(fobj)\n",
    "\n",
    "  def _build(key, repeat):\n",
    "    # Returns (dataset, images) for data[key]; the images are the ones actually\n",
    "    # stored in the dataset, i.e. after optional mean subtraction.\n",
    "    split = data[key]\n",
    "    images = split['images']\n",
    "    if FLAGS.standardize_mean_reps:\n",
    "      images = images - np.mean(images, axis=0, keepdims=True)\n",
    "    ds = tf.data.Dataset.from_tensor_slices({\n",
    "        'image': images,\n",
    "        'label': split['labels'],\n",
    "    })\n",
    "    if repeat:\n",
    "      ds = ds.repeat(-1)\n",
    "    return ds.shuffle(_SHUFFLE_BUFFER_SIZE.value), images\n",
    "\n",
    "  train_key = 'train_split' if FLAGS.train_split else 'train'\n",
    "  train_ds, train_image = _build(train_key, repeat=True)\n",
    "  val_ds, _ = _build('val', repeat=False)\n",
    "  test_ds, _ = _build('test', repeat=False)\n",
    "\n",
    "  OOD_train_key = 'OOD_train_split' if FLAGS.train_split else 'OOD_train'\n",
    "  OOD_train_ds, OOD_train_image = _build(OOD_train_key, repeat=True)\n",
    "  OOD_val_ds, _ = _build('OOD_val', repeat=False)\n",
    "  OOD_test_ds, _ = _build('OOD_test', repeat=False)\n",
    "\n",
    "  return (train_ds, val_ds, test_ds, OOD_train_ds, OOD_val_ds, OOD_test_ds,\n",
    "          train_image.shape[0], OOD_train_image.shape[0])\n",
    "\n",
    "def load_waterbirds_reps_2(path):\n",
    "  \"\"\"Load pickled Waterbirds representations as tf.data datasets.\n",
    "\n",
    "  Reads 'waterbirds_train.pkl' and 'waterbirds_val_test.pkl' from `path`.\n",
    "  FLAGS.use_complete_corr selects the '..._complete_corr' training entry and\n",
    "  FLAGS.use_complete_corr_test selects the '..._complete_corr' validation and\n",
    "  test entries. Every split provides 'inputs', 'labels' and 'metadata'. When\n",
    "  FLAGS.standardize_mean_reps is set, each split has its own mean over axis 0\n",
    "  subtracted from the inputs.\n",
    "\n",
    "  NOTE(review): pickle.load can execute code embedded in the file; only use\n",
    "  this with trusted data.\n",
    "\n",
    "  Args:\n",
    "    path: directory containing the two pickle files.\n",
    "\n",
    "  Returns:\n",
    "    Tuple (train_ds, val_ds, test_ds, n). train_ds repeats indefinitely; all\n",
    "    three datasets are shuffled. n is the hard-coded value 4555 when\n",
    "    FLAGS.use_complete_corr is set and 4795 otherwise.\n",
    "  \"\"\"\n",
    "  def _read(filename):\n",
    "    # Unpickle one file located in `path`.\n",
    "    with open(os.path.join(path, filename), 'rb') as f:\n",
    "      return pickle.load(f)\n",
    "\n",
    "  def _build(split, repeat):\n",
    "    # Turn one split into a shuffled dataset of image/label/metadata dicts.\n",
    "    images = split['inputs']\n",
    "    if FLAGS.standardize_mean_reps:\n",
    "      images = images - np.mean(images, axis=0, keepdims=True)\n",
    "    ds = tf.data.Dataset.from_tensor_slices({\n",
    "        'image': images,\n",
    "        'label': split['labels'],\n",
    "        'metadata': split['metadata'],\n",
    "    })\n",
    "    if repeat:\n",
    "      ds = ds.repeat(-1)\n",
    "    return ds.shuffle(_SHUFFLE_BUFFER_SIZE.value)\n",
    "\n",
    "  train_data = _read('waterbirds_train.pkl')\n",
    "  if FLAGS.use_complete_corr:\n",
    "    train_ds = _build(train_data['waterbirds_train_complete_corr'], repeat=True)\n",
    "  else:\n",
    "    train_ds = _build(train_data['waterbirds_train'], repeat=True)\n",
    "\n",
    "  val_test_data = _read('waterbirds_val_test.pkl')\n",
    "  if FLAGS.use_complete_corr_test:\n",
    "    val_ds = _build(val_test_data['waterbirds_val_complete_corr'], repeat=False)\n",
    "    test_ds = _build(val_test_data['waterbirds_test_complete_corr'], repeat=False)\n",
    "  else:\n",
    "    val_ds = _build(val_test_data['waterbirds_val'], repeat=False)\n",
    "    test_ds = _build(val_test_data['waterbirds_test'], repeat=False)\n",
    "\n",
    "  train_len = 4555 if FLAGS.use_complete_corr else 4795\n",
    "  return train_ds, val_ds, test_ds, train_len"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "194ef3ed",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pickle\n",
    "\n",
    "# Loader switches: train_split is read by the mnist-cifar loader and the two\n",
    "# use_complete_corr flags by the waterbirds loader; the Imagenette loader\n",
    "# called below does not consult them.\n",
    "FLAGS.train_split = True\n",
    "FLAGS.use_complete_corr = False\n",
    "FLAGS.use_complete_corr_test = False\n",
    "\n",
    "# One loader call is active; the commented ones are kept for switching datasets.\n",
    "#train_ds, val_ds, test_ds, train_len = load_waterbirds_reps_2('./data_reps')\n",
    "train_ds, val_ds, test_ds, train_len = load_imagenette_reps(\n",
    "    './data_reps', binary=False, batched=False)\n",
    "#train_ds, val_ds, test_ds, OOD_train_ds, OOD_val_ds, OOD_test_ds, train_len, OOD_train_len = load_mnistcifar_reps('./data_reps')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f21278f4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shares of four groups among 4795 examples (the four counts sum to 4795).\n",
    "# They weight the per-group accuracies in the waterbirds-only branches below.\n",
    "train_00_frac, train_01_frac, train_10_frac, train_11_frac = (\n",
    "    count/4795 for count in (3498, 184, 56, 1057))\n",
    "\n",
    "# Draw a single batch of 2048 training examples.\n",
    "train_ds_batch = train_ds.batch(2048)\n",
    "iterator = iter(train_ds_batch)\n",
    "data = next(iterator)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "00327064",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Count correct predictions of the model on the batch held in `data`.\n",
    "# The model returns three values; the second one is used as the logits.\n",
    "_, out, _ = model(data['image'], training=False)\n",
    "if FLAGS.binary_classification:\n",
    "    # Single-logit output: a positive logit is read as class 1.\n",
    "    corr = tf.cast(tf.reshape(out, [-1]) > 0, tf.int64) == data['label']\n",
    "else:\n",
    "    corr = tf.argmax(out, axis=1) == data['label']\n",
    "print(tf.reduce_sum(tf.cast(corr, tf.float32)))\n",
    "\n",
    "if FLAGS.dataset == 'waterbirds':\n",
    "    # Metadata columns 0 and 1 define four groups. Count members and correct\n",
    "    # predictions per group, reusing `corr` so the binary / multi-class choice\n",
    "    # made above also applies here.\n",
    "    meta_0, meta_1 = data['metadata'][:, 0], data['metadata'][:, 1]\n",
    "    in_00 = (meta_0 == 0) & (meta_1 == 0)\n",
    "    in_01 = (meta_0 == 0) & (meta_1 == 1)\n",
    "    in_10 = (meta_0 == 1) & (meta_1 == 0)\n",
    "    in_11 = (meta_0 == 1) & (meta_1 == 1)\n",
    "\n",
    "    test_00 = tf.reduce_sum(tf.cast(in_00, tf.float32))\n",
    "    test_01 = tf.reduce_sum(tf.cast(in_01, tf.float32))\n",
    "    test_10 = tf.reduce_sum(tf.cast(in_10, tf.float32))\n",
    "    test_11 = tf.reduce_sum(tf.cast(in_11, tf.float32))\n",
    "\n",
    "    corr_00 = corr & in_00\n",
    "    corr_00_frac = tf.reduce_sum(tf.cast(corr_00, tf.float32))/test_00\n",
    "    corr_01 = corr & in_01\n",
    "    corr_01_frac = tf.reduce_sum(tf.cast(corr_01, tf.float32))/test_01\n",
    "    corr_10 = corr & in_10\n",
    "    corr_10_frac = tf.reduce_sum(tf.cast(corr_10, tf.float32))/test_10\n",
    "    corr_11 = corr & in_11\n",
    "    corr_11_frac = tf.reduce_sum(tf.cast(corr_11, tf.float32))/test_11\n",
    "\n",
    "    # Per-group accuracies weighted by the train_*_frac constants.\n",
    "    print(corr_00_frac*train_00_frac + corr_01_frac*train_01_frac + corr_10_frac*train_10_frac + corr_11_frac*train_11_frac)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fc23e5ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep the last trainable variable of layer 2 whose name contains 'kernel'.\n",
    "for var in model.layers[2].trainable_variables:\n",
    "    if 'kernel' not in var.name:\n",
    "        continue\n",
    "    curr_var = var"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b2bb0ef6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Eigenvalues of curr_var @ curr_var^T for the kernel selected above.\n",
    "# NOTE(review): `v` (the eigenvectors here) is reassigned by the projection\n",
    "# optimisation cell further down.\n",
    "H = tf.linalg.matmul(curr_var, tf.transpose(curr_var))\n",
    "e,v = tf.linalg.eigh(H)\n",
    "print(e)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2c7bd895",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Print the name of every trainable variable of layer 4, one per line.\n",
    "for var in model.layers[4].trainable_variables:\n",
    "    print(var.name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3398ed62",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Obtain a projection matrix via optimization.\n",
    "\n",
    "# `v` starts as a random 2048 x 15 matrix whose columns are scaled to unit norm\n",
    "# and is the only quantity updated by train_step.\n",
    "v = tf.random.normal((2048, 15))\n",
    "v = tf.Variable(v/tf.linalg.norm(v, axis=0, keepdims=True))\n",
    "optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)\n",
    "# Weight of the second loss term in train_step.\n",
    "reg = 5.0\n",
    "\n",
    "@tf.function\n",
    "def train_step(data):\n",
    "    \"\"\"Run one SGD update of the projection basis `v` on a batch.\n",
    "\n",
    "    The batch representations (output of model.layers[0]) are split into the\n",
    "    part inside the column space of `v` and the remainder, and both parts are\n",
    "    passed through the remaining layers. The loss is the cross-entropy of the\n",
    "    remainder's logits against fully smoothed (label_smoothing=1.0) targets\n",
    "    plus `reg` times the cross-entropy between the predictions on the full\n",
    "    representations and the logits of the inside part. Only `v` is updated.\n",
    "\n",
    "    Prints, per call: the number of label matches for the remainder, the number\n",
    "    of label matches for the inside part, and the loss.\n",
    "\n",
    "    Args:\n",
    "      data: batch dict with 'image' and 'label' entries.\n",
    "    \"\"\"\n",
    "    reps = model.layers[0](data['image'], training=False)\n",
    "    if FLAGS.binary_classification:\n",
    "        loss_fn_smooth = tf.keras.losses.BinaryCrossentropy(\n",
    "            from_logits=True, reduction=tf.keras.losses.Reduction.SUM,\n",
    "            label_smoothing=1.0)\n",
    "    else:\n",
    "        loss_fn_smooth = tf.keras.losses.CategoricalCrossentropy(\n",
    "            from_logits=True, reduction=tf.keras.losses.Reduction.SUM,\n",
    "            label_smoothing=1.0)\n",
    "\n",
    "    with tf.GradientTape() as tape:\n",
    "        # v @ pinv(v) projects onto the column space of v.\n",
    "        proj_mat = tf.linalg.matmul(v, tf.linalg.pinv(v))\n",
    "        reps2 = tf.linalg.matmul(reps, proj_mat)  # part inside span(v)\n",
    "        reps1 = reps - reps2                      # remainder\n",
    "        for ind, layer in enumerate(model.layers):\n",
    "            if ind > 0:\n",
    "                reps1 = layer(reps1, training=False)\n",
    "                reps2 = layer(reps2, training=False)\n",
    "                reps = layer(reps, training=False)\n",
    "        if FLAGS.binary_classification:\n",
    "            loss = loss_fn_smooth(tf.zeros(data['label'].shape), reps1)\n",
    "            # Targets are the model's own predictions, not the dataset labels.\n",
    "            pred_1 = tf.nn.sigmoid(reps)\n",
    "            loss += reg*tf.reduce_sum(\n",
    "                tf.nn.sigmoid_cross_entropy_with_logits(labels=pred_1, logits=reps2))\n",
    "        else:\n",
    "            loss = loss_fn_smooth(\n",
    "                tf.zeros([data['label'].shape[0], FLAGS.num_classes]), reps1)\n",
    "            # Targets are the model's own predictions, not the dataset labels.\n",
    "            pred = tf.nn.softmax(reps, axis=1)\n",
    "            loss += reg*tf.reduce_mean(\n",
    "                tf.nn.softmax_cross_entropy_with_logits(labels=pred, logits=reps2))\n",
    "\n",
    "    grads = tape.gradient(loss, [v])\n",
    "    optimizer.apply_gradients(zip(grads, [v]))\n",
    "\n",
    "    def _num_correct(logits):\n",
    "        # Number of examples whose predicted class equals data['label'].\n",
    "        if FLAGS.binary_classification:\n",
    "            hits = tf.cast(tf.reshape(logits, [-1]) > 0, tf.int64) == data['label']\n",
    "        else:\n",
    "            hits = tf.argmax(logits, axis=1) == data['label']\n",
    "        return tf.reduce_sum(tf.cast(hits, tf.float32))\n",
    "\n",
    "    tf.print(_num_correct(reps1))\n",
    "    tf.print(_num_correct(reps2))\n",
    "    tf.print(loss)\n",
    "\n",
    "# Optimise v for a fixed number of steps on batches of 256 training examples.\n",
    "train_ds_batch = train_ds.batch(256)\n",
    "iterator = iter(train_ds_batch)\n",
    "steps = 2000\n",
    "for step in range(steps):\n",
    "    data = next(iterator)\n",
    "    train_step(data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "60fe55da",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Swap test on a fresh batch: keep each example's component inside span(v) and\n",
    "# take the remaining component from a randomly permuted example of the batch.\n",
    "data = next(iterator)\n",
    "data_rand = {}\n",
    "p = np.random.permutation(data['image'].shape[0])\n",
    "data_rand['image'] = tf.gather(data['image'], p)\n",
    "data_rand['label'] = tf.gather(data['label'], p)\n",
    "reps = model.layers[0](data['image'], training=False)\n",
    "p_inv = tf.linalg.pinv(v)\n",
    "proj_mat = tf.linalg.matmul(v, p_inv)\n",
    "reps = tf.linalg.matmul(reps, tf.transpose(proj_mat))\n",
    "reps_rand = model.layers[0](data_rand['image'], training=False)\n",
    "reps_rand = reps_rand - tf.linalg.matmul(reps_rand, tf.transpose(proj_mat))\n",
    "reps = reps + reps_rand\n",
    "for ind, layer in enumerate(model.layers):\n",
    "    if ind > 0:\n",
    "        reps = layer(reps, training=False)\n",
    "if FLAGS.binary_classification:\n",
    "    corr = tf.cast(tf.reshape(reps, [-1]) > 0, tf.int64) == data['label']\n",
    "else:\n",
    "    corr = tf.argmax(reps, axis=1) == data['label']\n",
    "print(tf.reduce_sum(tf.cast(corr, tf.float32)))\n",
    "\n",
    "if FLAGS.dataset == 'waterbirds':\n",
    "    # Metadata columns 0 and 1 define four groups. The group sizes only exist\n",
    "    # for this dataset, so they are printed inside the branch; printing them\n",
    "    # unconditionally fails with a NameError for the other datasets.\n",
    "    meta_0, meta_1 = data['metadata'][:, 0], data['metadata'][:, 1]\n",
    "    in_00 = (meta_0 == 0) & (meta_1 == 0)\n",
    "    in_01 = (meta_0 == 0) & (meta_1 == 1)\n",
    "    in_10 = (meta_0 == 1) & (meta_1 == 0)\n",
    "    in_11 = (meta_0 == 1) & (meta_1 == 1)\n",
    "\n",
    "    test_00 = tf.reduce_sum(tf.cast(in_00, tf.float32))\n",
    "    test_01 = tf.reduce_sum(tf.cast(in_01, tf.float32))\n",
    "    test_10 = tf.reduce_sum(tf.cast(in_10, tf.float32))\n",
    "    test_11 = tf.reduce_sum(tf.cast(in_11, tf.float32))\n",
    "\n",
    "    print(test_00)\n",
    "    print(test_01)\n",
    "    print(test_10)\n",
    "    print(test_11)\n",
    "\n",
    "    corr_00 = corr & in_00\n",
    "    corr_00_frac = tf.reduce_sum(tf.cast(corr_00, tf.float32))/test_00\n",
    "    corr_01 = corr & in_01\n",
    "    corr_01_frac = tf.reduce_sum(tf.cast(corr_01, tf.float32))/test_01\n",
    "    corr_10 = corr & in_10\n",
    "    corr_10_frac = tf.reduce_sum(tf.cast(corr_10, tf.float32))/test_10\n",
    "    corr_11 = corr & in_11\n",
    "    corr_11_frac = tf.reduce_sum(tf.cast(corr_11, tf.float32))/test_11\n",
    "\n",
    "    # Per-group accuracies weighted by the train_*_frac constants.\n",
    "    print(corr_00_frac*train_00_frac + corr_01_frac*train_01_frac + corr_10_frac*train_10_frac + corr_11_frac*train_11_frac)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4d917272",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inverse swap test on a fresh batch: keep each example's component outside\n",
    "# span(v) and take the component inside span(v) from a randomly permuted\n",
    "# example of the batch.\n",
    "data = next(iterator)\n",
    "p = np.random.permutation(data['image'].shape[0])\n",
    "data_rand = {key: tf.gather(data[key], p) for key in ('image', 'label')}\n",
    "\n",
    "p_inv = tf.linalg.pinv(v)\n",
    "proj_mat = tf.linalg.matmul(v, p_inv)\n",
    "reps = model.layers[0](data['image'], training=False)\n",
    "reps = reps - tf.linalg.matmul(reps, tf.transpose(proj_mat))\n",
    "reps_rand = model.layers[0](data_rand['image'], training=False)\n",
    "reps_rand = tf.linalg.matmul(reps_rand, tf.transpose(proj_mat))\n",
    "reps = reps + reps_rand\n",
    "\n",
    "# Push the mixed representations through every layer after the first.\n",
    "for ind, layer in enumerate(model.layers):\n",
    "    if ind == 0:\n",
    "        continue\n",
    "    reps = layer(reps, training=False)\n",
    "\n",
    "if FLAGS.binary_classification:\n",
    "    corr = tf.cast(tf.reshape(reps, [-1]) > 0, tf.int64) == data['label']\n",
    "else:\n",
    "    corr = tf.argmax(reps, axis=1) == data['label']\n",
    "print(tf.reduce_sum(tf.cast(corr, tf.float32)))\n",
    "\n",
    "if FLAGS.dataset == 'waterbirds':\n",
    "    # Metadata columns 0 and 1 define four groups; count members and correct\n",
    "    # predictions per group.\n",
    "    meta_0, meta_1 = data['metadata'][:, 0], data['metadata'][:, 1]\n",
    "    in_00 = (meta_0 == 0) & (meta_1 == 0)\n",
    "    in_01 = (meta_0 == 0) & (meta_1 == 1)\n",
    "    in_10 = (meta_0 == 1) & (meta_1 == 0)\n",
    "    in_11 = (meta_0 == 1) & (meta_1 == 1)\n",
    "\n",
    "    test_00 = tf.reduce_sum(tf.cast(in_00, tf.float32))\n",
    "    test_01 = tf.reduce_sum(tf.cast(in_01, tf.float32))\n",
    "    test_10 = tf.reduce_sum(tf.cast(in_10, tf.float32))\n",
    "    test_11 = tf.reduce_sum(tf.cast(in_11, tf.float32))\n",
    "\n",
    "    corr_00 = corr & in_00\n",
    "    corr_00_frac = tf.reduce_sum(tf.cast(corr_00, tf.float32))/test_00\n",
    "    corr_01 = corr & in_01\n",
    "    corr_01_frac = tf.reduce_sum(tf.cast(corr_01, tf.float32))/test_01\n",
    "    corr_10 = corr & in_10\n",
    "    corr_10_frac = tf.reduce_sum(tf.cast(corr_10, tf.float32))/test_10\n",
    "    corr_11 = corr & in_11\n",
    "    corr_11_frac = tf.reduce_sum(tf.cast(corr_11, tf.float32))/test_11\n",
    "\n",
    "    # Per-group accuracies weighted by the train_*_frac constants.\n",
    "    print(corr_00_frac*train_00_frac + corr_01_frac*train_01_frac + corr_10_frac*train_10_frac + corr_11_frac*train_11_frac)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7faf4406",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Materialise the whole test set as a single batch.\n",
    "test_ds_batch = test_ds.batch(len(test_ds))\n",
    "iterator = iter(test_ds_batch)\n",
    "data = next(iterator)\n",
    "\n",
    "# The same batch with images and labels reordered by one random permutation.\n",
    "p = np.random.permutation(data['image'].shape[0])\n",
    "data_rand = {key: tf.gather(data[key], p) for key in ('image', 'label')}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "101cada0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Baseline: count and fraction of correct predictions of the unmodified model\n",
    "# on the batch held in `data`.\n",
    "print(data['image'].shape)\n",
    "_, out, _ = model(data['image'], training=False)\n",
    "if FLAGS.binary_classification:\n",
    "    corr = tf.cast(tf.reshape(out, [-1]) > 0, tf.int64) == data['label']\n",
    "else:\n",
    "    corr = tf.argmax(out, axis=1) == data['label']\n",
    "print(tf.reduce_sum(tf.cast(corr, tf.float32)))\n",
    "print(tf.reduce_sum(tf.cast(corr, tf.float32))/data['label'].shape[0])\n",
    "\n",
    "if FLAGS.dataset == 'waterbirds':\n",
    "    # Metadata columns 0 and 1 define four groups; count members and correct\n",
    "    # predictions per group.\n",
    "    meta_0, meta_1 = data['metadata'][:, 0], data['metadata'][:, 1]\n",
    "    in_00 = (meta_0 == 0) & (meta_1 == 0)\n",
    "    in_01 = (meta_0 == 0) & (meta_1 == 1)\n",
    "    in_10 = (meta_0 == 1) & (meta_1 == 0)\n",
    "    in_11 = (meta_0 == 1) & (meta_1 == 1)\n",
    "\n",
    "    test_00 = tf.reduce_sum(tf.cast(in_00, tf.float32))\n",
    "    test_01 = tf.reduce_sum(tf.cast(in_01, tf.float32))\n",
    "    test_10 = tf.reduce_sum(tf.cast(in_10, tf.float32))\n",
    "    test_11 = tf.reduce_sum(tf.cast(in_11, tf.float32))\n",
    "\n",
    "    corr_00 = corr & in_00\n",
    "    corr_00_frac = tf.reduce_sum(tf.cast(corr_00, tf.float32))/test_00\n",
    "    corr_01 = corr & in_01\n",
    "    corr_01_frac = tf.reduce_sum(tf.cast(corr_01, tf.float32))/test_01\n",
    "    corr_10 = corr & in_10\n",
    "    corr_10_frac = tf.reduce_sum(tf.cast(corr_10, tf.float32))/test_10\n",
    "    corr_11 = corr & in_11\n",
    "    corr_11_frac = tf.reduce_sum(tf.cast(corr_11, tf.float32))/test_11\n",
    "\n",
    "    # Per-group accuracies weighted by the train_*_frac constants.\n",
    "    print(corr_00_frac*train_00_frac + corr_01_frac*train_01_frac + corr_10_frac*train_10_frac + corr_11_frac*train_11_frac)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "72fe3468",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Swap test on the batch held in `data`: keep each example's component inside\n",
    "# span(v) and take the remaining component from the permuted batch data_rand.\n",
    "p_inv = tf.linalg.pinv(v)\n",
    "proj_mat = tf.linalg.matmul(v, p_inv)\n",
    "reps = model.layers[0](data['image'], training=False)\n",
    "reps = tf.linalg.matmul(reps, tf.transpose(proj_mat))\n",
    "reps_rand = model.layers[0](data_rand['image'], training=False)\n",
    "reps_rand = reps_rand - tf.linalg.matmul(reps_rand, tf.transpose(proj_mat))\n",
    "reps = reps + reps_rand\n",
    "\n",
    "# Push the mixed representations through every layer after the first.\n",
    "for ind, layer in enumerate(model.layers):\n",
    "    if ind == 0:\n",
    "        continue\n",
    "    reps = layer(reps, training=False)\n",
    "\n",
    "if FLAGS.binary_classification:\n",
    "    corr = tf.cast(tf.reshape(reps, [-1]) > 0, tf.int64) == data['label']\n",
    "else:\n",
    "    corr = tf.argmax(reps, axis=1) == data['label']\n",
    "print(tf.reduce_sum(tf.cast(corr, tf.float32)))\n",
    "print(tf.reduce_sum(tf.cast(corr, tf.float32))/data['label'].shape[0])\n",
    "\n",
    "if FLAGS.dataset == 'waterbirds':\n",
    "    # Metadata columns 0 and 1 define four groups; count members and correct\n",
    "    # predictions per group.\n",
    "    meta_0, meta_1 = data['metadata'][:, 0], data['metadata'][:, 1]\n",
    "    in_00 = (meta_0 == 0) & (meta_1 == 0)\n",
    "    in_01 = (meta_0 == 0) & (meta_1 == 1)\n",
    "    in_10 = (meta_0 == 1) & (meta_1 == 0)\n",
    "    in_11 = (meta_0 == 1) & (meta_1 == 1)\n",
    "\n",
    "    test_00 = tf.reduce_sum(tf.cast(in_00, tf.float32))\n",
    "    test_01 = tf.reduce_sum(tf.cast(in_01, tf.float32))\n",
    "    test_10 = tf.reduce_sum(tf.cast(in_10, tf.float32))\n",
    "    test_11 = tf.reduce_sum(tf.cast(in_11, tf.float32))\n",
    "\n",
    "    corr_00 = corr & in_00\n",
    "    corr_00_frac = tf.reduce_sum(tf.cast(corr_00, tf.float32))/test_00\n",
    "    corr_01 = corr & in_01\n",
    "    corr_01_frac = tf.reduce_sum(tf.cast(corr_01, tf.float32))/test_01\n",
    "    corr_10 = corr & in_10\n",
    "    corr_10_frac = tf.reduce_sum(tf.cast(corr_10, tf.float32))/test_10\n",
    "    corr_11 = corr & in_11\n",
    "    corr_11_frac = tf.reduce_sum(tf.cast(corr_11, tf.float32))/test_11\n",
    "\n",
    "    # Per-group accuracies weighted by the train_*_frac constants.\n",
    "    print(corr_00_frac*train_00_frac + corr_01_frac*train_01_frac + corr_10_frac*train_10_frac + corr_11_frac*train_11_frac)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f761a9c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Norm of the difference between the baseline outputs `out` and the current\n",
    "# `reps`, relative to the norm of `out`.\n",
    "print(tf.linalg.norm(out-reps)/tf.linalg.norm(out))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5691fd30",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Map labels to a column of 2*label - 1 (0 -> -1, 1 -> +1) and use it to sign\n",
    "# the output difference.\n",
    "# NOTE(review): this is only a sign for 0/1 labels - confirm the cell is meant\n",
    "# for the binary setup.\n",
    "y = tf.reshape(2 * data['label'] - 1, (-1, 1))\n",
    "z = (out - reps) * tf.cast(y, tf.float32)\n",
    "print(z.shape)\n",
    "# Signed sum of the difference relative to its absolute sum.\n",
    "print(tf.reduce_sum(z) / tf.reduce_sum(tf.abs(out - reps)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "90a913ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inverse swap test on the batch held in `data`: keep each example's component\n",
    "# outside span(v) and take the component inside span(v) from the permuted\n",
    "# batch data_rand.\n",
    "p_inv = tf.linalg.pinv(v)\n",
    "proj_mat = tf.linalg.matmul(v, p_inv)\n",
    "reps = model.layers[0](data['image'], training=False)\n",
    "reps = reps - tf.linalg.matmul(reps, tf.transpose(proj_mat))\n",
    "reps_rand = model.layers[0](data_rand['image'], training=False)\n",
    "reps_rand = tf.linalg.matmul(reps_rand, tf.transpose(proj_mat))\n",
    "reps = reps + reps_rand\n",
    "\n",
    "# Push the mixed representations through every layer after the first.\n",
    "for ind, layer in enumerate(model.layers):\n",
    "    if ind == 0:\n",
    "        continue\n",
    "    reps = layer(reps, training=False)\n",
    "\n",
    "if FLAGS.binary_classification:\n",
    "    corr = tf.cast(tf.reshape(reps, [-1]) > 0, tf.int64) == data['label']\n",
    "else:\n",
    "    corr = tf.argmax(reps, axis=1) == data['label']\n",
    "print(tf.reduce_sum(tf.cast(corr, tf.float32)))\n",
    "print(tf.reduce_sum(tf.cast(corr, tf.float32))/data['label'].shape[0])\n",
    "\n",
    "if FLAGS.dataset == 'waterbirds':\n",
    "    # Metadata columns 0 and 1 define four groups; count members and correct\n",
    "    # predictions per group.\n",
    "    meta_0, meta_1 = data['metadata'][:, 0], data['metadata'][:, 1]\n",
    "    in_00 = (meta_0 == 0) & (meta_1 == 0)\n",
    "    in_01 = (meta_0 == 0) & (meta_1 == 1)\n",
    "    in_10 = (meta_0 == 1) & (meta_1 == 0)\n",
    "    in_11 = (meta_0 == 1) & (meta_1 == 1)\n",
    "\n",
    "    test_00 = tf.reduce_sum(tf.cast(in_00, tf.float32))\n",
    "    test_01 = tf.reduce_sum(tf.cast(in_01, tf.float32))\n",
    "    test_10 = tf.reduce_sum(tf.cast(in_10, tf.float32))\n",
    "    test_11 = tf.reduce_sum(tf.cast(in_11, tf.float32))\n",
    "\n",
    "    corr_00 = corr & in_00\n",
    "    corr_00_frac = tf.reduce_sum(tf.cast(corr_00, tf.float32))/test_00\n",
    "    corr_01 = corr & in_01\n",
    "    corr_01_frac = tf.reduce_sum(tf.cast(corr_01, tf.float32))/test_01\n",
    "    corr_10 = corr & in_10\n",
    "    corr_10_frac = tf.reduce_sum(tf.cast(corr_10, tf.float32))/test_10\n",
    "    corr_11 = corr & in_11\n",
    "    corr_11_frac = tf.reduce_sum(tf.cast(corr_11, tf.float32))/test_11\n",
    "\n",
    "    # Per-group accuracies weighted by the train_*_frac constants.\n",
    "    print(corr_00_frac*train_00_frac + corr_01_frac*train_01_frac + corr_10_frac*train_10_frac + corr_11_frac*train_11_frac)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b5063d7c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Relative L2 gap between `out` and `reps`, both of which are set by earlier cells.\n",
    "rel_gap = tf.linalg.norm(out - reps) / tf.linalg.norm(out)\n",
    "print(rel_gap)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "00160c7a",
   "metadata": {},
   "outputs": [],
   "source": [
    "#FLAGS.binary_classification=True\n",
    "#FLAGS.num_classes = 1\n",
    "\n",
    "# Base network: a Sequential holding a single bare Keras Layer.\n",
    "base_model = tf.keras.Sequential([tf.keras.layers.Layer()])\n",
    "\n",
    "# The positional arguments follow createmodel's signature (defined in an earlier cell).\n",
    "model2 = createmodel(\n",
    "    FLAGS.num_classes, 100, 1, 100, 0,\n",
    "    resnet_base=base_model,\n",
    "    dropout_rate=0.0,\n",
    "    num_heads=1,\n",
    "    use_proj=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "48f25fae",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Restore every layer of model2 except layer 0 from the latest checkpoint,\n",
    "# keyed 'model<index>'. Iterate over model2's own layers (not `model`'s) so the\n",
    "# cell neither depends on `model` nor mis-indexes if the layer counts differ.\n",
    "save_vars = {\n",
    "    'model{}'.format(i): layer\n",
    "    for i, layer in enumerate(model2.layers) if i != 0\n",
    "}\n",
    "ckpt = tf.train.Checkpoint(**save_vars)\n",
    "manager = tf.train.CheckpointManager(\n",
    "    ckpt, directory=\"./imagenette-he-init-model2-noproject-multirun-mean/run-1/1/1\", max_to_keep=1)\n",
    "status = ckpt.restore(manager.latest_checkpoint)\n",
    "print(manager.latest_checkpoint)\n",
    "if manager.latest_checkpoint is None:\n",
    "    # restore(None) restores nothing, so make the missing checkpoint visible.\n",
    "    print('WARNING: no checkpoint found in the directory above; model2 was not restored.')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c6d8c8b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Trainable 2048x15 matrix whose columns start at unit norm.\n",
    "v = tf.random.normal((2048, 15))\n",
    "v = v/tf.linalg.norm(v, axis=0, keepdims=True)\n",
    "v = tf.Variable(v)\n",
    "optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)\n",
    "reg = 5.0\n",
    "\n",
    "@tf.function\n",
    "def train_step(data):\n",
    "    '''Take one SGD step on v (model2's weights are not updated).\n",
    "\n",
    "    The layer-0 representation is split into its projection onto the column\n",
    "    space of v (reps2) and the remainder (reps1). Both go through the remaining\n",
    "    layers. The loss is a fully label-smoothed cross-entropy on reps1 plus\n",
    "    reg times the ordinary cross-entropy of reps2 against the labels.\n",
    "    Prints the number of correct predictions from reps1, then from reps2,\n",
    "    then the loss.\n",
    "    '''\n",
    "    if FLAGS.binary_classification:\n",
    "        loss_cls = tf.keras.losses.BinaryCrossentropy\n",
    "    else:\n",
    "        loss_cls = tf.keras.losses.CategoricalCrossentropy\n",
    "    sum_reduction = tf.keras.losses.Reduction.SUM\n",
    "    loss_fn = loss_cls(from_logits=True, reduction=sum_reduction)\n",
    "    # label_smoothing=1.0 replaces whatever target is passed by a uniform one.\n",
    "    loss_fn_smooth = loss_cls(from_logits=True, reduction=sum_reduction,\n",
    "                              label_smoothing=1.0)\n",
    "\n",
    "    reps = model2.layers[0](data['image'], training=False)\n",
    "    with tf.GradientTape() as tape:\n",
    "        # v @ pinv(v) projects onto the column space of v.\n",
    "        p_inv = tf.linalg.pinv(v)\n",
    "        proj_mat = tf.linalg.matmul(v, p_inv)\n",
    "        reps1 = reps - tf.linalg.matmul(reps, proj_mat)\n",
    "        reps2 = tf.linalg.matmul(reps, proj_mat)\n",
    "        for layer in model2.layers[1:]:\n",
    "            reps1 = layer(reps1, training=False)\n",
    "            reps2 = layer(reps2, training=False)\n",
    "        if FLAGS.binary_classification:\n",
    "            target_shape = data['label'].shape\n",
    "            targets = data['label']\n",
    "        else:\n",
    "            target_shape = [data['label'].shape[0], FLAGS.num_classes]\n",
    "            targets = tf.one_hot(data['label'], FLAGS.num_classes)\n",
    "        loss = loss_fn_smooth(tf.zeros(target_shape), reps1)\n",
    "        loss += reg*loss_fn(targets, reps2)\n",
    "    grads = tape.gradient(loss, [v])\n",
    "    optimizer.apply_gradients(zip(grads, [v]))\n",
    "\n",
    "    # Correct-prediction counts for the remainder branch, then the projected branch.\n",
    "    for logits in (reps1, reps2):\n",
    "        if FLAGS.binary_classification:\n",
    "            preds = tf.cast(tf.reshape(logits, [-1]) > 0, tf.int64)\n",
    "        else:\n",
    "            preds = tf.argmax(logits, axis=1)\n",
    "        corr = preds == data['label']\n",
    "        tf.print(tf.reduce_sum(tf.cast(corr, tf.float32)))\n",
    "    tf.print(loss)\n",
    "\n",
    "# Run `steps` optimisation steps over batches of 256 drawn from train_ds.\n",
    "train_ds_batch = train_ds.batch(256)\n",
    "iterator = iter(train_ds_batch)\n",
    "steps = 2000\n",
    "for step in range(steps):\n",
    "    data = next(iterator)\n",
    "    train_step(data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "66d6bd2d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Next batch from the current iterator, plus a randomly permuted copy of it.\n",
    "data = next(iterator)\n",
    "p = np.random.permutation(data['image'].shape[0])\n",
    "data_rand = {key: tf.gather(data[key], p) for key in ('image', 'label')}\n",
    "\n",
    "# proj_mat = v @ pinv(v) projects onto the column space of v.\n",
    "p_inv = tf.linalg.pinv(v)\n",
    "proj_mat = tf.linalg.matmul(v, p_inv)\n",
    "proj_t = tf.transpose(proj_mat)\n",
    "\n",
    "# Keep each image's span(v) component and take the component outside span(v)\n",
    "# from the permuted image in the same row.\n",
    "feats = model2.layers[0](data['image'], training=False)\n",
    "feats_rand = model2.layers[0](data_rand['image'], training=False)\n",
    "reps_rand = feats_rand - tf.linalg.matmul(feats_rand, proj_t)\n",
    "reps = tf.linalg.matmul(feats, proj_t) + reps_rand\n",
    "\n",
    "# Push the mixed representation through the remaining layers.\n",
    "for layer in model2.layers[1:]:\n",
    "    reps = layer(reps, training=False)\n",
    "\n",
    "if FLAGS.binary_classification:\n",
    "    preds = tf.cast(tf.reshape(reps, [-1]) > 0, tf.int64)\n",
    "else:\n",
    "    preds = tf.argmax(reps, axis=1)\n",
    "corr = preds == data['label']\n",
    "print(tf.reduce_sum(tf.cast(corr, tf.float32)))\n",
    "\n",
    "if FLAGS.dataset=='waterbirds':\n",
    "    # Boolean membership masks for the four (metadata[:,0], metadata[:,1]) groups.\n",
    "    meta_0 = data['metadata'][:,0]\n",
    "    meta_1 = data['metadata'][:,1]\n",
    "    in_00 = (meta_0==0) & (meta_1==0)\n",
    "    in_01 = (meta_0==0) & (meta_1==1)\n",
    "    in_10 = (meta_0==1) & (meta_1==0)\n",
    "    in_11 = (meta_0==1) & (meta_1==1)\n",
    "    # Number of examples of each group in this batch.\n",
    "    test_00 = tf.reduce_sum(tf.cast(in_00, tf.float32))\n",
    "    test_01 = tf.reduce_sum(tf.cast(in_01, tf.float32))\n",
    "    test_10 = tf.reduce_sum(tf.cast(in_10, tf.float32))\n",
    "    test_11 = tf.reduce_sum(tf.cast(in_11, tf.float32))\n",
    "    # Correct-prediction masks and accuracy per group.\n",
    "    corr_00 = corr & in_00\n",
    "    corr_01 = corr & in_01\n",
    "    corr_10 = corr & in_10\n",
    "    corr_11 = corr & in_11\n",
    "    corr_00_frac = tf.reduce_sum(tf.cast(corr_00, tf.float32))/test_00\n",
    "    corr_01_frac = tf.reduce_sum(tf.cast(corr_01, tf.float32))/test_01\n",
    "    corr_10_frac = tf.reduce_sum(tf.cast(corr_10, tf.float32))/test_10\n",
    "    corr_11_frac = tf.reduce_sum(tf.cast(corr_11, tf.float32))/test_11\n",
    "    # Group accuracies weighted by train_XX_frac (defined in an earlier cell).\n",
    "    print(corr_00_frac*train_00_frac + corr_01_frac*train_01_frac + corr_10_frac*train_10_frac + corr_11_frac*train_11_frac)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4b235aa7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Next batch from the current iterator, plus a randomly permuted copy of it.\n",
    "data = next(iterator)\n",
    "p = np.random.permutation(data['image'].shape[0])\n",
    "data_rand = {key: tf.gather(data[key], p) for key in ('image', 'label')}\n",
    "\n",
    "# proj_mat = v @ pinv(v) projects onto the column space of v.\n",
    "p_inv = tf.linalg.pinv(v)\n",
    "proj_mat = tf.linalg.matmul(v, p_inv)\n",
    "proj_t = tf.transpose(proj_mat)\n",
    "\n",
    "# Keep each image's component outside span(v) and take the span(v) component\n",
    "# from the permuted image in the same row.\n",
    "feats = model2.layers[0](data['image'], training=False)\n",
    "feats_rand = model2.layers[0](data_rand['image'], training=False)\n",
    "reps_rand = tf.linalg.matmul(feats_rand, proj_t)\n",
    "reps = (feats - tf.linalg.matmul(feats, proj_t)) + reps_rand\n",
    "\n",
    "# Push the mixed representation through the remaining layers.\n",
    "for layer in model2.layers[1:]:\n",
    "    reps = layer(reps, training=False)\n",
    "\n",
    "if FLAGS.binary_classification:\n",
    "    preds = tf.cast(tf.reshape(reps, [-1]) > 0, tf.int64)\n",
    "else:\n",
    "    preds = tf.argmax(reps, axis=1)\n",
    "corr = preds == data['label']\n",
    "print(tf.reduce_sum(tf.cast(corr, tf.float32)))\n",
    "\n",
    "if FLAGS.dataset=='waterbirds':\n",
    "    # Boolean membership masks for the four (metadata[:,0], metadata[:,1]) groups.\n",
    "    meta_0 = data['metadata'][:,0]\n",
    "    meta_1 = data['metadata'][:,1]\n",
    "    in_00 = (meta_0==0) & (meta_1==0)\n",
    "    in_01 = (meta_0==0) & (meta_1==1)\n",
    "    in_10 = (meta_0==1) & (meta_1==0)\n",
    "    in_11 = (meta_0==1) & (meta_1==1)\n",
    "    # Number of examples of each group in this batch.\n",
    "    test_00 = tf.reduce_sum(tf.cast(in_00, tf.float32))\n",
    "    test_01 = tf.reduce_sum(tf.cast(in_01, tf.float32))\n",
    "    test_10 = tf.reduce_sum(tf.cast(in_10, tf.float32))\n",
    "    test_11 = tf.reduce_sum(tf.cast(in_11, tf.float32))\n",
    "    # Correct-prediction masks and accuracy per group.\n",
    "    corr_00 = corr & in_00\n",
    "    corr_01 = corr & in_01\n",
    "    corr_10 = corr & in_10\n",
    "    corr_11 = corr & in_11\n",
    "    corr_00_frac = tf.reduce_sum(tf.cast(corr_00, tf.float32))/test_00\n",
    "    corr_01_frac = tf.reduce_sum(tf.cast(corr_01, tf.float32))/test_01\n",
    "    corr_10_frac = tf.reduce_sum(tf.cast(corr_10, tf.float32))/test_10\n",
    "    corr_11_frac = tf.reduce_sum(tf.cast(corr_11, tf.float32))/test_11\n",
    "    # Group accuracies weighted by train_XX_frac (defined in an earlier cell).\n",
    "    print(corr_00_frac*train_00_frac + corr_01_frac*train_01_frac + corr_10_frac*train_10_frac + corr_11_frac*train_11_frac)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9d1acc87",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load all of test_ds as one batch, plus a randomly permuted copy of its images and labels.\n",
    "test_ds_batch = test_ds.batch(len(test_ds))\n",
    "iterator = iter(test_ds_batch)\n",
    "data = next(iterator)\n",
    "p = np.random.permutation(data['image'].shape[0])\n",
    "data_rand = {key: tf.gather(data[key], p) for key in ('image', 'label')}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "51982ef7",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(data['image'].shape)\n",
    "\n",
    "# model2 returns three values; the second one is scored as the prediction.\n",
    "_, out, _ = model2(data['image'], training=False)\n",
    "if FLAGS.binary_classification:\n",
    "    preds = tf.cast(tf.reshape(out, [-1]) > 0, tf.int64)\n",
    "else:\n",
    "    preds = tf.argmax(out, axis=1)\n",
    "corr = preds == data['label']\n",
    "num_corr = tf.reduce_sum(tf.cast(corr, tf.float32))\n",
    "print(num_corr)\n",
    "print(num_corr/data['label'].shape[0])\n",
    "\n",
    "if FLAGS.dataset=='waterbirds':\n",
    "    # Boolean membership masks for the four (metadata[:,0], metadata[:,1]) groups.\n",
    "    meta_0 = data['metadata'][:,0]\n",
    "    meta_1 = data['metadata'][:,1]\n",
    "    in_00 = (meta_0==0) & (meta_1==0)\n",
    "    in_01 = (meta_0==0) & (meta_1==1)\n",
    "    in_10 = (meta_0==1) & (meta_1==0)\n",
    "    in_11 = (meta_0==1) & (meta_1==1)\n",
    "    # Number of examples of each group in this batch.\n",
    "    test_00 = tf.reduce_sum(tf.cast(in_00, tf.float32))\n",
    "    test_01 = tf.reduce_sum(tf.cast(in_01, tf.float32))\n",
    "    test_10 = tf.reduce_sum(tf.cast(in_10, tf.float32))\n",
    "    test_11 = tf.reduce_sum(tf.cast(in_11, tf.float32))\n",
    "    # Correct-prediction masks and accuracy per group.\n",
    "    corr_00 = corr & in_00\n",
    "    corr_01 = corr & in_01\n",
    "    corr_10 = corr & in_10\n",
    "    corr_11 = corr & in_11\n",
    "    corr_00_frac = tf.reduce_sum(tf.cast(corr_00, tf.float32))/test_00\n",
    "    corr_01_frac = tf.reduce_sum(tf.cast(corr_01, tf.float32))/test_01\n",
    "    corr_10_frac = tf.reduce_sum(tf.cast(corr_10, tf.float32))/test_10\n",
    "    corr_11_frac = tf.reduce_sum(tf.cast(corr_11, tf.float32))/test_11\n",
    "    # Group accuracies weighted by train_XX_frac (defined in an earlier cell).\n",
    "    print(corr_00_frac*train_00_frac + corr_01_frac*train_01_frac + corr_10_frac*train_10_frac + corr_11_frac*train_11_frac)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cf610171",
   "metadata": {},
   "outputs": [],
   "source": [
    "# proj_mat = v @ pinv(v) projects onto the column space of v.\n",
    "p_inv = tf.linalg.pinv(v)\n",
    "proj_mat = tf.linalg.matmul(v, p_inv)\n",
    "proj_t = tf.transpose(proj_mat)\n",
    "\n",
    "# Keep each image's span(v) component and take the component outside span(v)\n",
    "# from the matching row of data_rand.\n",
    "feats = model2.layers[0](data['image'], training=False)\n",
    "feats_rand = model2.layers[0](data_rand['image'], training=False)\n",
    "reps_rand = feats_rand - tf.linalg.matmul(feats_rand, proj_t)\n",
    "reps = tf.linalg.matmul(feats, proj_t) + reps_rand\n",
    "\n",
    "# Push the mixed representation through the remaining layers.\n",
    "for layer in model2.layers[1:]:\n",
    "    reps = layer(reps, training=False)\n",
    "\n",
    "if FLAGS.binary_classification:\n",
    "    preds = tf.cast(tf.reshape(reps, [-1]) > 0, tf.int64)\n",
    "else:\n",
    "    preds = tf.argmax(reps, axis=1)\n",
    "corr = preds == data['label']\n",
    "num_corr = tf.reduce_sum(tf.cast(corr, tf.float32))\n",
    "print(num_corr)\n",
    "print(num_corr/data['label'].shape[0])\n",
    "\n",
    "if FLAGS.dataset=='waterbirds':\n",
    "    # Boolean membership masks for the four (metadata[:,0], metadata[:,1]) groups.\n",
    "    meta_0 = data['metadata'][:,0]\n",
    "    meta_1 = data['metadata'][:,1]\n",
    "    in_00 = (meta_0==0) & (meta_1==0)\n",
    "    in_01 = (meta_0==0) & (meta_1==1)\n",
    "    in_10 = (meta_0==1) & (meta_1==0)\n",
    "    in_11 = (meta_0==1) & (meta_1==1)\n",
    "    # Number of examples of each group in this batch.\n",
    "    test_00 = tf.reduce_sum(tf.cast(in_00, tf.float32))\n",
    "    test_01 = tf.reduce_sum(tf.cast(in_01, tf.float32))\n",
    "    test_10 = tf.reduce_sum(tf.cast(in_10, tf.float32))\n",
    "    test_11 = tf.reduce_sum(tf.cast(in_11, tf.float32))\n",
    "    # Correct-prediction masks and accuracy per group.\n",
    "    corr_00 = corr & in_00\n",
    "    corr_01 = corr & in_01\n",
    "    corr_10 = corr & in_10\n",
    "    corr_11 = corr & in_11\n",
    "    corr_00_frac = tf.reduce_sum(tf.cast(corr_00, tf.float32))/test_00\n",
    "    corr_01_frac = tf.reduce_sum(tf.cast(corr_01, tf.float32))/test_01\n",
    "    corr_10_frac = tf.reduce_sum(tf.cast(corr_10, tf.float32))/test_10\n",
    "    corr_11_frac = tf.reduce_sum(tf.cast(corr_11, tf.float32))/test_11\n",
    "    # Group accuracies weighted by train_XX_frac (defined in an earlier cell).\n",
    "    print(corr_00_frac*train_00_frac + corr_01_frac*train_01_frac + corr_10_frac*train_10_frac + corr_11_frac*train_11_frac)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "739886f4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Relative L2 gap between `out` and `reps`, both of which are set by earlier cells.\n",
    "rel_gap = tf.linalg.norm(out - reps) / tf.linalg.norm(out)\n",
    "print(rel_gap)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b31654d2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sign-weight the gap between `out` and `reps` by 2*label - 1 (so 0/1 labels become -1/+1),\n",
    "# then report the signed total relative to the absolute total.\n",
    "diff = out - reps\n",
    "y = tf.reshape(2 * data['label'] - 1, [-1, 1])\n",
    "z = diff * tf.cast(y, tf.float32)\n",
    "print(z.shape)\n",
    "signed_total = tf.reduce_sum(z)\n",
    "abs_total = tf.reduce_sum(tf.abs(diff))\n",
    "print(signed_total / abs_total)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3584b7d3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# proj_mat = v @ pinv(v) projects onto the column space of v.\n",
    "p_inv = tf.linalg.pinv(v)\n",
    "proj_mat = tf.linalg.matmul(v, p_inv)\n",
    "proj_t = tf.transpose(proj_mat)\n",
    "\n",
    "# Keep each image's component outside span(v) and take the span(v) component\n",
    "# from the matching row of data_rand.\n",
    "feats = model2.layers[0](data['image'], training=False)\n",
    "feats_rand = model2.layers[0](data_rand['image'], training=False)\n",
    "reps_rand = tf.linalg.matmul(feats_rand, proj_t)\n",
    "reps = (feats - tf.linalg.matmul(feats, proj_t)) + reps_rand\n",
    "\n",
    "# Push the mixed representation through the remaining layers.\n",
    "for layer in model2.layers[1:]:\n",
    "    reps = layer(reps, training=False)\n",
    "\n",
    "if FLAGS.binary_classification:\n",
    "    preds = tf.cast(tf.reshape(reps, [-1]) > 0, tf.int64)\n",
    "else:\n",
    "    preds = tf.argmax(reps, axis=1)\n",
    "corr = preds == data['label']\n",
    "num_corr = tf.reduce_sum(tf.cast(corr, tf.float32))\n",
    "print(num_corr)\n",
    "print(num_corr/data['label'].shape[0])\n",
    "\n",
    "if FLAGS.dataset=='waterbirds':\n",
    "    # Boolean membership masks for the four (metadata[:,0], metadata[:,1]) groups.\n",
    "    meta_0 = data['metadata'][:,0]\n",
    "    meta_1 = data['metadata'][:,1]\n",
    "    in_00 = (meta_0==0) & (meta_1==0)\n",
    "    in_01 = (meta_0==0) & (meta_1==1)\n",
    "    in_10 = (meta_0==1) & (meta_1==0)\n",
    "    in_11 = (meta_0==1) & (meta_1==1)\n",
    "    # Number of examples of each group in this batch.\n",
    "    test_00 = tf.reduce_sum(tf.cast(in_00, tf.float32))\n",
    "    test_01 = tf.reduce_sum(tf.cast(in_01, tf.float32))\n",
    "    test_10 = tf.reduce_sum(tf.cast(in_10, tf.float32))\n",
    "    test_11 = tf.reduce_sum(tf.cast(in_11, tf.float32))\n",
    "    # Correct-prediction masks and accuracy per group.\n",
    "    corr_00 = corr & in_00\n",
    "    corr_01 = corr & in_01\n",
    "    corr_10 = corr & in_10\n",
    "    corr_11 = corr & in_11\n",
    "    corr_00_frac = tf.reduce_sum(tf.cast(corr_00, tf.float32))/test_00\n",
    "    corr_01_frac = tf.reduce_sum(tf.cast(corr_01, tf.float32))/test_01\n",
    "    corr_10_frac = tf.reduce_sum(tf.cast(corr_10, tf.float32))/test_10\n",
    "    corr_11_frac = tf.reduce_sum(tf.cast(corr_11, tf.float32))/test_11\n",
    "    # Group accuracies weighted by train_XX_frac (defined in an earlier cell).\n",
    "    print(corr_00_frac*train_00_frac + corr_01_frac*train_01_frac + corr_10_frac*train_10_frac + corr_11_frac*train_11_frac)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b2bdaf2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Relative L2 gap between `out` and `reps`, both of which are set by earlier cells.\n",
    "rel_gap = tf.linalg.norm(out - reps) / tf.linalg.norm(out)\n",
    "print(rel_gap)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4d5bf6f2",
   "metadata": {},
   "outputs": [],
   "source": [
    "#FLAGS.binary_classification=True\n",
    "#FLAGS.num_classes = 1\n",
    "\n",
    "# Base network: a Sequential holding a single bare Keras Layer.\n",
    "base_model = tf.keras.Sequential([tf.keras.layers.Layer()])\n",
    "\n",
    "# The positional arguments follow createmodel's signature (defined in an earlier cell).\n",
    "model3 = createmodel(\n",
    "    FLAGS.num_classes, 100, 1, 100, 0,\n",
    "    resnet_base=base_model,\n",
    "    dropout_rate=0.0,\n",
    "    num_heads=1,\n",
    "    use_proj=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e076928d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Restore every layer of model3 except layer 0 from the latest checkpoint,\n",
    "# keyed 'model<index>'. Iterate over model3's own layers (not `model`'s) so the\n",
    "# cell neither depends on `model` nor mis-indexes if the layer counts differ.\n",
    "save_vars = {\n",
    "    'model{}'.format(i): layer\n",
    "    for i, layer in enumerate(model3.layers) if i != 0\n",
    "}\n",
    "ckpt = tf.train.Checkpoint(**save_vars)\n",
    "manager = tf.train.CheckpointManager(\n",
    "    ckpt, directory=\"./imagenette-he-init-model2-noproject-multirun-mean/run-1/2/1\", max_to_keep=1)\n",
    "status = ckpt.restore(manager.latest_checkpoint)\n",
    "# Show which checkpoint was picked up, as the model2 restore cell does.\n",
    "print(manager.latest_checkpoint)\n",
    "if manager.latest_checkpoint is None:\n",
    "    # restore(None) restores nothing, so make the missing checkpoint visible.\n",
    "    print('WARNING: no checkpoint found in the directory above; model3 was not restored.')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "90cb75a2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Trainable 2048x15 matrix whose columns start at unit norm.\n",
    "v = tf.random.normal((2048, 15))\n",
    "v = v/tf.linalg.norm(v, axis=0, keepdims=True)\n",
    "v = tf.Variable(v)\n",
    "optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)\n",
    "reg = 5.0\n",
    "\n",
    "@tf.function\n",
    "def train_step(data):\n",
    "    '''Take one SGD step on v (model3's weights are not updated).\n",
    "\n",
    "    The layer-0 representation is split into its projection onto the column\n",
    "    space of v (reps2) and the remainder (reps1). Both, and the unsplit\n",
    "    representation, go through the remaining layers. The loss is a fully\n",
    "    label-smoothed cross-entropy on reps1 plus reg times a fit term on reps2:\n",
    "    in binary mode the targets are the sigmoid of the unsplit model output,\n",
    "    otherwise they are the one-hot labels.\n",
    "    Prints the number of correct predictions from reps1, then from reps2,\n",
    "    then the loss.\n",
    "    '''\n",
    "    if FLAGS.binary_classification:\n",
    "        loss_cls = tf.keras.losses.BinaryCrossentropy\n",
    "    else:\n",
    "        loss_cls = tf.keras.losses.CategoricalCrossentropy\n",
    "    sum_reduction = tf.keras.losses.Reduction.SUM\n",
    "    loss_fn = loss_cls(from_logits=True, reduction=sum_reduction)\n",
    "    # label_smoothing=1.0 replaces whatever target is passed by a uniform one.\n",
    "    loss_fn_smooth = loss_cls(from_logits=True, reduction=sum_reduction,\n",
    "                              label_smoothing=1.0)\n",
    "\n",
    "    reps = model3.layers[0](data['image'], training=False)\n",
    "    with tf.GradientTape() as tape:\n",
    "        # v @ pinv(v) projects onto the column space of v.\n",
    "        p_inv = tf.linalg.pinv(v)\n",
    "        proj_mat = tf.linalg.matmul(v, p_inv)\n",
    "        reps1 = reps - tf.linalg.matmul(reps, proj_mat)\n",
    "        reps2 = tf.linalg.matmul(reps, proj_mat)\n",
    "        for layer in model3.layers[1:]:\n",
    "            reps1 = layer(reps1, training=False)\n",
    "            reps2 = layer(reps2, training=False)\n",
    "            reps = layer(reps, training=False)\n",
    "        if FLAGS.binary_classification:\n",
    "            loss = loss_fn_smooth(tf.zeros(data['label'].shape), reps1)\n",
    "            # Soft targets come from the unsplit output rather than data['label'].\n",
    "            pred_1 = tf.nn.sigmoid(reps)\n",
    "            fit_terms = tf.nn.sigmoid_cross_entropy_with_logits(labels=pred_1, logits=reps2)\n",
    "            loss += reg*tf.reduce_sum(fit_terms)\n",
    "        else:\n",
    "            loss = loss_fn_smooth(tf.zeros([data['label'].shape[0], FLAGS.num_classes]), reps1)\n",
    "            one_hot_enc = tf.one_hot(data['label'], FLAGS.num_classes)\n",
    "            loss += reg*loss_fn(one_hot_enc, reps2)\n",
    "    grads = tape.gradient(loss, [v])\n",
    "    optimizer.apply_gradients(zip(grads, [v]))\n",
    "\n",
    "    # Correct-prediction counts for the remainder branch, then the projected branch.\n",
    "    for logits in (reps1, reps2):\n",
    "        if FLAGS.binary_classification:\n",
    "            preds = tf.cast(tf.reshape(logits, [-1]) > 0, tf.int64)\n",
    "        else:\n",
    "            preds = tf.argmax(logits, axis=1)\n",
    "        corr = preds == data['label']\n",
    "        tf.print(tf.reduce_sum(tf.cast(corr, tf.float32)))\n",
    "    tf.print(loss)\n",
    "\n",
    "# Run `steps` optimisation steps over batches of 256 drawn from train_ds.\n",
    "train_ds_batch = train_ds.batch(256)\n",
    "iterator = iter(train_ds_batch)\n",
    "steps = 2000\n",
    "for step in range(steps):\n",
    "    data = next(iterator)\n",
    "    train_step(data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "60412df9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Next batch from the current iterator, plus a randomly permuted copy of it.\n",
    "data = next(iterator)\n",
    "p = np.random.permutation(data['image'].shape[0])\n",
    "data_rand = {key: tf.gather(data[key], p) for key in ('image', 'label')}\n",
    "\n",
    "# proj_mat = v @ pinv(v) projects onto the column space of v.\n",
    "p_inv = tf.linalg.pinv(v)\n",
    "proj_mat = tf.linalg.matmul(v, p_inv)\n",
    "proj_t = tf.transpose(proj_mat)\n",
    "\n",
    "# Keep each image's span(v) component and take the component outside span(v)\n",
    "# from the permuted image in the same row.\n",
    "feats = model3.layers[0](data['image'], training=False)\n",
    "feats_rand = model3.layers[0](data_rand['image'], training=False)\n",
    "reps_rand = feats_rand - tf.linalg.matmul(feats_rand, proj_t)\n",
    "reps = tf.linalg.matmul(feats, proj_t) + reps_rand\n",
    "\n",
    "# Push the mixed representation through the remaining layers.\n",
    "for layer in model3.layers[1:]:\n",
    "    reps = layer(reps, training=False)\n",
    "\n",
    "if FLAGS.binary_classification:\n",
    "    preds = tf.cast(tf.reshape(reps, [-1]) > 0, tf.int64)\n",
    "else:\n",
    "    preds = tf.argmax(reps, axis=1)\n",
    "corr = preds == data['label']\n",
    "print(tf.reduce_sum(tf.cast(corr, tf.float32)))\n",
    "\n",
    "if FLAGS.dataset=='waterbirds':\n",
    "    # Boolean membership masks for the four (metadata[:,0], metadata[:,1]) groups.\n",
    "    meta_0 = data['metadata'][:,0]\n",
    "    meta_1 = data['metadata'][:,1]\n",
    "    in_00 = (meta_0==0) & (meta_1==0)\n",
    "    in_01 = (meta_0==0) & (meta_1==1)\n",
    "    in_10 = (meta_0==1) & (meta_1==0)\n",
    "    in_11 = (meta_0==1) & (meta_1==1)\n",
    "    # Number of examples of each group in this batch.\n",
    "    test_00 = tf.reduce_sum(tf.cast(in_00, tf.float32))\n",
    "    test_01 = tf.reduce_sum(tf.cast(in_01, tf.float32))\n",
    "    test_10 = tf.reduce_sum(tf.cast(in_10, tf.float32))\n",
    "    test_11 = tf.reduce_sum(tf.cast(in_11, tf.float32))\n",
    "    # Correct-prediction masks and accuracy per group.\n",
    "    corr_00 = corr & in_00\n",
    "    corr_01 = corr & in_01\n",
    "    corr_10 = corr & in_10\n",
    "    corr_11 = corr & in_11\n",
    "    corr_00_frac = tf.reduce_sum(tf.cast(corr_00, tf.float32))/test_00\n",
    "    corr_01_frac = tf.reduce_sum(tf.cast(corr_01, tf.float32))/test_01\n",
    "    corr_10_frac = tf.reduce_sum(tf.cast(corr_10, tf.float32))/test_10\n",
    "    corr_11_frac = tf.reduce_sum(tf.cast(corr_11, tf.float32))/test_11\n",
    "    # Group accuracies weighted by train_XX_frac (defined in an earlier cell).\n",
    "    print(corr_00_frac*train_00_frac + corr_01_frac*train_01_frac + corr_10_frac*train_10_frac + corr_11_frac*train_11_frac)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8659ad48",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Next batch from the current iterator, plus a randomly permuted copy of it.\n",
    "data = next(iterator)\n",
    "p = np.random.permutation(data['image'].shape[0])\n",
    "data_rand = {key: tf.gather(data[key], p) for key in ('image', 'label')}\n",
    "\n",
    "# proj_mat = v @ pinv(v) projects onto the column space of v.\n",
    "p_inv = tf.linalg.pinv(v)\n",
    "proj_mat = tf.linalg.matmul(v, p_inv)\n",
    "proj_t = tf.transpose(proj_mat)\n",
    "\n",
    "# Keep each image's component outside span(v) and take the span(v) component\n",
    "# from the permuted image in the same row.\n",
    "feats = model3.layers[0](data['image'], training=False)\n",
    "feats_rand = model3.layers[0](data_rand['image'], training=False)\n",
    "reps_rand = tf.linalg.matmul(feats_rand, proj_t)\n",
    "reps = (feats - tf.linalg.matmul(feats, proj_t)) + reps_rand\n",
    "\n",
    "# Push the mixed representation through the remaining layers.\n",
    "for layer in model3.layers[1:]:\n",
    "    reps = layer(reps, training=False)\n",
    "\n",
    "if FLAGS.binary_classification:\n",
    "    preds = tf.cast(tf.reshape(reps, [-1]) > 0, tf.int64)\n",
    "else:\n",
    "    preds = tf.argmax(reps, axis=1)\n",
    "corr = preds == data['label']\n",
    "print(tf.reduce_sum(tf.cast(corr, tf.float32)))\n",
    "\n",
    "if FLAGS.dataset=='waterbirds':\n",
    "    # Boolean membership masks for the four (metadata[:,0], metadata[:,1]) groups.\n",
    "    meta_0 = data['metadata'][:,0]\n",
    "    meta_1 = data['metadata'][:,1]\n",
    "    in_00 = (meta_0==0) & (meta_1==0)\n",
    "    in_01 = (meta_0==0) & (meta_1==1)\n",
    "    in_10 = (meta_0==1) & (meta_1==0)\n",
    "    in_11 = (meta_0==1) & (meta_1==1)\n",
    "    # Number of examples of each group in this batch.\n",
    "    test_00 = tf.reduce_sum(tf.cast(in_00, tf.float32))\n",
    "    test_01 = tf.reduce_sum(tf.cast(in_01, tf.float32))\n",
    "    test_10 = tf.reduce_sum(tf.cast(in_10, tf.float32))\n",
    "    test_11 = tf.reduce_sum(tf.cast(in_11, tf.float32))\n",
    "    # Correct-prediction masks and accuracy per group.\n",
    "    corr_00 = corr & in_00\n",
    "    corr_01 = corr & in_01\n",
    "    corr_10 = corr & in_10\n",
    "    corr_11 = corr & in_11\n",
    "    corr_00_frac = tf.reduce_sum(tf.cast(corr_00, tf.float32))/test_00\n",
    "    corr_01_frac = tf.reduce_sum(tf.cast(corr_01, tf.float32))/test_01\n",
    "    corr_10_frac = tf.reduce_sum(tf.cast(corr_10, tf.float32))/test_10\n",
    "    corr_11_frac = tf.reduce_sum(tf.cast(corr_11, tf.float32))/test_11\n",
    "    # Group accuracies weighted by train_XX_frac (defined in an earlier cell).\n",
    "    print(corr_00_frac*train_00_frac + corr_01_frac*train_01_frac + corr_10_frac*train_10_frac + corr_11_frac*train_11_frac)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4bb1a568",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load all of test_ds as one batch, plus a randomly permuted copy of its images and labels.\n",
    "test_ds_batch = test_ds.batch(len(test_ds))\n",
    "iterator = iter(test_ds_batch)\n",
    "data = next(iterator)\n",
    "p = np.random.permutation(data['image'].shape[0])\n",
    "data_rand = {key: tf.gather(data[key], p) for key in ('image', 'label')}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2ac585fc",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(data['image'].shape)\n",
    "\n",
    "# model3 returns three values; the second one is scored as the prediction.\n",
    "_, out, _ = model3(data['image'], training=False)\n",
    "if FLAGS.binary_classification:\n",
    "    preds = tf.cast(tf.reshape(out, [-1]) > 0, tf.int64)\n",
    "else:\n",
    "    preds = tf.argmax(out, axis=1)\n",
    "corr = preds == data['label']\n",
    "num_corr = tf.reduce_sum(tf.cast(corr, tf.float32))\n",
    "print(num_corr)\n",
    "print(num_corr/data['label'].shape[0])\n",
    "\n",
    "if FLAGS.dataset=='waterbirds':\n",
    "    # Boolean membership masks for the four (metadata[:,0], metadata[:,1]) groups.\n",
    "    meta_0 = data['metadata'][:,0]\n",
    "    meta_1 = data['metadata'][:,1]\n",
    "    in_00 = (meta_0==0) & (meta_1==0)\n",
    "    in_01 = (meta_0==0) & (meta_1==1)\n",
    "    in_10 = (meta_0==1) & (meta_1==0)\n",
    "    in_11 = (meta_0==1) & (meta_1==1)\n",
    "    # Number of examples of each group in this batch.\n",
    "    test_00 = tf.reduce_sum(tf.cast(in_00, tf.float32))\n",
    "    test_01 = tf.reduce_sum(tf.cast(in_01, tf.float32))\n",
    "    test_10 = tf.reduce_sum(tf.cast(in_10, tf.float32))\n",
    "    test_11 = tf.reduce_sum(tf.cast(in_11, tf.float32))\n",
    "    # Correct-prediction masks and accuracy per group.\n",
    "    corr_00 = corr & in_00\n",
    "    corr_01 = corr & in_01\n",
    "    corr_10 = corr & in_10\n",
    "    corr_11 = corr & in_11\n",
    "    corr_00_frac = tf.reduce_sum(tf.cast(corr_00, tf.float32))/test_00\n",
    "    corr_01_frac = tf.reduce_sum(tf.cast(corr_01, tf.float32))/test_01\n",
    "    corr_10_frac = tf.reduce_sum(tf.cast(corr_10, tf.float32))/test_10\n",
    "    corr_11_frac = tf.reduce_sum(tf.cast(corr_11, tf.float32))/test_11\n",
    "    # Group accuracies weighted by train_XX_frac (defined in an earlier cell).\n",
    "    print(corr_00_frac*train_00_frac + corr_01_frac*train_01_frac + corr_10_frac*train_10_frac + corr_11_frac*train_11_frac)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c237b662",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build mixed layer-0 representations: the part of data['image'] inside the\n",
    "# column space of v plus the part of data_rand['image'] outside it. The mix is\n",
    "# then run through the remaining layers and scored against data['label'].\n",
    "reps = model3.layers[0](data['image'], training=False)\n",
    "\n",
    "# proj_mat = v @ pinv(v) is the orthogonal projector onto the column space of v.\n",
    "p_inv = tf.linalg.pinv(v)\n",
    "proj_mat = tf.linalg.matmul(v, p_inv)\n",
    "_proj_t = tf.transpose(proj_mat)\n",
    "\n",
    "# Right-multiplying by the transpose projects each row\n",
    "# (assumes reps is (batch, dim) - TODO confirm).\n",
    "reps = tf.linalg.matmul(reps, _proj_t)\n",
    "reps_rand = model3.layers[0](data_rand['image'], training=False)\n",
    "reps_rand = reps_rand - tf.linalg.matmul(reps_rand, _proj_t)\n",
    "reps = reps + reps_rand\n",
    "\n",
    "# Layer 0 was already applied above, so only layers 1..end see the mix.\n",
    "for ind, layer in enumerate(model3.layers):\n",
    "    if ind == 0:\n",
    "        continue\n",
    "    reps = layer(reps, training=False)\n",
    "\n",
    "# reps now holds the outputs of the last layer: threshold at 0 in the binary\n",
    "# case, argmax over axis 1 otherwise.\n",
    "if FLAGS.binary_classification:\n",
    "    _pred = tf.cast(tf.reshape(reps, [-1]) > 0, tf.int64)\n",
    "else:\n",
    "    _pred = tf.argmax(reps, axis=1)\n",
    "corr = _pred == data['label']\n",
    "\n",
    "# Number of correct predictions, then accuracy over the batch.\n",
    "_num_corr = tf.reduce_sum(tf.cast(corr, tf.float32))\n",
    "print(_num_corr)\n",
    "print(_num_corr/data['label'].shape[0])\n",
    "\n",
    "if FLAGS.dataset=='waterbirds':\n",
    "    # One boolean mask per (metadata[:,0], metadata[:,1]) value pair.\n",
    "    _meta0 = data['metadata'][:,0]\n",
    "    _meta1 = data['metadata'][:,1]\n",
    "    _grp_00 = (_meta0==0) & (_meta1==0)\n",
    "    _grp_01 = (_meta0==0) & (_meta1==1)\n",
    "    _grp_10 = (_meta0==1) & (_meta1==0)\n",
    "    _grp_11 = (_meta0==1) & (_meta1==1)\n",
    "\n",
    "    # Group sizes.\n",
    "    test_00 = tf.reduce_sum(tf.cast(_grp_00, tf.float32))\n",
    "    test_01 = tf.reduce_sum(tf.cast(_grp_01, tf.float32))\n",
    "    test_10 = tf.reduce_sum(tf.cast(_grp_10, tf.float32))\n",
    "    test_11 = tf.reduce_sum(tf.cast(_grp_11, tf.float32))\n",
    "\n",
    "    # Correct predictions per group and the resulting per-group accuracy.\n",
    "    corr_00 = corr & _grp_00\n",
    "    corr_00_frac = tf.reduce_sum(tf.cast(corr_00, tf.float32))/test_00\n",
    "    corr_01 = corr & _grp_01\n",
    "    corr_01_frac = tf.reduce_sum(tf.cast(corr_01, tf.float32))/test_01\n",
    "    corr_10 = corr & _grp_10\n",
    "    corr_10_frac = tf.reduce_sum(tf.cast(corr_10, tf.float32))/test_10\n",
    "    corr_11 = corr & _grp_11\n",
    "    corr_11_frac = tf.reduce_sum(tf.cast(corr_11, tf.float32))/test_11\n",
    "\n",
    "    # Per-group accuracies weighted by the train_*_frac values, which are\n",
    "    # defined elsewhere in the notebook.\n",
    "    print(\n",
    "        corr_00_frac*train_00_frac\n",
    "        + corr_01_frac*train_01_frac\n",
    "        + corr_10_frac*train_10_frac\n",
    "        + corr_11_frac*train_11_frac\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ec0bfb7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Relative change of the outputs: norm(out - reps) / norm(out), using\n",
    "# tf.linalg.norm with its default arguments.\n",
    "_out_change = tf.linalg.norm(out - reps)\n",
    "print(_out_change/tf.linalg.norm(out))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e22e6683",
   "metadata": {},
   "outputs": [],
   "source": [
    "# How much of the output change (out - reps) points in the direction of the label sign.\n",
    "_out_delta = out - reps\n",
    "\n",
    "# 2*label - 1 sends a 0/1 label to -1/+1 (assumes labels are 0/1 - TODO confirm);\n",
    "# the result is reshaped to a single column.\n",
    "y = tf.reshape(2*data['label'] - 1, (-1, 1))\n",
    "z = tf.cast(y, tf.float32)*_out_delta\n",
    "print(z.shape)\n",
    "\n",
    "# Net label-signed change divided by the total absolute change.\n",
    "print(tf.reduce_sum(z)/tf.reduce_sum(tf.abs(_out_delta)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a8d930c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build mixed layer-0 representations: the part of data['image'] outside the\n",
    "# column space of v plus the part of data_rand['image'] inside it. The mix is\n",
    "# then run through the remaining layers and scored against data['label'].\n",
    "reps = model3.layers[0](data['image'], training=False)\n",
    "\n",
    "# proj_mat = v @ pinv(v) is the orthogonal projector onto the column space of v.\n",
    "p_inv = tf.linalg.pinv(v)\n",
    "proj_mat = tf.linalg.matmul(v, p_inv)\n",
    "_proj_t = tf.transpose(proj_mat)\n",
    "\n",
    "# Right-multiplying by the transpose projects each row\n",
    "# (assumes reps is (batch, dim) - TODO confirm).\n",
    "reps = reps - tf.linalg.matmul(reps, _proj_t)\n",
    "reps_rand = model3.layers[0](data_rand['image'], training=False)\n",
    "reps_rand = tf.linalg.matmul(reps_rand, _proj_t)\n",
    "reps = reps + reps_rand\n",
    "\n",
    "# Layer 0 was already applied above, so only layers 1..end see the mix.\n",
    "for ind, layer in enumerate(model3.layers):\n",
    "    if ind == 0:\n",
    "        continue\n",
    "    reps = layer(reps, training=False)\n",
    "\n",
    "# reps now holds the outputs of the last layer: threshold at 0 in the binary\n",
    "# case, argmax over axis 1 otherwise.\n",
    "if FLAGS.binary_classification:\n",
    "    _pred = tf.cast(tf.reshape(reps, [-1]) > 0, tf.int64)\n",
    "else:\n",
    "    _pred = tf.argmax(reps, axis=1)\n",
    "corr = _pred == data['label']\n",
    "\n",
    "# Number of correct predictions, then accuracy over the batch.\n",
    "_num_corr = tf.reduce_sum(tf.cast(corr, tf.float32))\n",
    "print(_num_corr)\n",
    "print(_num_corr/data['label'].shape[0])\n",
    "\n",
    "if FLAGS.dataset=='waterbirds':\n",
    "    # One boolean mask per (metadata[:,0], metadata[:,1]) value pair.\n",
    "    _meta0 = data['metadata'][:,0]\n",
    "    _meta1 = data['metadata'][:,1]\n",
    "    _grp_00 = (_meta0==0) & (_meta1==0)\n",
    "    _grp_01 = (_meta0==0) & (_meta1==1)\n",
    "    _grp_10 = (_meta0==1) & (_meta1==0)\n",
    "    _grp_11 = (_meta0==1) & (_meta1==1)\n",
    "\n",
    "    # Group sizes.\n",
    "    test_00 = tf.reduce_sum(tf.cast(_grp_00, tf.float32))\n",
    "    test_01 = tf.reduce_sum(tf.cast(_grp_01, tf.float32))\n",
    "    test_10 = tf.reduce_sum(tf.cast(_grp_10, tf.float32))\n",
    "    test_11 = tf.reduce_sum(tf.cast(_grp_11, tf.float32))\n",
    "\n",
    "    # Correct predictions per group and the resulting per-group accuracy.\n",
    "    corr_00 = corr & _grp_00\n",
    "    corr_00_frac = tf.reduce_sum(tf.cast(corr_00, tf.float32))/test_00\n",
    "    corr_01 = corr & _grp_01\n",
    "    corr_01_frac = tf.reduce_sum(tf.cast(corr_01, tf.float32))/test_01\n",
    "    corr_10 = corr & _grp_10\n",
    "    corr_10_frac = tf.reduce_sum(tf.cast(corr_10, tf.float32))/test_10\n",
    "    corr_11 = corr & _grp_11\n",
    "    corr_11_frac = tf.reduce_sum(tf.cast(corr_11, tf.float32))/test_11\n",
    "\n",
    "    # Per-group accuracies weighted by the train_*_frac values, which are\n",
    "    # defined elsewhere in the notebook.\n",
    "    print(\n",
    "        corr_00_frac*train_00_frac\n",
    "        + corr_01_frac*train_01_frac\n",
    "        + corr_10_frac*train_10_frac\n",
    "        + corr_11_frac*train_11_frac\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "35f2ca3a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Relative change of the outputs: norm(out - reps) / norm(out), using\n",
    "# tf.linalg.norm with its default arguments.\n",
    "_out_change = tf.linalg.norm(out - reps)\n",
    "print(_out_change/tf.linalg.norm(out))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "533c3a77",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean and standard deviation of three hard-coded values.\n",
    "# np.std uses its default ddof=0 here, i.e. the population standard deviation.\n",
    "# NOTE(review): the provenance of these numbers is not recorded in this cell.\n",
    "x = [79.52, 80.57, 79.82]\n",
    "print(np.mean(x), np.std(x), sep='\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f3a8db13",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean and standard deviation of three hard-coded values.\n",
    "# np.std uses its default ddof=0 here, i.e. the population standard deviation.\n",
    "# NOTE(review): the provenance of these numbers is not recorded in this cell.\n",
    "x = [68.31, 69.66, 66.78]\n",
    "print(np.mean(x), np.std(x), sep='\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3fddaa06",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean and standard deviation of three hard-coded values.\n",
    "# np.std uses its default ddof=0 here, i.e. the population standard deviation.\n",
    "# NOTE(review): the provenance of these numbers is not recorded in this cell.\n",
    "x = [54.57, 52.14, 61.27]\n",
    "print(np.mean(x), np.std(x), sep='\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e599b1e0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean and standard deviation of three hard-coded values.\n",
    "# np.std uses its default ddof=0 here, i.e. the population standard deviation.\n",
    "# NOTE(review): the provenance of these numbers is not recorded in this cell.\n",
    "x = [12.54, 12.46, 10.75]\n",
    "print(np.mean(x), np.std(x), sep='\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d52920f5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean and standard deviation of three hard-coded values.\n",
    "# np.std uses its default ddof=0 here, i.e. the population standard deviation.\n",
    "# NOTE(review): the provenance of these numbers is not recorded in this cell.\n",
    "x = [129.59, 130.49, 141.51]\n",
    "print(np.mean(x), np.std(x), sep='\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13133a78",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.10.6 64-bit",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  },
  "vscode": {
   "interpreter": {
    "hash": "b0fa6594d8f4cbf19f97940f81e996739fb7646882a419484c72d19e05852a7e"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
