{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "Experiments on semi-synthetic dataset with mask\n",
    "\"\"\"\n",
    "from MF import MF\n",
    "import pickle\n",
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "def getAccuracy(real, predList):\n",
    "    \"\"\"Relative absolute error of each estimate with respect to `real`.\n",
    "\n",
    "    Returns a numpy array whose k-th entry is |real - predList[k]| / real.\n",
    "    \"\"\"\n",
    "    estimates = np.asarray(predList)\n",
    "    return np.abs(real - estimates) / real\n",
    "\n",
    "def calculate_g_ml_100k(x_train_user, x_train_item, num_user=943, num_item=1682, gamma=353):\n",
    "    \"\"\"Binary interference indicator g for every (user, item) cell.\n",
    "\n",
    "    For cell (u, i) the interference count is: observed interactions of user u\n",
    "    + observed interactions of item i - the cell's own entry. g = 1 where that\n",
    "    count is >= gamma, else 0.\n",
    "\n",
    "    Args:\n",
    "        x_train_user: 0-based user index of each observed interaction.\n",
    "        x_train_item: 0-based item index, aligned with x_train_user.\n",
    "        num_user, num_item: shape of the interaction matrix.\n",
    "        gamma: interference threshold.\n",
    "\n",
    "    Returns:\n",
    "        float array of shape (num_user, num_item) containing 0.0 / 1.0.\n",
    "    \"\"\"\n",
    "    observed = np.zeros((num_user, num_item))\n",
    "    users = np.asarray(x_train_user)\n",
    "    items = np.asarray(x_train_item)\n",
    "    if users.size:  # an empty list becomes a float array, which cannot index\n",
    "        # Repeated (user, item) pairs still give a single 1.\n",
    "        observed[users, items] = 1\n",
    "    # Row totals broadcast across columns, column totals across rows; subtract the\n",
    "    # cell itself so an observed cell is not counted twice.\n",
    "    interference = observed.sum(axis=1)[:, None] + observed.sum(axis=0)[None, :] - observed\n",
    "    return (interference >= gamma).astype(float)\n",
    "\n",
    "# Load the semi-synthetic data. pickle.load can execute arbitrary code, so only\n",
    "# open trusted files. The objects are read in this order: the ground-truth\n",
    "# ratings, the six prediction arrays evaluated below, and the rating cut points `a`.\n",
    "with open(\"data/synthetic_data\", \"rb\") as file:\n",
    "    ground_truth = pickle.load(file)\n",
    "    one = pickle.load(file)\n",
    "    three = pickle.load(file)\n",
    "    four = pickle.load(file)\n",
    "    rotate = pickle.load(file)\n",
    "    skew = pickle.load(file)\n",
    "    crs = pickle.load(file)\n",
    "    a = pickle.load(file)\n",
    "\n",
    "\n",
    "# Then compute g: the interference indicator of every observed ML-100k rating.\n",
    "matrix = np.loadtxt(\"./data/u.data\", dtype=int)\n",
    "user, item = matrix[:, 0] - 1, matrix[:, 1] - 1  # 0-based ids\n",
    "rating = matrix[:, 2]\n",
    "user_num, item_num = user.max() + 1, item.max() + 1\n",
    "num_user, num_item = 943, 1682\n",
    "\n",
    "matrix_g = calculate_g_ml_100k(user, item, gamma = 353)\n",
    "\n",
    "# Column 3 of `matrix` (presumably the u.data timestamp - confirm) is overwritten\n",
    "# in place with each rating's g flag: 1 if g = 1 for that (user, item) cell, else 0.\n",
    "matrix[:, 3] = np.where(matrix_g[user, item] == 1, 1, 0)\n",
    "\n",
    "def _fit_group_mf(group_flag):\n",
    "    \"\"\"Fit an MF model on the observed ratings whose g flag (matrix[:, 3]) equals group_flag.\n",
    "\n",
    "    The first 90% of the group's rows (file order) are the training split and the\n",
    "    remaining 10% the test split passed to MF.fit.\n",
    "    \"\"\"\n",
    "    in_group = matrix[:, 3] == group_flag\n",
    "    users_g, items_g, ratings_g = user[in_group], item[in_group], rating[in_group]\n",
    "    n_train = int(len(users_g) * 0.9)\n",
    "    model = MF(num_users=user_num, num_items=item_num, embedding_size=64)\n",
    "    model.fit(users_g[:n_train], items_g[:n_train], ratings_g[:n_train],\n",
    "              users_g[n_train:], items_g[n_train:], ratings_g[n_train:],\n",
    "              batch_size = 256, lamb=1e-3, gamma = 1e-4)\n",
    "    return model\n",
    "\n",
    "\n",
    "def _quantize_ratings(scores, thresholds):\n",
    "    \"\"\"Map continuous scores to rating levels 1..5 using the cut points thresholds[0..3].\n",
    "\n",
    "    score <= t[0] -> 1, t[r-1] < score <= t[r] -> r + 1, score > t[3] -> 5.\n",
    "    All levels are assigned in one pass, so a value that has already been mapped\n",
    "    to a level can never be re-matched by a later threshold. Assumes the cut\n",
    "    points are in ascending order. The result keeps the dtype of `scores`.\n",
    "    \"\"\"\n",
    "    scores = np.asarray(scores)\n",
    "    cuts = np.asarray(thresholds[:4], dtype=float)\n",
    "    levels = np.searchsorted(cuts, scores, side=\"left\") + 1\n",
    "    return levels.astype(scores.dtype)\n",
    "\n",
    "\n",
    "# Every (user, item) pair, user-major: (0, 0), (0, 1), ..., (user_num - 1, item_num - 1).\n",
    "user_all = np.repeat(np.arange(user_num), item_num)\n",
    "item_all = np.tile(np.arange(item_num), user_num)\n",
    "\n",
    "# Ratings used as ground truth for g = 1 and g = 0 cells: one MF per group,\n",
    "# predictions over all pairs, quantized with the cut points `a`.\n",
    "mf = _fit_group_mf(1)\n",
    "prediction_g1 = mf.predict(user_all, item_all)\n",
    "ground_truth_g1 = _quantize_ratings(prediction_g1, a)\n",
    "print(ground_truth_g1[:50])\n",
    "\n",
    "mf = _fit_group_mf(0)\n",
    "prediction_g0 = mf.predict(user_all, item_all)\n",
    "ground_truth_g0 = _quantize_ratings(prediction_g0, a)\n",
    "print(ground_truth_g0[:50])\n",
    "\n",
    "# Rating-level proportions (levels 1..5) assumed by the propensity normalization:\n",
    "# [0.53, 0.24, 0.14, 0.06, 0.03].\n",
    "# Result series, one entry per mask level k. Each prediction array (one, three,\n",
    "# four, rotate, skew, crs) has a mean and a std series per estimator: DR,\n",
    "# DR with interference (inter_dr), MRDR, and MRDR with interference (inter_mrdr).\n",
    "(one_means_dr, one_stds_dr, one_means_inter_dr, one_stds_inter_dr,\n",
    " one_means_mrdr, one_stds_mrdr, one_means_inter_mrdr, one_stds_inter_mrdr) = ([] for _ in range(8))\n",
    "\n",
    "(three_means_dr, three_stds_dr, three_means_inter_dr, three_stds_inter_dr,\n",
    " three_means_mrdr, three_stds_mrdr, three_means_inter_mrdr, three_stds_inter_mrdr) = ([] for _ in range(8))\n",
    "\n",
    "(four_means_dr, four_stds_dr, four_means_inter_dr, four_stds_inter_dr,\n",
    " four_means_mrdr, four_stds_mrdr, four_means_inter_mrdr, four_stds_inter_mrdr) = ([] for _ in range(8))\n",
    "\n",
    "(rotate_means_dr, rotate_stds_dr, rotate_means_inter_dr, rotate_stds_inter_dr,\n",
    " rotate_means_mrdr, rotate_stds_mrdr, rotate_means_inter_mrdr, rotate_stds_inter_mrdr) = ([] for _ in range(8))\n",
    "\n",
    "(skew_means_dr, skew_stds_dr, skew_means_inter_dr, skew_stds_inter_dr,\n",
    " skew_means_mrdr, skew_stds_mrdr, skew_means_inter_mrdr, skew_stds_inter_mrdr) = ([] for _ in range(8))\n",
    "\n",
    "(crs_means_dr, crs_stds_dr, crs_means_inter_dr, crs_stds_inter_dr,\n",
    " crs_means_mrdr, crs_stds_mrdr, crs_means_inter_mrdr, crs_stds_inter_mrdr) = ([] for _ in range(8))\n",
    "\n",
    "\n",
    "def _weighted_mean(values, weights):\n",
    "    \"\"\"Return sum(values * weights) / sum(weights).\"\"\"\n",
    "    return np.sum(values * weights) / np.sum(weights)\n",
    "\n",
    "\n",
    "alpha = 0.5          # propensity multiplier per rating level below 4\n",
    "obs_gamma = 115      # interference threshold applied to the sampled observations\n",
    "n_trials = 50        # observation samples drawn per mask level\n",
    "c = 0.5              # weight of the g = 1 ground truth in the true loss\n",
    "predList = [one, three, four, rotate, skew, crs]\n",
    "# Propensity of a cell = p * rating_scale[its ground-truth level].\n",
    "rating_scale = {5: 1.0, 4: 1.0, 3: alpha, 2: alpha ** 2, 1: alpha ** 3}\n",
    "level_masks = {level: ground_truth == level for level in rating_scale}\n",
    "# (mean, std) series for DR, DR-inter, MRDR, MRDR-inter; one row per entry of predList.\n",
    "result_lists = [\n",
    "    (one_means_dr, one_stds_dr, one_means_inter_dr, one_stds_inter_dr,\n",
    "     one_means_mrdr, one_stds_mrdr, one_means_inter_mrdr, one_stds_inter_mrdr),\n",
    "    (three_means_dr, three_stds_dr, three_means_inter_dr, three_stds_inter_dr,\n",
    "     three_means_mrdr, three_stds_mrdr, three_means_inter_mrdr, three_stds_inter_mrdr),\n",
    "    (four_means_dr, four_stds_dr, four_means_inter_dr, four_stds_inter_dr,\n",
    "     four_means_mrdr, four_stds_mrdr, four_means_inter_mrdr, four_stds_inter_mrdr),\n",
    "    (rotate_means_dr, rotate_stds_dr, rotate_means_inter_dr, rotate_stds_inter_dr,\n",
    "     rotate_means_mrdr, rotate_stds_mrdr, rotate_means_inter_mrdr, rotate_stds_inter_mrdr),\n",
    "    (skew_means_dr, skew_stds_dr, skew_means_inter_dr, skew_stds_inter_dr,\n",
    "     skew_means_mrdr, skew_stds_mrdr, skew_means_inter_mrdr, skew_stds_inter_mrdr),\n",
    "    (crs_means_dr, crs_stds_dr, crs_means_inter_dr, crs_stds_inter_dr,\n",
    "     crs_means_mrdr, crs_stds_mrdr, crs_means_inter_mrdr, crs_stds_inter_mrdr),\n",
    "]\n",
    "# NOTE(review): no random seed is set, so masks and observations change on every run.\n",
    "\n",
    "for k in [50,150,250,350]:\n",
    "    # Keep each user / item with probability (total - k) / total; masked rows and\n",
    "    # columns get zero propensity.\n",
    "    mask_user = np.random.binomial(1, np.full(num_user, (num_user - k) / num_user))\n",
    "    mask_item = np.random.binomial(1, np.full(num_item, (num_item - k) / num_item))\n",
    "    mask_matrix = np.outer(mask_user, mask_item)\n",
    "\n",
    "    # Number of unmasked cells (kept users x kept items).\n",
    "    rest_number = np.sum(mask_user) * np.sum(mask_item)\n",
    "    # If unmasked cells follow level proportions [0.53, 0.24, 0.14, 0.06, 0.03] for\n",
    "    # levels 1..5, this p makes the expected observed count 5% of all cells.\n",
    "    p = (0.05 * num_user * num_item / rest_number)/(0.53*(alpha ** 3) + 0.24 * (alpha ** 2) + 0.14 * alpha + 0.09)\n",
    "    print(p)\n",
    "    # Float copy, so assigning p is not truncated if ground_truth is stored as ints.\n",
    "    propensity = np.array(ground_truth, dtype=float)\n",
    "    for level, scale in rating_scale.items():\n",
    "        propensity[level_masks[level]] = p * scale\n",
    "    propensity = (mask_matrix * propensity.reshape(num_user, num_item)).reshape(num_user * num_item)\n",
    "\n",
    "    res = np.zeros([6, 4])\n",
    "    res_var = np.zeros([6, 4])\n",
    "    count = 0\n",
    "    for trial in range(n_trials):\n",
    "        observation = np.random.binomial(1, propensity)\n",
    "        n_cells = observation.shape[0]\n",
    "        p_o = np.count_nonzero(observation) / n_cells  # overall observation rate\n",
    "        o = np.where(observation == 1)\n",
    "\n",
    "        # Interference of each cell on this sample: observed cells in its row plus\n",
    "        # its column, minus itself; g = 1 where that count is >= obs_gamma.\n",
    "        obs_2d = observation.reshape(num_user, num_item)\n",
    "        interference_matrix = obs_2d.sum(axis=1)[:, None] + obs_2d.sum(axis=0)[None, :] - obs_2d\n",
    "        matrix_g_reshape = (interference_matrix >= obs_gamma).astype(float).reshape(num_user * num_item)\n",
    "        g0_weight = 1 - matrix_g_reshape\n",
    "        in_g1 = matrix_g_reshape == 1\n",
    "        in_g0 = ~in_g1\n",
    "\n",
    "        p_g1 = np.sum(observation * matrix_g_reshape) / n_cells\n",
    "        p_g0 = np.sum(observation * g0_weight) / n_cells\n",
    "\n",
    "        # Per level: share of its observed cells that lie in g = 1 (NaN if never observed).\n",
    "        g_index = matrix_g_reshape[o]\n",
    "        observed_truth = ground_truth[o]\n",
    "        ratios = {}\n",
    "        for level in rating_scale:\n",
    "            in_level = observed_truth == level\n",
    "            n_level = np.count_nonzero(in_level)\n",
    "            ratios[level] = np.sum(g_index[in_level]) / n_level if n_level else np.nan\n",
    "        # A share of exactly 0 or 1 gives one group zero propensity (infinite weights\n",
    "        # below), and NaN means the level was never observed; skip such trials.\n",
    "        if not all(0 < ratio < 1 for ratio in ratios.values()):\n",
    "            continue\n",
    "        count += 1\n",
    "\n",
    "        # Group-specific propensity: p * scale, split by the level's observed g = 1 share.\n",
    "        # NOTE(review): every cell of levels 1..5 is overwritten here, including cells\n",
    "        # zeroed by mask_matrix - confirm this is intended.\n",
    "        propensity_new = np.copy(propensity)\n",
    "        for level, scale in rating_scale.items():\n",
    "            propensity_new[level_masks[level] & in_g1] = p * scale * ratios[level]\n",
    "            propensity_new[level_masks[level] & in_g0] = p * scale * (1 - ratios[level])\n",
    "\n",
    "        # Weights: mean of 1 / propensity and 1 / overall rate (masked cells use p_o).\n",
    "        propensity_copy = propensity.copy()\n",
    "        propensity_copy[propensity == 0] = p_o\n",
    "        p_hat = 0.5/propensity_copy + 0.5/p_o\n",
    "        p_hat_g = np.zeros(len(propensity_new))\n",
    "        p_hat_g[in_g1] = 0.5/propensity_new[in_g1] + 0.5/p_g1\n",
    "        p_hat_g[in_g0] = 0.5/propensity_new[in_g0] + 0.5/p_g0\n",
    "\n",
    "        # Imputation weights: zero on unobserved cells and independent of the prediction.\n",
    "        mrdr_p = p_hat * p_hat * (1 - 1 / p_hat)\n",
    "        mrdr_pg = p_hat_g * p_hat_g * (1 - 1 / p_hat_g)\n",
    "        w_dr = observation * p_hat\n",
    "        w_dr1 = observation * matrix_g_reshape * p_hat_g\n",
    "        w_dr0 = observation * g0_weight * p_hat_g\n",
    "        w_mrdr = mrdr_p * observation\n",
    "        w_mrdr1 = observation * matrix_g_reshape * mrdr_pg\n",
    "        w_mrdr0 = observation * g0_weight * mrdr_pg\n",
    "\n",
    "        for j, prediction in enumerate(predList):\n",
    "            err1 = abs(ground_truth_g1 - prediction)\n",
    "            err0 = abs(ground_truth_g0 - prediction)\n",
    "            real_ce = np.mean(c * err1 + (1 - c) * err0)\n",
    "\n",
    "            # Imputed errors: |weighted mean of the prediction - prediction|.\n",
    "            ce_hat = abs(_weighted_mean(prediction, w_dr) - prediction)\n",
    "            ce_hat1 = abs(_weighted_mean(prediction, w_dr1) - prediction)\n",
    "            ce_hat0 = abs(_weighted_mean(prediction, w_dr0) - prediction)\n",
    "            ce_mrdr = abs(_weighted_mean(prediction, w_mrdr) - prediction)\n",
    "            ce_mrdr1 = abs(_weighted_mean(prediction, w_mrdr1) - prediction)\n",
    "            ce_mrdr0 = abs(_weighted_mean(prediction, w_mrdr0) - prediction)\n",
    "\n",
    "            dr_ce = np.mean(ce_hat + observation * matrix_g_reshape * (err1 - ce_hat) * p_hat +\n",
    "                            observation * g0_weight * (err0 - ce_hat) * p_hat)\n",
    "            dr_ce_with_g = np.mean(c * (ce_hat1 + observation * matrix_g_reshape * (err1 - ce_hat1) * p_hat_g) +\n",
    "                                   (1 - c) * (ce_hat0 + observation * g0_weight * (err0 - ce_hat0) * p_hat_g))\n",
    "            mrdr_ce = np.mean(ce_mrdr + observation * matrix_g_reshape * (err1 - ce_mrdr) * p_hat +\n",
    "                              observation * g0_weight * (err0 - ce_mrdr) * p_hat)\n",
    "            mrdr_ce_with_g = np.mean(c * (ce_mrdr1 + observation * matrix_g_reshape * (err1 - ce_mrdr1) * p_hat_g) +\n",
    "                                     (1 - c) * (ce_mrdr0 + observation * g0_weight * (err0 - ce_mrdr0) * p_hat_g))\n",
    "\n",
    "            acc = getAccuracy(real_ce, [dr_ce, dr_ce_with_g, mrdr_ce, mrdr_ce_with_g])\n",
    "            res[j] += acc\n",
    "            res_var[j] += acc ** 2\n",
    "\n",
    "    # Mean and sample std of the relative errors over the retained trials; NaN when\n",
    "    # too few trials were retained (the std needs at least two).\n",
    "    print(f\"k={k}: {count} of {n_trials} trials used\")\n",
    "    loss_le = res / count if count > 0 else np.full_like(res, np.nan)\n",
    "    if count > 1:\n",
    "        # Clamp tiny negative values from floating-point cancellation before sqrt.\n",
    "        loss_std = np.sqrt((1/(count-1))*np.maximum(res_var - count*(res/count)**2, 0.0))\n",
    "    else:\n",
    "        loss_std = np.full_like(res, np.nan)\n",
    "    print(loss_le)\n",
    "    print(loss_std)\n",
    "\n",
    "    # Store the results: column 0 DR, 1 DR-inter, 2 MRDR, 3 MRDR-inter.\n",
    "    for row, series in enumerate(result_lists):\n",
    "        for col in range(4):\n",
    "            series[2 * col].append(loss_le[row][col])\n",
    "            series[2 * col + 1].append(loss_std[row][col])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:pytorch-gpu]",
   "language": "python",
   "name": "conda-env-pytorch-gpu-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
