{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Regression on Blog Feedback Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ['OMP_NUM_THREADS'] = '1'\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import scipy\n",
    "\n",
    "import os,glob\n",
    "\n",
    "from sklearn.linear_model import Ridge, ElasticNet\n",
    "from sklearn.preprocessing import  MaxAbsScaler\n",
    "\n",
    "from scipy.linalg import norm\n",
    "\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "import urllib.request\n",
    "\n",
    "import objectives"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## BlogFeedback Database"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Archive of the BlogFeedback dataset on the UCI repository.\n",
    "url = \"https://archive.ics.uci.edu/static/public/304/blogfeedback.zip\"\n",
    "\n",
    "# Location of the blogData_*.csv files read by the loading cell.\n",
    "# NOTE(review): 'REMOVED' looks like a scrubbed placeholder - point it at the local data directory.\n",
    "data_path = \"REMOVED\"\n",
    "\n",
    "# Disabled download of the archive next to the CSV files; the zip would still have to be extracted.\n",
    "# f = urllib.request.urlretrieve(url, data_path + \"blogfeedback.zip\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The dataset looks great: it even comes with a description of what the columns mean, see <https://archive.ics.uci.edu/dataset/304/blogfeedback>. One can federate the dataset by treating each blog as a separate client. The blog identity is not stored directly in the dataset, but attributes 1 to 50 are statistics of the source blog, so posts that share these values can be grouped together as coming from the same blog."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training data: a single header-less CSV. Columns 0-279 are the features and\n",
    "# column 280 is the regression target. The whole file is used for training;\n",
    "# no in-distribution hold-out set is split off.\n",
    "df = pd.read_csv(os.path.join(data_path, \"blogData_train.csv\"), header=None)\n",
    "df_train, df_y_train = df.iloc[:, :280], df[280]\n",
    "\n",
    "# Out-of-distribution test data: all blogData_test*.csv files pooled together.\n",
    "# glob returns matches in arbitrary order, so sort them to keep the row order reproducible.\n",
    "csv_files = sorted(glob.glob(os.path.join(data_path, \"blogData_test*.csv\")))\n",
    "if not csv_files:\n",
    "    raise FileNotFoundError(f\"No blogData_test*.csv files found in {data_path!r}\")\n",
    "dfs_ood = [pd.read_csv(path, header=None) for path in csv_files]\n",
    "# Every file is read with its own 0-based index; ignore_index avoids duplicated row labels.\n",
    "df_ood = pd.concat(dfs_ood, ignore_index=True)\n",
    "df_test_ood, df_y_test_ood = df_ood.iloc[:, :280], df_ood[280]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Federate the dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of clients: 554\n"
     ]
    }
   ],
   "source": [
    "# Fit the feature scaler on the training features only; it is reused below for the OOD set.\n",
    "source_scaler = MaxAbsScaler().fit(df_train)\n",
    "\n",
    "# Working copy holding the target in column 280 and the scaled features in columns 0-279.\n",
    "df_train_fl = df_train.copy()\n",
    "df_train_fl[280] = df_y_train\n",
    "df_train_fl.iloc[:, :280] = source_scaler.transform(df_train)\n",
    "\n",
    "# One client per unique combination of values in the first 50 columns (indices 0-49).\n",
    "groups = df_train_fl.groupby(list(range(50)))\n",
    "dfs = [client_df for _, client_df in groups]\n",
    "print(f\"Number of clients: {len(dfs)}\")\n",
    "Xs = [client_df.iloc[:, :280] for client_df in dfs]\n",
    "ys = [client_df[280] for client_df in dfs]\n",
    "\n",
    "# Centralized dataset: all client shards stacked back together (in group order).\n",
    "X_train, y_train = pd.concat(Xs), pd.concat(ys)\n",
    "\n",
    "# OOD test set, scaled with the statistics learned on the training features.\n",
    "X_test_ood = source_scaler.transform(df_test_ood)\n",
    "y_test_ood = df_y_test_ood.values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAjIAAAHLCAYAAADSuXIVAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8g+/7EAAAACXBIWXMAAA9hAAAPYQGoP6dpAAA5YUlEQVR4nO3deVxVdeL/8fdF2VdxAVEQtyQ1N0zFLM0l8msuaWXL5DKOWpH7ZPFt1HKaQZ00l0HNUjQnvxmVfktLM9xJc8klyz1cUsHcQDFB4fz+mB/32x0W78WLl6Ov5+PB48HZ39wD9u5zzrnXYhiGIQAAABNyc3UAAACA0qLIAAAA06LIAAAA06LIAAAA06LIAAAA06LIAAAA06LIAAAA06LIAAAA06LIAAAA06LIAC4SGRmpAQMGWKfXr18vi8Wi9evXuyyTvcyUtUOHDurQoYN1+tixY7JYLFq4cGGZH3vhwoWyWCw6duyYdV5kZKQee+yxMj+2ZK7zBJQWRQZwsqNHj2ro0KGqU6eOvLy8FBAQoAceeEAzZszQb7/95up4+vvf/67ly5e77Phffvml3njjDZcdv7Rmz559W8pPaZTnbEBZq+jqAMCdZOXKlXryySfl6empfv36qXHjxsrNzdXmzZv1yiuv6Mcff9S8efOK3Pahhx7Sb7/9Jg8PjzLN+Pe//11PPPGEevXqVabHKc6XX36pxMREl5WZWrVq6bfffpO7u7tD282ePVtVqlSxGUW7meeff15PP/20PD09HUzpmOKy3a7fKcCVKDKAk6Slpenpp59WrVq1tHbtWlWvXt26LC4uTkeOHNHKlSuL3d7NzU1eXl63I+pdzWKxlPnrnJ2dLV9fX1WoUEEVKlQo02OVhN8p3A24tAQ4yZQpU3TlyhXNnz/fpsQUqFevnkaMGFHs9sXdz/Ddd9/p0UcfVWBgoHx8fNS+fXulpqbarPPGG2/IYrHoyJEjGjBggIKCghQYGKiBAwfq6tWr1vUsFouys7O1aNEiWSwWWSyWm44w/PLLL+rVq5d8fX1VrVo1jRo1Sjk5OYXW27Rpk5588klFRETI09NT4eHhGjVqlM3ltAEDBigxMdGapeCrwNtvv622bduqcuXK8vb2VnR0tD755JMS8/3evHnzVLduXXl7e6tVq1batGlToXWKukcmPT1dAwcOVM2aNeXp6anq1aurZ8+e1ntbIiMj9eOPP2rDhg3WzAX33RTcB7Nhwwa99NJLqlatmmrWrGmz7Pf3yBT4+uuv1axZM3l5ealhw4b67LPPbJYXnNP/9J/7LClbcb9TycnJio6Olre3t6pUqaI//OEPOnXqlM06AwYMkJ+fn06dOqVevXrJz89PVatW1Z///Gfl5eUVcwaA248RGcBJvvjiC9WpU0dt27Z12j7Xrl2rrl27Kjo6WhMmTJCbm5uSkpLUsWNHbdq0Sa1atbJZ/6mnnlLt2rWVkJCg77//Xu+//76qVaumyZMnS5IWL16sP/3pT2rVqpWGDBkiSapbt26xx//tt9/UqVMnnThxQsOHD1dYWJgWL16stWvXFlo3OTlZV69e1YsvvqjKlStr27ZtmjVrln755RclJydLkoYOHarTp09rzZo1Wrx4caF9zJgxQz169NBzzz2n3NxcffTRR3ryySe1YsUKdevWrcTXav78+Ro6dKjatm2rkSNH6ueff1aPHj0UHBys8PDwErft06ePfvzxRw0bNkyRkZE6e/as1qxZoxMnTigyMlLTp0/XsGHD5Ofnp9dff12SFBISYrOPl156SVWrVtX48eOVnZ1d4vEOHz6svn376oUXXlD//v2VlJSkJ598UqtWrVKXLl1K3PY/2ZPt9xYuXKiBAwfq/vvvV0JCgjIyMjRjxgylpqZq165dCgoKsq6bl5en2NhYtW7dWm+//ba++eYbTZ06VXXr1tWLL77oUE6gzBgAbllmZqYhyejZs6fd29SqVcvo37+/
dXrdunWGJGPdunWGYRhGfn6+Ub9+fSM2NtbIz8+3rnf16lWjdu3aRpcuXazzJkyYYEgy/vjHP9oc4/HHHzcqV65sM8/X19fmuCWZPn26Icn4+OOPrfOys7ONevXq2WQtyPWfEhISDIvFYhw/ftw6Ly4uzijun57/3Edubq7RuHFjo2PHjiXmzM3NNapVq2Y0a9bMyMnJsc6fN2+eIclo3769dV5aWpohyUhKSjIMwzAuXrxoSDL+8Y9/lHiMRo0a2eynQFJSkiHJaNeunXHjxo0il6WlpVnn1apVy5BkfPrpp9Z5mZmZRvXq1Y3mzZtb5xWc0+KO9/t9FpftP3+nCl6nxo0bG7/99pt1vRUrVhiSjPHjx1vn9e/f35BkTJw40WafzZs3N6KjowsdC3AVLi0BTpCVlSVJ8vf3d9o+d+/ercOHD+vZZ5/V+fPnde7cOZ07d07Z2dnq1KmTNm7cqPz8fJttXnjhBZvpBx98UOfPn7fmc9SXX36p6tWr64knnrDO8/HxsY7m/J63t7f1++zsbJ07d05t27aVYRjatWuXXcf7/T4uXryozMxMPfjgg/r+++9L3G7Hjh06e/asXnjhBZsbWwcMGKDAwMCbHtPDw0Pr16/XxYsX7cpZlMGDB9t9P0xYWJgef/xx63RAQID69eunXbt2KT09vdQZbqbgdXrppZds7p3p1q2boqKiiryHq6jfqZ9//rnMMgKO4tIS4AQBAQGSpMuXLzttn4cPH5Yk9e/fv9h1MjMzValSJet0RESEzfKCZRcvXrRmdMTx48dVr169QvdqNGjQoNC6J06c0Pjx4/X5558XKgSZmZl2HW/FihV66623tHv3bpv7cIq6V+Q/c0pS/fr1bea7u7urTp06JW7r6empyZMna8yYMQoJCVGbNm302GOPqV+/fgoNDbUrtyTVrl3b7nWLek3vueceSf++h8eR4zqi4HUq6vxFRUVp8+bNNvO8vLxUtWpVm3mVKlW6pcIHOBtFBnCCgIAAhYWFad++fU7bZ8Foyz/+8Q81a9asyHX8/PxsposbETAMw2m5ipKXl6cuXbrowoULevXVVxUVFSVfX1+dOnVKAwYMKDRyVJRNmzapR48eeuihhzR79mxVr15d7u7uSkpK0pIlS8o0/8iRI9W9e3ctX75cq1ev1rhx45SQkKC1a9eqefPmdu3j96NJzlBcebudN9q68okrwF4UGcBJHnvsMc2bN09btmxRTEzMLe+v4CbcgIAAde7c+Zb3V+Bmoxu/V6tWLe3bt0+GYdhsd/DgQZv1fvjhBx06dEiLFi1Sv379rPPXrFlj9/E//fRTeXl5afXq1Tbvu5KUlGRXTunfo1gdO3a0zr9+/brS0tLUtGnTm+6jbt26GjNmjMaMGaPDhw+rWbNmmjp1qv71r3+VmLs0jhw5Uug1PXTokKR/P4Uk/d9o2qVLl2xuwC0YVfk9e7MVvE4HDx60eZ0K5hUsB8yEe2QAJxk7dqx8fX31pz/9SRkZGYWWHz16VDNmzLB7f9HR0apbt67efvttXblypdDyX3/9tVQ5fX19denSJbvW/a//+i+dPn3a5hHoq1evFnpTv4L/c//9yI9hGEX+vL6+vpJUKEOFChVksVhsRhyOHTtm17sQt2zZUlWrVtXcuXOVm5trnb9w4cKb/qxXr17VtWvXbObVrVtX/v7+Npe3HHndbub06dNatmyZdTorK0sffPCBmjVrZr2sVFBkN27caF2v4NH5/2RvtpYtW6patWqaO3euzc/21Vdfaf/+/Td9MgwojxiRAZykbt26WrJkifr27at7773X5p19v/32WyUnJzv0rrBubm56//331bVrVzVq1EgDBw5UjRo1dOrUKa1bt04BAQH64osvHM4ZHR2tb775RtOmTVNYWJhq166t1q1bF7nu4MGD9c9//lP9+vXTzp07Vb16dS1evFg+Pj4260VFRalu3br685//rFOnTikgIECffvppkfdSREdHS5KG
Dx+u2NhYVahQQU8//bS6deumadOm6dFHH9Wzzz6rs2fPKjExUfXq1dPevXtL/Jnc3d311ltvaejQoerYsaP69u2rtLQ0JSUl3fQemUOHDqlTp0566qmn1LBhQ1WsWFHLli1TRkaGnn76aZvcc+bM0VtvvaV69eqpWrVqhUY17HXPPfdo0KBB2r59u0JCQrRgwQJlZGTYjD498sgjioiI0KBBg/TKK6+oQoUKWrBggapWraoTJ07Y7M/ebO7u7po8ebIGDhyo9u3b65lnnrE+fh0ZGalRo0aV6ucBXMqVj0wBd6JDhw4ZgwcPNiIjIw0PDw/D39/feOCBB4xZs2YZ165ds653s8evC+zatcvo3bu3UblyZcPT09OoVauW8dRTTxkpKSnWdQoe1f31119tti3qUd0DBw4YDz30kOHt7W1Iuumj2MePHzd69Ohh+Pj4GFWqVDFGjBhhrFq1qlDWn376yejcubPh5+dnVKlSxRg8eLCxZ88em0edDcMwbty4YQwbNsyoWrWqYbFYbB4xnj9/vlG/fn3D09PTiIqKMpKSkop9DLkos2fPNmrXrm14enoaLVu2NDZu3Gi0b9++xMevz507Z8TFxRlRUVGGr6+vERgYaLRu3drmkXPDMIz09HSjW7duhr+/v80j3QWv8fbt2wvlKe7x627duhmrV682mjRpYv1Zk5OTC22/c+dOo3Xr1oaHh4cRERFhTJs2rch9FpetuN+ppUuXGs2bNzc8PT2N4OBg47nnnjN++eUXm3X69+9v+Pr6FsrkyPkAbgeLYZTxXYAAAABlhHtkAACAaVFkAACAaVFkAACAaVFkAACAaVFkAACAaVFkAACAad3xb4iXn5+v06dPy9/f36lvMQ4AAMqOYRi6fPmywsLC5OZW/LjLHV9kTp8+rfDwcFfHAAAApXDy5EnVrFmz2OV3fJHx9/eX9O8XIiAgwMVpAACAPbKyshQeHm7973hx7vgiU3A5KSAggCIDAIDJ3Oy2EG72BQAApkWRAQAApkWRAQAApkWRAQAApkWRAQAApkWRAQAApkWRAQAApkWRAQAApkWRAQAApkWRAQAApkWRAQAApkWRAQAApkWRAQAApkWRAQAAplXR1QHgGpGvrbyl7Y9N6uakJAAAlB4jMgAAwLQoMgAAwLQoMgAAwLQoMgAAwLQoMgAAwLQoMgAAwLQoMgAAwLQoMgAAwLQoMgAAwLQoMgAAwLQoMgAAwLQoMgAAwLQoMgAAwLQoMgAAwLQoMgAAwLQoMgAAwLQoMgAAwLQoMgAAwLQoMgAAwLQoMgAAwLQoMgAAwLQoMgAAwLQoMgAAwLQoMgAAwLQoMgAAwLQoMgAAwLQoMgAAwLQoMgAAwLQoMgAAwLQoMgAAwLQoMgAAwLQoMgAAwLQoMgAAwLRcWmTeeOMNWSwWm6+oqCjr8mvXrikuLk6VK1eWn5+f+vTpo4yMDBcmBgAA5YnLR2QaNWqkM2fOWL82b95sXTZq1Ch98cUXSk5O1oYNG3T69Gn17t3bhWkBAEB5UtHlASpWVGhoaKH5mZmZmj9/vpYsWaKOHTtKkpKSknTvvfdq69atatOmze2OCgAAyhmXj8gcPnxYYWFhqlOnjp577jmdOHFCkrRz505dv35dnTt3tq4bFRWliIgIbdmypdj95eTkKCsry+YLAADcmVxaZFq3bq2FCxdq1apVmjNnjtLS0vTggw/q8uXLSk9Pl4eHh4KCgmy2CQkJUXp6erH7TEhIUGBgoPUrPDy8jH8KAADgKi69tNS1a1fr902aNFHr1q1Vq1Ytffzxx/L29i7VPuPj4zV69GjrdFZWFmUGAIA7lMsvLf1eUFCQ7rnnHh05ckShoaHKzc3VpUuXbNbJyMgo8p6aAp6engoICLD5AgAAd6ZyVWSuXLmio0ePqnr16oqOjpa7u7tSUlKsyw8ePKgTJ04oJibGhSkBAEB54dJLS3/+85/VvXt31apVS6dPn9aECRNUoUIFPfPM
MwoMDNSgQYM0evRoBQcHKyAgQMOGDVNMTAxPLAEAAEkuLjK//PKLnnnmGZ0/f15Vq1ZVu3bttHXrVlWtWlWS9M4778jNzU19+vRRTk6OYmNjNXv2bFdGBgAA5YjFMAzD1SHKUlZWlgIDA5WZmcn9Mr8T+drKW9r+2KRuTkoCe93KOeN8ATAbe//7Xa7ukQEAAHAERQYAAJgWRQYAAJgWRQYAAJgWRQYAAJgWRQYAAJgWRQYAAJgWRQYAAJgWRQYAAJgWRQYAAJgWRQYAAJgWRQYAAJiWSz/9GoA58CGjAMorRmQAAIBpUWQAAIBpUWQAAIBpUWQAAIBpUWQAAIBp8dQSTOdWnqDh6RkAuLMwIgMAAEyLIgMAAEyLIgMAAEyLIgMAAEyLIgMAAEyLp5YAB/DEFACUL4zIAAAA06LIAAAA06LIAAAA06LIAAAA06LIAAAA06LIAAAA06LIAAAA06LIAAAA06LIAAAA06LIAAAA06LIAAAA06LIAAAA06LIAAAA06LIAAAA06LIAAAA06LIAAAA06LIAAAA06LIAAAA06LIAAAA06LIAAAA06LIAAAA06LIAAAA06LIAAAA06LIAAAA06LIAAAA06LIAAAA06LIAAAA0yo3RWbSpEmyWCwaOXKkdd61a9cUFxenypUry8/PT3369FFGRobrQgIAgHKlXBSZ7du3691331WTJk1s5o8aNUpffPGFkpOTtWHDBp0+fVq9e/d2UUoAAFDeuLzIXLlyRc8995zee+89VapUyTo/MzNT8+fP17Rp09SxY0dFR0crKSlJ3377rbZu3erCxAAAoLxweZGJi4tTt27d1LlzZ5v5O3fu1PXr123mR0VFKSIiQlu2bCl2fzk5OcrKyrL5AgAAd6aKrjz4Rx99pO+//17bt28vtCw9PV0eHh4KCgqymR8SEqL09PRi95mQkKA333zT2VEBAEA55LIRmZMnT2rEiBH68MMP5eXl5bT9xsfHKzMz0/p18uRJp+0bAACULy4rMjt37tTZs2fVokULVaxYURUrVtSGDRs0c+ZMVaxYUSEhIcrNzdWlS5dstsvIyFBoaGix+/X09FRAQIDNFwAAuDO57NJSp06d9MMPP9jMGzhwoKKiovTqq68qPDxc7u7uSklJUZ8+fSRJBw8e1IkTJxQTE+OKyAAAoJxxWZHx9/dX48aNbeb5+vqqcuXK1vmDBg3S6NGjFRwcrICAAA0bNkwxMTFq06aNKyIDAIByxqU3+97MO++8Izc3N/Xp00c5OTmKjY3V7NmzXR0LAACUE+WqyKxfv95m2svLS4mJiUpMTHRNIAAAUK45fLPvyZMn9csvv1int23bppEjR2revHlODQYAAHAzDheZZ599VuvWrZP07/d66dKli7Zt26bXX39dEydOdHpAAACA4jhcZPbt26dWrVpJkj7++GM1btxY3377rT788EMtXLjQ2fkAAACK5XCRuX79ujw9PSVJ33zzjXr06CHp3x8fcObMGeemAwAAKIHDRaZRo0aaO3euNm3apDVr1ujRRx+VJJ0+fVqVK1d2ekAAAIDiOFxkJk+erHfffVcdOnTQM888o6ZNm0qSPv/8c+slJwAAgNvB4cevO3TooHPnzikrK0uVKlWyzh8yZIh8fX2dGg4AAKAkDo/IdOzYUZcvX7YpMZIUHBysvn37Oi0YAADAzThcZNavX6/c3NxC869du6ZNmzY5JRQAAIA97L60tHfvXuv3P/30k9LT063TeXl5WrVqlWrUqOHcdAAAACWwu8g0a9ZMFotFFotFHTt2LLTc29tbs2bNcmo4AACAkthdZNLS0mQYhurUqaNt27apatWq1mUeHh6qVq2aKlSoUCYhAQAAimJ3kalVq5YkKT8/v8zCACgbka+tdHUEACgTpfr068OHD2vdunU6e/ZsoWIzfvx4pwQDAAC4GYeLzHvvvacXX3xRVapUUWhoqCwWi3WZxWKhyAAAgNvG4SLz1ltv6W9/+5te
ffXVssgDAABgN4ffR+bixYt68sknyyILAACAQxwuMk8++aS+/vrrssgCAADgEIcvLdWrV0/jxo3T1q1bdd9998nd3d1m+fDhw50WDgBu1a08sXVsUjcnJgFQFhwuMvPmzZOfn582bNigDRs22CyzWCwUGQAAcNs4XGTS0tLKIgcAAIDDHL5HpkBubq4OHjyoGzduODMPAACA3RwuMlevXtWgQYPk4+OjRo0a6cSJE5KkYcOGadKkSU4PCAAAUByHi0x8fLz27Nmj9evXy8vLyzq/c+fOWrp0qVPDAQAAlMThe2SWL1+upUuXqk2bNjbv6tuoUSMdPXrUqeEAAABK4vCIzK+//qpq1aoVmp+dnW1TbAAAAMqaw0WmZcuWWrny/96XoaC8vP/++4qJiXFeMgAAgJtw+NLS3//+d3Xt2lU//fSTbty4oRkzZuinn37St99+W+h9ZQAAAMqSwyMy7dq10+7du3Xjxg3dd999+vrrr1WtWjVt2bJF0dHRZZERAACgSA6PyEhS3bp19d577zk7CwAAgEPsKjJZWVkKCAiwfl+SgvUAAADKml1FplKlSjpz5oyqVaumoKCgIp9OMgxDFotFeXl5Tg8JAABQFLuKzNq1axUcHCxJWrduXZkGAgAAsJddRaZ9+/ZFfg8AAOBKdhWZvXv32r3DJk2alDoMAACAI+wqMs2aNZPFYpFhGCWuxz0yAADgdrKryKSlpZV1DgAAAIfZVWRq1apV1jkA3MEiX1t585UAoBQcfmffhIQELViwoND8BQsWaPLkyU4JBQAAYA+Hi8y7776rqKioQvMbNWqkuXPnOiUUAACAPRwuMunp6apevXqh+VWrVtWZM2ecEgoAAMAeDheZ8PBwpaamFpqfmpqqsLAwp4QCAACwh8MfGjl48GCNHDlS169fV8eOHSVJKSkpGjt2rMaMGeP0gAAAAMVxuMi88sorOn/+vF566SXl5uZKkry8vPTqq68qPj7e6QEBAACK43CRsVgsmjx5ssaNG6f9+/fL29tb9evXl6enZ1nkAwAAKJbDRaaAn5+f7r//fmdmAQAAcIjDN/sCAACUFxQZAABgWhQZAABgWnYVmRYtWujixYuSpIkTJ+rq1atlGgoAAMAedhWZ/fv3Kzs7W5L05ptv6sqVK2UaCgAAwB52PbXUrFkzDRw4UO3atZNhGHr77bfl5+dX5Lrjx493akAAAIDi2DUis3DhQlWuXFkrVqyQxWLRV199pWXLlhX6Wr58uUMHnzNnjpo0aaKAgAAFBAQoJiZGX331lXX5tWvXFBcXp8qVK8vPz099+vRRRkaGQ8cAAAB3LrtGZBo0aKCPPvpIkuTm5qaUlBRVq1btlg9es2ZNTZo0SfXr15dhGFq0aJF69uypXbt2qVGjRho1apRWrlyp5ORkBQYG6uWXX1bv3r2L/KwnAABw93H4DfHy8/OddvDu3bvbTP/tb3/TnDlztHXrVtWsWVPz58/XkiVLrJ/plJSUpHvvvVdbt25VmzZtnJYDAACYU6ne2ffo0aOaPn269u/fL0lq2LChRowYobp165Y6SF5enpKTk5Wdna2YmBjt3LlT169fV+fOna3rREVFKSIiQlu2bCm2yOTk5CgnJ8c6nZWVVepMAACgfHP4fWRWr16thg0batu2bWrSpImaNGmi7777To0aNdKaNWscDvDDDz/Iz89Pnp6eeuGFF7Rs2TI1bNhQ6enp8vDwUFBQkM36ISEhSk9PL3Z/CQkJCgwMtH6Fh4c7nAkAAJiDwyMyr732mkaNGqVJkyYVmv/qq6+qS5cuDu2vQYMG2r17tzIzM/XJJ5+of//+2rBhg6OxrOLj4zV69GjrdFZWFmUGAIA7lMNFZv/+/fr4448Lzf/jH/+o6dOnOxzAw8ND9erVkyRFR0dr+/btmjFjhvr27avc3FxdunTJZlQmIyNDoaGhxe7P09OTT+IGAOAu4fClpapVq2r37t2F5u/evdspTzLl5+crJydH
0dHRcnd3V0pKinXZwYMHdeLECcXExNzycQAAgPk5PCIzePBgDRkyRD///LPatm0rSUpNTdXkyZNtLunYIz4+Xl27dlVERIQuX76sJUuWaP369Vq9erUCAwM1aNAgjR49WsHBwQoICNCwYcMUExPDE0sAAEBSKYrMuHHj5O/vr6lTpyo+Pl6SFBYWpjfeeEPDhw93aF9nz55Vv379dObMGQUGBqpJkyZavXq19T6bd955R25uburTp49ycnIUGxur2bNnOxoZAADcoRwuMhaLRaNGjdKoUaN0+fJlSZK/v3+pDj5//vwSl3t5eSkxMVGJiYml2j8AALizlep9ZAqUtsAAAAA4g8M3+wIAAJQXFBkAAGBaFBkAAGBaDhWZ69evq1OnTjp8+HBZ5QEAALCbQ0XG3d1de/fuLassAAAADnH40tIf/vCHmz42DQAAcDs4/Pj1jRs3tGDBAn3zzTeKjo6Wr6+vzfJp06Y5LRwAAEBJHC4y+/btU4sWLSRJhw4dsllmsVickwoAAMAODheZdevWlUUOAAAAh5X68esjR45o9erV+u233yRJhmE4LRQAAIA9HB6ROX/+vJ566imtW7dOFotFhw8fVp06dTRo0CBVqlRJU6dOLYucgFNEvrbS1REAAE7k8IjMqFGj5O7urhMnTsjHx8c6v2/fvlq1apVTwwEAAJTE4RGZr7/+WqtXr1bNmjVt5tevX1/Hjx93WjAAAICbcXhEJjs722YkpsCFCxfk6enplFAAAAD2cLjIPPjgg/rggw+s0xaLRfn5+ZoyZYoefvhhp4YDAAAoicOXlqZMmaJOnTppx44dys3N1dixY/Xjjz/qwoULSk1NLYuMAAAARXJ4RKZx48Y6dOiQ2rVrp549eyo7O1u9e/fWrl27VLdu3bLICAAAUCSHR2QkKTAwUK+//rqzswB3NB79vvvcyjk/NqmbE5MAd65SFZmLFy9q/vz52r9/vySpYcOGGjhwoIKDg50aDgAAoCQOX1rauHGjIiMjNXPmTF28eFEXL17UzJkzVbt2bW3cuLEsMgIAABTJ4RGZuLg49e3bV3PmzFGFChUkSXl5eXrppZcUFxenH374wekhAQAAiuLwiMyRI0c0ZswYa4mRpAoVKmj06NE6cuSIU8MBAACUxOEi06JFC+u9Mb+3f/9+NW3a1CmhAAAA7GHXpaW9e/davx8+fLhGjBihI0eOqE2bNpKkrVu3KjExUZMmTSqblAAAAEWwq8g0a9ZMFotFhmFY540dO7bQes8++6z69u3rvHQAAAAlsKvIpKWllXUOAAAAh9lVZGrVqlXWOQAAABxWqjfEO336tDZv3qyzZ88qPz/fZtnw4cOdEgwAAOBmHC4yCxcu1NChQ+Xh4aHKlSvLYrFYl1ksFooMAAC4bRwuMuPGjdP48eMVHx8vNzeHn94GANPg87GA8s/hJnL16lU9/fTTlBgAAOByDreRQYMGKTk5uSyyAAAAOMThS0sJCQl67LHHtGrVKt13331yd3e3WT5t2jSnhQMAAChJqYrM6tWr1aBBA0kqdLMvAADA7eJwkZk6daoWLFigAQMGlEEcAAAA+zl8j4ynp6ceeOCBssgCAADgEIeLzIgRIzRr1qyyyAIAAOAQhy8tbdu2TWvXrtWKFSvUqFGjQjf7fvbZZ04LBwAAUBKHi0xQUJB69+5dFlkAAAAc4nCRSUpKKoscAAAADuPteQEAgGk5PCJTu3btEt8v5ueff76lQAAAAPZyuMiMHDnSZvr69evatWuXVq1apVdeecVZuQAAAG7K4SIzYsSIIucnJiZqx44dtxwIAADAXk67R6Zr16769NNPnbU7AACAm3Jakfnkk08UHBzsrN0BAADclMOXlpo3b25zs69hGEpPT9evv/6q2bNnOzUcAABASRwuMr169bKZdnNzU9WqVdWhQwdFRUU5KxcAAMBNOVxkJkyYUBY5AAAAHMYb4gEAANOye0TGzc2txDfCkySLxaIbN27c
cigAAAB72F1kli1bVuyyLVu2aObMmcrPz3fo4AkJCfrss8904MABeXt7q23btpo8ebIaNGhgXefatWsaM2aMPvroI+Xk5Cg2NlazZ89WSEiIQ8cCAAB3HruLTM+ePQvNO3jwoF577TV98cUXeu655zRx4kSHDr5hwwbFxcXp/vvv140bN/Tf//3feuSRR/TTTz/J19dXkjRq1CitXLlSycnJCgwM1Msvv6zevXsrNTXVoWMBAIA7j8M3+0rS6dOnNWHCBC1atEixsbHavXu3Gjdu7PB+Vq1aZTO9cOFCVatWTTt37tRDDz2kzMxMzZ8/X0uWLFHHjh0l/fvTt++9915t3bpVbdq0KbTPnJwc5eTkWKezsrIczgUAAMzBoZt9MzMz9eqrr6pevXr68ccflZKSoi+++KJUJaa4/UuyvrHezp07df36dXXu3Nm6TlRUlCIiIrRly5Yi95GQkKDAwEDrV3h4uFOyAQCA8sfuIjNlyhTVqVNHK1as0P/8z//o22+/1YMPPui0IPn5+Ro5cqQeeOABazFKT0+Xh4eHgoKCbNYNCQlRenp6kfuJj49XZmam9evkyZNOywgAAMoXuy8tvfbaa/L29la9evW0aNEiLVq0qMj1Pvvss1IFiYuL0759+7R58+ZSbV/A09NTnp6et7QPAABgDnYXmX79+t308evSevnll7VixQpt3LhRNWvWtM4PDQ1Vbm6uLl26ZDMqk5GRodDQ0DLJAgAAzMPuIrNw4UKnH9wwDA0bNkzLli3T+vXrVbt2bZvl0dHRcnd3V0pKivr06SPp309KnThxQjExMU7PAwAAzKVUTy05S1xcnJYsWaL//d//lb+/v/W+l8DAQHl7eyswMFCDBg3S6NGjFRwcrICAAA0bNkwxMTFFPrEEAADuLi4tMnPmzJEkdejQwWZ+UlKSBgwYIEl655135Obmpj59+ti8IR4AAIBLi4xhGDddx8vLS4mJiUpMTLwNiQAAgJnwoZEAAMC0KDIAAMC0KDIAAMC0KDIAAMC0KDIAAMC0XPrUEgCgaJGvrbyl7Y9N6uakJED5xogMAAAwLYoMAAAwLYoMAAAwLYoMAAAwLYoMAAAwLYoMAAAwLYoMAAAwLYoMAAAwLYoMAAAwLYoMAAAwLYoMAAAwLYoMAAAwLYoMAAAwLYoMAAAwLYoMAAAwLYoMAAAwLYoMAAAwLYoMAAAwLYoMAAAwLYoMAAAwLYoMAAAwLYoMAAAwLYoMAAAwLYoMAAAwLYoMAAAwLYoMAAAwLYoMAAAwLYoMAAAwLYoMAAAwLYoMAAAwLYoMAAAwLYoMAAAwLYoMAAAwLYoMAAAwLYoMAAAwLYoMAAAwLYoMAAAwLYoMAAAwLYoMAAAwLYoMAAAwLYoMAAAwLYoMAAAwLYoMAAAwrYquDgAAuLNEvrbylrY/Nqmbk5LgbsCIDAAAMC2KDAAAMC2XFpmNGzeqe/fuCgsLk8Vi0fLly22WG4ah8ePHq3r16vL29lbnzp11+PBh14QFAADljkuLTHZ2tpo2barExMQil0+ZMkUzZ87U3Llz9d1338nX11exsbG6du3abU4KAADKI5fe7Nu1a1d17dq1yGWGYWj69On6y1/+op49e0qSPvjgA4WEhGj58uV6+umni9wuJydHOTk51umsrCznBwcAAOVCuX1qKS0tTenp6ercubN1XmBgoFq3bq0tW7YUW2QSEhL05ptv3q6YAFAu3cqTQzw1BDMptzf7pqenS5JCQkJs5oeEhFiXFSU+Pl6ZmZnWr5MnT5ZpTgAA4DrldkSmtDw9PeXp6enqGAAA4DYotyMyoaGhkqSMjAyb+RkZGdZlAADg7lZui0zt2rUVGhqqlJQU67ysrCx99913iomJcWEyAABQXrj00tKVK1d05MgR63RaWpp2796t4OBgRUREaOTIkXrrrbdUv3591a5dW+PGjVNYWJh69erlutAAAKDccGmR2bFjhx5++GHr9OjRoyVJ/fv318KF
CzV27FhlZ2dryJAhunTpktq1a6dVq1bJy8vLVZEBAEA54tIi06FDBxmGUexyi8WiiRMnauLEibcxFQAAMItye48MAADAzVBkAACAaVFkAACAaVFkAACAaVFkAACAaVFkAACAaVFkAACAaVFkAACAaVFkAACAaVFkAACAaVFkAACAabn0s5YAAOVP5GsrXR0BsBsjMgAAwLQoMgAAwLQoMgAAwLQoMgAAwLQoMgAAwLR4agkAUK7cylNTxyZ1c2ISmAEjMgAAwLQoMgAAwLQoMgAAwLQoMgAAwLQoMgAAwLQoMgAAwLR4/BoAcMe41Q+85PFt82FEBgAAmBZFBgAAmBZFBgAAmBZFBgAAmBZFBgAAmBZPLaFU+FA3AEB5wIgMAAAwLYoMAAAwLYoMAAAwLYoMAAAwLYoMAAAwLYoMAAAwLYoMAAAwLYoMAAAwLYoMAAAwLYoMAAAwLYoMAAAwLYoMAAAwLT40EgCA/+9WPhDXle7mD+NlRAYAAJgWRQYAAJgWRQYAAJgWRQYAAJgWRQYAAJgWTy0BAGByrnzaytVPTDEiAwAATMsURSYxMVGRkZHy8vJS69attW3bNldHAgAA5UC5LzJLly7V6NGjNWHCBH3//fdq2rSpYmNjdfbsWVdHAwAALlbui8y0adM0ePBgDRw4UA0bNtTcuXPl4+OjBQsWuDoaAABwsXJ9s29ubq527typ+Ph46zw3Nzd17txZW7ZsKXKbnJwc5eTkWKczMzMlSVlZWWUb1mTyc6667Ni3ei5cmR0AYKus/vtasF/DMEpcr1wXmXPnzikvL08hISE280NCQnTgwIEit0lISNCbb75ZaH54eHiZZITjAqe7OgEAwFnK+t/0y5cvKzAwsNjl5brIlEZ8fLxGjx5tnc7Pz9eFCxdUuXJlWSwW3X///dq+fXuR2zqyLCsrS+Hh4Tp58qQCAgKc+0M4QUk/iyv36+j29q5/s/VKu7yo+eX53HPeHVuHv/my3y9/82XrTj7vhmHo8uXLCgsLK3Hbcl1kqlSpogoVKigjI8NmfkZGhkJDQ4vcxtPTU56enjbzgoKCrN9XqFCh2F/E0iwLCAgod7/YUsk/iyv36+j29q5/s/VKu7yk7crjuee8O7YOf/Nlv1/+5svWnX7eSxqJKVCub/b18PBQdHS0UlJSrPPy8/OVkpKimJiYUu0zLi7O6cvKo7LKe6v7dXR7e9e/2XqlXc55d85+XXXeb7YOf/Nlv1/+5svW3Xbei2IxbnYXjYstXbpU/fv317vvvqtWrVpp+vTp+vjjj3XgwIFC987cTllZWQoMDFRmZma5a+goW5z7uxPn/e7FuS/fyvWlJUnq27evfv31V40fP17p6elq1qyZVq1a5dISI/37EtaECRMKXcbCnY9zf3fivN+9OPflW7kfkQEAAChOub5HBgAAoCQUGQAAYFoUGQAAYFoUGQAAYFoUGQAAYFoUmTKyYsUKNWjQQPXr19f777/v6ji4TR5//HFVqlRJTzzxhKuj4DY6efKkOnTooIYNG6pJkyZKTk52dSTcBpcuXVLLli3VrFkzNW7cWO+9956rI92VePy6DNy4cUMNGzbUunXrFBgYqOjoaH377beqXLmyq6OhjK1fv16XL1/WokWL9Mknn7g6Dm6TM2fOKCMjQ82aNVN6erqio6N16NAh+fr6ujoaylBeXp5ycnLk4+Oj7OxsNW7cWDt27ODf+tuMEZkysG3bNjVq1Eg1atSQn5+funbtqq+//trVsXAbdOjQQf7+/q6OgdusevXqatasmSQpNDRUVapU0YULF1wbCmWuQoUK8vHxkSTl5OTIMAwxNnD7UWSKsHHjRnXv3l1hYWGyWCxavnx5oXUSExMVGRkpLy8vtW7dWtu2bbMuO336tGrUqGGdrlGjhk6dOnU7ouMW3Op5h3k589zv3LlTeXl5Cg8PL+PUuFXOOO+XLl1S06ZNVbNmTb3yyiuqUqXKbUqPAhSZImRnZ6tp06ZK
TEwscvnSpUs1evRoTZgwQd9//72aNm2q2NhYnT179jYnhTNx3u9ezjr3Fy5cUL9+/TRv3rzbERu3yBnnPSgoSHv27FFaWpqWLFmijIyM2xUfBQyUSJKxbNkym3mtWrUy4uLirNN5eXlGWFiYkZCQYBiGYaSmphq9evWyLh8xYoTx4Ycf3pa8cI7SnPcC69atM/r06XM7YqIMlPbcX7t2zXjwwQeNDz744HZFhRPdyt98gRdffNFITk4uy5goAiMyDsrNzdXOnTvVuXNn6zw3Nzd17txZW7ZskSS1atVK+/bt06lTp3TlyhV99dVXio2NdVVkOIE95x13JnvOvWEYGjBggDp27Kjnn3/eVVHhRPac94yMDF2+fFmSlJmZqY0bN6pBgwYuyXs3K/effl3enDt3Tnl5eYU+fTskJEQHDhyQJFWsWFFTp07Vww8/rPz8fI0dO5a72E3OnvMuSZ07d9aePXuUnZ2tmjVrKjk5WTExMbc7LpzInnOfmpqqpUuXqkmTJtb7LBYvXqz77rvvdseFk9hz3o8fP64hQ4ZYb/IdNmwY59wFKDJlpEePHurRo4erY+A2++abb1wdAS7Qrl075efnuzoGbrNWrVpp9+7dro5x1+PSkoOqVKmiChUqFLqhKyMjQ6GhoS5KhbLGeb97ce7vTpx386DIOMjDw0PR0dFKSUmxzsvPz1dKSgqXEO5gnPe7F+f+7sR5Nw8uLRXhypUrOnLkiHU6LS1Nu3fvVnBwsCIiIjR69Gj1799fLVu2VKtWrTR9+nRlZ2dr4MCBLkyNW8V5v3tx7u9OnPc7hIufmiqX1q1bZ0gq9NW/f3/rOrNmzTIiIiIMDw8Po1WrVsbWrVtdFxhOwXm/e3Hu706c9zsDn7UEAABMi3tkAACAaVFkAACAaVFkAACAaVFkAACAaVFkAACAaVFkAACAaVFkAACAaVFkAACAaVFkAACAaVFkAJQpi8Wi5cuXuzrGbfH7n/XYsWOyWCzavXu3SzMBdzqKDGByv/76q1588UVFRETI09NToaGhio2NVWpqqquj3dXCw8N15swZNW7c2Kn7vZuKIWAPPv0aMLk+ffooNzdXixYtUp06dZSRkaGUlBSdP3/e1dFMKzc3Vx4eHre0jwoVKig0NNRJiQAUhxEZwMQuXbqkTZs2afLkyXr44YdVq1YttWrVSvHx8erRo4d1vWnTpum+++6Tr6+vwsPD9dJLL+nKlSvW5QsXLlRQUJBWrFihBg0ayMfHR0888YSuXr2qRYsWKTIyUpUqVdLw4cOVl5dn3S4yMlJ//etf9cwzz8jX11c1atRQYmJiiZlPnjypp556SkFBQQoODlbPnj117Ngx6/L169erVatW8vX1VVBQkB544AEdP368yH0VXL756KOP1LZtW3l5ealx48basGGDzXr79u1T165d5efnp5CQED3//PM6d+6cdXmHDh308ssva+TIkapSpYpiY2OLzb9gwQI1atRInp6eql69ul5++eUSs/3+0pI9OYYPH66xY8cqODhYoaGheuONN6zLIyMjJUmPP/64LBaLdRq4m1FkABPz8/OTn5+fli9frpycnGLXc3Nz08yZM/Xjjz9q0aJFWrt2rcaOHWuzztWrVzVz5kx99NFHWrVqldavX6/HH39cX375pb788kstXrxY7777rj755BOb7f7xj3+oadOm2rVrl1577TWNGDFCa9asKTLH9evXFRsbK39/f23atEmpqany8/PTo48+qtzcXN24cUO9evVS+/bttXfvXm3ZskVDhgyRxWIp8XV45ZVXNGbMGO3atUsxMTHq3r27dUTq0qVL6tixo5o3b64dO3Zo1apVysjI0FNPPWWzj0WLFsnDw0OpqamaO3dukceZM2eO4uLiNGTIEP3www/6/PPPVa9evRKzFXAkh6+vr7777jtNmTJFEydOtL6e27dvlyQlJSXpzJkz1mngrmYAMLVPPvnEqFSpkuHl5WW0bdvWiI+PN/bs
2VPiNsnJyUblypWt00lJSYYk48iRI9Z5Q4cONXx8fIzLly9b58XGxhpDhw61TteqVct49NFHbfbdt29fo2vXrtZpScayZcsMwzCMxYsXGw0aNDDy8/Oty3Nycgxvb29j9erVxvnz5w1Jxvr16+362dPS0gxJxqRJk6zzrl+/btSsWdOYPHmyYRiG8de//tV45JFHbLY7efKkIck4ePCgYRiG0b59e6N58+Y3PV5YWJjx+uuvF7v89z9rQbZdu3Y5lKNdu3Y269x///3Gq6++WuQxABgGIzKAyfXp00enT5/W559/rkcffVTr169XixYttHDhQus633zzjTp16qQaNWrI399fzz//vM6fP6+rV69a1/Hx8VHdunWt0yEhIYqMjJSfn5/NvLNnz9ocPyYmptD0/v37i8y6Z88eHTlyRP7+/tbRpODgYF27dk1Hjx5VcHCwBgwYoNjYWHXv3l0zZszQmTNnbvoa/D5DxYoV1bJlS2uGPXv2aN26ddbj+fn5KSoqSpJ09OhR63bR0dElHuPs2bM6ffq0OnXqdNM8RbE3R5MmTWy2q169eqHXHMD/4WZf4A7g5eWlLl26qEuXLho3bpz+9Kc/acKECRowYICOHTumxx57TC+++KL+9re/KTg4WJs3b9agQYOUm5srHx8fSZK7u7vNPi0WS5Hz8vPzS53zypUrio6O1ocfflhoWdWqVSX9+7LJ8OHDtWrVKi1dulR/+ctftGbNGrVp06bUx+zevbsmT55caFn16tWt3/v6+pa4H29v71Id39Eczn7NgTsdRQa4AzVs2ND6iO7OnTuVn5+vqVOnys3t34OwH3/8sdOOtXXr1kLT9957b5HrtmjRQkuXLlW1atUUEBBQ7D6bN2+u5s2bKz4+XjExMVqyZEmJRWbr1q166KGHJEk3btzQzp07rTfhtmjRQp9++qkiIyNVsWLp/8nz9/dXZGSkUlJS9PDDDzu8vbNyuLu729xwDdztuLQEmNj58+fVsWNH/etf/9LevXuVlpam5ORkTZkyRT179pQk1atXT9evX9esWbP0888/a/HixcXezFoaqampmjJlig4dOqTExEQlJydrxIgRRa773HPPqUqVKurZs6c2bdqktLQ0rV+/XsOHD9cvv/yitLQ0xcfHa8uWLTp+/Li+/vprHT58uNhiVCAxMVHLli3TgQMHFBcXp4sXL+qPf/yjJCkuLk4XLlzQM888o+3bt+vo0aNavXq1Bg4c6HAheOONNzR16lTNnDlThw8f1vfff69Zs2bZta2zchSUqfT0dF28eNGh/MCdiCIDmJifn59at26td955Rw899JAaN26scePGafDgwfrnP/8pSWratKmmTZumyZMnq3Hjxvrwww+VkJDgtAxjxozRjh071Lx5c7311luaNm1asY8v+/j4aOPGjYqIiFDv3r117733atCgQbp27ZoCAgLk4+OjAwcOqE+fPrrnnns0ZMgQxcXFaejQoSVmmDRpkiZNmqSmTZtq8+bN+vzzz1WlShVJUlhYmFJTU5WXl6dHHnlE9913n0aOHKmgoCDrCJW9+vfvr+nTp2v27Nlq1KiRHnvsMR0+fNiubZ2VY+rUqVqzZo3Cw8PVvHlzh/IDdyKLYRiGq0MAMKfIyEiNHDlSI0eOdMnxjx07ptq1a2vXrl1q1qyZSzIAcC1GZAAAgGlRZAAAgGlxaQkAAJgWIzIAAMC0KDIAAMC0KDIAAMC0KDIAAMC0KDIAAMC0KDIAAMC0KDIAAMC0KDIAAMC0/h+I6oFJARGm+wAAAABJRU5ErkJggg==",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Histogram of the number of samples held by each client, with log-spaced bins\n",
    "# running from 1 up to the largest client.\n",
    "sizes = [len(client_X) for client_X in Xs]\n",
    "bins = np.logspace(0, np.log10(max(sizes)), num=30)\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.hist(sizes, bins=bins)\n",
    "ax.set_xscale(\"log\")\n",
    "ax.set_title(\"Client data distribution\")\n",
    "ax.set_xlabel(\"Samples per client\")\n",
    "ax.set_ylabel(\"Number of clients\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training score: 0.29\n",
      "OOD Score: 0.2666\n",
      "Train loss: 2.8031e+07\n",
      "Norm x^*: 56.2\n"
     ]
    }
   ],
   "source": [
    "# Centralized ridge regression (no intercept) on the pooled client data.\n",
    "# alpha = 2e7 was an alternative setting.\n",
    "alpha = 1e3\n",
    "reg = Ridge(fit_intercept=False, alpha=alpha, tol=1e-15).fit(X_train, y_train)\n",
    "x_sol = reg.coef_\n",
    "\n",
    "# Disabled experiment that post-processes the solution with objectives.topK_prox:\n",
    "# d = x_sol.shape[0]\n",
    "# x_sol = objectives.topK_prox(x_sol,0,int(0.05*d))\n",
    "# reg.coef_ = x_sol\n",
    "\n",
    "# Scores are the estimator's default score (R^2) on the training and OOD sets.\n",
    "ridge_train_score = reg.score(X_train, y_train)\n",
    "ridge_ood_score = reg.score(X_test_ood, y_test_ood)\n",
    "# NOTE(review): obj_fun is given 2*alpha as its last argument - presumably the coefficient\n",
    "# of a (mu/2)*||x||^2 term; confirm against the objectives module.\n",
    "ridge_train_loss = objectives.obj_fun(X_train, y_train, x_sol, 2*alpha)\n",
    "\n",
    "print(f\"Training score: {ridge_train_score:.3g}\")\n",
    "print(f\"OOD Score: {ridge_ood_score:.4g}\")\n",
    "print(f\"Train loss: {ridge_train_loss:.5g}\")\n",
    "print(f\"Norm x^*: {norm(x_sol):.3g}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "l1_ratio: 0.979\n",
      "Training score: 0.189\n",
      "OOD Score: 0.188\n",
      "Train loss: 3.07e+07\n",
      "Norm x_l1: 32.6\n",
      "Sparsity x_l1: 0.95\n",
      "Distance: 1.97e+03\n"
     ]
    }
   ],
   "source": [
    "# Ridge problem with an additional l1 penalty, solved with scikit-learn's ElasticNet.\n",
    "alpha = 1e3\n",
    "\n",
    "# l1 weights noted for different target sparsity levels (key = sparsity).\n",
    "l1_weight_for_sparsity = {0.8: 9e3, 0.9: 3.2e4, 0.95: 9.5e4}\n",
    "reg_coeff = l1_weight_for_sparsity[0.95]\n",
    "\n",
    "num_samples, d = X_train.shape\n",
    "\n",
    "# ElasticNet minimises\n",
    "#   1/(2*n) * ||y - Xw||^2 + alpha_E*l1_ratio*||w||_1 + 0.5*alpha_E*(1 - l1_ratio)*||w||^2,\n",
    "# so the choice below corresponds to ||y - Xw||^2 + reg_coeff*||w||_1 + alpha*||w||^2.\n",
    "a = reg_coeff / 2 / num_samples\n",
    "b = alpha / num_samples\n",
    "alpha_E = a + b\n",
    "l1_ratio = a / alpha_E\n",
    "print(f\"l1_ratio: {l1_ratio:.3f}\")\n",
    "\n",
    "reg = ElasticNet(fit_intercept=False, alpha=alpha_E, l1_ratio=l1_ratio, tol=1e-10).fit(X_train, y_train)\n",
    "x_l1 = reg.coef_\n",
    "\n",
    "enet_train_score = reg.score(X_train, y_train)\n",
    "enet_ood_score = reg.score(X_test_ood, y_test_ood)\n",
    "enet_train_loss = objectives.obj_fun(X_train, y_train, x_l1, 2*alpha)\n",
    "# Fraction of coefficients that are zero after rounding to 15 decimals.\n",
    "sparsity = 1 - norm(x_l1.round(15), 0) / d\n",
    "# Squared Euclidean distance to the ridge solution x_sol from the previous cell.\n",
    "sq_distance = norm(x_l1 - x_sol)**2\n",
    "\n",
    "print(f\"Training score: {enet_train_score:.3g}\")\n",
    "print(f\"OOD Score: {enet_ood_score:.3g}\")\n",
    "print(f\"Train loss: {enet_train_loss:.3g}\")\n",
    "print(f\"Norm x_l1: {norm(x_l1):.3g}\")\n",
    "print(f\"Sparsity x_l1: {sparsity}\")\n",
    "print(f\"Distance: {sq_distance:.3g}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "L: 4.489e+02\n",
      "empirical mu: 2.000e+03\n",
      "Max L_i: 1.949e+04\n",
      "empirical mu FL: 3.610e+00\n",
      "Optimal theoretical step size centralized: 2.228e-03\n",
      "Optimal theoretical step size FL: 5.130e-05\n",
      "Condition number centralized: 2.245e-01\n",
      "Condition number FL: 5.400e+03\n",
      "Optimal theoretical p FL: 1.361e-02\n",
      "pT = 5e+02 (p = 0.0136)\n"
     ]
    }
   ],
   "source": [
    "# Smoothness and strong-convexity estimates for the pooled problem and for the clients,\n",
    "# plus the step sizes and communication probability derived from them.\n",
    "# NOTE(review): the scaling conventions (division by n, factor 2 on alpha) are assumed to\n",
    "# match objectives.obj_fun, which is not visible here - confirm.\n",
    "mu = 2*alpha\n",
    "n = len(dfs)\n",
    "\n",
    "\n",
    "def _gram_eig_range(X):\n",
    "    \"\"\"Return (smallest, largest) eigenvalue of X^T X, computed from the singular values of X.\"\"\"\n",
    "    sq_svals = scipy.linalg.svdvals(X)**2\n",
    "    # svdvals returns only min(rows, cols) values. With fewer rows than columns, X^T X is\n",
    "    # rank deficient, so its smallest eigenvalue is 0 even if every returned value is positive.\n",
    "    smallest = sq_svals.min() if X.shape[0] >= X.shape[1] else 0.0\n",
    "    return smallest, sq_svals.max()\n",
    "\n",
    "\n",
    "# One SVD per matrix is enough: X and X.T have the same singular values.\n",
    "pooled_min, pooled_max = _gram_eig_range(X_train)\n",
    "client_ranges = [_gram_eig_range(X) for X in Xs]\n",
    "\n",
    "L = pooled_max / n\n",
    "empirical_mu = pooled_min + mu\n",
    "max_L = max(largest for _, largest in client_ranges)\n",
    "min_mu_i = min(smallest for smallest, _ in client_ranges) + mu/n\n",
    "central_step_size = 1/L\n",
    "fl_step_size = 1/max_L\n",
    "fl_p = np.sqrt(min_mu_i/max_L)\n",
    "print(f\"L: {L:.3e}\")\n",
    "print(f\"empirical mu: {empirical_mu:.3e}\")\n",
    "print(f\"Max L_i: {max_L:.3e}\")\n",
    "print(f\"empirical mu FL: {min_mu_i:.3e}\")\n",
    "print(f\"Optimal theoretical step size centralized: {central_step_size:.3e}\")\n",
    "print(f\"Optimal theoretical step size FL: {fl_step_size:.3e}\")\n",
    "# NOTE(review): L is divided by n while mu is not, so this ratio can fall below 1;\n",
    "# confirm the intended normalisation before relying on it as a condition number.\n",
    "print(f\"Condition number centralized: {L/mu:.3e}\")\n",
    "print(f\"Condition number FL: {max_L/min_mu_i:.3e}\")\n",
    "print(f\"Optimal theoretical p FL: {fl_p:.3e}\")\n",
    "\n",
    "# Target accuracy used for the printed pT estimate.\n",
    "eps = 1e-3\n",
    "p = fl_p\n",
    "print(f\"pT = {max(1/fl_step_size/mu, 1/p/p)* np.log(1/eps)*p:.1g} (p = {p:.3g})\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(47157, 280)"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Dimensions (rows, columns) of the pooled training matrix built from the client shards.\n",
    "X_train.shape"
   ]
  }
 ],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 2
}
