{
 "cells": [
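  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Example of using the code from JSE\n",
    "\n",
    "In this notebook, we briefly show how to use the code in this repository on our Toy dataset. We train a plain ERM classifier, estimate the spurious-concept subspace with both JSE and INLP, project that subspace out of the data, and compare the test accuracy of the resulting classifiers."
   ]
  },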
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# import the necessary libraries\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "\n",
    "# import from the JSE package\n",
    "from JSE.data import *\n",
    "from JSE.settings import data_info, optimizer_info\n",
    "from JSE.models import *\n",
    "from JSE.training import *\n",
    "\n",
    "# import in order \n",
    "import argparse\n",
    "import os\n",
    "import sys"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Configuration for the whole notebook\n",
    "\n",
    "# compute device and dataset selection\n",
    "device_type = 'cpu'\n",
    "device = torch.device(device_type)\n",
    "dataset = 'Toy'\n",
    "dataset_setting = 'default'\n",
    "\n",
    "# look up the settings of the chosen dataset and the shared optimizer settings\n",
    "dataset_settings = data_info[dataset][dataset_setting]\n",
    "optimizer_settings = optimizer_info['All']\n",
    "\n",
    "# strength of the spurious correlation - corresponds to the rho value in the paper\n",
    "spurious_ratio = 0.8\n",
    "\n",
    "# set the random seed\n",
    "seed = 0\n",
    "set_seed(seed)\n",
    "\n",
    "# preprocessing: demean the data, apply no pca\n",
    "demean, pca = True, False\n",
    "k_components = 20  # number of pca components - only relevant if pca = True\n",
    "d = 20  # number of features\n",
    "\n",
    "# training settings shared by all models in this notebook\n",
    "solver = 'SGD'\n",
    "lr = 0.01\n",
    "weight_decay = 0.0\n",
    "early_stopping = True\n",
    "epochs = 50\n",
    "per_step = 5  # number of epochs between printing the loss\n",
    "batch_size = 128\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# build the dataset object for the chosen dataset and spurious ratio\n",
    "data_obj = get_dataset_obj(\n",
    "    dataset,\n",
    "    dataset_settings,\n",
    "    spurious_ratio,\n",
    "    data_info,\n",
    "    seed,\n",
    "    device,\n",
    "    use_punctuation_MNLI=True,\n",
    ")\n",
    "\n",
    "# optional preprocessing, switched on/off by the flags in the configuration cell\n",
    "if demean:\n",
    "    data_obj.demean_X(reset_mean=True, include_test=True)\n",
    "if pca:\n",
    "    data_obj.transform_data_to_k_components(k_components, reset_V_k=True, include_test=True)\n",
    "    V_k_train = data_obj.V_k_train\n",
    "\n",
    "# keep references to the features and both label sets (y_c, y_m) of every split\n",
    "X_train = data_obj.X_train\n",
    "y_c_train = data_obj.y_c_train\n",
    "y_m_train = data_obj.y_m_train\n",
    "\n",
    "X_val = data_obj.X_val\n",
    "y_c_val = data_obj.y_c_val\n",
    "y_m_val = data_obj.y_m_val\n",
    "\n",
    "X_test = data_obj.X_test\n",
    "y_c_test = data_obj.y_c_test\n",
    "y_m_test = data_obj.y_m_test\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Baseline: simple ERM on the data without any projection\n",
    "\n",
    "# set the loaders of the dataset object from the training and validation data\n",
    "data_obj.reset_X(\n",
    "    X_train,\n",
    "    X_val,\n",
    "    batch_size=batch_size,\n",
    "    reset_X_objects=False,\n",
    "    include_weights=False,\n",
    "    train_weights=None,\n",
    "    val_weights=None,\n",
    ")\n",
    "\n",
    "# train the linear model on the main-task loader\n",
    "ERM_model = return_linear_model(\n",
    "    d,\n",
    "    data_obj.main_loader,\n",
    "    device,\n",
    "    solver=solver,\n",
    "    lr=lr,\n",
    "    per_step=per_step,\n",
    "    tol=optimizer_settings['tol'],\n",
    "    early_stopping=early_stopping,\n",
    "    patience=optimizer_settings['patience'],\n",
    "    epochs=epochs,\n",
    "    bias=True,\n",
    "    weight_decay=weight_decay,\n",
    "    model_name=dataset + '_main_model',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# JSE: estimate the spurious-concept and main-task bases\n",
    "\n",
    "# loader settings: no sample weights for the concept task\n",
    "balanced_training_concept = False\n",
    "concept_weights_train = None\n",
    "concept_weights_val = None\n",
    "concept_first = True  # if True, then inner loop is for the main-task, outer loop is for the spurious concept-task\n",
    "\n",
    "# define the loaders\n",
    "loaders = data_obj.create_loaders(\n",
    "    batch_size=batch_size,\n",
    "    workers=0,\n",
    "    with_concept=True,\n",
    "    include_weights=balanced_training_concept,\n",
    "    train_weights=concept_weights_train,\n",
    "    val_weights=concept_weights_val,\n",
    "    concept_first=concept_first,\n",
    ")\n",
    "\n",
    "# JSE-specific settings\n",
    "alpha = 0.05\n",
    "Delta = 0\n",
    "eval_balanced = True\n",
    "\n",
    "# run the JSE algorithm - returns the spurious concept basis, the main task concept basis, and the dimension of both\n",
    "V_c, V_m, d_c, d_m = train_JSE(\n",
    "    data_obj,\n",
    "    device=device,\n",
    "    batch_size=batch_size,\n",
    "    solver=solver,\n",
    "    lr=lr,\n",
    "    per_step=per_step,\n",
    "    tol=optimizer_settings['tol'],\n",
    "    early_stopping=early_stopping,\n",
    "    patience=optimizer_settings['patience'],\n",
    "    epochs=epochs,\n",
    "    Delta=Delta,\n",
    "    alpha=alpha,\n",
    "    null_is_concept=False,\n",
    "    eval_balanced=eval_balanced,\n",
    "    weight_decay=weight_decay,\n",
    "    include_weights=balanced_training_concept,\n",
    "    train_weights=concept_weights_train,\n",
    "    val_weights=concept_weights_val,\n",
    "    model_base_name='JSE_' + dataset,\n",
    "    concept_first=concept_first,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Next, we run INLP on the same data as a comparison method\n",
    "\n",
    "# define the loaders (same settings as used for JSE: no sample weights for the concept task)\n",
    "balanced_training_concept = False\n",
    "concept_weights_train = None\n",
    "concept_weights_val = None\n",
    "concept_first = True  # if True, then inner loop is for the main-task, outer loop is for the spurious concept-task\n",
    "loaders = data_obj.create_loaders(\n",
    "    batch_size=batch_size,\n",
    "    workers=0,\n",
    "    with_concept=True,\n",
    "    include_weights=balanced_training_concept,\n",
    "    train_weights=concept_weights_train,\n",
    "    val_weights=concept_weights_val,\n",
    "    concept_first=concept_first,\n",
    ")\n",
    "\n",
    "# INLP-specific switches, both disabled in this example\n",
    "rafvogel_with_joint = False\n",
    "orthogonality_constraint = False\n",
    "\n",
    "# Train the model to get V_c; the seed is reset first so this run does not depend on earlier cells\n",
    "set_seed(seed)\n",
    "V_c_INLP, d_c_INLP = train_INLP(\n",
    "    data_obj,\n",
    "    device,\n",
    "    batch_size=batch_size,\n",
    "    solver=solver,\n",
    "    lr=lr,\n",
    "    weight_decay=weight_decay,\n",
    "    per_step=per_step,\n",
    "    tol=optimizer_settings['tol'],\n",
    "    early_stopping=early_stopping,\n",
    "    patience=optimizer_settings['patience'],\n",
    "    epochs=epochs,\n",
    "    alpha=alpha,\n",
    "    # named after the method and the dataset in use, following the 'JSE_'+dataset pattern\n",
    "    # (previously hard-coded to 'waterbird_rafvogel', a different dataset)\n",
    "    model_base_name='INLP_' + dataset,\n",
    "    bias=True,\n",
    "    joint_decision_rule=rafvogel_with_joint,\n",
    "    include_weights=balanced_training_concept,\n",
    "    train_weights=concept_weights_train,\n",
    "    val_weights=concept_weights_val,\n",
    "    orthogonality_constraint=orthogonality_constraint,\n",
    "    expected_diff=None,\n",
    "    var_diff=None,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ERM on the data after projecting with the JSE concept basis V_c\n",
    "\n",
    "# orthogonal projection matrix: identity minus create_P(V_c)\n",
    "P_c_orth = torch.eye(d) - create_P(V_c)\n",
    "\n",
    "# apply the same projection to every split\n",
    "X_train_transformed = torch.matmul(X_train, P_c_orth)\n",
    "X_val_transformed = torch.matmul(X_val, P_c_orth)\n",
    "X_test_transformed = torch.matmul(X_test, P_c_orth)\n",
    "\n",
    "# no sample weights for the main-task model\n",
    "balanced_training_main = False\n",
    "main_weights_train = None\n",
    "main_weights_val = None\n",
    "\n",
    "# set the loaders of the dataset object to the projected training and validation data\n",
    "data_obj.reset_X(\n",
    "    X_train_transformed,\n",
    "    X_val_transformed,\n",
    "    batch_size=batch_size,\n",
    "    reset_X_objects=True,\n",
    "    include_weights=balanced_training_main,\n",
    "    train_weights=main_weights_train,\n",
    "    val_weights=main_weights_val,\n",
    "    only_main=True,\n",
    ")\n",
    "\n",
    "# train the model on the transformed embeddings, starting from the same seed\n",
    "set_seed(seed)\n",
    "ERM_after_JSE = return_linear_model(\n",
    "    d,\n",
    "    data_obj.main_loader,\n",
    "    device,\n",
    "    solver=solver,\n",
    "    lr=lr,\n",
    "    per_step=per_step,\n",
    "    tol=optimizer_settings['tol'],\n",
    "    early_stopping=early_stopping,\n",
    "    patience=optimizer_settings['patience'],\n",
    "    epochs=epochs,\n",
    "    bias=True,\n",
    "    weight_decay=weight_decay,\n",
    "    model_name=dataset + '_main_model',\n",
    "    save_best_model=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ERM on the data after projecting with the INLP concept basis V_c_INLP\n",
    "\n",
    "# orthogonal projection matrix for INLP: identity minus create_P(V_c_INLP)\n",
    "P_c_orth_INLP = torch.eye(d) - create_P(V_c_INLP)\n",
    "\n",
    "# apply the same projection to every split\n",
    "X_train_transformed_INLP = torch.matmul(X_train, P_c_orth_INLP)\n",
    "X_val_transformed_INLP = torch.matmul(X_val, P_c_orth_INLP)\n",
    "X_test_transformed_INLP = torch.matmul(X_test, P_c_orth_INLP)\n",
    "\n",
    "# no sample weights for the main-task model\n",
    "balanced_training_main = False\n",
    "main_weights_train = None\n",
    "main_weights_val = None\n",
    "\n",
    "# set the loaders of the dataset object to the projected training and validation data\n",
    "data_obj.reset_X(\n",
    "    X_train_transformed_INLP,\n",
    "    X_val_transformed_INLP,\n",
    "    batch_size=batch_size,\n",
    "    reset_X_objects=True,\n",
    "    include_weights=balanced_training_main,\n",
    "    train_weights=main_weights_train,\n",
    "    val_weights=main_weights_val,\n",
    "    only_main=True,\n",
    ")\n",
    "\n",
    "# train the model on the transformed embeddings, starting from the same seed\n",
    "set_seed(seed)\n",
    "ERM_after_INLP = return_linear_model(\n",
    "    d,\n",
    "    data_obj.main_loader,\n",
    "    device,\n",
    "    solver=solver,\n",
    "    lr=lr,\n",
    "    per_step=per_step,\n",
    "    tol=optimizer_settings['tol'],\n",
    "    early_stopping=early_stopping,\n",
    "    patience=optimizer_settings['patience'],\n",
    "    epochs=epochs,\n",
    "    bias=True,\n",
    "    weight_decay=weight_decay,\n",
    "    model_name=dataset + '_main_model',\n",
    "    save_best_model=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# NOTE(review): this cell repeats the training call of the previous cell with the same seed\n",
    "# and the same arguments, and re-assigns ERM_after_INLP. It looks redundant - confirm that\n",
    "# nothing relies on the second run, then remove this cell.\n",
    "# Train the model on the transformed embeddings\n",
    "set_seed(seed)\n",
    "ERM_after_INLP= return_linear_model(d, \n",
    "                                               data_obj.main_loader,\n",
    "                                              device,\n",
    "                                              solver = solver,\n",
    "                                              lr=lr,\n",
    "                                              per_step=per_step,\n",
    "                                              tol = optimizer_settings['tol'],\n",
    "                                              early_stopping = early_stopping,\n",
    "                                              patience = optimizer_settings['patience'],\n",
    "                                              epochs = epochs,\n",
    "                                              bias=True,\n",
    "                                              weight_decay=weight_decay, \n",
    "                                              model_name=dataset+'_main_model',\n",
    "                                              save_best_model=True\n",
    "                                              )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Overall Accuracy of JSE (test):  tensor(0.8325)\n",
      "Accuracy per group of JSE (test):                count      mean\n",
      "main concept                 \n",
      "0.0  0.0        509  0.842829\n",
      "     1.0        467  0.839400\n",
      "1.0  0.0        522  0.827586\n",
      "     1.0        502  0.820717\n"
     ]
    }
   ],
   "source": [
    "# get the accuracy of the main model after JSE\n",
    "y_m_pred_test = ERM_after_JSE(X_test_transformed)\n",
    "\n",
    "# get the accuracy of the main model overall after JSE\n",
    "main_acc_after = get_acc_pytorch_model(y_m_test, y_m_pred_test)\n",
    "\n",
    "# get the accuracy of the main model per group\n",
    "result_per_group, _ = get_acc_per_group(y_m_pred_test, y_m_test, y_c_test)\n",
    "\n",
    "print(\"Overall Accuracy of JSE (test): \", main_acc_after)\n",
    "\n",
    "print(\"Accuracy per group of JSE (test): \", result_per_group)\n",
    "\n",
    "   \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Overall Accuracy of JSE (test):  tensor(0.5125)\n",
      "Accuracy per group of JSE (test):                count      mean\n",
      "main concept                 \n",
      "0.0  0.0        509  0.842829\n",
      "     1.0        467  0.203426\n",
      "1.0  0.0        522  0.187739\n",
      "     1.0        502  0.802789\n"
     ]
    }
   ],
   "source": [
    "# get the accuracy of the main model after JSE\n",
    "y_m_pred_test_INLP = ERM_after_INLP(X_test_transformed_INLP)\n",
    "\n",
    "# get the accuracy of the main model overall after JSE\n",
    "main_acc_after_INLP = get_acc_pytorch_model(y_m_test, y_m_pred_test_INLP)\n",
    "\n",
    "# get the accuracy of the main model per group\n",
    "result_per_group_INLP, _ = get_acc_per_group(y_m_pred_test_INLP, y_m_test, y_c_test)\n",
    "\n",
    "print(\"Overall Accuracy of JSE (test): \", main_acc_after_INLP)\n",
    "\n",
    "print(\"Accuracy per group of JSE (test): \", result_per_group_INLP)\n",
    "\n",
    "   \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Overall Accuracy of ERM (test):  tensor(0.8170)\n",
      "Accuracy per group of ERM (test):                count      mean\n",
      "main concept                 \n",
      "0.0  0.0        509  0.899804\n",
      "     1.0        467  0.747323\n",
      "1.0  0.0        522  0.739464\n",
      "     1.0        502  0.878486\n"
     ]
    }
   ],
   "source": [
    "# get the accuracy of the main model after JSE\n",
    "y_m_pred_test_ERM = ERM_model(X_test)\n",
    "\n",
    "# get the accuracy of the main model overall after JSE\n",
    "main_acc_ERM = get_acc_pytorch_model(y_m_test, y_m_pred_test_ERM)\n",
    "\n",
    "# get the accuracy of the main model per group\n",
    "result_per_group_ERM, _ = get_acc_per_group(y_m_pred_test_ERM, y_m_test, y_c_test)\n",
    "\n",
    "print(\"Overall Accuracy of ERM (test): \", main_acc_ERM)\n",
    "\n",
    "print(\"Accuracy per group of ERM (test): \", result_per_group_ERM)\n",
    "\n",
    "   \n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "JSE-replicate-code",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
