{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "uGMmryQKaVkp"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "import pandas as pd\n",
        "import gdown\n",
        "import numpy as np\n",
        "import tensorflow as tf\n",
        "import tensorflow_datasets as tfds\n",
        "import tensorflow_addons as tfa\n",
        "from sklearn.metrics import roc_curve, auc\n",
        "import matplotlib.pyplot as plt\n",
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "sn7vSEKWKl5d"
      },
      "outputs": [],
      "source": [
        "path = os.getcwd()\n",
        "\n",
        "# Check if running in Colab\n",
        "try:\n",
        "  from google.colab import drive\n",
        "  IN_COLAB=True\n",
        "  print(\"Running in Colab\")\n",
        "  # Mount Google Drive\n",
        "  drive.mount('/content/drive', force_remount=True)\n",
        "  # Change directory\n",
        "  %cd /content/drive/MyDrive/Class_Distribution_Shifts_in_Zero_Shot_Learning_Learning_Robust_Representations\n",
        "except:\n",
        "  IN_COLAB=False\n",
        "  print(\"Running locally\")"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "from pairs import distinct_pairs_func, make_pairs\n",
        "from algorithm import *\n",
        "from synthetic_data import *"
      ],
      "metadata": {
        "id": "GzJcicZA111S"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0wGVBdB_QlCM"
      },
      "source": [
        "### 1. Load data"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "95obuxuHWMeP"
      },
      "outputs": [],
      "source": [
        "# Get the index of the dataset files\n",
        "! wget https://raw.githubusercontent.com/tensorflow/datasets/master/tensorflow_datasets/datasets/celeb_a/checksums.tsv\n",
        "urls = pd.read_csv('checksums.tsv', sep='\\t', names=['url', 'size', 'checksum', 'filename'])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "c-O-b9BVYGmR"
      },
      "outputs": [],
      "source": [
        "# download the files manually. If one of them says \"access denied\",\n",
        "# you can download it from the link, and upload it to colab as you see fit - directly or to Google Drive\n",
        "for _, row in urls.iterrows():\n",
        "    if row.filename not in os.listdir():\n",
        "        gdown.download(row.url, row.filename, quiet=False)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "p4yPccmYQpQY"
      },
      "outputs": [],
      "source": [
        "gdown.download('https://drive.google.com/uc?export=download&id=1roEIMXWh8rxneYlSGGSkit2-adkb0oxC')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "EgyZy11aZATn"
      },
      "outputs": [],
      "source": [
        "! mkdir -p ~/tensorflow_datasets/downloads/manual\n",
        "! mv list_eval_partition.txt ~/tensorflow_datasets/downloads/manual\n",
        "! mv img_align_celeba.zip ~/tensorflow_datasets/downloads/manual\n",
        "! mv list_attr_celeba.txt ~/tensorflow_datasets/downloads/manual\n",
        "! mv identity_CelebA.txt ~/tensorflow_datasets/downloads/manual\n",
        "! mv list_landmarks_align_celeba.txt ~/tensorflow_datasets/downloads/manual"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "re5rpR70BFpE"
      },
      "outputs": [],
      "source": [
        "celeb_a_builder = tfds.builder('celeb_a', version='2.1.0', try_gcs=False)\n",
        "celeb_a_builder.download_and_prepare()\n",
        "celeb_a_data = celeb_a_builder.as_dataset()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "OwK7hAg97Z7w"
      },
      "outputs": [],
      "source": [
        "ATTR_KEY = \"attributes\"\n",
        "IMAGE_KEY = \"image\"\n",
        "LABEL_KEY = \"identity\"\n",
        "GROUP_KEY = \"Blond_Hair\"\n",
        "IMAGE_SIZE = 45"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "gEFtLUbl7aLW"
      },
      "outputs": [],
      "source": [
        "def preprocess_input_dict(feat_dict):\n",
        "  # Separate out the image and target variable from the feature dictionary.\n",
        "  image = feat_dict[IMAGE_KEY]\n",
        "  label = feat_dict[LABEL_KEY]\n",
        "  group = feat_dict[ATTR_KEY][GROUP_KEY]\n",
        "\n",
        "  # Resize and normalize image.\n",
        "  image = tf.cast(image, tf.float32)\n",
        "  image = tf.image.resize(image, [IMAGE_SIZE, IMAGE_SIZE])\n",
        "  image /= 255.0\n",
        "\n",
        "  # Cast label and group to float32.\n",
        "  label = label\n",
        "  group = tf.cast(group, tf.float32)\n",
        "\n",
        "  feat_dict[IMAGE_KEY] = image\n",
        "  feat_dict[LABEL_KEY] = label\n",
        "  feat_dict[ATTR_KEY][GROUP_KEY] = group\n",
        "  return feat_dict"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "G5ugtuS8FGVf"
      },
      "outputs": [],
      "source": [
        "get_image_label_and_group = lambda feat_dict: (feat_dict[IMAGE_KEY], feat_dict[LABEL_KEY], feat_dict[ATTR_KEY][GROUP_KEY])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "bxmLaRahRgMg"
      },
      "outputs": [],
      "source": [
        "train_data = celeb_a_builder.as_dataset(split='train').batch(1).map(preprocess_input_dict).map(get_image_label_and_group)\n",
        "train_iterator = train_data.as_numpy_iterator()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Q7GVTcdbK9xp"
      },
      "outputs": [],
      "source": [
        "test_data = celeb_a_builder.as_dataset(split='test').batch(1).map(preprocess_input_dict).map(get_image_label_and_group)\n",
        "test_iterator = test_data.as_numpy_iterator()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "DGD-SRBgTY3q"
      },
      "outputs": [],
      "source": [
        "def get_data(iterator, IMAGE_SIZE):\n",
        "  Z, C, G  = [], [], []\n",
        "  for v in iterator:\n",
        "    Z.append(v[0].reshape(IMAGE_SIZE, IMAGE_SIZE, 3))\n",
        "    C.append(v[1]['Identity_No'][0])\n",
        "    G.append(v[2][0])\n",
        "  return np.array(Z), np.array(C), np.array(G)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "BzuXwMN5MA7U"
      },
      "outputs": [],
      "source": [
        "Z_train, C_train, G_train = get_data(train_iterator, IMAGE_SIZE)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "U-Rts8-zMAwS"
      },
      "outputs": [],
      "source": [
        "Z_test, C_test, G_test = get_data(test_iterator, IMAGE_SIZE)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "WIEN09fxfKYE"
      },
      "outputs": [],
      "source": [
        "# remove classes with too little images from train\n",
        "\n",
        "min_images = 3\n",
        "\n",
        "def filter_rare(Z, C, G, min_images):\n",
        "  unq_C = np.unique(C)\n",
        "  img_cnts = np.array([np.sum(C==c) for c in unq_C])\n",
        "  classes_to_keep = unq_C[img_cnts >= min_images]\n",
        "  keep_idx = np.isin(C, classes_to_keep)\n",
        "  return Z[keep_idx], C[keep_idx], G[keep_idx]\n",
        "\n",
        "Z_train, C_train, G_train = filter_rare(Z_train, C_train, G_train, min_images)\n",
        "Z_test, C_test, G_test = filter_rare(Z_test, C_test, G_test, min_images)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Eu5MKQCjy7hJ"
      },
      "outputs": [],
      "source": [
        "def divide_by_grp(C, G, min_images, minor_labl):\n",
        "  unq_C = np.unique(C)\n",
        "  w_grp, wo_grp = [], []\n",
        "\n",
        "  for c in unq_C:\n",
        "    c_idx = np.where(C==c)[0]\n",
        "    if np.sum(G[c_idx]==minor_labl) >= min_images:\n",
        "      w_grp.append(c)\n",
        "    elif np.sum(G[c_idx]==1-minor_labl) >= min_images:\n",
        "      wo_grp.append(c)\n",
        "  return np.array(w_grp), np.array(wo_grp)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "HH5p_3j8kbmm"
      },
      "outputs": [],
      "source": [
        "minor_labl = 1\n",
        "people_w_grp_images_train, people_wo_grp_images_train = divide_by_grp(C_train, G_train, min_images, minor_labl)\n",
        "people_w_grp_images_test, people_wo_grp_images_test = divide_by_grp(C_test, G_test, 2, minor_labl)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "XaAfiHZ03WN3"
      },
      "outputs": [],
      "source": [
        "def select_instances(Z, C, G, unq_c_w_grp, unq_c_wo_grp, minor_labl):\n",
        "\n",
        "  # for people with grp keep only those images, for the rest keep only images not in grp\n",
        "  new_Z, new_C, new_G = [], [], []\n",
        "\n",
        "  for c in unq_c_w_grp:\n",
        "    c_idx = np.where(C==c)[0]\n",
        "    for j in c_idx:\n",
        "      if G[j]==minor_labl:\n",
        "        new_Z.append(Z[j])\n",
        "        new_C.append(c)\n",
        "        new_G.append(minor_labl)\n",
        "\n",
        "  for c in unq_c_wo_grp:\n",
        "    c_idx = np.where(C==c)[0]\n",
        "    for j in c_idx:\n",
        "      if G[j]==1-minor_labl:\n",
        "        new_Z.append(Z[j])\n",
        "        new_C.append(c)\n",
        "        new_G.append(1-minor_labl)\n",
        "\n",
        "  return(np.array(new_Z), np.array(new_C), np.array(new_G))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_H15dOXa-Q_t"
      },
      "outputs": [],
      "source": [
        "def select_classes(Z, C, G, classes_in_major_grp, classes_in_minor_grp, p_minor_grp, minor_labl, max_Nc=500):\n",
        "\n",
        "  # keep all classes in major grp\n",
        "  N_major = len(classes_in_major_grp)\n",
        "  N_minor = len(classes_in_minor_grp)\n",
        "\n",
        "  N_major = min(N_major, int((1-p_minor_grp)*N_minor/p_minor_grp))\n",
        "  N_minor = min(N_minor, int(p_minor_grp*N_major/(1-p_minor_grp)))\n",
        "\n",
        "  N_major = min(N_major, int((1-p_minor_grp)*max_Nc))\n",
        "  N_minor = min(N_minor, int(p_minor_grp*max_Nc))\n",
        "\n",
        "  Nc = N_major + N_minor\n",
        "\n",
        "  # select classes for minor grp for\n",
        "  unq_c_minor_grp = np.random.choice(classes_in_minor_grp, N_minor, replace=False)\n",
        "  # keep all those in major grp\n",
        "  unq_c_major_grp = np.random.choice(classes_in_major_grp, N_major, replace=False)\n",
        "\n",
        "  # for people for minor grp keep only those images, for the rest keep only images not in grp\n",
        "  new_Z, new_C, new_G = select_instances(Z, C, G, unq_c_minor_grp, unq_c_major_grp, minor_labl)\n",
        "\n",
        "  return new_Z, new_C, new_G"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xHlIjv7oKC3t"
      },
      "outputs": [],
      "source": [
        "p_minor = 0.05"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-sppISiQ_Gbf"
      },
      "outputs": [],
      "source": [
        "# in train mostly non-blonde people\n",
        "z_train, c_train, a_train = select_classes(Z_train, C_train, G_train, people_wo_grp_images_train, people_w_grp_images_train, p_minor, minor_labl)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "FxWbVCZs_Gn4"
      },
      "outputs": [],
      "source": [
        "# in test mostly blonde people\n",
        "z_test, c_test, a_test = select_classes(Z_test, C_test, G_test, people_w_grp_images_test, people_wo_grp_images_test, p_minor, 1-minor_labl)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Be5axY8oFAU_"
      },
      "outputs": [],
      "source": [
        "# split train into train and validation\n",
        "\n",
        "unq_c_train = np.unique(c_train)\n",
        "\n",
        "unique_c_val = np.random.choice(unq_c_train, int(0.1*len(unq_c_train)), replace=False)\n",
        "unq_c_train = np.array([c for c in unq_c_train if c not in unique_c_val])\n",
        "\n",
        "val_bool = np.isin(c_train, unique_c_val)\n",
        "z_val, c_val, a_val = z_train[val_bool], c_train[val_bool], a_train[val_bool]\n",
        "\n",
        "train_bool = np.isin(c_train, unq_c_train)\n",
        "z_train, c_train, a_train = z_train[train_bool], c_train[train_bool], a_train[train_bool]"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "pos_per_class_train, pos_per_class_test = 10, 10\n",
        "\n",
        "# generate pairs\n",
        "train_z1, train_z2, train_y, train_Cs = make_pairs(z_train, c_train, pos_per_class=pos_per_class_train)\n",
        "val_z1, val_z2, val_y, val_Cs = make_pairs(z_val, c_val, pos_per_class=pos_per_class_train)\n",
        "test_z1, test_z2, test_y, test_Cs = make_pairs(z_test, c_test, pos_per_class_test)"
      ],
      "metadata": {
        "id": "X0b59nAL1wEp"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qCuslAQ3Qqe_"
      },
      "source": [
        "### Class sampling"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "classes_in_env = 2\n",
        "classes_in_env_test = 2\n",
        "\n",
        "n_sim_envs = 150"
      ],
      "metadata": {
        "id": "aTOUs84M3-cL"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "n_envs =  10**6"
      ],
      "metadata": {
        "id": "TeSWVsF94L9l"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "train_envs = []\n",
        "for i in range(n_envs):\n",
        "  e = np.random.choice(unq_c_train, classes_in_env, replace=False)\n",
        "  train_envs.append(e)\n",
        "\n",
        "train_envs = np.array(train_envs)"
      ],
      "metadata": {
        "id": "SJR4NUxM4NDH"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_YjAeIHk0Jkz"
      },
      "source": [
        "### Models"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7b_BX7NXr7W8"
      },
      "outputs": [],
      "source": [
        "s1, s2, s3 = IMAGE_SIZE, IMAGE_SIZE, 3"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "P7NrTu6fr1Eh"
      },
      "outputs": [],
      "source": [
        "def add_conv_block(model):\n",
        "  model.add(tf.keras.layers.Conv2D(filters=16, kernel_size=3, strides=1, padding=\"same\"))\n",
        "  model.add(tf.keras.layers.BatchNormalization())\n",
        "  model.add(tf.keras.layers.ReLU())\n",
        "  model.add(tf.keras.layers.MaxPooling2D(pool_size=(2, 2)))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "K7C1W0_fIukF"
      },
      "outputs": [],
      "source": [
        "def init_representation():\n",
        "    model = tf.keras.Sequential()\n",
        "    model.add(tf.keras.layers.Input(shape=(s1, s2, s3)))\n",
        "    for i in range(2):\n",
        "      add_conv_block(model)\n",
        "    model.add(tf.keras.layers.Flatten())\n",
        "    model.add(tf.keras.layers.Dense(32))\n",
        "    return model"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "KpAkkDaUC4fU"
      },
      "source": [
        "Parameters"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "DHzkNNi6C3lX"
      },
      "outputs": [],
      "source": [
        "lr = 1e-3\n",
        "n_pairs = 8 * 10**6\n",
        "\n",
        "ERM_factor = 0.0\n",
        "CLoVE_factor = 0.085\n",
        "VarAUC_factor = 0.2\n",
        "VarREx_factor = 0.1\n",
        "\n",
        "IRM_factor = 0.01\n",
        "l2_regularizer_weight = tf.constant(0.01)"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "Initialization"
      ],
      "metadata": {
        "id": "4fog0RUo5Xq8"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "KvfZUC_U9F1M"
      },
      "outputs": [],
      "source": [
        "init_g = init_representation()\n",
        "\n",
        "ERM_g = init_representation()\n",
        "IRM_g = init_representation()\n",
        "CLoVE_g = init_representation()\n",
        "VarREx_g = init_representation()\n",
        "VarAUC_g = init_representation()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "W4-Of6yyhztG"
      },
      "source": [
        "#### ERM"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "gLGQP09m2yGf"
      },
      "outputs": [],
      "source": [
        "ERM_g.set_weights(init_g.get_weights())"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "HBsrZoVOmuV-"
      },
      "outputs": [],
      "source": [
        "ERM_optimizer = tf.keras.optimizers.Adam(learning_rate=lr)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "0dWaNkhDLfe4"
      },
      "outputs": [],
      "source": [
        "ERM_g, ERM_losses, ERM_Ns, ERM_test_aucs, ERM_val_aucs = training(ERM_optimizer, ERM_g, z_train, c_train, train_envs, val_z1, val_z2, val_y,\n",
        "                                                                  test_z1, test_z2, test_y, n_pairs, pos_per_class_train, ERM_factor, n_sim_envs, penalty_type=None)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "YRDiGqmphwCL"
      },
      "outputs": [],
      "source": [
        "ERM_auc_train = evaluate(ERM_g, train_z1, train_z2, train_y)\n",
        "ERM_auc_val = evaluate(ERM_g, val_z1, val_z2, val_y)\n",
        "ERM_auc_test = evaluate(ERM_g, test_z1, test_z2, test_y)\n",
        "\n",
        "ERM_auc_train, ERM_auc_val, ERM_auc_test\n",
        "print(\"Trainig data: {:.4f}, In-distribution: {:.4f}, Distribution-shift: {:.4f}\".format(ERM_auc_train, ERM_auc_val, ERM_auc_test))"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "IRM"
      ],
      "metadata": {
        "id": "siEEhfuT50ku"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "IRM_g.set_weights(init_g.get_weights())"
      ],
      "metadata": {
        "id": "Khp57tll5yao"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "IRM_optimizer = tf.keras.optimizers.Adam(learning_rate=lr)"
      ],
      "metadata": {
        "id": "ZCtj2xtE57ua"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "IRM_g, IRM_losses, IRM_Ns, IRM_test_aucs, IRM_val_aucs = training(IRM_optimizer, IRM_g, z_train, c_train, train_envs, val_z1, val_z2, val_y,\n",
        "                                                                                 test_z1, test_z2, test_y, n_pairs, pos_per_class_train, IRM_factor, n_sim_envs,\n",
        "                                                                                 penalty_type='IRM')"
      ],
      "metadata": {
        "id": "VYvlV69958zJ"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "IRM_auc_train = evaluate(IRM_g, train_z1, train_z2, train_y)\n",
        "IRM_auc_val = evaluate(IRM_g, val_z1, val_z2, val_y)\n",
        "IRM_auc_test = evaluate(IRM_g, test_z1, test_z2, test_y)\n",
        "\n",
        "print(\"Trainig data: {:.4f}, In-distribution: {:.4f}, Distribution-shift: {:.4f}\".format(IRM_auc_train, IRM_auc_val, IRM_auc_test))"
      ],
      "metadata": {
        "id": "cMHqiFqV5-BH"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "CLoVE"
      ],
      "metadata": {
        "id": "bBCAi9hs6Bzv"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "CLoVE_g.set_weights(init_g.get_weights())"
      ],
      "metadata": {
        "id": "wGx3DoaD5_IX"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "CLoVE_optimizer = tf.keras.optimizers.Adam(learning_rate=lr)"
      ],
      "metadata": {
        "id": "aCHAN62D6C1S"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "CLoVE_g, CLoVE_losses, CLoVE_Ns, CLoVE_test_aucs, CLoVE_val_aucs = training(CLoVE_optimizer, CLoVE_g, z_train, c_train, train_envs, val_z1, val_z2, val_y,\n",
        "                                                                            test_z1, test_z2, test_y, n_pairs, pos_per_class_train, CLoVE_factor, n_sim_envs,\n",
        "                                                                            penalty_type='CLoVE')"
      ],
      "metadata": {
        "id": "6YBrMTd26Gjj"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "CLoVE_auc_train = evaluate(CLoVE_g, train_z1, train_z2, train_y)\n",
        "CLoVE_auc_val = evaluate(CLoVE_g, val_z1, val_z2, val_y)\n",
        "CLoVE_auc_test = evaluate(CLoVE_g, test_z1, test_z2, test_y)\n",
        "\n",
        "print(\"Trainig data: {:.4f}, In-distribution: {:.4f}, Distribution-shift: {:.4f}\".format(CLoVE_auc_train, CLoVE_auc_val, CLoVE_auc_test))"
      ],
      "metadata": {
        "id": "RZj-sWpA6Hi5"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "VarREx"
      ],
      "metadata": {
        "id": "PcjMvsrh6J0S"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "VarREx_g.set_weights(init_g.get_weights())"
      ],
      "metadata": {
        "id": "Q0iA2A7l6JM1"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "VarREx_optimizer = tf.keras.optimizers.Adam(learning_rate=lr)"
      ],
      "metadata": {
        "id": "FqFP-knM6Lv3"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "VarREx_g, VarREx_losses, VarREx_Ns, VarREx_test_aucs, VarREx_val_aucs = training(VarREx_optimizer, VarREx_g, z_train, c_train, train_envs, val_z1, val_z2, val_y,\n",
        "                                                                                 test_z1, test_z2, test_y, n_pairs, pos_per_class_train, VarREx_factor, n_sim_envs,\n",
        "                                                                                 penalty_type='VarREx')"
      ],
      "metadata": {
        "id": "jRBaIEyx6QbL"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "VarREx_auc_train = evaluate(VarREx_g, train_z1, train_z2, train_y)\n",
        "VarREx_auc_val = evaluate(VarREx_g, val_z1, val_z2, val_y)\n",
        "VarREx_auc_test = evaluate(VarREx_g, test_z1, test_z2, test_y)\n",
        "\n",
        "print(\"Trainig data: {:.4f}, In-distribution: {:.4f}, Distribution-shift: {:.4f}\".format(VarREx_auc_train, VarREx_auc_val, VarREx_auc_test))"
      ],
      "metadata": {
        "id": "hIxJMKdk6Sfz"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "VarAUC"
      ],
      "metadata": {
        "id": "F1Yxo3c06UpQ"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "VarAUC_g.set_weights(init_g.get_weights())"
      ],
      "metadata": {
        "id": "PNaW4zER6TxM"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "VarAUC_optimizer = tf.keras.optimizers.Adam(learning_rate=lr)"
      ],
      "metadata": {
        "id": "7iHKsU9j6Vlw"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "VarAUC_g, VarAUC_losses, VarAUC_Ns, VarAUC_test_aucs, VarAUC_val_aucs = training(VarAUC_optimizer, VarAUC_g, z_train, c_train, train_envs, val_z1, val_z2, val_y,\n",
        "                                                                                 test_z1, test_z2, test_y, n_pairs, pos_per_class_train, VarAUC_factor, n_sim_envs,\n",
        "                                                                                 penalty_type='VarAUC')"
      ],
      "metadata": {
        "id": "glTdxt4A6ZU9"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "VarAUC_auc_train = evaluate(VarAUC_g, train_z1, train_z2, train_y)\n",
        "VarAUC_auc_val = evaluate(VarAUC_g, val_z1, val_z2, val_y)\n",
        "VarAUC_auc_test = evaluate(VarAUC_g, test_z1, test_z2, test_y)\n",
        "\n",
        "print(\"Trainig data: {:.4f}, In-distribution: {:.4f}, Distribution-shift: {:.4f}\".format(VarAUC_auc_train, VarAUC_auc_val, VarAUC_auc_test))"
      ],
      "metadata": {
        "id": "hU1wIo-M6aYE"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Comparing results"
      ],
      "metadata": {
        "id": "qg8ktVt86doo"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "plt.plot(ERM_Ns, ERM_val_aucs, '--', color='C3')\n",
        "plt.plot(ERM_Ns, ERM_test_aucs,  label='ERM', color='C3')\n",
        "\n",
        "plt.plot(IRM_Ns, IRM_val_aucs, '--', color='C1', alpha=0.8)\n",
        "plt.plot(IRM_Ns, IRM_test_aucs, label='IRM ', color='C1', alpha=0.8)\n",
        "\n",
        "plt.plot(CLoVE_Ns, CLoVE_val_aucs, '--', color='C0', alpha=0.8)\n",
        "plt.plot(CLoVE_Ns, CLoVE_test_aucs, label='CLOvE', color='C0', alpha=0.8)\n",
        "\n",
        "plt.plot(VarREx_Ns, VarREx_val_aucs, '--', color='C4', alpha=0.8)\n",
        "plt.plot(VarREx_Ns, VarREx_test_aucs, label='VarREx', color='C4', alpha=0.8)\n",
        "\n",
        "plt.plot(VarAUC_Ns, VarAUC_val_aucs, '--',  color='C2')\n",
        "plt.plot(VarAUC_Ns, VarAUC_test_aucs, label='VarAUC', color='C2')\n",
        "\n",
        "\n",
        "plt.xlabel('Data Points (pairs)')\n",
        "plt.ylabel('AUC')\n",
        "plt.legend(loc=4);"
      ],
      "metadata": {
        "id": "4UPy4gRp6gWa"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# split trainin data by attribute\n",
        "z_a1, c_a1 = z_train[a_train==1], c_train[a_train==1]\n",
        "z_a0, c_a0 = z_train[a_train==0], c_train[a_train==0]\n",
        "\n",
        "# make pairs\n",
        "z1_a1, z2_a1, y_a1, c_a1 = make_pairs(z_a1, c_a1, pos_per_class_train)\n",
        "z1_a0, z2_a0, y_a0, c_a0 = make_pairs(z_a0, c_a0, pos_per_class_train)\n",
        "\n",
        "# ERM representations\n",
        "z1_hat_a1_ERM, z2_hat_a1_ERM = ERM_g(z1_a1), ERM_g(z2_a1)\n",
        "z1_hat_a0_ERM, z2_hat_a0_ERM = ERM_g(z1_a0), ERM_g(z2_a0)\n",
        "\n",
        "# VarAUC representations\n",
        "z1_hat_a1_VarAUC, z2_hat_a1_VarAUC = VarAUC_g(z1_a1), VarAUC_g(z2_a1)\n",
        "z1_hat_a0_VarAUC, z2_hat_a0_VarAUC = VarAUC_g(z1_a0), VarAUC_g(z2_a0)\n",
        "\n",
        "# unpenalized losses\n",
        "def raw_loss(z1_hat, z2_hat, y_true, margin=0.5):\n",
        "  dist = cosine_distance(z1_hat, z2_hat)\n",
        "  l = tfa.losses.contrastive_loss(y_true, dist, margin)\n",
        "  return l, dist\n",
        "\n",
        "# on a1\n",
        "base_loss_a1_ERM, dist_a1_ERM = raw_loss(z1_hat_a1_ERM, z2_hat_a1_ERM, y_a1)\n",
        "base_loss_a1_VarAUC, dist_a1_VarAUC = raw_loss(z1_hat_a1_VarAUC, z2_hat_a1_VarAUC, y_a1)\n",
        "dist_a1_ERM, dist_a1_VarAUC = dist_a1_ERM.numpy(), dist_a1_VarAUC.numpy()\n",
        "\n",
        "# on a0\n",
        "base_loss_a0_ERM, dist_a0_ERM = raw_loss(z1_hat_a0_ERM, z2_hat_a0_ERM, y_a0)\n",
        "base_loss_a0_VarAUC, dist_a0_VarAUC = raw_loss(z1_hat_a0_VarAUC, z2_hat_a0_VarAUC, y_a0)\n",
        "dist_a0_ERM, dist_a0_VarAUC = dist_a0_ERM.numpy(), dist_a0_VarAUC.numpy()"
      ],
      "metadata": {
        "id": "FwFEZNxY6is0"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "def get_bins(diff, n_bins=19):\n",
        "  w = np.ptp(diff)/n_bins\n",
        "  u = np.ceil((max(diff)-w/2)/w)*w + w/2\n",
        "  l = np.ceil((abs(min(diff))-w/2)/w)*w + w/2\n",
        "  return np.linspace(-l, u, n_bins+2)"
      ],
      "metadata": {
        "id": "ubD3sWLH6lC8"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "fig, axs = plt.subplots(2,2, figsize=(8,8))\n",
        "\n",
        "t1 = len(base_loss_a1_ERM[y_a1==0])\n",
        "t0 = len(base_loss_a0_ERM[y_a0==0])\n",
        "\n",
        "n_bins=19\n",
        "\n",
        "diff_00 = base_loss_a1_ERM[y_a1==0] - base_loss_a1_VarAUC[y_a1==0]\n",
        "bins_00 = get_bins(diff_00, n_bins)\n",
        "axs[0,0].hist(diff_00, bins=bins_00, alpha=0.5, weights=(np.ones(t1)/t1), color='C5', ec='C5')\n",
        "axs[0,0].axvline(0, c='k', linestyle=':')\n",
        "axs[0,0].set_title('Nymphalidae (minority in training)', fontsize=11)\n",
        "axs[0,0].set_ylabel(r'$y=0$', fontsize=11)\n",
        "xl = max(abs(bins_00))*1.1\n",
        "axs[0,0].set_xlim(-xl, xl)\n",
        "\n",
        "diff_01 = base_loss_a0_ERM[y_a0==0] - base_loss_a0_VarAUC[y_a0==0]\n",
        "bins_01 = get_bins(diff_01, n_bins)\n",
        "axs[0,1].hist(diff_01, bins=bins_01, alpha=0.5, weights=(np.ones(t0)/t0), color='C5', ec='C5')\n",
        "axs[0,1].axvline(0, c='k', linestyle=':')\n",
        "axs[0,1].set_title('Lycaenidae (majority in training)', fontsize=11)\n",
        "xl = max(abs(bins_01))*1.1\n",
        "axs[0,1].set_xlim(-xl, xl)\n",
        "\n",
        "t1 = len(base_loss_a1_ERM[y_a1==1])\n",
        "t0 = len(base_loss_a0_ERM[y_a0==1])\n",
        "\n",
        "diff_10 = base_loss_a1_ERM[y_a1==1] - base_loss_a1_VarAUC[y_a1==1]\n",
        "bins_10 = get_bins(diff_10, n_bins)\n",
        "axs[1,0].hist(diff_10, bins=bins_10, alpha=0.5, weights=(np.ones(t1)/t1), color='C7', ec='C7')\n",
        "axs[1,0].axvline(0, c='k', linestyle=':')\n",
        "axs[1,0].set_ylabel(r'$y=1$', fontsize=11)\n",
        "xl = max(abs(bins_10))*1.1\n",
        "axs[1,0].set_xlim(-xl, xl)\n",
        "\n",
        "diff_11 = base_loss_a0_ERM[y_a0==1] - base_loss_a0_VarAUC[y_a0==1]\n",
        "bins_11 = get_bins(diff_11, n_bins)\n",
        "axs[1,1].hist(diff_11, bins=bins_11, alpha=0.5, weights=(np.ones(t0)/t0), color='C7', ec='C7')\n",
        "axs[1,1].axvline(0, c='k', linestyle=':')\n",
        "xl = max(abs(bins_11))*1.1\n",
        "axs[1,1].set_xlim(-xl, xl);"
      ],
      "metadata": {
        "id": "kiwD5rKL6vKm"
      },
      "execution_count": null,
      "outputs": []
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "machine_shape": "hm",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}