{"cells":[{"cell_type":"code","execution_count":null,"metadata":{"id":"uGMmryQKaVkp"},"outputs":[],"source":["import os\n","import cv2\n","import numpy as np\n","import pandas as pd\n","import tensorflow as tf\n","import tensorflow_datasets as tfds\n","import tensorflow_addons as tfa\n","from sklearn.metrics import roc_curve, auc\n","import matplotlib.pyplot as plt\n","%matplotlib inline\n","\n","# Fix the global seeds: the class splits, image sampling and training environments\n","# below all draw from np.random, and the Keras weight initialisation uses TensorFlow's RNG.\n","SEED = 0\n","np.random.seed(SEED)\n","tf.random.set_seed(SEED)"]},{"cell_type":"markdown","source":["This notebook assumes that the ETHEC dataset (zip file) was downloaded and placed in the same folder as the code. The dataset can be downloaded from: https://www.research-collection.ethz.ch/handle/20.500.11850/365379"],"metadata":{"id":"u4QU0J7v966f"}},{"cell_type":"code","source":["path = os.getcwd()\n","\n","# Check if running in Colab\n","try:\n","  from google.colab import drive\n","  IN_COLAB=True\n","  print(\"Running in Colab\")\n","  # Mount Google Drive\n","  drive.mount('/content/drive', force_remount=True)\n","  # Change directory\n","  %cd /content/drive/MyDrive/Class_Distribution_Shifts_in_Zero_Shot_Learning_Learning_Robust_Representations\n","except ImportError:\n","  # Only a missing google.colab module means \"not in Colab\"; a failed Drive mount must surface.\n","  IN_COLAB=False\n","  print(\"Running locally\")"],"metadata":{"id":"6QMj_MOn9yVM"},"execution_count":null,"outputs":[]},{"cell_type":"code","execution_count":null,"metadata":{"id":"sn7vSEKWKl5d"},"outputs":[],"source":["from pairs import distinct_pairs_func, make_pairs\n","from algorithm import *\n","from synthetic_data import *"]},{"cell_type":"markdown","metadata":{"id":"0wGVBdB_QlCM"},"source":["### Load data"]},{"cell_type":"code","source":["! 
unzip \"ETHEC_v02\""],"metadata":{"id":"0tvk0d9DkepH"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# get metadata (image path/name and taxonomy columns) from the split JSON files\n","\n","def extract_details(path):\n","  df = pd.read_json(path)\n","  df = df.transpose()  # presumably the JSON is keyed by image id, so transpose to one row per image -- confirm\n","  df = df[['image_path', 'image_name', 'family', 'subfamily', 'genus', 'specific_epithet']]\n","  return df\n","\n","train_df = extract_details('ETHEC_dataset/splits/train.json')\n","test_df = extract_details('ETHEC_dataset/splits/test.json')\n","val_df = extract_details('ETHEC_dataset/splits/val.json')"],"metadata":{"id":"513MZ9jdwBUc"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# focus on Lycaenidae and Nymphalidae families\n","chosen_families = ['Lycaenidae', 'Nymphalidae']\n","minority = 'Lycaenidae'\n","\n","# subset chosen families\n","train_df = train_df[train_df['family'].isin(chosen_families)]\n","test_df = test_df[test_df['family'].isin(chosen_families)]\n","val_df = val_df[val_df['family'].isin(chosen_families)]\n","df = pd.concat([train_df, test_df, val_df], ignore_index=True, axis=0)\n","\n","minority_classes = (df[df['family']==minority]['specific_epithet']).unique()\n","majority_classes = (df[df['family']!=minority]['specific_epithet']).unique()"],"metadata":{"id":"oBEDDfkIyFF3"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["C = df['specific_epithet']\n","Nc = len(minority_classes) + len(majority_classes)\n","\n","p_minor = 0.1 # proportion of minority group\n","test_p = 0.2 # proportion of classes to save for test (under distribution shift)\n","\n","N_minor_test = int(len(minority_classes) * test_p)\n","N_major_test = int(N_minor_test * (p_minor)/(1-p_minor))  # inverted ratio: the minority family dominates the test classes"],"metadata":{"id":"gEfaur4755DV"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# select test classes (distribution shift scenario)\n","\n","test_minority_classes = np.random.choice(minority_classes, N_minor_test, replace=False)\n","test_majority_classes = 
np.random.choice(majority_classes, N_major_test, replace=False)\n","unq_c_test = np.hstack([test_minority_classes, test_majority_classes])"],"metadata":{"id":"qo-XriAfpULb"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["assert np.mean(np.isin(test_minority_classes, minority_classes))==1"],"metadata":{"id":"H9IU-ovlefP5"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["print(\"minority proportion in test: {:.4f}, majority proportion in test: {:.4f}\".format(np.mean(np.isin(unq_c_test, test_minority_classes)), np.mean(np.isin(unq_c_test, test_majority_classes))))"],"metadata":{"id":"4POuNgaALz0t"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# classes not reserved for the distribution-shift test split (excludes every test class,\n","# so an epithet shared by both families cannot leak into training)\n","remaining_minority_classes = minority_classes[~np.isin(minority_classes, unq_c_test)]\n","remaining_majority_classes = majority_classes[~np.isin(majority_classes, unq_c_test)]"],"metadata":{"id":"z1f5tVPhQmG1"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# select train classes -- only from the remaining pools, so test classes stay unseen (zero-shot)\n","\n","N_minor_train = min(int(len(remaining_majority_classes) * (p_minor)/(1-p_minor)), len(remaining_minority_classes))\n","N_major_train = min(len(remaining_majority_classes), int(N_minor_train * (1 - p_minor)/p_minor))\n","\n","train_minority_classes = np.random.choice(remaining_minority_classes, N_minor_train, replace=False)\n","train_majority_classes = np.random.choice(remaining_majority_classes, N_major_train, replace=False)\n","unq_c_train = np.hstack([train_minority_classes, train_majority_classes])"],"metadata":{"id":"BYE9CKMxnry5"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["assert np.mean(np.isin(train_minority_classes, minority_classes))==1\n","# zero-shot setting: no class may appear in both the training and the test split\n","assert not np.isin(unq_c_train, unq_c_test).any()"],"metadata":{"id":"vnYG9c_2m2fq"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["print(\"minority proportion in training: {:.4f}, majority proportion in training: {:.4f}\".format(np.mean(np.isin(unq_c_train, train_minority_classes)), 
np.mean(np.isin(unq_c_train, train_majority_classes))))"],"metadata":{"id":"_OeLeQukQlyM"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# set apart validation data (in-distribution setting)\n","\n","unq_c_val = np.random.choice(unq_c_train, int(0.4*len(unq_c_train)), replace=False)\n","unq_c_train = np.array([c for c in unq_c_train if c not in unq_c_val])\n","Nc = len(unq_c_val) + len(unq_c_train) + len(unq_c_test)"],"metadata":{"id":"fRJlcNp0Qlq3"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["def get_file_names(df, c_list, r = 10):\n","  \"\"\"Return the image paths of every image in `df` whose species is in `c_list`.\n","\n","  Paths follow the row order of `df`, which keeps them aligned with the label\n","  lists built later from `df[df['specific_epithet'].isin(...)]`.\n","  `r` is unused; it is kept only so existing calls keep working.\n","  \"\"\"\n","  # An earlier per-class loop drew `r` random files per class but discarded them and\n","  # returned the full list anyway; it only wasted work, consumed the global RNG and\n","  # raised UnboundLocalError for an empty `c_list`, so it was removed.\n","  mask = df['specific_epithet'].isin(c_list)\n","  return [os.path.join('ETHEC_dataset', 'IMAGO_build_test_resized', p, n)\n","          for p, n in zip(df.image_path[mask], df.image_name[mask])]\n","\n","train_files, val_files, test_files = get_file_names(df, unq_c_train), get_file_names(df, unq_c_val), get_file_names(df, unq_c_test)"],"metadata":{"id":"VtfbZdh1W9fr"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# extract images\n","\n","SIZE = (100,100)\n","\n","def load_img(path, size):\n","  img = cv2.imread(path)\n","  if img is None:  # cv2.imread signals a missing/unreadable file by returning None\n","    raise FileNotFoundError(path)\n","  img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)\n","  img = cv2.resize(img, size)\n","  return np.array(img)\n","\n","# data\n","z_train = np.array([load_img(file, SIZE) for file in train_files]) / 255.0\n","z_test = np.array([load_img(file, SIZE) for file in test_files]) / 255.0\n","z_val = np.array([load_img(file, SIZE) for file in val_files]) / 255.0"],"metadata":{"id":"5UuVZFn90dvw"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# class names\n","c_train_name = list(df.specific_epithet[df['specific_epithet'].isin(unq_c_train)])\n","c_test_name = list(df.specific_epithet[df['specific_epithet'].isin(unq_c_test)])\n","c_val_name = 
list(df.specific_epithet[df['specific_epithet'].isin(unq_c_val)])\n","\n","# attributes: family of each image's species, taken from the first matching row of df\n","A_train_name = [list(df.family[df.specific_epithet==c])[0] for c in c_train_name]\n","A_test_name = [list(df.family[df.specific_epithet==c])[0] for c in c_test_name]\n","A_val_name = [list(df.family[df.specific_epithet==c])[0] for c in c_val_name]\n","\n","attribute_dict = {'Nymphalidae': 0, 'Lycaenidae':1}\n","A_train = np.array([attribute_dict[name] for name in A_train_name])\n","A_test = np.array([attribute_dict[name] for name in A_test_name])\n","A_val = np.array([attribute_dict[name] for name in A_val_name])\n","\n","# classes: integer id per species name, shared by the train/test/val splits\n","unq_c = np.unique(np.hstack([c_train_name, c_test_name, c_val_name]))\n","class_dict = {}\n","for i, c in enumerate(unq_c):\n","  class_dict[c] = i\n","\n","c_train, c_test, c_val = np.array([class_dict[c] for c in c_train_name]), np.array([class_dict[c] for c in c_test_name]), np.array([class_dict[c] for c in c_val_name])\n","unq_c_train, unq_c_test, unq_c_val = np.unique(c_train), np.unique(c_test), np.unique(c_val)"],"metadata":{"id":"ox_ArPE_fAbB"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# keep species with at least `min_images` images and sample exactly `min_images` images of each\n","\n","def filter_images(z, c_list, selected_c, names_list, A, min_images):\n","  c_idx = np.array([np.isin(c_list, c) for c in selected_c])  # one membership mask over c_list per selected class\n","  selected_c = selected_c[np.sum(c_idx, axis=1) >= min_images]\n","  idx = np.array([np.random.choice(np.where(np.isin(c_list, c))[0], min_images, replace=False) for c in selected_c]).flatten()  # min_images random indices per kept class\n","  return c_list[idx], z[idx], np.array(names_list)[idx], A[idx]\n","\n","min_images = 5\n","\n","c_train, z_train, c_train_name, A_train = filter_images(z_train, c_train, unq_c_train, c_train_name, A_train, min_images)\n","c_val, z_val, c_val_name, A_val = filter_images(z_val, c_val, unq_c_val, c_val_name, A_val, min_images)\n","c_test, z_test, c_test_name, A_test = filter_images(z_test, c_test, unq_c_test, c_test_name, A_test, 
min_images)\n","\n","unq_c_train, unq_c_test, unq_c_val = np.unique(c_train), np.unique(c_test), np.unique(c_val)\n","unq_c = np.unique(np.hstack([c_train, c_test, c_val]))"],"metadata":{"id":"TxCab4y1Hf-_"},"execution_count":null,"outputs":[]},{"cell_type":"code","execution_count":null,"metadata":{"id":"yIdyHd9oHcYx"},"outputs":[],"source":["pos_per_class_train, pos_per_class_test = 10, 10\n","\n","# generate image pairs for each split (pairing scheme defined in pairs.make_pairs)\n","train_z1, train_z2, train_y, train_Cs = make_pairs(z_train, c_train, pos_per_class=pos_per_class_train)\n","val_z1, val_z2, val_y, val_Cs = make_pairs(z_val, c_val, pos_per_class=pos_per_class_train)\n","test_z1, test_z2, test_y, test_Cs = make_pairs(z_test, c_test, pos_per_class_test)"]},{"cell_type":"markdown","metadata":{"id":"8KK2D2BBrdRX"},"source":["### Class sampling"]},{"cell_type":"code","source":["classes_in_env = 2\n","classes_in_env_test = 2\n","\n","n_sim_envs = 200"],"metadata":{"id":"tz-VDxc4XkAx"},"execution_count":null,"outputs":[]},{"cell_type":"code","execution_count":null,"metadata":{"id":"bKEWRjdy2jDg"},"outputs":[],"source":["n_envs =  10**6  # number of training environments sampled in the next cell"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"e8BYwkqoV_1v"},"outputs":[],"source":["train_envs = []  # each environment: classes_in_env distinct training classes, drawn uniformly\n","for i in range(n_envs):\n","  e = np.random.choice(unq_c_train, classes_in_env, replace=False)\n","  train_envs.append(e)\n","\n","train_envs = np.array(train_envs)"]},{"cell_type":"markdown","metadata":{"id":"ePNaCa4KFc8-"},"source":["### Models"]},{"cell_type":"code","source":["s1, s2, s3 = SIZE[0], SIZE[1], 3  # input height, width and RGB channels"],"metadata":{"id":"7b_BX7NXr7W8"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["def init_representation():  # MLP encoder: flattened image -> 16-d embedding\n","    model = tf.keras.Sequential()\n","    model.add(tf.keras.layers.Input(shape=(s1, s2, s3)))\n","    model.add(tf.keras.layers.Flatten())\n","    model.add(tf.keras.layers.Dense(128, activation='relu'))\n","    model.add(tf.keras.layers.Dense(64, activation='relu'))\n","    model.add(tf.keras.layers.Dense(32, 
activation='relu'))\n","    model.add(tf.keras.layers.Dense(16))\n","    return model"],"metadata":{"id":"DbE58vUJ8ocg"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"KpAkkDaUC4fU"},"source":["Parameters"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"DHzkNNi6C3lX"},"outputs":[],"source":["lr = 1e-5\n","n_pairs = 10**5\n","\n","ERM_factor = 0.0\n","CLoVE_factor = 0.05\n","VarAUC_factor = 0.2\n","VarREx_factor = 0.1\n","\n","IRM_factor = 0.02\n","l2_regularizer_weight = tf.constant(0.01)  # NOTE(review): not passed to training() in this notebook -- confirm it is needed"]},{"cell_type":"markdown","source":["Initialization"],"metadata":{"id":"lifUb0bnZD33"}},{"cell_type":"code","execution_count":null,"metadata":{"id":"KvfZUC_U9F1M"},"outputs":[],"source":["init_g = init_representation()\n","\n","ERM_g = init_representation()\n","IRM_g = init_representation()\n","CLoVE_g = init_representation()\n","VarREx_g = init_representation()\n","VarAUC_g = init_representation()"]},{"cell_type":"markdown","source":["ERM"],"metadata":{"id":"wB1NUEZkZ-wv"}},{"cell_type":"code","source":["ERM_g.set_weights(init_g.get_weights())  # start from the same initial weights as the other methods"],"metadata":{"id":"44L8MGNhZs3-"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["ERM_optimizer = tf.keras.optimizers.Adam(learning_rate=lr)"],"metadata":{"id":"fI9XPbypaLrn"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["ERM_g, ERM_losses, ERM_Ns, ERM_test_aucs, ERM_val_aucs = training(ERM_optimizer, ERM_g, z_train, c_train, train_envs, val_z1, val_z2, val_y,\n","                                                                  test_z1, test_z2, test_y, n_pairs, pos_per_class_train, ERM_factor, n_sim_envs, penalty_type=None)"],"metadata":{"id":"AygP_T8LaLxu"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["ERM_auc_train = evaluate(ERM_g, train_z1, train_z2, train_y)\n","ERM_auc_val = evaluate(ERM_g, val_z1, val_z2, val_y)\n","ERM_auc_test = evaluate(ERM_g, test_z1, test_z2, test_y)\n","\n","print(\"Trainig data: {:.4f}, In-distribution: {:.4f}, Distribution-shift: {:.4f}\".format(ERM_auc_train, ERM_auc_val, ERM_auc_test))"],"metadata":{"id":"CjMfWMnnaL21"},"execution_count":null,"outputs":[]},
{"cell_type":"markdown","source":["IRM"],"metadata":{"id":"iFN_zbo7Z-9w"}},{"cell_type":"code","source":["IRM_g.set_weights(init_g.get_weights())  # start from the same initial weights as the other methods"],"metadata":{"id":"B0jdacqOa4rU"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["IRM_optimizer = tf.keras.optimizers.Adam(learning_rate=lr)"],"metadata":{"id":"3BEUOiBza4zy"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["IRM_g, IRM_losses, IRM_Ns, IRM_test_aucs, IRM_val_aucs = training(IRM_optimizer, IRM_g, z_train, c_train, train_envs, val_z1, val_z2, val_y,\n","                                                                  test_z1, test_z2, test_y, n_pairs, pos_per_class_train, IRM_factor, n_sim_envs,\n","                                                                  penalty_type='IRM')"],"metadata":{"id":"hCTblXWBa45j"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["IRM_auc_train = evaluate(IRM_g, train_z1, train_z2, train_y)\n","IRM_auc_val = evaluate(IRM_g, val_z1, val_z2, val_y)\n","IRM_auc_test = evaluate(IRM_g, test_z1, test_z2, test_y)\n","\n","print(\"Trainig data: {:.4f}, In-distribution: {:.4f}, Distribution-shift: {:.4f}\".format(IRM_auc_train, IRM_auc_val, IRM_auc_test))"],"metadata":{"id":"fmrSgXw3a4-l"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["CLoVE"],"metadata":{"id":"XeZPVPWoZ_H_"}},{"cell_type":"code","source":["CLoVE_g.set_weights(init_g.get_weights())  # start from the same initial weights as the other methods"],"metadata":{"id":"qAyTPFima__I"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["CLoVE_optimizer = tf.keras.optimizers.Adam(learning_rate=lr)"],"metadata":{"id":"RW2gSFo6bAB4"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["CLoVE_g, CLoVE_losses, CLoVE_Ns, 
CLoVE_test_aucs, CLoVE_val_aucs = training(CLoVE_optimizer, CLoVE_g, z_train, c_train, train_envs, val_z1, val_z2, val_y,\n","                                                                            test_z1, test_z2, test_y, n_pairs, pos_per_class_train, CLoVE_factor, n_sim_envs,\n","                                                                            penalty_type='CLoVE')"],"metadata":{"id":"ycrIdajqbAFP"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["CLoVE_auc_train = evaluate(CLoVE_g, train_z1, train_z2, train_y)\n","CLoVE_auc_val = evaluate(CLoVE_g, val_z1, val_z2, val_y)  # held-out classes from the training distribution\n","CLoVE_auc_test = evaluate(CLoVE_g, test_z1, test_z2, test_y)  # held-out classes under distribution shift\n","\n","print(\"Trainig data: {:.4f}, In-distribution: {:.4f}, Distribution-shift: {:.4f}\".format(CLoVE_auc_train, CLoVE_auc_val, CLoVE_auc_test))"],"metadata":{"id":"fmCJHaDxbAJZ"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["VarREx"],"metadata":{"id":"lzWnqd5QZ_RQ"}},{"cell_type":"code","source":["VarREx_g.set_weights(init_g.get_weights())  # start from the same initial weights as the other methods"],"metadata":{"id":"HCHX-ooZbOlp"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["VarREx_optimizer = tf.keras.optimizers.Adam(learning_rate=lr)"],"metadata":{"id":"SbJwFOHUbOoq"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["VarREx_g, VarREx_losses, VarREx_Ns, VarREx_test_aucs, VarREx_val_aucs = training(VarREx_optimizer, VarREx_g, z_train, c_train, train_envs, val_z1, val_z2, val_y,\n","                                                                                 test_z1, test_z2, test_y, n_pairs, pos_per_class_train, VarREx_factor, n_sim_envs,\n","                                                                                 penalty_type='VarREx')"],"metadata":{"id":"hryodTy6bOti"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["VarREx_auc_train = evaluate(VarREx_g, train_z1, train_z2, train_y)\n","VarREx_auc_val = evaluate(VarREx_g, val_z1, val_z2, 
val_y)\n","VarREx_auc_test = evaluate(VarREx_g, test_z1, test_z2, test_y)\n","\n","print(\"Trainig data: {:.4f}, In-distribution: {:.4f}, Distribution-shift: {:.4f}\".format(VarREx_auc_train, VarREx_auc_val, VarREx_auc_test))"],"metadata":{"id":"Y0oh2Wl7bOxb"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["VarAUC"],"metadata":{"id":"ZTsJzh8BZ_cB"}},{"cell_type":"code","source":["VarAUC_g.set_weights(init_g.get_weights())  # start from the same initial weights as the other methods"],"metadata":{"id":"_Tuau0KLbZAz"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["VarAUC_optimizer = tf.keras.optimizers.Adam(learning_rate=lr)"],"metadata":{"id":"ZEbSWm6LbZDP"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["VarAUC_g, VarAUC_losses, VarAUC_Ns, VarAUC_test_aucs, VarAUC_val_aucs = training(VarAUC_optimizer, VarAUC_g, z_train, c_train, train_envs, val_z1, val_z2, val_y,\n","                                                                                 test_z1, test_z2, test_y, n_pairs, pos_per_class_train, VarAUC_factor, n_sim_envs,\n","                                                                                 penalty_type='VarAUC')"],"metadata":{"id":"PItcwscKZsuF"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["VarAUC_auc_train = evaluate(VarAUC_g, train_z1, train_z2, train_y)\n","VarAUC_auc_val = evaluate(VarAUC_g, val_z1, val_z2, val_y)\n","VarAUC_auc_test = evaluate(VarAUC_g, test_z1, test_z2, test_y)\n","\n","print(\"Trainig data: {:.4f}, In-distribution: {:.4f}, Distribution-shift: {:.4f}\".format(VarAUC_auc_train, VarAUC_auc_val, VarAUC_auc_test))"],"metadata":{"id":"n9eX9QccZsln"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"kaOsmCdFOk8r"},"source":["### Comparing results"]},{"cell_type":"code","source":["plt.plot(ERM_Ns, ERM_val_aucs, '--', color='C3')  # dashed: validation (in-distribution) AUC\n","plt.plot(ERM_Ns, ERM_test_aucs,  label='ERM', color='C3')  # solid: test (distribution-shift) AUC\n","\n","plt.plot(IRM_Ns, IRM_val_aucs, '--', color='C1', 
alpha=0.8)\n","plt.plot(IRM_Ns, IRM_test_aucs, label='IRM', color='C1', alpha=0.8)\n","\n","plt.plot(CLoVE_Ns, CLoVE_val_aucs, '--', color='C0', alpha=0.8)\n","plt.plot(CLoVE_Ns, CLoVE_test_aucs, label='CLoVE', color='C0', alpha=0.8)\n","\n","plt.plot(VarREx_Ns, VarREx_val_aucs, '--', color='C4', alpha=0.8)\n","plt.plot(VarREx_Ns, VarREx_test_aucs, label='VarREx', color='C4', alpha=0.8)\n","\n","plt.plot(VarAUC_Ns, VarAUC_val_aucs, '--',  color='C2')\n","plt.plot(VarAUC_Ns, VarAUC_test_aucs, label='VarAUC', color='C2')\n","\n","\n","plt.title('AUC during training (dashed: in-distribution val, solid: distribution-shift test)', fontsize=10)\n","plt.xlabel('Data Points (pairs)')\n","plt.ylabel('AUC')\n","plt.legend(loc=4);"],"metadata":{"id":"6H26YR5kuajL"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# split training data by attribute (A=1: Lycaenidae, A=0: Nymphalidae)\n","z_a1, c_a1 = z_train[A_train==1], c_train[A_train==1]\n","z_a0, c_a0 = z_train[A_train==0], c_train[A_train==0]\n","\n","# make pairs\n","z1_a1, z2_a1, y_a1, c_a1 = make_pairs(z_a1, c_a1, pos_per_class_train)\n","z1_a0, z2_a0, y_a0, c_a0 = make_pairs(z_a0, c_a0, pos_per_class_train)\n","\n","# ERM representations\n","z1_hat_a1_ERM, z2_hat_a1_ERM = ERM_g(z1_a1), ERM_g(z2_a1)\n","z1_hat_a0_ERM, z2_hat_a0_ERM = ERM_g(z1_a0), ERM_g(z2_a0)\n","\n","# VarAUC representations\n","z1_hat_a1_VarAUC, z2_hat_a1_VarAUC = VarAUC_g(z1_a1), VarAUC_g(z2_a1)\n","z1_hat_a0_VarAUC, z2_hat_a0_VarAUC = VarAUC_g(z1_a0), VarAUC_g(z2_a0)\n","\n","# unpenalized losses\n","def raw_loss(z1_hat, z2_hat, y_true, margin=0.5):\n","  \"\"\"Per-pair contrastive loss on the cosine distance, without any invariance penalty.\n","\n","  Returns (loss, dist), both with one entry per pair.\n","  \"\"\"\n","  dist = cosine_distance(z1_hat, z2_hat)\n","  l = tfa.losses.contrastive_loss(y_true, dist, margin)\n","  return l, dist\n","\n","# on a1\n","base_loss_a1_ERM, dist_a1_ERM = raw_loss(z1_hat_a1_ERM, z2_hat_a1_ERM, y_a1)\n","base_loss_a1_VarAUC, dist_a1_VarAUC = raw_loss(z1_hat_a1_VarAUC, z2_hat_a1_VarAUC, y_a1)\n","dist_a1_ERM, dist_a1_VarAUC = dist_a1_ERM.numpy(), dist_a1_VarAUC.numpy()\n","\n","# on a0\n","base_loss_a0_ERM, dist_a0_ERM = raw_loss(z1_hat_a0_ERM, z2_hat_a0_ERM, y_a0)\n","base_loss_a0_VarAUC, dist_a0_VarAUC = 
raw_loss(z1_hat_a0_VarAUC, z2_hat_a0_VarAUC, y_a0)\n","dist_a0_ERM, dist_a0_VarAUC = dist_a0_ERM.numpy(), dist_a0_VarAUC.numpy()"],"metadata":{"id":"PtR8BnB72C4F"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["def get_bins(diff, n_bins=19):\n","  \"\"\"Return n_bins+2 evenly spaced histogram edges from -l to u for `diff`.\n","\n","  The nominal bin width is ptp(diff)/n_bins; u and l push max(diff) and |min(diff)|\n","  out to half-width offsets, aiming to put 0 at a bin centre (exact only when\n","  u + l equals (n_bins+1) widths).\n","  \"\"\"\n","  diff = np.asarray(diff)\n","  w = np.ptp(diff)/n_bins\n","  if w == 0:\n","    # constant input: any positive width gives valid edges (avoids 0/0 -> NaN edges)\n","    w = 1.0/n_bins\n","  u = np.ceil((diff.max()-w/2)/w)*w + w/2\n","  l = np.ceil((abs(diff.min())-w/2)/w)*w + w/2\n","  return np.linspace(-l, u, n_bins+2)"],"metadata":{"id":"TGeF82VhSdMB"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# per-pair loss difference ERM - VarAUC: values > 0 mean VarAUC has the lower loss\n","# rows: y=0 / y=1 pairs; columns: A=1 (Lycaenidae, minority) / A=0 (Nymphalidae, majority)\n","fig, axs = plt.subplots(2,2, figsize=(8,8))\n","\n","t1 = len(base_loss_a1_ERM[y_a1==0])\n","t0 = len(base_loss_a0_ERM[y_a0==0])\n","\n","n_bins=19\n","\n","diff_00 = base_loss_a1_ERM[y_a1==0] - base_loss_a1_VarAUC[y_a1==0]\n","bins_00 = get_bins(diff_00, n_bins)\n","axs[0,0].hist(diff_00, bins=bins_00, alpha=0.5, weights=(np.ones(t1)/t1), color='C5', ec='C5')\n","axs[0,0].axvline(0, c='k', linestyle=':')\n","axs[0,0].set_title('Lycaenidae (minority in training)', fontsize=11)\n","axs[0,0].set_ylabel(r'$y=0$', fontsize=11)\n","xl = max(abs(bins_00))*1.1\n","axs[0,0].set_xlim(-xl, xl)\n","\n","diff_01 = base_loss_a0_ERM[y_a0==0] - base_loss_a0_VarAUC[y_a0==0]\n","bins_01 = get_bins(diff_01, n_bins)\n","axs[0,1].hist(diff_01, bins=bins_01, alpha=0.5, weights=(np.ones(t0)/t0), color='C5', ec='C5')\n","axs[0,1].axvline(0, c='k', linestyle=':')\n","axs[0,1].set_title('Nymphalidae (majority in training)', fontsize=11)\n","xl = max(abs(bins_01))*1.1\n","axs[0,1].set_xlim(-xl, xl)\n","\n","t1 = len(base_loss_a1_ERM[y_a1==1])\n","t0 = len(base_loss_a0_ERM[y_a0==1])\n","\n","diff_10 = base_loss_a1_ERM[y_a1==1] - base_loss_a1_VarAUC[y_a1==1]\n","bins_10 = get_bins(diff_10, n_bins)\n","axs[1,0].hist(diff_10, bins=bins_10, alpha=0.5, weights=(np.ones(t1)/t1), color='C7', ec='C7')\n","axs[1,0].axvline(0, c='k', linestyle=':')\n","axs[1,0].set_ylabel(r'$y=1$', fontsize=11)\n","xl = max(abs(bins_10))*1.1\n","axs[1,0].set_xlim(-xl, xl)\n","\n","diff_11 = 
base_loss_a0_ERM[y_a0==1] - base_loss_a0_VarAUC[y_a0==1]  # A=0 (Nymphalidae), y=1 pairs\n","bins_11 = get_bins(diff_11, n_bins)\n","axs[1,1].hist(diff_11, bins=bins_11, alpha=0.5, weights=(np.ones(t0)/t0), color='C7', ec='C7')\n","axs[1,1].axvline(0, c='k', linestyle=':')\n","xl = max(abs(bins_11))*1.1  # symmetric x-limits around 0\n","axs[1,1].set_xlim(-xl, xl);"],"metadata":{"id":"4_GbWn4-KE8c"},"execution_count":null,"outputs":[]}],"metadata":{"colab":{"provenance":[{"file_id":"1iobAI4sQxrx1QM7Bir52JlcXQ6U8boRH","timestamp":1697920683664},{"file_id":"1QwGuTjlnE3M7PYdU-d-fJiAl4UAorhKD","timestamp":1696437637986},{"file_id":"1TDUCBHaEkD_Qx19_86FAorpdlWy91wRE","timestamp":1696416652983},{"file_id":"1sSyl87Kxpsd1eqWmva2xKT9L6t7EEEB7","timestamp":1696262491203},{"file_id":"1Yo2zIAh6D1-0nM7MOiDJQZKfJ9vqb3Or","timestamp":1696255402049},{"file_id":"16lKbBtyFaKmFk7RvwyzI_4wBwTGpF9-u","timestamp":1696172030568},{"file_id":"1mtQ7SZQU-svSsBO90LHAu-s7lxI0K_J0","timestamp":1696103554209},{"file_id":"15zpW4yGsGBVvRA7WReCdLXB8dJy4DW_7","timestamp":1696088714133},{"file_id":"1aCwJxhJazOxQl9EkUHBzaNEX9ePQdcA_","timestamp":1695981440686},{"file_id":"1d9FYUGBy2JZFpOb7kxLKLd0REUN8By2q","timestamp":1695914965719},{"file_id":"1OQR55crtvRtsCkICbdhJNBj4qnGbr6Me","timestamp":1695823964114},{"file_id":"1EiPLfNe8BIz5s3x_8ShUQZc0jC2iRw4H","timestamp":1695201530425},{"file_id":"1IkMrw6RJj6rsctSE6aXZwYMwIVBTUoGb","timestamp":1695120461090},{"file_id":"1app-Ikpz7ciNIY0Odzg7tFgNbCaAwHt8","timestamp":1694981579448},{"file_id":"1lcUrasMJAj89ULVmbyv0HN850cPZ1gv0","timestamp":1694941968925},{"file_id":"1mSYz-T4i6itmFUQy8uaHa5WKvXBgxPRn","timestamp":1694880923641},{"file_id":"1mOoW0BvZMNRum3Y30qrOkgohy9TqRkC7","timestamp":1694784592758},{"file_id":"1YorfIjlYD9GrKlfnbVtCH4oQQCxzfiwz","timestamp":1693906045347},{"file_id":"1vikBs4qVDYFssMEtbKIsN1BOcuEq4Tma","timestamp":1693828958876},{"file_id":"111lYqPXYzHyb7oXBP5PGg0og-u8_-Bqt","timestamp":1693234521949},{"file_id":"1-Woo69BgyMRS1_zTwvB_w2pbXCJMFwE0","timestamp":1692610678039},{"file_id":"1fuGUZSbL-7l-UoKiGI_i
puqJsl68z5ZH","timestamp":1683212300244}],"machine_shape":"hm","authorship_tag":"ABX9TyNePgE/zSuc/PRgGA/d+oMs"},"kernelspec":{"display_name":"Python 3","name":"python3"},"language_info":{"name":"python"}},"nbformat":4,"nbformat_minor":0}