{
 "cells": [
  {"cell_type":"code","execution_count":null,"metadata":{"id":"uGMmryQKaVkp"},"outputs":[],"source":[
   "import os\n",
   "import numpy as np\n",
   "import tensorflow as tf\n",
   "import tensorflow_addons as tfa\n",
   "from tqdm.notebook import tqdm\n",
   "from sklearn.metrics import roc_curve, auc\n",
   "import matplotlib.pyplot as plt\n",
   "%matplotlib inline"
  ]},
  {"cell_type":"markdown","metadata":{"id":"LkH6K5U9kISC"},"source":[
   "This notebook assumes it is run either locally or in Colab, from within a folder named\n",
   "`Class_Distribution_Shifts_in_Zero_Shot_Learning_Learning_Robust_Representations`\n",
   "(in Colab, the cell below expects that folder directly under `MyDrive`)."
  ]},
  {"cell_type":"code","execution_count":null,"metadata":{"id":"cbkAZvsbZ8L9"},"outputs":[],"source":[
   "path = os.getcwd()\n",
   "\n",
   "# Detect Colab through its Colab-only package: only a failed import means we are\n",
   "# running locally. Mount errors are left to surface instead of being swallowed\n",
   "# and misreported as a local run.\n",
   "try:\n",
   "  from google.colab import drive\n",
   "  IN_COLAB = True\n",
   "except ImportError:\n",
   "  IN_COLAB = False\n",
   "\n",
   "if IN_COLAB:\n",
   "  print(\"Running in Colab\")\n",
   "  # Mount Google Drive\n",
   "  drive.mount('/content/drive', force_remount=True)\n",
   "  # Change directory to the project folder\n",
   "  %cd /content/drive/MyDrive/Class_Distribution_Shifts_in_Zero_Shot_Learning_Learning_Robust_Representations\n",
   "else:\n",
   "  print(\"Running locally\")"
  ]},
  {"cell_type":"code","execution_count":null,"metadata":{"id":"7E2jBgVXZilf"},"outputs":[],"source":[
   "# Project-local modules, resolved from the working directory set above.\n",
   "# NOTE(review): the wildcard imports hide where `training`, `evaluate` and\n",
   "# `generate_synthetic_data` are defined - consider importing them by name.\n",
   "from pairs import distinct_pairs_func, make_pairs\n",
   "from algorithm import *\n",
   "from synthetic_data import *"
  ]},
  {"cell_type":"markdown","metadata":{"id":"0wGVBdB_QlCM"},"source":[
   "### Generate Data"
  ]},
  {"cell_type":"code","execution_count":null,"metadata":{"id":"cP-T8r3p5mcz"},"outputs":[],"source":[
   "# Synthetic-data configuration; every value here is passed to\n",
   "# generate_synthetic_data below.\n",
   "v0_dim = 5\n",
   "vminus_dim = 10\n",
   "vplus_dim = 10\n",
   "noise_dim = 25\n",
   "\n",
   "v0 = 1.0\n",
   "vminus = 0.1\n",
   "vplus = 2.0\n",
   "\n",
   "p_minor = 0.1\n",
   "Nc = 500\n",
   "r = 30\n",
   "\n",
   "vz = 1.0\n",
   "vz_noise = 10.0"
  ]},
  {"cell_type":"code","execution_count":null,"metadata":{"id":"sALbpLq-mI7A"},"outputs":[],"source":[
   "# total_dim is the input width of the representation network defined below.\n",
   "signal_dim = v0_dim + vminus_dim + vplus_dim\n",
   "total_dim = signal_dim + noise_dim"
  ]},
  {"cell_type":"code","execution_count":null,"metadata":{"id":"7JNbJt63038Y"},"outputs":[],"source":[
   "z_train, c_train, z_val, c_val, z_test, c_test = generate_synthetic_data(\n",
   "    Nc, r, v0, vminus, vplus, vz, vz_noise,\n",
   "    v0_dim, vminus_dim, vplus_dim, noise_dim, p_minor)"
  ]},
  {"cell_type":"code","execution_count":null,"metadata":{"id":"1uhPtXFdb75g"},"outputs":[],"source":[
   "pos_per_class_train, pos_per_class_test = 5, 5\n",
   "\n",
   "# generate pairs (the validation split reuses the training setting)\n",
   "train_z1, train_z2, train_y, train_Cs = make_pairs(z_train, c_train, pos_per_class_train)\n",
   "val_z1, val_z2, val_y, val_Cs = make_pairs(z_val, c_val, pos_per_class_train)\n",
   "test_z1, test_z2, test_y, test_Cs = make_pairs(z_test, c_test, pos_per_class_test)"
  ]},
  {"cell_type":"markdown","metadata":{"id":"SDbqj5huocsO"},"source":[
   "### Class sampling"
  ]},
  {"cell_type":"code","execution_count":null,"metadata":{"id":"w1evNukLoQ94"},"outputs":[],"source":[
   "classes_in_env = 2\n",
   "classes_in_env_test = 2\n",
   "\n",
   "# Number of environments n solving (1 - p_minor**2)**n = 0.5, truncated to an integer.\n",
   "n_sim_envs = int(np.log(0.5)/np.log(1 - p_minor**2))"
  ]},
  {"cell_type":"code","execution_count":null,"metadata":{"id":"bKEWRjdy2jDg"},"outputs":[],"source":[
   "n_envs = 10**5"
  ]},
  {"cell_type":"code","execution_count":null,"metadata":{"id":"e8BYwkqoV_1v"},"outputs":[],"source":[
   "unq_c_train = np.unique(c_train)\n",
   "\n",
   "# Each environment is one draw of `classes_in_env` distinct training classes.\n",
   "train_envs = np.array([\n",
   "  np.random.choice(unq_c_train, classes_in_env, replace=False)\n",
   "  for _ in range(n_envs)\n",
   "])"
  ]},
  {"cell_type":"markdown","metadata":{"id":"ePNaCa4KFc8-"},"source":[
   "### Models"
  ]},
  {"cell_type":"code","execution_count":null,"metadata":{"id":"K7C1W0_fIukF"},"outputs":[],"source":[
   "def init_representation(add_dropout=False):\n",
   "    \"\"\"Build the representation network: one Dense(16) layer on `total_dim` inputs.\n",
   "\n",
   "    `add_dropout` is kept for interface compatibility; this implementation\n",
   "    does not use it.\n",
   "    \"\"\"\n",
   "    model = tf.keras.Sequential()\n",
   "    # `shape` must be a tuple: `(total_dim)` is a plain int, `(total_dim,)` is the 1-D shape.\n",
   "    model.add(tf.keras.layers.Input(shape=(total_dim,)))\n",
   "    model.add(tf.keras.layers.Dense(16))\n",
   "    return model"
  ]},
  {"cell_type":"markdown","metadata":{"id":"KpAkkDaUC4fU"},"source":[
   "Parameters"
  ]},
  {"cell_type":"code","execution_count":null,"metadata":{"id":"DHzkNNi6C3lX"},"outputs":[],"source":[
   "lr = 0.01\n",
   "n_pairs = 7 * 10**5\n",
   "\n",
   "# Per-method factor passed to training() together with the matching penalty_type.\n",
   "ERM_factor = 0.0\n",
   "CLoVE_factor = 0.085\n",
   "VarAUC_factor = 1.3\n",
   "IRM_factor = 0.01\n",
   "VarREx_factor = 3.0"
  ]},
  {"cell_type":"markdown","metadata":{"id":"UdmdODm3qAcz"},"source":[
   "Initialization"
  ]},
  {"cell_type":"code","execution_count":null,"metadata":{"id":"IV-ozSeXeyV9"},"outputs":[],"source":[
   "# init_g holds the shared starting weights; each method's model is reset to\n",
   "# them before training so all methods start from the same point.\n",
   "init_g = init_representation()\n",
   "\n",
   "ERM_g = init_representation()\n",
   "IRM_g = init_representation()\n",
   "CLoVE_g = init_representation()\n",
   "VarREx_g = init_representation()\n",
   "VarAUC_g = init_representation()"
  ]},
  {"cell_type":"markdown","metadata":{"id":"W4-Of6yyhztG"},"source":[
   "ERM"
  ]},
  {"cell_type":"code","execution_count":null,"metadata":{"id":"gLGQP09m2yGf"},"outputs":[],"source":[
   "ERM_g.set_weights(init_g.get_weights())"
  ]},
  {"cell_type":"code","execution_count":null,"metadata":{"id":"HBsrZoVOmuV-"},"outputs":[],"source":[
   "ERM_optimizer = tf.keras.optimizers.Adam(learning_rate=lr)"
  ]},
  {"cell_type":"code","execution_count":null,"metadata":{"id":"YRDiGqmphwCL"},"outputs":[],"source":[
   "ERM_g, ERM_losses, ERM_Ns, ERM_test_aucs, ERM_val_aucs = training(\n",
   "    ERM_optimizer, ERM_g, z_train, c_train, train_envs,\n",
   "    val_z1, val_z2, val_y, test_z1, test_z2, test_y,\n",
   "    n_pairs, pos_per_class_train, ERM_factor, n_sim_envs, penalty_type=None)"
  ]},
  {"cell_type":"code","execution_count":null,"metadata":{"id":"xLNZ-EGbMkQn"},"outputs":[],"source":[
   "ERM_w = ERM_g.get_weights()\n",
   "# Importance per row of the first weight array: its share of the total absolute weight.\n",
   "ERM_imp = np.abs(ERM_w[0]).sum(axis=1)/np.abs(ERM_w[0]).sum()\n",
   "\n",
   "ERM_auc_train = evaluate(ERM_g, train_z1, train_z2, train_y)\n",
   "ERM_auc_val = evaluate(ERM_g, val_z1, val_z2, val_y)\n",
   "ERM_auc_test = evaluate(ERM_g, test_z1, test_z2, test_y)\n",
   "\n",
   "print(\"Training data: {:.4f}, In-distribution: {:.4f}, Distribution-shift: {:.4f}\".format(ERM_auc_train, ERM_auc_val, ERM_auc_test))"
  ]},
  {"cell_type":"markdown","metadata":{"id":"ljCdsE273OHp"},"source":[
   "IRM"
  ]},
  {"cell_type":"code","execution_count":null,"metadata":{"id":"5DZ6r6xj3FLQ"},"outputs":[],"source":[
   "IRM_g.set_weights(init_g.get_weights())"
  ]},
  {"cell_type":"code","execution_count":null,"metadata":{"id":"NCwFI1v_3Fc0"},"outputs":[],"source":[
   "IRM_optimizer = tf.keras.optimizers.Adam(learning_rate=lr)"
  ]},
  {"cell_type":"code","execution_count":null,"metadata":{"id":"3b0o1jSW3TZM"},"outputs":[],"source":[
   "IRM_g, IRM_losses, IRM_Ns, IRM_test_aucs, IRM_val_aucs = training(\n",
   "    IRM_optimizer, IRM_g, z_train, c_train, train_envs,\n",
   "    val_z1, val_z2, val_y, test_z1, test_z2, test_y,\n",
   "    n_pairs, pos_per_class_train, IRM_factor, n_sim_envs,\n",
   "    penalty_type='IRM')"
  ]},
  {"cell_type":"code","execution_count":null,"metadata":{"id":"T3uwGG403Whe"},"outputs":[],"source":[
   "IRM_w = IRM_g.get_weights()\n",
   "# Importance per row of the first weight array: its share of the total absolute weight.\n",
   "IRM_imp = np.abs(IRM_w[0]).sum(axis=1)/np.abs(IRM_w[0]).sum()\n",
   "\n",
   "IRM_auc_train = evaluate(IRM_g, train_z1, train_z2, train_y)\n",
   "IRM_auc_val = evaluate(IRM_g, val_z1, val_z2, val_y)\n",
   "IRM_auc_test = evaluate(IRM_g, test_z1, test_z2, test_y)\n",
   "\n",
   "print(\"Training data: {:.4f}, In-distribution: {:.4f}, Distribution-shift: {:.4f}\".format(IRM_auc_train, IRM_auc_val, IRM_auc_test))"
  ]},
  {"cell_type":"markdown","metadata":{"id":"T4WTPLGvO_bp"},"source":[
   "CLoVE"
  ]},
  {"cell_type":"code","execution_count":null,"metadata":{"id":"dmr3G0w-2mvZ"},"outputs":[],"source":[
   "CLoVE_g.set_weights(init_g.get_weights())"
  ]},
  {"cell_type":"code","execution_count":null,"metadata":{"id":"szPYwQKymok1"},"outputs":[],"source":[
   "CLoVE_optimizer = tf.keras.optimizers.Adam(learning_rate=lr)"
  ]},
  {"cell_type":"code","execution_count":null,"metadata":{"id":"_PezdV-FGIpS"},"outputs":[],"source":[
   "CLoVE_g, CLoVE_losses, CLoVE_Ns, CLoVE_test_aucs, CLoVE_val_aucs = training(\n",
   "    CLoVE_optimizer, CLoVE_g, z_train, c_train, train_envs,\n",
   "    val_z1, val_z2, val_y, test_z1, test_z2, test_y,\n",
   "    n_pairs, pos_per_class_train, CLoVE_factor, n_sim_envs,\n",
   "    penalty_type='CLoVE')"
  ]},
  {"cell_type":"code","execution_count":null,"metadata":{"id":"M6T6V9wy4XwA"},"outputs":[],"source":[
   "CLoVE_w = CLoVE_g.get_weights()\n",
   "# Importance per row of the first weight array: its share of the total absolute weight.\n",
   "CLoVE_imp = np.abs(CLoVE_w[0]).sum(axis=1)/np.abs(CLoVE_w[0]).sum()\n",
   "\n",
   "CLoVE_auc_train = evaluate(CLoVE_g, train_z1, train_z2, train_y)\n",
   "CLoVE_auc_val = evaluate(CLoVE_g, val_z1, val_z2, val_y)\n",
   "CLoVE_auc_test = evaluate(CLoVE_g, test_z1, test_z2, test_y)\n",
   "\n",
   "print(\"Training data: {:.4f}, In-distribution: {:.4f}, Distribution-shift: {:.4f}\".format(CLoVE_auc_train, CLoVE_auc_val, CLoVE_auc_test))"
  ]},
  {"cell_type":"markdown","metadata":{"id":"FNJbfiO94a8e"},"source":[
   "VarREx"
  ]},
  {"cell_type":"code","execution_count":null,"metadata":{"id":"7inMVQLp4Zyg"},"outputs":[],"source":[
   "VarREx_g.set_weights(init_g.get_weights())"
  ]},
  {"cell_type":"code","execution_count":null,"metadata":{"id":"ZQeKKtEa4Zfd"},"outputs":[],"source":[
   "VarREx_optimizer = tf.keras.optimizers.Adam(learning_rate=lr)"
  ]},
  {"cell_type":"code","execution_count":null,"metadata":{"id":"4KP_Qn-Q4ZV3"},"outputs":[],"source":[
   "VarREx_g, VarREx_losses, VarREx_Ns, VarREx_test_aucs, VarREx_val_aucs = training(\n",
   "    VarREx_optimizer, VarREx_g, z_train, c_train, train_envs,\n",
   "    val_z1, val_z2, val_y, test_z1, test_z2, test_y,\n",
   "    n_pairs, pos_per_class_train, VarREx_factor, n_sim_envs,\n",
   "    penalty_type='VarREx')"
  ]},
  {"cell_type":"code","execution_count":null,"metadata":{"id":"NgbuTaky4Y5P"},"outputs":[],"source":[
   "VarREx_w = VarREx_g.get_weights()\n",
   "# Importance per row of the first weight array: its share of the total absolute weight.\n",
   "VarREx_imp = np.abs(VarREx_w[0]).sum(axis=1)/np.abs(VarREx_w[0]).sum()\n",
   "\n",
   "VarREx_auc_train = evaluate(VarREx_g, train_z1, train_z2, train_y)\n",
   "VarREx_auc_val = evaluate(VarREx_g, val_z1, val_z2, val_y)\n",
   "VarREx_auc_test = evaluate(VarREx_g, test_z1, test_z2, test_y)\n",
   "\n",
   "print(\"Training data: {:.4f}, In-distribution: {:.4f}, Distribution-shift: {:.4f}\".format(VarREx_auc_train, VarREx_auc_val, VarREx_auc_test))"
  ]},
  {"cell_type":"markdown","metadata":{"id":"PYdyvE3SJ4V5"},"source":[
   "VarAUC"
  ]},
  {"cell_type":"code","execution_count":null,"metadata":{"id":"IDnCQxUH2rIr"},"outputs":[],"source":[
   "VarAUC_g.set_weights(init_g.get_weights())"
  ]},
  {"cell_type":"code","execution_count":null,"metadata":{"id":"nOuOW_NiZ6AX"},"outputs":[],"source":[
   "VarAUC_optimizer = tf.keras.optimizers.Adam(learning_rate=lr)"
  ]},
  {"cell_type":"code","execution_count":null,"metadata":{"id":"KbF7z4JoJ6Zd"},"outputs":[],"source":[
   "VarAUC_g, VarAUC_losses, VarAUC_Ns, VarAUC_test_aucs, VarAUC_val_aucs = training(\n",
   "    VarAUC_optimizer, VarAUC_g, z_train, c_train, train_envs,\n",
   "    val_z1, val_z2, val_y, test_z1, test_z2, test_y,\n",
   "    n_pairs, pos_per_class_train, VarAUC_factor, n_sim_envs,\n",
   "    penalty_type='VarAUC')"
  ]},
  {"cell_type":"code","execution_count":null,"metadata":{"id":"FYlSXvHBKh5i"},"outputs":[],"source":[
   "VarAUC_w = VarAUC_g.get_weights()\n",
   "# Importance per row of the first weight array: its share of the total absolute weight.\n",
   "VarAUC_imp = np.abs(VarAUC_w[0]).sum(axis=1)/np.abs(VarAUC_w[0]).sum()\n",
   "\n",
   "VarAUC_auc_train = evaluate(VarAUC_g, train_z1, train_z2, train_y)\n",
   "VarAUC_auc_val = evaluate(VarAUC_g, val_z1, val_z2, val_y)\n",
   "VarAUC_auc_test = evaluate(VarAUC_g, test_z1, test_z2, test_y)\n",
   "\n",
   "print(\"Training data: {:.4f}, In-distribution: {:.4f}, Distribution-shift: {:.4f}\".format(VarAUC_auc_train, VarAUC_auc_val, VarAUC_auc_test))"
  ]},
  {"cell_type":"markdown","metadata":{"id":"kaOsmCdFOk8r"},"source":[
   "### Comparing results"
  ]},
  {"cell_type":"code","execution_count":null,"metadata":{"id":"zQphIiggmbrQ"},"outputs":[],"source":[
   "plt.plot(ERM_imp, '^', markersize=4, color='C3', label ='ERM', alpha=0.65)\n",
   "plt.plot(IRM_imp, '*', color='C1', label='IRM', alpha=0.65)\n",
   "plt.plot(CLoVE_imp, 'X', color='C0', label='CLOvE', alpha=0.65, markersize=5)\n",
   "plt.plot(VarREx_imp, 'p', color='C4', label='VarREx', alpha=0.65, markersize=5)\n",
   "plt.plot(VarAUC_imp, '.', markersize=8, color='C2', label='VarAUC', alpha=0.65)\n",
   "\n",
   "# Block boundaries derived from the configured dimensions, so the guide lines\n",
   "# and the ticks stay in sync when the dimensions change.\n",
   "# NOTE(review): the order used here (v0, then vplus, then vminus) is carried over\n",
   "# from the original plot - confirm it against generate_synthetic_data.\n",
   "block_edges = [v0_dim, v0_dim + vplus_dim, v0_dim + vplus_dim + vminus_dim]\n",
   "for edge in block_edges:\n",
   "  plt.axvline(edge, color='k', linewidth=0.7)\n",
   "\n",
   "plt.xticks(block_edges)\n",
   "plt.xlabel('Dimension')\n",
   "plt.ylabel('Importance')\n",
   "plt.legend();"
  ]},
  {"cell_type":"code","execution_count":null,"metadata":{"id":"0Wh8DLkYmbkd"},"outputs":[],"source":[
   "# Dashed lines: validation AUC; solid (labelled) lines: test AUC.\n",
   "plt.plot(ERM_Ns, ERM_val_aucs, '--', color='C3')\n",
   "plt.plot(ERM_Ns, ERM_test_aucs,  label='ERM', color='C3')\n",
   "\n",
   "plt.plot(IRM_Ns, IRM_val_aucs, '--', color='C1', alpha=0.8)\n",
   "plt.plot(IRM_Ns, IRM_test_aucs, label='IRM', color='C1', alpha=0.8)\n",
   "\n",
   "plt.plot(CLoVE_Ns, CLoVE_val_aucs, '--', color='C0', alpha=0.8)\n",
   "plt.plot(CLoVE_Ns, CLoVE_test_aucs, label='CLOvE', color='C0', alpha=0.8)\n",
   "\n",
   "plt.plot(VarREx_Ns, VarREx_val_aucs, '--', color='C4', alpha=0.8)\n",
   "plt.plot(VarREx_Ns, VarREx_test_aucs, label='VarREx', color='C4', alpha=0.8)\n",
   "\n",
   "plt.plot(VarAUC_Ns, VarAUC_val_aucs, '--',  color='C2')\n",
   "plt.plot(VarAUC_Ns, VarAUC_test_aucs, label='VarAUC', color='C2')\n",
   "\n",
   "\n",
   "plt.xlabel('Data Points (pairs)')\n",
   "plt.ylabel('AUC')\n",
   "plt.legend(loc=4);"
  ]}
 ],
 "metadata": {
  "colab": {
   "provenance": [
    {"file_id":"1Yo2zIAh6D1-0nM7MOiDJQZKfJ9vqb3Or","timestamp":1697901542744},
    {"file_id":"16lKbBtyFaKmFk7RvwyzI_4wBwTGpF9-u","timestamp":1696172030568},
    {"file_id":"1mtQ7SZQU-svSsBO90LHAu-s7lxI0K_J0","timestamp":1696103554209},
    {"file_id":"15zpW4yGsGBVvRA7WReCdLXB8dJy4DW_7","timestamp":1696088714133},
    {"file_id":"1aCwJxhJazOxQl9EkUHBzaNEX9ePQdcA_","timestamp":1695981440686},
    {"file_id":"1d9FYUGBy2JZFpOb7kxLKLd0REUN8By2q","timestamp":1695914965719},
    {"file_id":"1OQR55crtvRtsCkICbdhJNBj4qnGbr6Me","timestamp":1695823964114},
    {"file_id":"1EiPLfNe8BIz5s3x_8ShUQZc0jC2iRw4H","timestamp":1695201530425},
    {"file_id":"1IkMrw6RJj6rsctSE6aXZwYMwIVBTUoGb","timestamp":1695120461090},
    {"file_id":"1app-Ikpz7ciNIY0Odzg7tFgNbCaAwHt8","timestamp":1694981579448},
    {"file_id":"1lcUrasMJAj89ULVmbyv0HN850cPZ1gv0","timestamp":1694941968925},
    {"file_id":"1mSYz-T4i6itmFUQy8uaHa5WKvXBgxPRn","timestamp":1694880923641},
    {"file_id":"1mOoW0BvZMNRum3Y30qrOkgohy9TqRkC7","timestamp":1694784592758},
    {"file_id":"1YorfIjlYD9GrKlfnbVtCH4oQQCxzfiwz","timestamp":1693906045347},
    {"file_id":"1vikBs4qVDYFssMEtbKIsN1BOcuEq4Tma","timestamp":1693828958876},
    {"file_id":"111lYqPXYzHyb7oXBP5PGg0og-u8_-Bqt","timestamp":1693234521949},
    {"file_id":"1-Woo69BgyMRS1_zTwvB_w2pbXCJMFwE0","timestamp":1692610678039},
    {"file_id":"1fuGUZSbL-7l-UoKiGI_ipuqJsl68z5ZH","timestamp":1683212300244}
   ],
   "authorship_tag": "ABX9TyMdoLvxNdZmv38hVdffGGlF"
  },
  "kernelspec": {"display_name":"Python 3","name":"python3"},
  "language_info": {"name":"python"}
 },
 "nbformat": 4,
 "nbformat_minor": 0
}