{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### This code is based on the classes in  https://github.com/IBM/IRM-games"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Instructions to run this notebook\n",
    "\n",
    "In this notebook, we present the comparisons for CS-CMNIST: covariate-shift-based colored MNIST.\n",
    "Run all the cells sequentially from top to bottom; we have commented the cells to help the reader."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "import argparse\n",
    "import IPython.display as display\n",
    "import matplotlib.pyplot as plt\n",
    "from tensorflow import keras\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.preprocessing import OneHotEncoder\n",
    "from sklearn.utils import shuffle\n",
    "import pandas as pd\n",
    "tf.compat.v1.enable_eager_execution()\n",
    "import cProfile\n",
    "from sklearn.model_selection import train_test_split\n",
    "import copy as cp\n",
    "from sklearn.model_selection import KFold\n",
    "from datetime import date\n",
    "import time"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Generate CS-CMNIST"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class assemble_data_mnist_sb:\n",
    "    \"\"\"Builds the CS-CMNIST training and testing environments from MNIST.\n",
    "\n",
    "    Digit labels are binarized (digit >= 5 or not), every image is given a\n",
    "    random color (red or green), and each sample is then kept or dropped with\n",
    "    a probability that depends on whether its label and color agree.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, n_tr):\n",
    "        \"\"\"Load MNIST and draw n_tr training images.\n",
    "\n",
    "        n_tr: number of training images to draw. np.random.choice is called\n",
    "              with its default (sampling with replacement), so the same image\n",
    "              can be drawn more than once.\n",
    "        \"\"\"\n",
    "        (x_all, y_all), (x_test, y_test) = tf.keras.datasets.mnist.load_data()\n",
    "        n_tr_total = x_all.shape[0]\n",
    "        ind_tr = np.random.choice(n_tr_total, n_tr)\n",
    "        x_train = x_all[ind_tr].astype('float32')\n",
    "        x_test = x_test.astype('float32')\n",
    "        num_train = x_train.shape[0]\n",
    "        num_test = x_test.shape[0]\n",
    "        # add a trailing channel axis so color channels can be concatenated on axis 3\n",
    "        self.x_train_mnist = x_train.reshape((num_train, 28, 28, 1))\n",
    "        self.y_train_mnist = y_all[ind_tr].reshape((num_train, 1))\n",
    "        self.x_test_mnist = x_test.reshape((num_test, 28, 28, 1))\n",
    "        self.y_test_mnist = y_test.reshape((num_test, 1))\n",
    "        self.n_tr = n_tr\n",
    "\n",
    "    def create_environment(self, env_index, x, y, prob_e, prob_label):\n",
    "        \"\"\"Create one colored, selection-biased environment.\n",
    "\n",
    "        env_index:  value written into the returned environment-index column\n",
    "        x:          gray-scale images indexed as (sample, row, column, channel)\n",
    "        y:          digit labels, one row per sample\n",
    "        prob_e:     probability of keeping a sample whose label and color differ;\n",
    "                    samples whose label and color agree are kept with 1 - prob_e\n",
    "        prob_label: not used here; retained so the signature matches other data classes\n",
    "\n",
    "        Returns (dataset, labels, env, colors). All four arrays share one row\n",
    "        order: the red samples first, followed by the green samples.\n",
    "        \"\"\"\n",
    "        y = (y >= 5).astype(int)  # binarize: digits 5-9 -> 1, digits 0-4 -> 0\n",
    "        num_samples = len(y)\n",
    "\n",
    "        z_color = np.random.binomial(1, 0.5, (num_samples, 1))  # random color per sample\n",
    "        w_comb = 1 - np.logical_xor(y, z_color)                 # 1 where label == color, 0 otherwise\n",
    "\n",
    "        selection_0 = np.where(w_comb == 0)[0]  # samples with label != color\n",
    "        selection_1 = np.where(w_comb == 1)[0]  # samples with label == color\n",
    "        ns0 = np.shape(selection_0)[0]\n",
    "        ns1 = np.shape(selection_1)[0]\n",
    "\n",
    "        # label != color: keep with probability prob_e\n",
    "        keep_0 = np.random.binomial(1, prob_e, (ns0, 1))\n",
    "        # label == color: keep with probability 1 - prob_e\n",
    "        keep_1 = np.random.binomial(1, 1 - prob_e, (ns1, 1))\n",
    "        final_selection_0 = selection_0[np.where(keep_0 == 1)[0]]\n",
    "        final_selection_1 = selection_1[np.where(keep_1 == 1)[0]]\n",
    "        final_selection = np.concatenate((final_selection_0, final_selection_1), axis=0)\n",
    "\n",
    "        z_color_final = z_color[final_selection]  # colors of the selected samples\n",
    "        y = y[final_selection]                    # labels of the selected samples\n",
    "        x = x[final_selection]                    # gray-scale images of the selected samples\n",
    "\n",
    "        red = np.where(z_color_final == 0)[0]     # color 0 is rendered red\n",
    "        green = np.where(z_color_final == 1)[0]   # color 1 is rendered green\n",
    "        num_samples_final = np.shape(y)[0]\n",
    "\n",
    "        tsh = 0.5\n",
    "\n",
    "        def paint(images, channel):\n",
    "            # Pixels above tsh become 1 in the chosen channel and 0 in the other\n",
    "            # two; pixels at or below tsh keep their value in every channel.\n",
    "            lit = images.copy()\n",
    "            lit[lit > tsh] = 1\n",
    "            dark = images.copy()\n",
    "            dark[dark > tsh] = 0\n",
    "            planes = [dark, dark, dark]\n",
    "            planes[channel] = lit\n",
    "            return np.concatenate(planes, axis=3)\n",
    "\n",
    "        r = paint(x[red, :], 0)\n",
    "        g = paint(x[green, :], 1)\n",
    "\n",
    "        dataset = np.concatenate((r, g), axis=0)\n",
    "        labels = np.concatenate((y[red, :], y[green, :]), axis=0)\n",
    "        # reorder the colors exactly like the images and labels so row i of every\n",
    "        # returned array describes the same sample\n",
    "        colors = np.concatenate((z_color_final[red, :], z_color_final[green, :]), axis=0)\n",
    "        env = np.ones((num_samples_final, 1)) * env_index\n",
    "\n",
    "        return (dataset, labels, env, colors)\n",
    "\n",
    "    def create_training_data(self, n_e, corr_list, p_label_list):\n",
    "        \"\"\"Split the training images into n_e disjoint folds and build one environment per fold.\n",
    "\n",
    "        n_e:          number of environments (KFold splits)\n",
    "        corr_list:    prob_e for each environment\n",
    "        p_label_list: prob_label for each environment (forwarded unchanged)\n",
    "\n",
    "        Stores the list of environment tuples in self.data_tuple_list.\n",
    "        \"\"\"\n",
    "        x_train_mnist = self.x_train_mnist\n",
    "        y_train_mnist = self.y_train_mnist\n",
    "        kf = KFold(n_splits=n_e, shuffle=True)\n",
    "        # the held-out part of every split is used as that environment's sample indices\n",
    "        ind_list = [test for _, test in kf.split(range(0, self.n_tr))]\n",
    "\n",
    "        self.data_tuple_list = [\n",
    "            self.create_environment(l, x_train_mnist[ind_list[l], :, :, :], y_train_mnist[ind_list[l], :], corr_list[l], p_label_list[l])\n",
    "            for l in range(n_e)\n",
    "        ]\n",
    "\n",
    "    def create_testing_data(self, corr_test, prob_label, n_e):\n",
    "        \"\"\"Build the test environment from the MNIST test images.\n",
    "\n",
    "        corr_test:  prob_e of the test environment\n",
    "        prob_label: forwarded to create_environment unchanged\n",
    "        n_e:        used as the environment index of the test environment\n",
    "\n",
    "        Stores (images, labels, environment index) in self.data_tuple_test;\n",
    "        the color vector returned by create_environment is dropped.\n",
    "        \"\"\"\n",
    "        (x_test, y_test, e_test, z_color_test) = self.create_environment(n_e, self.x_test_mnist, self.y_test_mnist, corr_test, prob_label)\n",
    "        self.data_tuple_test = (x_test, y_test, e_test)\n",
    "\n",
    "    def create_testing_data_blue(self, n_e):\n",
    "        \"\"\"Build a test environment through create_environment_blue.\n",
    "\n",
    "        NOTE(review): create_environment_blue is not defined in this class as\n",
    "        written, so calling this method raises AttributeError unless that\n",
    "        method is supplied elsewhere.\n",
    "        \"\"\"\n",
    "        (x_test, y_test, e_test) = self.create_environment_blue(n_e, self.x_test_mnist, self.y_test_mnist)\n",
    "        self.data_tuple_test_blue = (x_test, y_test, e_test)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Class for ERM model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class standard_erm_model:\n",
    "    \"\"\"ERM baseline: pools every training environment and fits one Keras model on the result.\"\"\"\n",
    "\n",
    "    def __init__(self, model, num_epochs, batch_size, learning_rate):\n",
    "        \"\"\"Store the model and the training hyperparameters.\n",
    "\n",
    "        model:         Keras model; fit() compiles it with sparse categorical\n",
    "                       cross-entropy, so it is expected to output class probabilities\n",
    "        num_epochs:    number of passes over the pooled data\n",
    "        batch_size:    batch size per gradient update\n",
    "        learning_rate: learning rate for the Adam optimizer\n",
    "        \"\"\"\n",
    "        self.model = model\n",
    "        self.num_epochs = num_epochs\n",
    "        self.batch_size = batch_size\n",
    "        self.learning_rate = learning_rate\n",
    "\n",
    "    def fit(self, data_tuple_list):\n",
    "        \"\"\"Pool the environments and train self.model.\n",
    "\n",
    "        data_tuple_list: one tuple per environment; element 0 holds the\n",
    "                         features, element 1 the labels and element 2 the\n",
    "                         environment index of every sample.\n",
    "\n",
    "        The pooled arrays are kept on self.x_in, self.y_in and self.e_in.\n",
    "        \"\"\"\n",
    "        # stack the environments in list order\n",
    "        x_in = np.concatenate([env[0] for env in data_tuple_list], axis=0)\n",
    "        y_in = np.concatenate([env[1] for env in data_tuple_list], axis=0)\n",
    "        e_in = np.concatenate([env[2] for env in data_tuple_list], axis=0)\n",
    "\n",
    "        ### fit the model\n",
    "        model = self.model\n",
    "        model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=self.learning_rate),\n",
    "                      loss='sparse_categorical_crossentropy',\n",
    "                      metrics=['accuracy'])\n",
    "        model.fit(x_in, y_in, epochs=self.num_epochs, batch_size=self.batch_size)\n",
    "\n",
    "        self.x_in = x_in\n",
    "        self.y_in = y_in\n",
    "        self.e_in = e_in\n",
    "        self.model = model\n",
    "\n",
    "    def evaluate(self, data_tuple_test):\n",
    "        \"\"\"Compute accuracy on the pooled training data and on the test tuple.\n",
    "\n",
    "        data_tuple_test: tuple whose element 0 holds the test features and\n",
    "                         element 1 the test labels.\n",
    "\n",
    "        Results are stored in self.train_acc and self.test_acc as Python floats.\n",
    "        Must be called after fit(), which provides self.x_in and self.y_in.\n",
    "        \"\"\"\n",
    "        x_test = data_tuple_test[0]\n",
    "        y_test = data_tuple_test[1]\n",
    "        model = self.model\n",
    "\n",
    "        # one metric object per split: these metrics accumulate across calls\n",
    "        train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()\n",
    "        test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()\n",
    "\n",
    "        # builtin float instead of np.float, which was removed from NumPy\n",
    "        ytr_ = model.predict(self.x_in)\n",
    "        self.train_acc = float(train_accuracy(self.y_in, ytr_))\n",
    "\n",
    "        yts_ = model.predict(x_test)\n",
    "        self.test_acc = float(test_accuracy(y_test, yts_))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Class for IRM model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class irm_model:\n",
    "    \"\"\"IRMv1 trainer: cross-entropy loss plus a squared-gradient penalty for every environment.\"\"\"\n",
    "\n",
    "    def __init__(self, model, learning_rate, batch_size, steps_max, steps_threshold, gamma_new):\n",
    "        self.model = model                      # initialized model passed (its outputs are treated as logits)\n",
    "        self.learning_rate = learning_rate      # learning rate for Adam optimizer\n",
    "        self.batch_size = batch_size            # batch size per gradient update\n",
    "        self.steps_max = steps_max              # maximum number of gradient steps\n",
    "        self.steps_threshold = steps_threshold  # number of steps after which the penalty gamma_new is used\n",
    "        self.gamma_new = gamma_new              # penalty weight; the penalty is 1 initially and gamma_new applies once steps reach steps_threshold\n",
    "\n",
    "    def fit(self, data_tuple_list):\n",
    "        \"\"\"Train self.model on the pooled environments with the IRMv1 objective.\n",
    "\n",
    "        data_tuple_list: one tuple per environment; element 0 holds the\n",
    "                         features, element 1 the labels and element 2 the\n",
    "                         environment index of every sample.\n",
    "\n",
    "        The first 80% of the pooled samples are used for training and the\n",
    "        remaining 20% are stored as a validation split for evaluate().\n",
    "        Raises ValueError when the training split is empty.\n",
    "        \"\"\"\n",
    "        n_e = len(data_tuple_list)  # number of environments\n",
    "        # combine the data from the different environments, in list order\n",
    "        x_in = np.concatenate([env[0] for env in data_tuple_list], axis=0)\n",
    "        y_in = np.concatenate([env[1] for env in data_tuple_list], axis=0)\n",
    "        e_in = np.concatenate([env[2] for env in data_tuple_list], axis=0)\n",
    "\n",
    "        # NOTE(review): the split below is positional and the pooled data is not\n",
    "        # shuffled first, so the validation part is the tail of the concatenation\n",
    "        # (the end of the last environment) rather than a random sample. Left\n",
    "        # unchanged to keep the experiment as published; consider shuffling.\n",
    "        self.frac = 0.8\n",
    "        n_tr = int(np.shape(x_in)[0] * self.frac)  # builtin int: np.int was removed from NumPy\n",
    "\n",
    "        self.x_in, self.x_in_val = x_in[:n_tr], x_in[n_tr:]\n",
    "        self.y_in, self.y_in_val = y_in[:n_tr], y_in[n_tr:]\n",
    "        self.e_in, self.e_in_val = e_in[:n_tr], e_in[n_tr:]\n",
    "        x_in, y_in, e_in = self.x_in, self.y_in, self.e_in\n",
    "\n",
    "        num_examples = x_in.shape[0]\n",
    "        if num_examples == 0:\n",
    "            # without this guard the training loop below would never advance\n",
    "            raise ValueError('irm_model.fit needs at least one training sample')\n",
    "\n",
    "        # cross entropy (we do not use the cross entropy from keras because there are issues when computing gradient of the gradient)\n",
    "        def cross_entropy_manual(y, y_pred):\n",
    "            y_p = tf.math.log(tf.nn.softmax(y_pred))\n",
    "            n_p = float(tf.shape(y_p)[0])\n",
    "            ind_0 = tf.where(y == 0)[:, 0]\n",
    "            ind_1 = tf.where(y == 1)[:, 0]\n",
    "            y_p0 = tf.gather(y_p, ind_0)[:, 0]   # log-probability of class 0 for samples labeled 0\n",
    "            y_p1 = tf.gather(y_p, ind_1)[:, 1]   # log-probability of class 1 for samples labeled 1\n",
    "            ent_0 = tf.reduce_sum(y_p0)\n",
    "            ent_1 = tf.reduce_sum(y_p1)\n",
    "            total = -(ent_0 + ent_1) / n_p\n",
    "            return total\n",
    "\n",
    "        # cross entropy loss for environment k, with the model output scaled by w\n",
    "        def loss_n(model, x, e, y, w, k):\n",
    "            index = np.where(e == k)\n",
    "            y1_ = model(x[index[0]]) * w\n",
    "            y1 = y[index[0]]\n",
    "            return cross_entropy_manual(y1, y1_)\n",
    "\n",
    "        # squared gradient of the environment-k loss w.r.t. w\n",
    "        def grad_norm_n(model, x, e, y, w, k):\n",
    "            with tf.GradientTape() as g:\n",
    "                g.watch(w)\n",
    "                loss_value = loss_n(model, x, e, y, w, k)\n",
    "            return g.gradient(loss_value, w)**2\n",
    "\n",
    "        # total cross entropy loss across all environments\n",
    "        def loss_0(model, x, e, y, w):\n",
    "            y_ = model(x)\n",
    "            loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)\n",
    "            return loss_object(y_true=y, y_pred=y_)\n",
    "\n",
    "        # sum of cross entropy loss and penalty, rescaled by 1/gamma\n",
    "        def loss_total(model, x, e, y, w, gamma, n_e):\n",
    "            loss0 = loss_0(model, x, e, y, w)\n",
    "            loss_penalty = 0.0\n",
    "            for k in range(n_e):\n",
    "                loss_penalty += gamma * grad_norm_n(model, x, e, y, w, k)\n",
    "            return (loss0 + loss_penalty) * (1 / gamma)\n",
    "\n",
    "        # gradient of sum of cross entropy loss and penalty w.r.t model parameters\n",
    "        def grad_total_n(model, x, e, y, w, gamma, n_e):\n",
    "            with tf.GradientTape() as tape:\n",
    "                loss_value = loss_total(model, x, e, y, w, gamma, n_e)\n",
    "            return loss_value, tape.gradient(loss_value, model.trainable_variables)\n",
    "\n",
    "        model = self.model\n",
    "        optimizer = tf.keras.optimizers.Adam(learning_rate=self.learning_rate)\n",
    "\n",
    "        ## train\n",
    "        batch_size = self.batch_size\n",
    "        steps_max = self.steps_max\n",
    "        gamma = 1.0\n",
    "        w = tf.constant(1.0)\n",
    "        steps = 0\n",
    "\n",
    "        # The per-step accuracy over the whole training set that used to be\n",
    "        # computed here was only appended to a discarded list, so it is omitted.\n",
    "        while steps <= steps_max:\n",
    "            (xt, yt, et) = shuffle(x_in, y_in, e_in)\n",
    "            # the penalty switch is checked once per pass over the data\n",
    "            if steps >= self.steps_threshold:\n",
    "                gamma = self.gamma_new\n",
    "            for offset in range(0, num_examples, batch_size):\n",
    "                if steps > steps_max:\n",
    "                    break\n",
    "                end = offset + batch_size\n",
    "                batch_x, batch_y, batch_e = xt[offset:end, :], yt[offset:end, :], et[offset:end, :]\n",
    "\n",
    "                loss_values, grads = grad_total_n(model, batch_x, batch_e, batch_y, w, gamma, n_e)\n",
    "                optimizer.apply_gradients(zip(grads, model.trainable_variables))\n",
    "                steps = steps + 1\n",
    "\n",
    "        self.model = model\n",
    "\n",
    "    def evaluate(self, data_tuple_test):\n",
    "        \"\"\"Compute training, validation and test accuracy of the fitted model.\n",
    "\n",
    "        data_tuple_test: tuple whose element 0 holds the test features and\n",
    "                         element 1 the test labels.\n",
    "\n",
    "        Results are stored in self.train_acc, self.val_acc (only when\n",
    "        self.frac < 1.0) and self.test_acc as Python floats. Must be called\n",
    "        after fit().\n",
    "        \"\"\"\n",
    "        x_test = data_tuple_test[0]\n",
    "        y_test = data_tuple_test[1]\n",
    "        model = self.model\n",
    "\n",
    "        # one metric object per split: these metrics accumulate across calls,\n",
    "        # so sharing one object would mix the splits together\n",
    "        train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()\n",
    "        val_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()\n",
    "        test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()\n",
    "\n",
    "        ytr_ = model.predict(self.x_in)\n",
    "        self.train_acc = float(train_accuracy(self.y_in, ytr_))\n",
    "\n",
    "        if self.frac < 1.0:\n",
    "            yv_ = model.predict(self.x_in_val)\n",
    "            self.val_acc = float(val_accuracy(self.y_in_val, yv_))\n",
    "\n",
    "        yts_ = model.predict(x_test)\n",
    "        self.test_acc = float(test_accuracy(y_test, yts_))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Sample complexity on CS-CMNIST"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# ---- experiment configuration ----\n",
    "n_trial = 10                                   # independent repetitions per training-set size\n",
    "n_tr_list = [1000, 5000, 10000, 30000, 60000]  # training-set sizes to compare\n",
    "\n",
    "n_e = 2                     # number of training environments\n",
    "p_color_list = [0.2, 0.1]   # prob_e handed to each training environment\n",
    "p_label_list = [0.25]*n_e   # forwarded as prob_label, which create_environment does not use\n",
    "p_label_test = 0.25         # forwarded as prob_label for the test environment (not used either)\n",
    "p_color_test = 0.9          # prob_e of the test environment\n",
    "\n",
    "num_epochs = 100            # ERM: passes over the pooled data\n",
    "batch_size = 512\n",
    "learning_rate = 4.9e-4\n",
    "steps_max = 1000            # IRM: maximum number of gradient steps\n",
    "steps_threshold = 190       # IRM: threshold after which gamma_new is used\n",
    "gamma_list = [10000, 33000, 66000,100000.0]  # IRM penalties tried; the best one is picked on validation accuracy\n",
    "\n",
    "# ---- result containers: one row per training-set size, one column per trial ----\n",
    "K = len(n_tr_list)\n",
    "ERM_model_acc = np.zeros((K,n_trial))\n",
    "ERM_model_acc_nb = np.zeros((K,n_trial))\n",
    "IRM_model_acc = np.zeros((K,n_trial))\n",
    "IRM_model_acc_v = np.zeros((K,n_trial))\n",
    "IRM_model_ind_v = np.zeros((K,n_trial))\n",
    "\n",
    "ERM_model_acc1 = np.zeros((K,n_trial))\n",
    "ERM_model_acc1_nb = np.zeros((K,n_trial))\n",
    "IRM_model_acc1 = np.zeros((K,n_trial))\n",
    "IRM_model_acc1_v = np.zeros((K,n_trial))\n",
    "\n",
    "ERM_model_acc_av = np.zeros(K)\n",
    "ERM_model_acc_av_nb = np.zeros(K)\n",
    "IRM_model_acc_av = np.zeros(K)\n",
    "IRM_model_acc_av_v = np.zeros(K)\n",
    "\n",
    "ERM_model_acc_av1 = np.zeros(K)\n",
    "ERM_model_acc_av1_nb = np.zeros(K)\n",
    "IRM_model_acc_av1 = np.zeros(K)\n",
    "IRM_model_acc_av1_v = np.zeros(K)\n",
    "\n",
    "list_params = []\n",
    "for k, n_tr in enumerate(n_tr_list):\n",
    "    print (\"tr\" + str(n_tr))\n",
    "    t_start = time.time()\n",
    "    for trial in range(n_trial):\n",
    "        print (\"trial \" + str(trial))\n",
    "        D = assemble_data_mnist_sb(n_tr)                         # initialize mnist digits data object\n",
    "        D.create_training_data(n_e, p_color_list, p_label_list)  # creates the training environments\n",
    "        D.create_testing_data(p_color_test, p_label_test, n_e)   # sets up the testing environment\n",
    "        (num_examples_environment,length, width, height) = D.data_tuple_list[0][0].shape # attributes of the data\n",
    "        num_classes = len(np.unique(D.data_tuple_list[0][1])) # number of classes in the data\n",
    "\n",
    "        # ERM baseline: outputs class probabilities\n",
    "        model_erm =  keras.Sequential([\n",
    "                keras.layers.Flatten(input_shape=(length,width,height)),\n",
    "                keras.layers.Dense(390, activation = 'relu',kernel_regularizer=keras.regularizers.l2(0.0011)),\n",
    "                keras.layers.Dense(390, activation='relu',kernel_regularizer=keras.regularizers.l2(0.0011)),\n",
    "                keras.layers.Dense(2, activation='softmax')\n",
    "        ])\n",
    "        erm_model1 = standard_erm_model(model_erm, num_epochs, batch_size, learning_rate)\n",
    "        erm_model1.fit(D.data_tuple_list)\n",
    "        erm_model1.evaluate(D.data_tuple_test)\n",
    "        print (\"Training accuracy:\" + str(erm_model1.train_acc))\n",
    "        print (\"Testing accuracy:\" + str(erm_model1.test_acc))\n",
    "\n",
    "        ERM_model_acc[k][trial] = erm_model1.test_acc\n",
    "        ERM_model_acc1[k][trial] = erm_model1.train_acc\n",
    "\n",
    "        # IRM: train one fresh model per candidate penalty\n",
    "        train_list = []\n",
    "        val_list = []\n",
    "        test_list = []\n",
    "        for gamma_new in gamma_list:\n",
    "            # same architecture as the ERM model, but the last layer outputs logits\n",
    "            model_irm = keras.Sequential([\n",
    "                                keras.layers.Flatten(input_shape=(length,width,height)),\n",
    "                                keras.layers.Dense(390, activation = 'relu',kernel_regularizer=keras.regularizers.l2(0.0011)),\n",
    "                                keras.layers.Dense(390, activation='relu',kernel_regularizer=keras.regularizers.l2(0.0011)),\n",
    "                                keras.layers.Dense(num_classes)\n",
    "                        ])\n",
    "            irm_model1 = irm_model(model_irm, learning_rate, batch_size, steps_max, steps_threshold, gamma_new)\n",
    "            irm_model1.fit(D.data_tuple_list)\n",
    "            irm_model1.evaluate(D.data_tuple_test)\n",
    "            train_list.append(irm_model1.train_acc)\n",
    "            val_list.append(irm_model1.val_acc)\n",
    "            test_list.append(irm_model1.test_acc)\n",
    "\n",
    "        # highest validation accuracy wins; on ties the earliest penalty is kept\n",
    "        index_best = int(np.argmax(val_list))\n",
    "\n",
    "        print (\"Training accuracy:\" + str(train_list[index_best]))\n",
    "        print (\"Validation accuracy:\" + str(val_list[index_best]))\n",
    "        print (\"Testing accuracy:\" + str(test_list[index_best]))\n",
    "\n",
    "        IRM_model_acc_v[k][trial]  = test_list[index_best]\n",
    "        IRM_model_acc1_v[k][trial] = train_list[index_best]\n",
    "        IRM_model_ind_v[k][trial]  = index_best\n",
    "\n",
    "    # averages over the trials for this training-set size\n",
    "    IRM_model_acc_av_v[k] = np.mean(IRM_model_acc_v[k])\n",
    "    ERM_model_acc_av[k] = np.mean(ERM_model_acc[k])\n",
    "    IRM_model_acc_av1_v[k] = np.mean(IRM_model_acc1_v[k])\n",
    "    ERM_model_acc_av1[k] = np.mean(ERM_model_acc1[k])\n",
    "\n",
    "    # one summary row (mean and standard deviation) per method\n",
    "    for method, acc in [(\"IRMv_test\", IRM_model_acc_v), (\"ERM_test\", ERM_model_acc),\n",
    "                        (\"IRMv_train\", IRM_model_acc1_v), (\"ERM_train\", ERM_model_acc1)]:\n",
    "        list_params.append([n_tr, method, np.mean(acc[k]), np.std(acc[k])])\n",
    "\n",
    "    t_end = time.time()\n",
    "    print(\"total time: \" + str(t_end-t_start))\n",
    "\n",
    "\n",
    "results = pd.DataFrame(list_params, columns= [\"Sample\",\"Method\", \"Performance\", \"Sdev\"])\n",
    "ideal_error = np.zeros(K)   # plotted as the zero-error reference; sized to match n_tr_list\n",
    "\n",
    "print (\"end\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plot the results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test error (1 - mean test accuracy) against the number of training samples,\n",
    "# drawn through the explicit figure/axes interface.\n",
    "fig, ax = plt.subplots()\n",
    "ax.set_xlabel(\"Number of samples\", fontsize=16)\n",
    "ax.set_ylabel(\"Test error\", fontsize=16)\n",
    "ax.plot(n_tr_list, 1-ERM_model_acc_av, \"-r\", marker=\"+\", label=\"ERM\")\n",
    "ax.plot(n_tr_list, 1-IRM_model_acc_av_v, \"-b\", marker=\"s\",label=\"IRMv1\")\n",
    "ax.plot(n_tr_list, ideal_error, \"-g\", marker=\"x\", label=\"Optimal invariant\")\n",
    "ax.legend(loc=\"upper left\", fontsize=18)\n",
    "ax.set_ylim(-0.01,0.8)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Summary table: one row per (training-set size, method) with the mean and standard deviation over trials\n",
    "results"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
