{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from src.utils.training import get_objective, get_optimizer\n",
    "from src.utils.data import load_dataset\n",
    "from scipy.optimize import minimize\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from tqdm import tqdm\n",
    "from src.optim.SOREL_batch import SorelBatch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataset identifier; it is passed to load_dataset and interpolated into a later log message.\n",
    "dataset = \"amazon\"\n",
    "\n",
    "# Unpack the four values returned by load_dataset: train features/labels, then validation features/labels.\n",
    "(X_train, y_train, X_val, y_val) = load_dataset(dataset)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Get Objective and Optimizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Objective configuration shared by the training and validation objectives.\n",
    "model_cfg = dict(\n",
    "    # Options: 'cvar', 'extremile', 'esrm'. Append '_lite' or '_hard' for a less/more skewed spectrum.\n",
    "    objective=\"cvar\",\n",
    "    l2_reg=1.0,\n",
    "    # Spectrum parameter: set to 1 - alpha for alpha-CVaR, rho for rho-ESRM, or r for r-extremile.\n",
    "    para_value=0.25,\n",
    "    # Options: 'squared_error', 'binary_cross_entropy', 'multinomial_cross_entropy'.\n",
    "    loss=\"multinomial_cross_entropy\",\n",
    "    # Alternative: derive the class count from the labels with torch.max(y_train) + 1.\n",
    "    n_class=5,\n",
    "    shift_cost=0,\n",
    ")\n",
    "\n",
    "# Select the non-autodiff variants of the objectives.\n",
    "autodiff = False\n",
    "\n",
    "# Build one objective per data split (train first, then validation) from the same configuration.\n",
    "train_obj, val_obj = (\n",
    "    get_objective(model_cfg, features, labels, autodiff=autodiff)\n",
    "    for features, labels in ((X_train, y_train), (X_val, y_val))\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyperparameters forwarded to SorelBatch as lr, lrdcon and seed respectively.\n",
    "lr, lrd, seed = 0.3, 0.02, 1\n",
    "\n",
    "# SOREL optimizer on the training objective; one epoch spans train_obj.n steps.\n",
    "optimizer = SorelBatch(\n",
    "    train_obj,\n",
    "    lr=lr,\n",
    "    smooth_coef=0,\n",
    "    smoothing=\"l2\",\n",
    "    seed=seed,\n",
    "    length_epoch=train_obj.n,\n",
    "    lrdcon=lrd,\n",
    "    xlrcon=20,\n",
    "    batch_size=64,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# You can get baseline optimizers by running the following code.\n",
    "# optim_cfg = {\n",
    "#         \"optimizer\": \"sgd\", # Options: 'sgd', 'lsvrg_batch', 'prospect_batch'\n",
    "#         \"lr\": lr,\n",
    "#         \"epoch_len\": train_obj.n, # Used as an update interval for LSVRG, and otherwise is simply a logging interval for other methods.\n",
    "#         \"shift_cost\": 0,\n",
    "#     }\n",
    "\n",
    "# optimizer = get_optimizer(optim_cfg, train_obj, seed)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "L-BGFS sucsess! Minimum loss on amazon: 1.36986277\n"
     ]
    }
   ],
   "source": [
    "# Get a reference optimum of the training objective with L-BFGS.\n",
    "train_obj_ = get_objective(model_cfg, X_train, y_train)\n",
    "\n",
    "\n",
    "# Function and Jacobian oracles in the form scipy.optimize.minimize expects.\n",
    "def fun(w):\n",
    "    \"\"\"Return train_obj_.get_batch_loss at the NumPy vector w as a Python scalar.\"\"\"\n",
    "    return train_obj_.get_batch_loss(torch.tensor(w, dtype=torch.float64)).item()\n",
    "\n",
    "\n",
    "def jac(w):\n",
    "    \"\"\"Return train_obj_.get_batch_subgrad at the NumPy vector w as a NumPy array.\"\"\"\n",
    "    w_tensor = torch.tensor(w, dtype=torch.float64, requires_grad=True)\n",
    "    return train_obj_.get_batch_subgrad(w_tensor).detach().numpy()\n",
    "\n",
    "\n",
    "# Run optimizer from an all-zeros start: n_class * d weights when n_class is set, otherwise d.\n",
    "d = train_obj.d\n",
    "num_classes = model_cfg.get(\"n_class\")  # a missing key is treated like a falsy n_class\n",
    "init = np.zeros((num_classes * d,) if num_classes else (d,), dtype=np.float64)\n",
    "output = minimize(fun, init, method=\"L-BFGS-B\", jac=jac)\n",
    "# output = minimize(fun, init, method='CG', jac=jac, options={'disp': True, 'eps': 1e0})\n",
    "if output.success:\n",
    "    print(f\"L-BFGS success! Minimum loss on {dataset}: {output.fun:0.8f}\")\n",
    "else:\n",
    "    # Surface the solver's own termination message instead of plotting against a bad reference.\n",
    "    raise RuntimeError(output.message)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 200/200 [00:17<00:00, 11.63it/s]\n"
     ]
    }
   ],
   "source": [
    "# Train for a fixed number of epochs and log one loss value per epoch.\n",
    "n_epochs = 200\n",
    "train_losses = []\n",
    "epoch_len = optimizer.get_epoch_len()\n",
    "\n",
    "for epoch in tqdm(range(n_epochs)):\n",
    "    optimizer.start_epoch()\n",
    "    for _ in range(epoch_len):\n",
    "        optimizer.step()\n",
    "    optimizer.end_epoch()\n",
    "\n",
    "    # Evaluate train_obj.get_batch_loss at the current weights; .item() yields a plain number.\n",
    "    epoch_loss = train_obj.get_batch_loss(optimizer.weights)\n",
    "    train_losses.append(epoch_loss.item())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAcIAAAHACAYAAAAx0GhOAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAA9hAAAPYQGoP6dpAABJN0lEQVR4nO3deVhU1f8H8PdlGPZNFlkEBRdUFlFQcd/KBc0tNS0z0dQ096XMzNKs7GupWS5p5lJWmpZmWhlWbuCKoLhroqCAuLHvzP394c/Jy6CCDHNneb+eZ57HOWe48+Ey8uYu5xxBFEURREREJspM7gKIiIjkxCAkIiKTxiAkIiKTxiAkIiKTxiAkIiKTxiAkIiKTxiAkIiKTxiAkIiKTZi53AdqmUqmQkpICe3t7CIIgdzlERCQDURSRnZ0NLy8vmJk9/pjP6IIwJSUFPj4+cpdBRER6IDk5Gd7e3o99jdEFob29PYD737yDg4PM1RARkRyysrLg4+OjzoTHMbogfHA61MHBgUFIRGTiKnKJjDfLEBGRSWMQEhGRSWMQEhGRSTO6a4REpKm0tBTFxcVyl0GkVUqlEgqFosrbYRASGbmcnBxcv34dXIObjI0gCPD29oadnV2VtsMgJDJipaWluH79OmxsbODm5sZJJshoiKKIW7du4fr162jQoEGVjgwZhERGrLi4GKIows3NDdbW1nKXQ6RVbm5uuHr1KoqLi6sUhLxZhsgE8EiQjJG2Ptd6GYQ7d+5Ew4YN0aBBA6xZs0bucoiIyIjpXRCWlJRg2rRp+Pvvv3HixAn873//w927d+Uui4gMXKdOnTBlyhS5y8DevXshCAIyMjJkef+5c+eiadOmOn+fyMhI9OvXr9rf92noXRAePXoUgYGBqFWrFuzt7dGzZ0/s3r1b7rKISEcEQXjsIzIy8qm2+/PPP2P+/PlVqi09PR2vvfYaateuDUtLS3h4eKB79+44dOhQlbZbXQRBwPbt2yVtM2bMwF9//aXzWpYuXYr169ern+vLHyZANQTh/v370bt3b3h5eZX7QwCAFStWwM/PD1ZWVggLC8OBAwfUfSkpKahVq5b6ube3N27cuKHtMh9LFEVsPHwNKRn5On1fIgJSU1PVj88++wwODg6StqVLl0peX9Hxkc7OzhWagPlxBgwYgJMnT2LDhg24ePEiduzYgU6dOhnUWSs7Ozu4uLjo/H0dHR3h5OSk8/etCK0HYW5uLkJCQrBs2bJy+zdv3owpU6Zg9uzZiIuLQ/v27REREYGkpCQAKHesky4v9BcUl2L6lpN4Z/tpjNsYi4LiUp29N1F1UqlE3MkplPWhUj15LKOHh4f64ejoCEEQ1M8LCgrg5OSEH3/8EZ06dYKVlRU2btyIO3fu4MUXX4S3tzdsbGwQHByMH374QbLdskcgvr6++OijjzBy5EjY29ujdu3aWL169SPrysjIwMGDB/G///0PnTt3Rp06ddCyZUvMmjULvXr1AgBcvXoVgiAgPj5e8nWCIGDv3r2S7UVHRyMkJARWVlYIDw9HQkKCpP+nn35CYGAgLC0t4evri0WLFkn6fX19MX/+fLz00kuws7ODl5cXvvjiC0k/APTv3x+CIKifP+qU5UcffQR3d3c4OTlh3rx5KCkpwRtvvAFnZ2d4e3tj7dq1kvefOXMm/P39YWNjg7p162LOnDmP/aPk4VOjkZGR2LdvH5YuXao+0k9MTET9+vXx6aefSr7u9OnTMDMzw7///vvIbVeV1odPREREICIi4pH9ixcvxquvvopRo0YBAD777DPs3r0bK1euxIIFC1CrVi3JEeD169cRHh7+yO0VFhaisLBQ/TwrK+upa7+ZVYBRG44j4UYmAODk9UzM2X4aCwc24V13ZPDu5RUh7IM9stYQ+86zcLGzrPJ2Zs6ciUWLFmHdunWwtLREQUEBwsLCMHPmTDg4OGDXrl0YNmwY6tat+9jfH4sWLcL8+fPx
9ttvY+vWrRg3bhw6dOiARo0aabzWzs4OdnZ22L59O1q1agVLy6p9H2+88QaWLl0KDw8PvP322+jTpw8uXrwIpVKJ2NhYvPDCC5g7dy4GDx6MmJgYvP7663BxcZGcGv7kk0/w9ttvY+7cudi9ezemTp2KRo0aoWvXrjh27Bhq1qyJdevWoUePHo8dXvD333/D29sb+/fvR3R0NF599VUcOnQIHTp0wJEjR7B582aMHTsWXbt2Va/3am9vj/Xr18PLywsJCQkYPXo07O3t8eabbz7xe1+6dCkuXryIoKAgvP/++wDuD4UYOXIk1q1bhxkzZqhfu3btWrRv3x716tV7yj39ZDq9RlhUVITY2Fh069ZN0t6tWzfExMQAAFq2bInTp0/jxo0byM7Oxm+//Ybu3bs/cpsLFiyAo6Oj+lGVRXmtlApkF0j/otkSex0bD1976m0SkfZNmTIFzz//PPz8/ODl5YVatWphxowZaNq0KerWrYuJEyeie/fu2LJly2O307NnT7z++uuoX78+Zs6cCVdXV40jtwfMzc2xfv16bNiwAU5OTmjbti3efvttnDp16qm+h/feew9du3ZFcHAwNmzYgJs3b2Lbtm0A7h8wPPPMM5gzZw78/f0RGRmJCRMm4JNPPpFso23btnjrrbfg7++PiRMnYuDAgViyZAmA+8ECAE5OTvDw8FA/L4+zszM+//xzNGzYECNHjkTDhg2Rl5eHt99+Gw0aNMCsWbNgYWGB6Oho9de88847aNOmDXx9fdG7d29Mnz4dP/74Y4W+d0dHR1hYWMDGxkZ9tK9QKDBixAhcuHABR48eBXD/tPfGjRsxcuTIiu/Yp6DTILx9+zZKS0vh7u4uaXd3d0daWhqA+x+2RYsWoXPnzmjWrBneeOONx57PnjVrFjIzM9WP5OTkp67P0VqJVcOaw8ZC+pfTvF/PIi7p3lNvl4i0q3nz5pLnpaWl+PDDD9GkSRO4uLjAzs4Of/75p/qSy6M0adJE/e8Hp2DT09Mf+foBAwYgJSUFO3bsQPfu3bF3716EhoZKbgKpqNatW6v/7ezsjIYNG+LcuXMAgHPnzqFt27aS17dt2xaXLl1CaWlpudt48PzBNiojMDAQZmb/xYG7uzuCg4PVzxUKBVxcXCT7ZuvWrWjXrh08PDxgZ2eHOXPmPHF/P4mnpyd69eqlPg27c+dOFBQUYNCgQVXa7pPIctdo2dOMoihK2h6cIrh8+TLGjBnz2G1ZWlqqF+HVxmK8DT3ssWhQiKStRCVi4g9xyMznpMVE+sDW1lbyfNGiRViyZAnefPNN/P3334iPj0f37t1RVFT02O0olUrJc0EQoFKpHvs1VlZW6Nq1K959913ExMQgMjIS7733HgCow+Thex0qM9n5g9+DZX8nlt1mRbZRGeXth8ftm8OHD2PIkCGIiIjAzp07ERcXh9mzZz9xf1fEqFGjsGnTJuTn52PdunUYPHgwbGxsqrzdx9HpFGuurq5QKBTqo78H0tPTNY4S5RQR7Ilxneph5d7/Ls5ev5eP2dsSsOylUBkrI3p6NWwsEPvOs7LXUB0OHDiAvn374uWXXwYAqFQqXLp0CY0bN66W93tYQECA+u74B6cfU1NT0axZMwCQ3DjzsMOHD6N27doAgHv37uHixYvqa5MBAQE4ePCg5PUxMTHw9/eXXOs7fPiwxjYfvr6pVColR5DaEh0djTp16mD27NnqtmvXKncJycLCotzaevbsCVtbW6xcuRK///479u/fX+V6n0SnQWhhYYGwsDBERUWhf//+6vaoqCj07dtXl6U80fSu/jh+9S6OXf3vlOjOU6kY1T4DTX2c5CuM6CmZmQlauVFFH9WvXx8//fQTYmJiUKNGDSxevBhpaWlaDcI7d+5g0KBBGDlyJJo0aQJ7e3scP34cCxcuVP/+sra2RqtWrfDxxx/D19cXt2/fxjvvvFPu9t5//324uLjA3d0ds2fPhqurq/quyunTp6NFixaYP38+
Bg8ejEOHDmHZsmVYsWKFZBvR0dFYuHAh+vXrh6ioKGzZsgW7du1S9/v6+uKvv/5C27ZtYWlpiRo1amhlX9SvXx9JSUnYtGkTWrRogV27dqmvb1aUr68vjhw5gqtXr8LOzg7Ozs4wMzODQqFAZGQkZs2ahfr162uc/q0OWj81mpOTg/j4ePVfQYmJiYiPj1efO542bRrWrFmDtWvX4ty5c5g6dSqSkpIwduxYbZdSJeYKMywd0gyO1tLTA1/urb5beIno6cyZMwehoaHo3r07OnXqBA8PD63PYmJnZ4fw8HAsWbIEHTp0QFBQEObMmYPRo0dLhoutXbsWxcXFaN68OSZPnowPPvig3O19/PHHmDx5MsLCwpCamoodO3bAwuL+EXNoaCh+/PFHbNq0CUFBQXj33Xfx/vvva0wmMH36dMTGxqJZs2aYP38+Fi1aJLm5cNGiRYiKioKPj4/6CFUb+vbti6lTp2LChAlo2rQpYmJiMGfOnEptY8aMGVAoFAgICICbm5vk+uKrr76KoqKiar9JRk3Usn/++UcEoPEYPny4+jXLly8X69SpI1pYWIihoaHivn37qvy+y5YtExs3biz6+/uLAMTMzMwqb1MURXHl3stinZk71Q/ft3aKl25ma2XbRNUtPz9fPHv2rJifny93KaRlderUEZcsWSJ3GdXi4MGDorm5uZiWlvbY1z3u852ZmVnhLBBE0bhW68zKyoKjoyMyMzOrfOMMAGQXFKPNx38ju6BE3fZCc28sHBjymK8i0g8FBQVITExUz+RExsPX1xdTpkzRm2nKtKGwsBDJyckYM2YMPD098d133z329Y/7fFcmC/RurlF9Y2+lxCut60jafj2Ziryikkd8BRERPY0ffvgBDRs2RGZmJhYuXKiz92UQVsDwNr5QmP13S3J+cSn+OvfosUZERNXt6tWrRnU0CNyfeq20tBSxsbGSOaerG4OwAmraW6FNPemg/l9PpshUDRERaRODsIJ6N/GSPN978ZbGdGxERGR4GIQV1D3QA0rFf6dHi0pUiDp7U8aKiCrOyO6JIwKgvc+10QTh8uXLERAQgBYtWlTL9h1tlOjQQDpp7e+n0x7xaiL98GAWEm1MfUWkbx58rh+3skZF6HRmmeo0fvx4jB8/Xn3LbHXoGeyJv87/d5NMzOXbKCpRwcLcaP6eICNjbm4OGxsb3Lp1C0qlUjKxMpEhU6lUuHXrFmxsbGBuXrUoM5og1IWODaVHhLlFpYi9dg+t6+l+tWeiihAEAZ6enkhMTKz0XJBE+s7MzAy1a9eu8nqxDMJKcLWzRHAtR/XCvQCw7+ItBiHpNQsLCzRo0ICnR8noWFhYaOUsB4Owkjr6u2kE4VsRmqtZE+kTMzMzzixD9Ai8YFBJZU+PnkvNws2sApmqISKiqmIQVlIzHyfYW0kPpPdfvCVTNUREVFUMwkoyV5ihbT1XSdvhK3dlqoaIiKrKaIKwuscRPqxNfenNMYev3Kn29yQiouphNEE4fvx4nD17FseOHav292pVVxqENzLykXw3r9rfl4iItM9oglCXGtS0g7OthaTtSCJPjxIRGSIG4VMQBAHhfs6SNp4eJSIyTAzCp1T29CiDkIjIMDEIn1LZILx+Lx/X7/E6IRGRoWEQPqUGNe1Qw0YpaYu9dk+maoiI6GkxCJ+SmZmAsDrS64QnGIRERAbHaIJQl+MIHwirU0Py/DiDkIjI4BhNEOpyHOEDZYPwXGoWcgtLdPb+RERUdUYThHJo4u0Ic7P/1sFSicDJ5Az5CiIiokpjEFaBlVKBwFqOkjbeMENEZFgYhFUUVlt6ejQ2iUFIRGRIGIRVVPY64Ylr96BSiTJVQ0RElcUgrKLQOk6S51kFJfj3Vo48xRARUaUxCKvI09EatZysJW28TkhEZDgYhFoQWub0KIOQiMhwMAi1IKy2k+Q5b5ghIjIcDEItKDvV2pVbubibWyRTNUREVBlGE4RyTLH2QCNPe1grFZK2OB4V
EhEZBKMJQjmmWHtAqTBDiA8H1hMRGSKjCUK5cQJuIiLDxCDUkrJBeOp6BopLVTJVQ0REFcUg1JJmPtIgLChW4VxqlkzVEBFRRTEItaSGrQXqutlK2nidkIhI/zEItajsBNwnkjLkKYSIiCqMQahF5U3ATURE+o1BqEVlg/BGRj5SM/NlqoaIiCqCQahF9dzs4GBlLmk7cS1DnmKIiKhCGIRaZGYmoJnGdUKeHiUi0mcMQi0re3qUd44SEek3owlCOecafVjZIDyTkomC4lKZqiEioicxmiCUc67Rh4X4OMFM+O95camIhBuZ8hVERESPZTRBqC/sLM3R0MNB0sbTo0RE+otBWA3C6jhJnnM8IRGR/mIQVgONgfVJ9yCKokzVEBHR4zAIq0FYbemK9bdzipB0N0+maoiI6HEYhNXAx9karnYWkrYjiXdlqoaIiB6HQVgNBEFASz/pUWHM5dsyVUNERI/DIKwmbeq5Sp5H/3uH1wmJiPQQg7CatK0vDcJb2YW4nJ4jUzVERPQoDMJq4utiAy9HK0lbNE+PEhHpHQZhNREEAW3qa54eJSIi/cIgrEZt6rlInh++cgclpSqZqiEiovIwCKtR2euE2QUlOHqVwyiIiPQJg7AauTtYIaiWdN7R3afTZKqGiIjKwyCsZj0CPSTPd5+5CZWKwyiIiPQFg7Ca9QiSBmFaVgFOXs+QpxgiItJgNEGoLwvzllW/pj3qudlK2v44w9OjRET6wmiCUF8W5i1P2aPCHfEpvHuUiEhPGE0Q6rOIIE/J89TMAuw5d1OmaoiI6GEMQh0IquWIEB8nSduGmGvyFENERBIMQh0Z3rqO5PmhK3dw8Wa2TNUQEdEDDEId6dXEEy620jUKV/xzWaZqiIjoAQahjliaK/Biy9qStu3xKUi4nilTRUREBDAIdWpkOz/YW5pL2j787SzXKSQikhGDUIecbS3weuf6krbDV+5ix8kUmSoiIiIGoY6NaOuLWk7WkrZ3fzmDm1kFMlVERGTaGIQ6ZqVUYHavxpK2zPxivLn1FOcgJSKSAYNQBj2DPdE7xEvStu/iLaw5eEWmioiITBeDUCbz+wbCzd5S0va/Py4g9hrXKyQi0iUGoUycbCyw5IWmEIT/2kpVIiZ+H4d7uUXyFUZEZGIYhDJq18AVE8vcRZqSWYAZW05ySAURkY4wCGU2+Vl/tKrrLGn763w61hxIlKkiIiLTwiCUmcJMwNIhzTSmX/vfH+dxIumeTFUREZkOBqEecHewwpLB0uuFJf9/vTAzr1i+woiITACDUE908HfDhDLXC29k5OPt7Qm8XkhEVI0YhHpk8jMNEO4nvV6461QqtsffkKkiIiLjxyDUI+YKMywd0gyO1kpJ+7vbz+D6vTyZqiIiMm4MQj3j4WiFBc8HS9qyC0sw7ceTKOUUbEREWscg1EM9gz0xINRb0nY08S5W7+cUbERE2mY0Qbh8+XIEBASgRYsWcpeiFXP7BMC7hnSVisVRF3D6BhfyJSLSJkE0slsSs7Ky4OjoiMzMTDg4OMhdTpUcTbyLIasP4eEzovXcbLFzYntYWyjkK4yISM9VJguM5ojQGLX0c8a4TvUkbf/eysXHv5+TqSIiIuPDINRzk5/xR3AtR0nbhkPX8M+FdJkqIiIyLgxCPWdhboYlg5vCSin9Ub259RTu5BTKVBURkfFgEBqA+jXtMLtXgKTtVnYhZv3MWWeIiKqKQWggXg6vjc4N3SRtf569iR+PJ8tUERGRcWAQGghBELBwYIjGKhXzfj2La3dyZaqKiMjwMQgNiJu9JT4e0ETSlldUiimb41FSqpKpKiIiw8YgNDBdA9zxYksfSVtcUgaW//OvTBURERk2BqEBeqdXAHxdbCRtn/99CXFcyJeIqNIYhAbI1tIcSwY3hcLsv5V8S1Uipm6OR25hiYyVEREZHgahgWpWuwYmdWkgabt6Jw/v/nKGQyqIiCqBQWjAxneuh2a1nSRtP524ju+PJslT
EBGRAWIQGjBzhRk+G9wUtmUm4J674wyvFxIRVRCD0MDVcbHFwoEhkrbiUhGvf3cCtzkFGxHREzEIjUCvJp4Y06GupC01swATv4/j+EIioidgEBqJN7s3RKu6zpK2Q1fu4JM/L8hUERGRYWAQGglzhRmWvRQKDwcrSfuqfVfwe0KqTFUREek/BqERcbWzxIqXQ6FUCJL26VtO4nxalkxVERHpNwahkQmtXQPv9g6UtOUVlWLUhuO4m1skU1VERPqLQWiEXg6vjYFh3pK26/fyMW5jLIpKePMMEdHDGIRGSBAEfNAvSGOw/ZHEu5j7K2eeISJ6GIPQSFkpFVj1cpjGzTPfH0nCxsPXZKqKiEj/MAiNWE0HK6x+JQyW5tIf89xfzyLm8m2ZqiIi0i8MQiPXxNsJnw6SzjxTqhLx+vcnuLI9EREYhCahd4gXJnSuL2nLyCvGqxuOI7ugWKaqiIj0A4PQREzr6o+uAe6StsvpOZi8KR6lKt48Q0Smi0FoIszMBCwZ3BQN3e0l7X+fT8f8nWdlqoqISH4MQhNiZ2mONcObo4aNUtK+PuYq1h5MlKkqIiJ5MQhNjI+zDb58OQwWCumPfv6us/jzTJpMVRERyYdBaILC67pg4cAmkjZRBCZtisPJ5Ax5iiIikgmD0ET1a1YL07v6S9oKilV4dcNxJN/Nk6kqIiLdYxCasAld6mvMSXo7pxAj1x9DZj6HVRCRadDLIOzfvz9q1KiBgQMHyl2KURMEAR/1D0abei6S9kvpOZygm4hMhl4G4aRJk/DNN9/IXYZJsDA3w8qXw9Cgpp2kPebfO3h7WwIn6CYio6eXQdi5c2fY29s/+YWkFY7WSqwb0QKudpaS9q2x17Hoz4syVUVEpBuVDsL9+/ejd+/e8PLygiAI2L59u8ZrVqxYAT8/P1hZWSEsLAwHDhzQRq1Ujbxr2ODr4c1hpZR+JJb9cxnroznGkIiMV6WDMDc3FyEhIVi2bFm5/Zs3b8aUKVMwe/ZsxMXFoX379oiIiEBSUpL6NWFhYQgKCtJ4pKSkVPobKCwsRFZWluRBTyfExwlfvBgKM0HaPm/nWew8VfmfDRGRIRDEKlwEEgQB27ZtQ79+/dRt4eHhCA0NxcqVK9VtjRs3Rr9+/bBgwYIKb3vv3r1YtmwZtm7d+tjXzZ07F/PmzdNoz8zMhIODQ4Xfj/7zw9EkzPo5QdKmVAhYP6Il2tZ3lakqIqKKy8rKgqOjY4WyQKvXCIuKihAbG4tu3bpJ2rt164aYmBhtvpXarFmzkJmZqX4kJydXy/uYkhdb1sa0MmMMi0tFvPZtLE7fyJSpKiKi6qHVILx9+zZKS0vh7i5d5cDd3R1paRWfvqt79+4YNGgQfvvtN3h7e+PYsWOPfK2lpSUcHBwkD6q6iV3qY1irOpK2nMISRK47xnUMiciomFfHRgVBepFJFEWNtsfZvXu3tkuiShIEAXP7BOJObiF+S/jvj5jbOYV4Ze1RbB3bBm72lo/ZAhGRYdDqEaGrqysUCoXG0V96errGUSLpP4WZgMUvNEWrus6S9mt38jBi/VHkFJbIVBkRkfZoNQgtLCwQFhaGqKgoSXtUVBTatGmjzbciHbFSKrD6leZo7Ck95Xz6RhZe+/Y4CktKZaqMiEg7Kh2EOTk5iI+PR3x8PAAgMTER8fHx6uER06ZNw5o1a7B27VqcO3cOU6dORVJSEsaOHavVwstavnw5AgIC0KJFi2p9H1PkYKXEhhEt4ONsLWmPvnwH0348yRXuicigVXr4xN69e9G5c2eN9uHDh2P9+vUA7g+oX7hwIVJTUxEUFIQlS5agQ4cOWin4SSpzyyxVTuLtXAxcGYM7uUWS9hdb+uCj/sGVug5MRFSdKpMFVRpHqI8YhNXr1PUMDFl9GHlF0lOiYzrUxayIRgxDItILso0jJOPXxNsJq4ZprnC/ev8VrNj7r0xVERE9PQYhVVr7Bm74/MVmGlOxfbL7Ar45dFWWmoiInhaD
kJ5KjyAPLBwYotH+7i9n8POJ6zJURET0dBiE9NQGhnljbu8AjfY3tp7C7jMVn0mIiEhORhOEHD4hj8i2fpheZl7SUpWIid/H4eCl2zJVRURUcbxrlKpMFEUs+P08Vu+/Imm3Virwzast0cLX+RFfSURUPXjXKOmUIAiYFdEIL7b0kbTnF5dixLpjiEu6J1NlRERPxiAkrRAEAR/0C0bvEC9Je05hCV5ZexQJ17l8ExHpJwYhac39SbpD0DVAOsF6dkEJhq09grMpWTJVRkT0aAxC0iqlwgzLXmqGzg3dJO0ZecV4+esjuHQzW6bKiIjKxyAkrbM0V2Dly2FoV99V0n43twgvrTmCK7dyZKqMiEgTg5CqhZVSga9eaY5wP+kdo7eyC/HSV0e4yj0R6Q2jCUKOI9Q/1hYKrI1sgbA6NSTtaVkFeOmrI7h+L0+myoiI/sNxhFTtsgqKMWzNEZwsc+dobWcbbH6tFTwdrR/xlURET4fjCEmvOFgp8c3IcASUWeU+6W4eXvrqCNIyC2SqjIiIQUg64mijxMZR4Wjobi9pT7ydiyGrDyE1M1+myojI1DEISWecbS2wcVQ46rnZStqv3snDkNWHkZLBMCQi3WMQkk652Vvih9GtULdMGF77/zC8wTAkIh1jEJLO1XSwwqbRrTSODJPu5mHI6kMMQyLSKQYhyaKmgxV+GNMK9WvaSdqT7+ZjyOpDHFpBRDrDICTZ1LS3wg+jW6FBuWF4GMl3GYZEVP2MJgg5oN4wudlb4vvRreDvLg3D6/cYhkSkGxxQT3rhdk4hhn51BBfKTMpdy8kam8a0go+zjUyVEZEh4oB6Mjiudpb4fnQ4GnlIxxneyMjH4FWHkHSHR4ZEVD0YhKQ3XOws8d0ozTBMySzA4NWHcPU2J+omIu1jEJJecbG7f82wcZnp2FIzCzBo1SFc5HqGRKRlDELSO862Fvh+lObcpLeyCzF41SGcvpH5iK8kIqo8BiHppRq2Fvh+dDiCaknD8F5eMV5cfRix1+7KVBkRGRsGIektJxsLfD+6lcZ6htmFJRj29VHEXL4tU2VEZEwYhKTX7i/h1BJt6rlI2vOKShG5/hj+Pn9TpsqIyFgwCEnv2VqaY21kC3RpVFPSXlSiwphvYrHrVKpMlRGRMTCaIOTMMsbNSqnAly+HoVewp6S9RCVi4g8nsDX2ukyVEZGh48wyZFBKVSJm/nSq3OCb3zcQw1r76r4oItI7nFmGjJbCTMDCAU3wSus6Gn1zfjmD5f9chpH9bUdE1YxBSAbHzEzAvD6BeK1jXY2+T3ZfwAe7zkGlYhgSUcUwCMkgCYKAt3o0wvSu/hp9Xx9MxIwtJ1FcqpKhMiIyNAxCMliCIGDiMw0wt3eARt/PcTfw2rexyC8qlaEyIjIkDEIyeJFt/bB0SFOYmwmS9r/Pp2PY10eQmV8sU2VEZAgYhGQU+jatha+GN4eVUvqRPn7tHgavOoT0rAKZKiMifccgJKPRuWFNfDcqHA5W5pL282nZGPBlDK7d4TJORKSJQUhGJayOM7aMbYOa9paS9uS7+Riw8hDOpHDlCiKSYhCS0WnoYY+fxrWBr4uNpP12TiGGrDrMybqJSIJBSEbJx9kGW8a2QaCXdEaJ7MISDF93FL/E35CpMiLSNwxCMlpu9pb4YUwrhPs5S9qLS0VM3hSPVfv+5Sw0RGQ8QchJt6k8DlZKbBjZEhFBHhp9C34/j3m/nkUpZ6EhMmmcdJtMQqlKxPydZ7E+5qpGX49AD3w2pCmslArdF0ZE1YKTbhOVoTAT8F7vAMzu2Vij748zaXh5zRHcyy2SoTIikhuDkEyGIAgY3aEulg5pCqVCOgvN8Wv3MODLGCTfzZOpOiKSC4OQTE7fprWwYWRL2JcZeH/lVi6eXxmD0zc41pDIlDAIySS1qeeKLWNbw8PBStJ+K7sQL6w6hL/P35SpMiLSNQYhmaxGHg7YNr4NGrrbS9rzikoxasNxrItOlKkyItIl
BiGZNE9Ha/w4tjVa1ZWONVSJwLxfz+K9X06jhOsaEhk1BiGZPEfr+2MNn29WS6Nvw6FrGPXNcWQXcCknImPFICQCYGmuwKIXQspd8X7vhVsY9OUhpGTky1AZEVU3BiHR/3uw4v3SIU1hYS79r3E+LRt9l0fj1PUMeYojomrDICQqo2/TWvh+VDicbS0k7Q/uKN19Jk2myoioOjAIicrR3NcZ215vg7putpL2gmIVxm6M5YTdREaEQUj0CHVcbLFtXFu0rusiaRfF+xN2T99yEgXFpTJVR0TawiAkegxHm/t3lL7Q3Fuj7+cTN/DSV4eRnl0gQ2VEpC0MQqInsDA3w/8GNMGbPRpq9J1IykC/ZdGclo3IgBlNEHI9QqpOgiDg9U71sXpYGGwspMs1pWQWYNCXh/BbQqpM1RFRVXA9QqJKOpeahVEbjuNGOeMKpz7rj0nP1IcgCOV8JRHpCtcjJKpGjT0dsGNCW7T0ddboW7LnIiZ8H4f8It5EQ2QoGIRET8HFzhIbR4VjcHMfjb5dCakYtCoGqZmciYbIEDAIiZ6ShbkZPh4QjHefC4BZmTOhp29koc+yaJxIuidPcURUYQxCoioQBAEj2/lh3QjNhX5vZRdiyOrD2HI8WabqiKgiGIREWtDR3w3bx7eFn6t0JpqiEhXe2HoKc3ecQTGXcyLSSwxCIi2p52aH7a+3Rbv6rhp962Ou4uU1R3A7p1CGyojocRiERFrkaKPE+hEtMKKtr0bfkcS76PPFQSRc5+B7In3CICTSMnOFGd7rHYhPB4VoLOeUklmAgV/G4OcT12WqjojKYhASVZOBYd7Y8lpreDpaSdoLS1SY9uNJvP/rWZTwuiGR7BiERNUoxMcJOya0K3fw/droRLyy9iju5hbJUBkRPcAgJKpmbvb3B9+/0rqORl/Mv3fQ+4uDnLSbSEYMQiIdsDA3w/t9g7BwQBNYKKT/7W5k5GPglzH4Jf6GTNURmTYGIZEOvdDCB5tfawV3B0tJe0GxCpM3xePDXbxuSKRrDEIiHWtWuwZ+ndgOYXVqaPR9dSARw9cdxR2ONyTSGQYhkQxq2lvhh9GtMDS8tkZf9OU7eO6Lg4jjPKVEOsEgJJKJhbkZPuwfjAXPB0OpkM7anZpZgBdWHcK3h67CyJYMJdI7DEIimb3YsjY2jWmtcd2wuFTEnF/OYNqPJ7m+IVE1YhAS6YGwOjWwc2J7tKqrOd5wW9wN9F8RjcTbuTJURmT8GIREesLN3hIbXw3Hax3ravSdT8tGny8O4s8zaTJURmTcGIREesRcYYZZEY3x5cthsLOUrm+YXViCMd/G4n9/nOcQCyItYhAS6aEeQR7YMaEtGrrba/St3PsvXll7lEs6EWmJ0QTh8uXLERAQgBYtWshdCpFW1HWzw7bxbdC3qZdGX8y/d/Dc5wdxgkMsiKpMEI3s3uysrCw4OjoiMzMTDg4OcpdDVGWiKOKbQ9fwwa6zKC6V/ndVKgTMeS4Aw1rVgSAIj9gCkempTBYYzREhkbESBAHD2/hi05jW8HCQLulUXCri3V/OYNKmeOQUlshUIZFhYxASGYiwOjWwc1I7tK7rotH368kU9PniIM6nZclQGZFhYxASGRBXO0t8+2pLjOtUT6Pvyu1c9F0WjR+PJ8tQGZHhYhASGRhzhRlm9miEr15pDgcr6RCLwhIV3tx6CtN/PIm8Ip4qJaoIBiGRgeoa4I5dk9qjibejRt9PJ66j3/JoXE7PlqEyIsPCICQyYD7ONtgytjUi2/hq9F28mYM+y6K54C/REzAIiQycpbkCc/sEYvlLoRqz0eQVlWLypni8vS0BBcWcuJuoPAxCIiPRq4knfp3YDo09NcdMfX8kCc+viMFVTtxNpIFBSGRE/Fxtse31NnixpY9G39nULPT+4iB+T0iVoTIi/cUgJDIyVkoFFjzfBItfCIG1UiHpyy4swbjvTmDer2dQVMKJu4kABiGR0Xo+1Bs7JrRF/Zp2Gn3roq9i
0KpDuH4vT4bKiPQLg5DIiDVwt8eOCW3xfLNaGn0nkzPQ6/OD2HP2pgyVEekPBiGRkbOxMMeiF0Lw8fPBsDSX/pfPzC/GqG+O44OdZ3mqlEwWg5DIBAiCgCEta2Pb623h52qr0b/mYCIGfRmDpDs8VUqmh0FIZEICvBywY0Jb9Ar21Og7eT0TvT4/gJ2nUmSojEg+DEIiE2NvpcSyl5phft9AWJQ5VZpdWIIJ38dxAD6ZFAYhkQkSBAHDWvti2+ttULecU6XfH0niXKVkMhiERCYs0MsRv05sV+5dpefTstH7i2hsOZ4MURRlqI5INxiERCbO1tIciwc3xaeDNAfg5xeX4o2tpzDtx5PIKeSyTmScGIREBAAYGOaNXye2QyMPe42+bXE30PuLgziTkilDZUTVi0FIRGr1a9ph+/i2GBpeW6Mv8XYu+i+PwTeHrvJUKRkVBiERSVgpFfiwfzCWvxQK+zLLOhWVqvDuL2cwdmMsMvOKZaqQSLsYhERUrl5NPPHb5PYI8XHS6Nt95iZ6fn4Ax6/e1X1hRFrGICSiR/JxtsGW11pjdHs/jb4bGfl4YdUhLN1zCSWlnJ6NDBeDkIgey8LcDLN7BWBtZHPUsFFK+lQisGTPRbz41WHcyMiXqUKiqmEQElGFdGnkjt8nd0C4n7NG37Gr9xDx2X78xkV/yQAxCImowjwcrfD96FaY0c0fCjNB0pdVUILXvzuBmVtPIa+IYw7JcDAIiahSFGYCJnRpgB9faw3vGtYa/ZuPJ+O5zw/i9A2OOSTDwCAkoqcSVqcGfpvcHn2bemn0Xbmdi/4rorHmwBWoVBxzSPqNQUhET83BSonPBjfFokEhsLWQTs9WXCrig13nELn+GNKzC2SqkOjJGIREVCWCIGBAmDd2TWqPEG9Hjf79F2+h59ID+OdCugzVET0Zg5CItMLX1RZbxrbBuE71IEjvo8HtnCKMWHcM8349w3UOSe/oXRAmJyejU6dOCAgIQJMmTbBlyxa5SyKiCrIwN8PMHo2w8dVw1LS31OhfF30V/VfEcJ1D0iuCqGez56ampuLmzZto2rQp0tPTERoaigsXLsDWVnPx0PJkZWXB0dERmZmZcHBwqOZqiehR7uYW4c2tp7Dn3E2NPiulGd59LhAvtvSBUPbwkUgLKpMFendE6OnpiaZNmwIAatasCWdnZ9y9y/kMiQyNs60FvnolDPP7BsLSXPqrpqBYhbe3JWDcxhPIyCuSqUKi+yodhPv370fv3r3h5eUFQRCwfft2jdesWLECfn5+sLKyQlhYGA4cOPBUxR0/fhwqlQo+Pj5P9fVEJC9BEDCstS92TGgHf3c7jf4/zqShx2cHcOjfOzJUR3RfpYMwNzcXISEhWLZsWbn9mzdvxpQpUzB79mzExcWhffv2iIiIQFJSkvo1YWFhCAoK0nikpKSoX3Pnzh288sorWL169VN8W0SkTxp62GPHhHZ4pXUdjb60rAK8tOYwPtl9HkUlnLybdK9K1wgFQcC2bdvQr18/dVt4eDhCQ0OxcuVKdVvjxo3Rr18/LFiwoELbLSwsRNeuXTF69GgMGzbsia8tLCxUP8/KyoKPjw+vERLpqaizN/Hm1pO4V856hsG1HLFkcFPUr6l59EhUGbJdIywqKkJsbCy6desmae/WrRtiYmIqtA1RFBEZGYkuXbo8MQQBYMGCBXB0dFQ/eBqVSL91Dbg/eXebei4afQk3MvHcFwfw7aGr0LP7+MiIaTUIb9++jdLSUri7u0va3d3dkZaWVqFtREdHY/Pmzdi+fTuaNm2Kpk2bIiEh4ZGvnzVrFjIzM9WP5OTkKn0PRFT9PBytsPHVcMzs0QjmZSbvLihWYc4vZzCCM9KQjphXx0bL3g4timKFb5Fu164dVKqKXyewtLSEpaXmeCUi0m9mZgLGdaqH9g1cMXlTHP69lSvp33vhFnp8dgALng9G90APmaokU6DVI0JXV1coFAqNo7/09HSNo0QiIgAIquWI
nRPbY3g5N9LczS3Ca9/GYubWU8gt5NJOVD20GoQWFhYICwtDVFSUpD0qKgpt2rTR5lsRkRGxtlBgXt8grB/RAm7lzEiz+Xgyen5+ALHX7slQHRm7SgdhTk4O4uPjER8fDwBITExEfHy8enjEtGnTsGbNGqxduxbnzp3D1KlTkZSUhLFjx2q1cCIyPp0a1sTuKR3QPVDzDNK1O3kY9GUMFkddRHEph1mQ9lR6+MTevXvRuXNnjfbhw4dj/fr1AO4PqF+4cCFSU1MRFBSEJUuWoEOHDlop+FGWL1+O5cuXo7S0FBcvXuTwCSIDJooitsRex7wdZ5BbpDlJd4iPEz4b3BR+rhWbepFMT2WGT+jdXKNVxblGiYxH0p08TP0xvtxTotZKBeY8F8D5SqlcBj3XKBHRA7VdbLB5TCtM7+qvMcwiv7gUb29LwKgNxznMgqqEQUhEes1cYYaJzzTAT+PaoG45p0L/Op+Obkv2Y+eplHK+mujJGIREZBBCfJywc1I7DA2vrdGXkVeMCd/HYeIPcVzNgiqNQUhEBsPGwhwf9g/G2sjmcLXTHGbx68kUdFuyH/9cSJehOjJURhOEy5cvR0BAAFq0aCF3KURUzbo0csefUzugZ7DmjDPp2YUYse4YZv18CjkchE8VwLtGichgiaKIHSdT8O4vZ5CZr7mahY+zNT4dGILwupoTfJNx412jRGQSBEFA36a18OfUDujo76bRn3w3H0O+Ooz5O8+ioFhzPCIRwCAkIiPg7mCF9SNa4KP+wbCxUEj6RBH4+mAinvviIE5dz5CnQNJrDEIiMgqCIOCl8Nr4Y3IHtPR11ui/nJ6D/is4RRtpYhASkVGp7WKDH8a0wuyejWFhLv0VV6oS8flfl9B/RTQu3syWqULSNwxCIjI6CjMBozvUxa6J7RBcy1Gj//SNLDz3xUGs3v8vSlVGdb8gPQUGIREZrQbu9vj59TaY8mwDjSnaikpU+Oi38xiy+hCu3cl9xBbIFBhNEHIcIRGVR6kww5Rn/bHt9bZoUNNOo//Y1XuIWHoA3x6+BhWPDk0SxxESkckoKC7F4qiL+OrAFZT3m69tfRf8b0ATeNew0X1xpFUcR0hEVA4rpQJv92yMzWNaw8fZWqM/+vIddF+yHz8cTYKRHSPQYzAIicjktPRzxh+TO+Clcibwzi0qxayfE/DK2qNIyciXoTrSNQYhEZkkW0tzfNQ/GN+MbAkvRyuN/gOXbqP7kv348Vgyjw6NHIOQiExaB383/DG1AwY399Hoyy4swZs/ncKI9ceQlsnFf40Vg5CITJ6DlRL/G9gE60a0gIeD5tHh3gu30HXJPmyNvc6jQyPEICQi+n+dG9bE7qkdMDDMW6Mvu6AEM7acxKgNx3Ezi0eHxoRBSET0EEdrJT4dFIK1kc1R015z8d+/zqej25L92BbHo0NjYTRByAH1RKRNXRq5I2pqRzzfrJZGX2Z+MaZuPokx38YiPZtHh4aOA+qJiJ7gzzNpeHvbadzOKdToc7JRYl6fQPQJ8YIgCOV8NcmBA+qJiLSoW6AHoqZ2QN+mXhp9GXnFmLwpHuM2nig3KEn/MQiJiCqghq0Flg5phi9fDoWLrYVG/x9n0tBtyX7sOpUqQ3VUFQxCIqJK6BHkiT+ndkCvJp4afXdzizD++xMY/x2PDg0Jg5CIqJJc7Cyx/KVQLH8pFM7lHB3uSkhFtyX7seNkCu8sNQAMQiKip9Sryf2jw4ggD42+u7lFmPRD3P07SznuUK8xCImIqsDVzhIrhobi8xebwclGqdEfdfYmnl28D1uOc85SfcUgJCKqIkEQ0CfEC1FTO6JnsObRYVZBCd7YegqR647hBle00DsMQiIiLXGzt8SKoWFYMTQUrnaa1w73XbyF7kv247sj13h0qEeMJgg5swwR6YuewZ74c2pH9Ctn3GFOYQlmbzuNoWuOIOlOngzVUVmcWYaIqBrtOXsTs7cn4GaW5nAKa6UCM3s0xCutfWFmxllp
tIkzyxAR6YlnA9zx59SOeKG55ooW+cWlmPvrWQxefQhXbuXIUB0BDEIiomrnaK3EwoEh+GZkS9RystboP3b1HiKWHsCqff+iVGVUJ+kMAoOQiEhHOvi74Y8p7fFyq9oafYUlKiz4/TyeXxmDizezZajOdDEIiYh0yN5KiQ/6BeOH0a1Q29lGo/9kcgae+/wglv19CcWlKhkqND0MQiIiGbSu54I/prTHyLZ+KLt6U1GpCp/+eRF9l0XjTEqmPAWaEAYhEZFMbCzM8W7vAGx5rTXqutpq9J9NzULfZdFY/OcFFJaUylChaWAQEhHJrLmvM36b3B5jO9ZD2VEUJSoRn/99Gb2/OIiTyRmy1GfsGIRERHrASqnAWxGNsO31tmjobq/Rf/FmDvqviMaC386hoJhHh9rEICQi0iMhPk7YMbEtJj3TAOZlDg9VIrBq/xVELD2Aw1fuyFSh8WEQEhHpGUtzBaZ19ccvE9oi0EtzVpTE27kYsvowZm9LQHZBsQwVGhcGIRGRngr0csT28W3xRveGsFBo/rr+7kgSui3Zj7/P35ShOuNhNEHISbeJyBgpFWYY37k+dk1qh2a1nTT6UzMLMHL9cUzZFIe7uUW6L9AIcNJtIiIDUaoS8c2hq1j4xwXkl3PDjLOtBeb2CUTvJp4Qyg5ONDGcdJuIyAgpzASMaOuHP6d2QLv6rhr9d3OLMOmHOIz+5jjSMgtkqNAwMQiJiAyMj7MNvn21JRYObAIHK3ON/j3n0tF18T58fyQJKk7i/UQMQiIiAyQIAl5o7oM90zqie6C7Rn92YQne3paAl9YcxtXbuTJUaDgYhEREBqymgxVWDWuOlUND4WpnqdF/+Mpd9Fi6H1/tv4ISTuJdLgYhEZERiAj2xJ5pHTAgVHMB4IJiFT787RwGrIzB+bQsGarTbwxCIiIj4WRjgUUvhGDDIxYAPnk9E899fhCLoy5yEu+HMAiJiIxMR3837J7aAZFtfDWWeCpRifj8r0t47vODOJF0T54C9QyDkIjICNlZmmNun8D7Szy5aS7xdCk9BwNWxmD+zrPIKyqRoUL9wSAkIjJizX2d8duk9pjQuT4UZSbxFkXg64OJ6P7ZfkRfvi1ThfJjEBIRGTkrpQIzujfEjkdM4p18Nx9D1xzBzK2nkJlvepN4MwiJiExEoJcjfhnfFjN7NIKFueav/83Hk9F18T78cTpNhurkwyAkIjIh5gozjOtUD39Mbo+Wvs4a/enZhRi7MRavfWs607QxCImITFBdNztsGtMK8/sGwtZCodG/+8xNPLt4H749dNXop2ljEBIRmSgzMwHDWvviz2kd0amhm0Z/TmEJ5vxyBgO/jMHFm9kyVKgbRhOEXI+QiOjp1HKyxrrIFlg6pClcbC00+k8kZaDX5wew6M8LKChn+SdDx/UIiYhI7V5uET787Ry2xl4vt9/P1RYf9Q9G63ouOq6scrgeIRERPZUathb4dFAIvh8VDl8XG43+xNu5ePGrw5i59RQy8opkqFD7GIRERKShTX1X/DGlA8Z3rgdzM83V7jcfT8azi/dhx8kUGPqJRQYhERGVy0qpwBvdG+HXie3Q1MdJo/92ThEm/RCHEeuPIflunu4L1BIGIRERPVZjTwf8NK4N5vUpf6jF3gu30G3Jfqw5YJhrHjIIiYjoiRRmAoa38UXUtI54trG7Rn9+cSk+2HUO/VfE4PSNTBkqfHoMQiIiqjAvJ2t89UoYVg4NRU17S43+hBuZ6LPsIObuOIOsAsOYt5RBSERElSIIAiKCPRE1rSOGhtfW6FeJwPqYq3h2kWHcTMMgJCKip+JorcSH/YOxZWxr1K9pp9Gfnl2IST/EYdjXR5F4O1eGCiuGQUhERFXSwtcZuya1w4xu/rAsZ1WLg5dvo/uS/VgcdVEvZ6ZhEBIRUZVZmiswoUsDRE0tf97SolIVPv/rEnp8th/7L96SocJHYxASEZHW1HaxwbrIFlg5NBQeDlYa/Vfv5OGVtUcx
/vsTerPME4OQiIi06sHNNHumd8Sodn5QlDMzza5TqXh28T6sPZgo+9hDTrpNRETV6lxqFmZvS8CJpIxy+wM8HfBh/yA0q11Da+/JSbeJiEhvNPZ0wNaxbfDx88FwslFq9J9NzcLzK2Mw6+cE3MvV/UTeDEIiIqp2ZmYChrSsjb+mdcSgMG+NflEEfjiahC6L9mLT0SSoVLo7WclTo0REpHNHE+/ine0JuHgzp9z+pj5OmN83CMHejk+1fZ4aJSIivdbSzxm7JrXHrIhGsFZqTuQdn5yBPssP4sdjydVeC4OQiIhkoVSY4bWO9fDX9I7oGeyh0W+tVKC9v2u118EgJCIiWXk5WWPF0DB8M7Il6rraqtsnPdMAno7W1f7+DEIiItILHfzd8PuU9nije0ME13LEyLZ+Onlf3ixDRER6R6USYVbOQPyK4s0yRERk0KoSgpV+L529UzVbvnw5AgIC0KJFC7lLISIiA8JTo0REZHR4apSIiKiCGIRERGTSGIRERGTSGIRERGTSGIRERGTSGIRERGTSGIRERGTSGIRERGTSGIRERGTSGIRERGTSzOUuQNsezBiXlZUlcyVERCSXBxlQkVlEjS4Is7OzAQA+Pj4yV0JERHLLzs6Go6PjY19jdJNuq1QqpKSkwN7eHoJQtbWsfHx8kJycbBCTd7Pe6mVo9QKGVzPrrV6GVi9QtZpFUUR2dja8vLxgZvb4q4BGd0RoZmYGb29vrW3PwcHBYD40AOutboZWL2B4NbPe6mVo9QJPX/OTjgQf4M0yRERk0hiERERk0hiEj2BpaYn33nsPlpaWcpdSIay3ehlavYDh1cx6q5eh1Qvormaju1mGiIioMnhESEREJo1BSEREJo1BSEREJo1BSEREJo1BWI4VK1bAz88PVlZWCAsLw4EDB+QuCQCwYMECtGjRAvb29qhZsyb69euHCxcuSF4TGRkJQRAkj1atWslS79y5czVq8fDwUPeLooi5c+fCy8sL1tbW6NSpE86cOSNLrQ/4+vpq1CwIAsaPHw9A/v27f/9+9O7dG15eXhAEAdu3b5f0V2SfFhYWYuLEiXB1dYWtrS369OmD69ev67ze4uJizJw5E8HBwbC1tYWXlxdeeeUVpKSkSLbRqVMnjX0+ZMgQndcLVOznr8v9W5Gay/s8C4KATz75RP0aXe3jivwOk+MzzCAsY/PmzZgyZQpmz56NuLg4tG/fHhEREUhKSpK7NOzbtw/jx4/H4cOHERUVhZKSEnTr1g25ubmS1/Xo0QOpqanqx2+//SZTxUBgYKCkloSEBHXfwoULsXjxYixbtgzHjh2Dh4cHunbtqp4vVg7Hjh2T1BsVFQUAGDRokPo1cu7f3NxchISEYNmyZeX2V2SfTpkyBdu2bcOmTZtw8OBB5OTk4LnnnkNpaalO683Ly8OJEycwZ84cnDhxAj///DMuXryIPn36aLx29OjRkn2+atUqrdf6pHofeNLPX5f7tyI1P1xramoq1q5dC0EQMGDAAMnrdLGPK/I7TJbPsEgSLVu2FMeOHStpa9SokfjWW2/JVNGjpaeniwDEffv2qduGDx8u9u3bV76iHvLee++JISEh5fapVCrRw8ND/Pjjj9VtBQUFoqOjo/jll1/qqMInmzx5slivXj1RpVKJoqhf+xeAuG3bNvXziuzTjIwMUalUips2bVK/5saNG6KZmZn4xx9/6LTe8hw9elQEIF67dk3d1rFjR3Hy5MnVWlt5yqv3ST9/OfevKFZsH/ft21fs0qWLpE2ufVz2d5hcn2EeET6kqKgIsbGx6Natm6S9W7duiImJkamqR8vMzAQAODs7S9r37t2LmjVrwt/fH6NHj0Z6eroc5QEALl26BC8vL/j5+WHIkCG4cuUKACAxMRFpaWmSfW1paYmOHTvqzb4uKirCxo0bMXLkSMkE7vq0fx9WkX0aGxuL4uJiyWu8vLwQFBSkF/s9MzMTgiDAyclJ0v7dd9/B1dUVgYGBmDFj
hqxnDR7389f3/Xvz5k3s2rULr776qkafHPu47O8wuT7DRjfpdlXcvn0bpaWlcHd3l7S7u7sjLS1NpqrKJ4oipk2bhnbt2iEoKEjdHhERgUGDBqFOnTpITEzEnDlz0KVLF8TGxup8Ronw8HB888038Pf3x82bN/HBBx+gTZs2OHPmjHp/lrevr127ptM6H2X79u3IyMhAZGSkuk2f9m9ZFdmnaWlpsLCwQI0aNTReI/dnvKCgAG+99RZeeuklyQTLQ4cOhZ+fHzw8PHD69GnMmjULJ0+eVJ+21qUn/fz1ef8CwIYNG2Bvb4/nn39e0i7HPi7vd5hcn2EGYTnKLt8kimKVlnSqDhMmTMCpU6dw8OBBSfvgwYPV/w4KCkLz5s1Rp04d7Nq1S+PDX90iIiLU/w4ODkbr1q1Rr149bNiwQX2DgT7v66+//hoRERHw8vJSt+nT/n2Up9mncu/34uJiDBkyBCqVCitWrJD0jR49Wv3voKAgNGjQAM2bN8eJEycQGhqq0zqf9ucv9/59YO3atRg6dCisrKwk7XLs40f9DgN0/xnmqdGHuLq6QqFQaPxVkZ6ervEXipwmTpyIHTt24J9//nniklOenp6oU6cOLl26pKPqHs3W1hbBwcG4dOmS+u5Rfd3X165dw549ezBq1KjHvk6f9m9F9qmHhweKiopw7969R75G14qLi/HCCy8gMTERUVFRT1xuJzQ0FEqlUi/2edmfvz7u3wcOHDiACxcuPPEzDVT/Pn7U7zC5PsMMwodYWFggLCxM43RAVFQU2rRpI1NV/xFFERMmTMDPP/+Mv//+G35+fk/8mjt37iA5ORmenp46qPDxCgsLce7cOXh6eqpPwzy8r4uKirBv3z692Nfr1q1DzZo10atXr8e+Tp/2b0X2aVhYGJRKpeQ1qampOH36tCz7/UEIXrp0CXv27IGLi8sTv+bMmTMoLi7Wi31e9uevb/v3YV9//TXCwsIQEhLyxNdW1z5+0u8w2T7DT3WLjRHbtGmTqFQqxa+//lo8e/asOGXKFNHW1la8evWq3KWJ48aNEx0dHcW9e/eKqamp6kdeXp4oiqKYnZ0tTp8+XYyJiRETExPFf/75R2zdurVYq1YtMSsrS+f1Tp8+Xdy7d6945coV8fDhw+Jzzz0n2tvbq/flxx9/LDo6Ooo///yzmJCQIL744ouip6enLLU+rLS0VKxdu7Y4c+ZMSbs+7N/s7GwxLi5OjIuLEwGIixcvFuPi4tR3WVZkn44dO1b09vYW9+zZI544cULs0qWLGBISIpaUlOi03uLiYrFPnz6it7e3GB8fL/lMFxYWiqIoipcvXxbnzZsnHjt2TExMTBR37dolNmrUSGzWrJnO663oz1+X+/dJNT+QmZkp2tjYiCtXrtT4el3u4yf9DhNFeT7DDMJyLF++XKxTp45oYWEhhoaGSoYnyAlAuY9169aJoiiKeXl5Yrdu3UQ3NzdRqVSKtWvXFocPHy4mJSXJUu/gwYNFT09PUalUil5eXuLzzz8vnjlzRt2vUqnE9957T/Tw8BAtLS3FDh06iAkJCbLU+rDdu3eLAMQLFy5I2vVh//7zzz/lfgaGDx8uimLF9ml+fr44YcIE0dnZWbS2thafe+65avseHldvYmLiIz/T//zzjyiKopiUlCR26NBBdHZ2Fi0sLMR69eqJkyZNEu/cuaPzeiv689fl/n1SzQ+sWrVKtLa2FjMyMjS+Xpf7+Em/w0RRns8wl2EiIiKTxmuERERk0hiERERk0hiERERk0hiERERk0hiERERk0hiERERk0hiERERk0hiERCasvBXNiUwNg5BIJpGRkRAEQePRo0cPuUsjMilcholIRj169MC6deskbXKva0hkanhESCQjS0tLeHh4SB4PFhwVBAErV65EREQErK2t4efnhy1btki+PiEhAV26dIG1tTVcXFwwZswY5OTkSF6zdu1aBAYGwtLSEp6enpgwYYKk//bt2+jfvz9sbGzQ
oEED7NixQ9137949DB06FG5ubrC2tkaDBg00gpvI0DEIifTYnDlzMGDAAJw8eRIvv/wyXnzxRZw7dw4AkJeXhx49eqBGjRo4duwYtmzZgj179kiCbuXKlRg/fjzGjBmDhIQE7NixA/Xr15e8x7x58/DCCy/g1KlT6NmzJ4YOHYq7d++q3//s2bP4/fffce7cOaxcuRKurq662wFEuvDU03UTUZUMHz5cVCgUoq2treTx/vvvi6J4f6b+sWPHSr4mPDxcHDdunCiKorh69WqxRo0aYk5Ojrp/165dopmZmZiWliaKoih6eXmJs2fPfmQNAMR33nlH/TwnJ0cUBEH8/fffRVEUxd69e4sjRozQzjdMpKd4jZBIRp07d8bKlSslbc7Ozup/t27dWtLXunVrxMfHAwDOnTuHkJAQ2Nraqvvbtm0LlUqFCxcuQBAEpKSk4JlnnnlsDU2aNFH/29bWFvb29khPTwcAjBs3DgMGDMCJEyfQrVs39OvXT/YFZom0jUFIJCNbW1uNU5VPIggCgPurfT/4d3mvsba2rtD2lEqlxteqVCoAQEREBK5du4Zdu3Zhz549eOaZZzB+/Hh8+umnlaqZSJ/xGiGRHjt8+LDG80aNGgEAAgICEB8fj9zcXHV/dHQ0zMzM4O/vD3t7e/j6+uKvv/6qUg1ubm6IjIzExo0b8dlnn2H16tVV2h6RvuERIZGMCgsLkZaWJmkzNzdX35CyZcsWNG/eHO3atcN3332Ho0eP4uuvvwYADB06FO+99x6GDx+OuXPn4tatW5g4cSKGDRsGd3d3AMDcuXMxduxY1KxZExEREcjOzkZ0dDQmTpxYofreffddhIWFITAwEIWFhdi5cycaN26sxT1AJD8GIZGM/vjjD3h6ekraGjZsiPPnzwO4f0fnpk2b8Prrr8PDwwPfffcdAgICAAA2NjbYvXs3Jk+ejBYtWsDGxgYDBgzA4sWL1dsaPnw4CgoKsGTJEsyYMQOurq4YOHBgheuzsLDArFmzcPXqVVhbW6N9+/bYtGmTFr5zIv0hiKIoyl0EEWkSBAHbtm1Dv3795C6FyKjxGiEREZk0BiEREZk0XiMk0lO8akGkGzwiJCIik8YgJCIik8YgJCIik8YgJCIik8YgJCIik8YgJCIik8YgJCIik8YgJCIik8YgJCIik/Z/Mhw7uyP1TggAAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 500x500 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Display training suboptimality, normalized by the gap at the first logged epoch.\n",
    "\n",
    "minimum_loss = output.fun\n",
    "# Work on a separate array so train_losses keeps its type (a list) if this cell is re-run.\n",
    "loss_curve = np.asarray(train_losses)\n",
    "subopt = (loss_curve - minimum_loss) / (loss_curve[0] - minimum_loss)\n",
    "\n",
    "fig, ax = plt.subplots(1, 1, figsize=(5, 5))\n",
    "ax.plot(np.arange(len(subopt)), subopt, linewidth=3, label=r\"Train Suboptimality\")\n",
    "ax.set_xlabel(\"Epochs\")\n",
    "ax.set_ylabel(\"Relative suboptimality\")\n",
    "\n",
    "ax.legend()\n",
    "ax.set_yscale(\"log\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Compute the Distribution Shift Metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test worst gruop error: 0.8095238208770752\n"
     ]
    }
   ],
   "source": [
    "# Worst per-group misclassification rate on (X_val, y_val); each distinct label in y_val is one group.\n",
    "# The flat weight vector is reshaped using the class count from model_cfg rather than a literal.\n",
    "weights = optimizer.weights.clone().detach().view(-1, model_cfg[\"n_class\"])\n",
    "logits = torch.matmul(X_val, weights)\n",
    "pred_probs = F.softmax(logits, dim=1)\n",
    "predictions = torch.argmax(pred_probs, dim=1)\n",
    "unique_groups = torch.unique(y_val)\n",
    "group_errors = []\n",
    "\n",
    "for group in unique_groups:\n",
    "    # Restrict predictions and labels to the samples whose label equals this group.\n",
    "    group_mask = (y_val == group)\n",
    "    group_predictions = predictions[group_mask]\n",
    "    group_labels = y_val[group_mask]\n",
    "    misclassified = (group_predictions != group_labels).float().sum()\n",
    "    group_error_rate = misclassified / group_mask.sum()\n",
    "    group_errors.append(group_error_rate.item())\n",
    "\n",
    "worst_group_error = max(group_errors)\n",
    "print(\"Test worst group error:\", worst_group_error)\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "lrm",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
