{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import datetime"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Num rows: 7214\n"
     ]
    }
   ],
   "source": [
    "# Read the COMPAS two-year CSV from the working directory and report its row count.\n",
    "raw_data = pd.read_csv('./compas-scores-two-years.csv')\n",
    "print('Num rows: %d' % raw_data.shape[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Num rows filtered: 6172\n"
     ]
    }
   ],
   "source": [
    "# Keep only rows where:\n",
    "#   * days_b_screening_arrest lies in [-30, 30],\n",
    "#   * is_recid is not -1,\n",
    "#   * c_charge_degree is not 'O',\n",
    "#   * score_text is not 'N/A'.\n",
    "# NOTE(review): pd.read_csv's default NA handling likely parses the literal 'N/A' as NaN,\n",
    "# in which case the last condition never excludes a row - confirm.\n",
    "#\n",
    "# .copy() makes df an independent frame. Without it df is a slice of raw_data and adding\n",
    "# a column to it later raises SettingWithCopyWarning.\n",
    "df = raw_data[((raw_data['days_b_screening_arrest'] <= 30) &\n",
    "               (raw_data['days_b_screening_arrest'] >= -30) &\n",
    "               (raw_data['is_recid'] != -1) &\n",
    "               (raw_data['c_charge_degree'] != 'O') &\n",
    "               (raw_data['score_text'] != 'N/A')\n",
    "              )].copy()\n",
    "\n",
    "print('Num rows filtered: %d' % len(df))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from datetime import datetime\n",
    "from scipy.stats import pearsonr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def date_from_str(s):\n",
    "    \"\"\"Parse a 'YYYY-MM-DD HH:MM:SS' string into a datetime.\n",
    "\n",
    "    Raises ValueError (from datetime.strptime) when s does not match that layout.\n",
    "    \"\"\"\n",
    "    timestamp_layout = '%Y-%m-%d %H:%M:%S'\n",
    "    return datetime.strptime(s, timestamp_layout)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Correlation btw stay length and COMPAS scores: 0.207\n"
     ]
    }
   ],
   "source": [
    "# Jail stay in seconds: c_jail_out minus c_jail_in, both parsed with date_from_str.\n",
    "# assign() returns a new frame instead of writing into df in place, so this works even\n",
    "# when df is a slice of another frame (the in-place form raised SettingWithCopyWarning)\n",
    "# and re-running the cell is idempotent.\n",
    "df = df.assign(\n",
    "    length_of_stay=(\n",
    "        df['c_jail_out'].apply(date_from_str) - df['c_jail_in'].apply(date_from_str)\n",
    "    ).dt.total_seconds()\n",
    ")\n",
    "\n",
    "# pearsonr returns (statistic, p-value); only the correlation coefficient is reported.\n",
    "stay_score_corr = pearsonr(df['length_of_stay'], df['decile_score'])[0]\n",
    "print('Correlation btw stay length and COMPAS scores: %.3f' % stay_score_corr)"
   ]
  },
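  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "After filtering, we one-hot encode the demographic and charge features (charge degree, age category, race, sex) and binarize the COMPAS score text (`Low` vs. everything else):"
   ]
  },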
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One-hot encode the categorical columns of df.\n",
    "# drop_first=True is passed for charge degree, sex and the score indicator, so the first\n",
    "# level of each of those factors gets no column; age and race keep every level.\n",
    "df_crime = pd.get_dummies(df['c_charge_degree'], prefix='crimefactor', drop_first=True)\n",
    "df_gender = pd.get_dummies(df['sex'], prefix='sex', drop_first=True)\n",
    "df_age = pd.get_dummies(df['age_cat'], prefix='age')\n",
    "df_race = pd.get_dummies(df['race'], prefix='race')\n",
    "\n",
    "# Target frame: dummy encoding of the boolean 'score_text is not Low'.\n",
    "df_score = pd.get_dummies(df['score_text'] != 'Low', prefix='score_factor', drop_first=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Assemble the modelling frame column-wise. The order (crime, age, race, sex,\n",
    "# priors_count, two_year_recid) matters: a later cell picks a column by position.\n",
    "df_lr = pd.concat(\n",
    "    [\n",
    "        df_crime,\n",
    "        df_age,\n",
    "        df_race,\n",
    "        df_gender,\n",
    "        df[['priors_count', 'two_year_recid']],\n",
    "    ],\n",
    "    axis=1,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pathlib\n",
    "import numpy as np\n",
    "from sklearn.datasets import fetch_openml\n",
    "from sklearn.model_selection import train_test_split\n",
    "import torch\n",
    "import functools\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from copy import deepcopy\n",
    "from sklearn.preprocessing import normalize"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "# y: the score indicator flattened to a 1-D array.\n",
    "y = df_score.values.flatten()\n",
    "\n",
    "# x: every df_lr column as a NumPy array; a: the third-from-last column, which is then\n",
    "# removed from x. Given the concat order of df_lr this should be the sex dummy -\n",
    "# TODO confirm, the index is positional.\n",
    "x = df_lr.values\n",
    "a = x[:, -3]\n",
    "x = np.delete(x, -3, axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw ten proportions from a Dirichlet distribution with concentration parameters 1..10\n",
    "# and scale them by len(x), so the (float) sizes sum to the number of rows.\n",
    "# A dedicated seeded generator keeps the draw identical across kernel restarts and does\n",
    "# not touch NumPy's global random state.\n",
    "DIRICHLET_SEED = 0\n",
    "dirichlet_rng = np.random.default_rng(DIRICHLET_SEED)\n",
    "dirichlet = dirichlet_rng.dirichlet([1,2,3,4,5,6,7,8,9,10])\n",
    "num_samples_dirichlet = len(x)*dirichlet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fixed group sizes (floats). This assignment overrides any value num_samples_dirichlet\n",
    "# was given earlier in the notebook; the ten entries add up to ~6172.\n",
    "num_samples_dirichlet = np.array([\n",
    "    95.66940272,\n",
    "    103.47257568,\n",
    "    67.29848141,\n",
    "    598.75228282,\n",
    "    386.13166627,\n",
    "    737.45948646,\n",
    "    1241.56110773,\n",
    "    573.61883238,\n",
    "    1126.89500305,\n",
    "    1241.14116147,\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Running end offsets for each group: sizes are truncated to int and accumulated for all\n",
    "# but the last group, whose end offset is pinned to len(x) so no row is left out.\n",
    "cumulative_samples_dirichlet = [int(num_samples_dirichlet[0])]\n",
    "for group_size in num_samples_dirichlet[1:-1]:\n",
    "    previous_end = cumulative_samples_dirichlet[-1]\n",
    "    cumulative_samples_dirichlet.append(previous_end + int(group_size))\n",
    "cumulative_samples_dirichlet.append(len(x))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{0: {'x': tensor([[0.0000, 0.7071, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.7071, 0.0000, 0.0000],\n",
       "          [0.0000, 0.5774, 0.0000, 0.0000, 0.5774, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.0000, 0.5774],\n",
       "          [0.0000, 0.7071, 0.0000, 0.0000, 0.7071, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.0000, 0.0000],\n",
       "          [0.0000, 0.1222, 0.0000, 0.0000, 0.1222, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.9774, 0.1222],\n",
       "          [0.5774, 0.5774, 0.0000, 0.0000, 0.0000, 0.0000, 0.5774, 0.0000, 0.0000,\n",
       "           0.0000, 0.0000, 0.0000],\n",
       "          [0.5774, 0.0000, 0.5774, 0.0000, 0.0000, 0.0000, 0.5774, 0.0000, 0.0000,\n",
       "           0.0000, 0.0000, 0.0000],\n",
       "          [0.5774, 0.5774, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.5774, 0.0000,\n",
       "           0.0000, 0.0000, 0.0000],\n",
       "          [0.0000, 0.0000, 0.0000, 0.5774, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.5774, 0.0000, 0.5774],\n",
       "          [0.0000, 0.0000, 0.0000, 0.1222, 0.0000, 0.0000, 0.1222, 0.0000, 0.0000,\n",
       "           0.0000, 0.9774, 0.1222],\n",
       "          [0.0000, 0.1387, 0.0000, 0.0000, 0.0000, 0.0000, 0.1387, 0.0000, 0.0000,\n",
       "           0.0000, 0.9707, 0.1387],\n",
       "          [0.5774, 0.5774, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.5774, 0.0000,\n",
       "           0.0000, 0.0000, 0.0000],\n",
       "          [0.0000, 0.5774, 0.0000, 0.0000, 0.0000, 0.0000, 0.5774, 0.0000, 0.0000,\n",
       "           0.0000, 0.0000, 0.5774],\n",
       "          [0.5774, 0.5774, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.5774, 0.0000, 0.0000],\n",
       "          [0.4472, 0.0000, 0.0000, 0.4472, 0.4472, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.4472, 0.4472],\n",
       "          [0.4472, 0.0000, 0.0000, 0.4472, 0.4472, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.4472, 0.4472],\n",
       "          [0.0000, 0.1091, 0.0000, 0.0000, 0.0000, 0.0000, 0.1091, 0.0000, 0.0000,\n",
       "           0.0000, 0.9820, 0.1091],\n",
       "          [0.0000, 0.0000, 0.0000, 0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.5000, 0.5000],\n",
       "          [0.0000, 0.0000, 0.0000, 0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.5000, 0.5000],\n",
       "          [0.0000, 0.2294, 0.0000, 0.0000, 0.2294, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.9177, 0.2294],\n",
       "          [0.0000, 0.3780, 0.0000, 0.0000, 0.0000, 0.0000, 0.3780, 0.0000, 0.0000,\n",
       "           0.0000, 0.7559, 0.3780],\n",
       "          [0.0000, 0.1890, 0.0000, 0.0000, 0.1890, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.9449, 0.1890],\n",
       "          [0.5774, 0.0000, 0.5774, 0.0000, 0.5774, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.0000, 0.0000],\n",
       "          [0.0000, 0.0762, 0.0000, 0.0000, 0.0762, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.9912, 0.0762],\n",
       "          [0.0000, 0.0709, 0.0000, 0.0000, 0.0000, 0.0000, 0.0709, 0.0000, 0.0000,\n",
       "           0.0000, 0.9924, 0.0709],\n",
       "          [0.0000, 0.1601, 0.0000, 0.0000, 0.0000, 0.0000, 0.1601, 0.0000, 0.0000,\n",
       "           0.0000, 0.9608, 0.1601],\n",
       "          [0.0000, 0.0000, 0.0762, 0.0000, 0.0762, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.9912, 0.0762],\n",
       "          [0.1085, 0.1085, 0.0000, 0.0000, 0.0000, 0.0000, 0.1085, 0.0000, 0.0000,\n",
       "           0.0000, 0.9762, 0.1085],\n",
       "          [0.0000, 0.0000, 0.1400, 0.0000, 0.1400, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.9802, 0.0000],\n",
       "          [0.0000, 0.3015, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.3015, 0.9045, 0.0000],\n",
       "          [0.0000, 0.5774, 0.0000, 0.0000, 0.5774, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.5774, 0.0000],\n",
       "          [0.2236, 0.2236, 0.0000, 0.0000, 0.0000, 0.0000, 0.2236, 0.0000, 0.0000,\n",
       "           0.0000, 0.8944, 0.2236],\n",
       "          [0.0000, 0.0000, 0.7071, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.7071, 0.0000, 0.0000],\n",
       "          [0.5774, 0.5774, 0.0000, 0.0000, 0.0000, 0.0000, 0.5774, 0.0000, 0.0000,\n",
       "           0.0000, 0.0000, 0.0000],\n",
       "          [0.0000, 0.2294, 0.0000, 0.0000, 0.2294, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.9177, 0.2294],\n",
       "          [0.0000, 0.0000, 0.1231, 0.0000, 0.0000, 0.0000, 0.1231, 0.0000, 0.0000,\n",
       "           0.0000, 0.9847, 0.0000],\n",
       "          [0.0000, 0.0000, 0.5000, 0.0000, 0.0000, 0.0000, 0.5000, 0.0000, 0.0000,\n",
       "           0.0000, 0.5000, 0.5000],\n",
       "          [0.0000, 0.0000, 0.0000, 0.4082, 0.4082, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.8165, 0.0000],\n",
       "          [0.5774, 0.5774, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.5774, 0.0000, 0.0000],\n",
       "          [0.0000, 0.0662, 0.0000, 0.0000, 0.0662, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.9934, 0.0662],\n",
       "          [0.0000, 0.7071, 0.0000, 0.0000, 0.7071, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.0000, 0.0000],\n",
       "          [0.2236, 0.2236, 0.0000, 0.0000, 0.0000, 0.0000, 0.2236, 0.0000, 0.0000,\n",
       "           0.0000, 0.8944, 0.2236],\n",
       "          [0.1213, 0.1213, 0.0000, 0.0000, 0.1213, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.9701, 0.1213],\n",
       "          [0.0000, 0.1601, 0.0000, 0.0000, 0.1601, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.9608, 0.1601],\n",
       "          [0.0000, 0.0000, 0.0000, 0.3780, 0.3780, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.7559, 0.3780],\n",
       "          [0.0000, 0.5774, 0.0000, 0.0000, 0.5774, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.0000, 0.5774],\n",
       "          [0.4472, 0.0000, 0.0000, 0.4472, 0.0000, 0.0000, 0.4472, 0.0000, 0.0000,\n",
       "           0.0000, 0.4472, 0.4472],\n",
       "          [0.0000, 0.0000, 0.0000, 0.7071, 0.0000, 0.0000, 0.7071, 0.0000, 0.0000,\n",
       "           0.0000, 0.0000, 0.0000],\n",
       "          [0.0000, 0.3015, 0.0000, 0.0000, 0.3015, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.9045, 0.0000],\n",
       "          [0.0000, 0.5000, 0.0000, 0.0000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.5000, 0.5000],\n",
       "          [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.5000, 0.0000, 0.0000,\n",
       "           0.0000, 0.0000, 0.5000],\n",
       "          [0.0000, 0.0000, 0.7071, 0.0000, 0.0000, 0.0000, 0.7071, 0.0000, 0.0000,\n",
       "           0.0000, 0.0000, 0.0000],\n",
       "          [0.2294, 0.0000, 0.2294, 0.0000, 0.2294, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.9177, 0.0000],\n",
       "          [0.0000, 0.5774, 0.0000, 0.0000, 0.0000, 0.0000, 0.5774, 0.0000, 0.0000,\n",
       "           0.0000, 0.5774, 0.0000],\n",
       "          [0.0000, 0.0000, 0.0000, 0.2887, 0.2887, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.8660, 0.2887],\n",
       "          [0.0000, 0.7071, 0.0000, 0.0000, 0.7071, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.0000, 0.0000],\n",
       "          [0.0000, 0.0000, 0.0000, 0.7071, 0.0000, 0.0000, 0.7071, 0.0000, 0.0000,\n",
       "           0.0000, 0.0000, 0.0000],\n",
       "          [0.0000, 0.7071, 0.0000, 0.0000, 0.0000, 0.0000, 0.7071, 0.0000, 0.0000,\n",
       "           0.0000, 0.0000, 0.0000],\n",
       "          [0.0000, 0.0000, 0.5774, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.5774, 0.0000, 0.5774],\n",
       "          [0.0000, 0.1091, 0.0000, 0.0000, 0.1091, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.9820, 0.1091],\n",
       "          [0.0000, 0.1890, 0.0000, 0.0000, 0.1890, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.9449, 0.1890],\n",
       "          [0.3536, 0.3536, 0.0000, 0.0000, 0.0000, 0.0000, 0.3536, 0.0000, 0.0000,\n",
       "           0.0000, 0.7071, 0.3536],\n",
       "          [0.0000, 0.0664, 0.0000, 0.0000, 0.0664, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.9956, 0.0000],\n",
       "          [0.0000, 0.5774, 0.0000, 0.0000, 0.5774, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.0000, 0.5774],\n",
       "          [0.0000, 0.4082, 0.0000, 0.0000, 0.4082, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.8165, 0.0000],\n",
       "          [0.5774, 0.5774, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.5774, 0.0000, 0.0000],\n",
       "          [0.0000, 0.1387, 0.0000, 0.0000, 0.1387, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.9707, 0.1387],\n",
       "          [0.5000, 0.5000, 0.0000, 0.0000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.5000, 0.0000],\n",
       "          [0.3780, 0.0000, 0.3780, 0.0000, 0.3780, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.7559, 0.0000],\n",
       "          [0.0000, 0.0000, 0.0000, 0.5000, 0.0000, 0.0000, 0.0000, 0.5000, 0.0000,\n",
       "           0.0000, 0.5000, 0.5000],\n",
       "          [0.5774, 0.0000, 0.5774, 0.0000, 0.0000, 0.0000, 0.5774, 0.0000, 0.0000,\n",
       "           0.0000, 0.0000, 0.0000],\n",
       "          [0.5774, 0.5774, 0.0000, 0.0000, 0.0000, 0.0000, 0.5774, 0.0000, 0.0000,\n",
       "           0.0000, 0.0000, 0.0000],\n",
       "          [0.0000, 0.5774, 0.0000, 0.0000, 0.0000, 0.0000, 0.5774, 0.0000, 0.0000,\n",
       "           0.0000, 0.5774, 0.0000],\n",
       "          [0.0000, 0.0000, 0.0000, 0.7071, 0.0000, 0.0000, 0.7071, 0.0000, 0.0000,\n",
       "           0.0000, 0.0000, 0.0000],\n",
       "          [0.0000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.5000, 0.0000, 0.0000,\n",
       "           0.0000, 0.5000, 0.5000],\n",
       "          [0.0000, 0.0000, 0.0000, 0.3780, 0.3780, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.7559, 0.3780],\n",
       "          [0.0000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.5000, 0.0000, 0.0000,\n",
       "           0.0000, 0.5000, 0.5000]], dtype=torch.float64),\n",
       "  'a': array([0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0,\n",
       "         1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1,\n",
       "         1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 0, 0, 0, 1, 1, 1, 1]),\n",
       "  'y': tensor([0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 1, 1,\n",
       "          1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 0, 0, 1, 1,\n",
       "          0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "          1, 1, 0, 0]),\n",
       "  'x_test': tensor([[0.0000, 0.0000, 0.0000, 0.2294, 0.2294, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.9177, 0.2294],\n",
       "          [0.0000, 0.0000, 0.1387, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.1387, 0.9707, 0.1387],\n",
       "          [0.5774, 0.0000, 0.5774, 0.0000, 0.0000, 0.0000, 0.0000, 0.5774, 0.0000,\n",
       "           0.0000, 0.0000, 0.0000],\n",
       "          [0.0000, 0.1890, 0.0000, 0.0000, 0.0000, 0.0000, 0.1890, 0.0000, 0.0000,\n",
       "           0.0000, 0.9449, 0.1890],\n",
       "          [0.5774, 0.5774, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.5774, 0.0000,\n",
       "           0.0000, 0.0000, 0.0000],\n",
       "          [0.0000, 0.0000, 0.0000, 0.5774, 0.0000, 0.0000, 0.0000, 0.5774, 0.0000,\n",
       "           0.0000, 0.5774, 0.0000],\n",
       "          [0.5000, 0.5000, 0.0000, 0.0000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.5000, 0.0000],\n",
       "          [0.0000, 0.1091, 0.0000, 0.0000, 0.0000, 0.0000, 0.1091, 0.0000, 0.0000,\n",
       "           0.0000, 0.9820, 0.1091],\n",
       "          [0.0000, 0.0475, 0.0000, 0.0000, 0.0475, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.9966, 0.0475],\n",
       "          [0.0000, 0.0000, 0.0000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.5000, 0.5000, 0.5000],\n",
       "          [0.5000, 0.5000, 0.0000, 0.0000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.5000, 0.0000],\n",
       "          [0.0000, 0.7071, 0.0000, 0.0000, 0.0000, 0.0000, 0.7071, 0.0000, 0.0000,\n",
       "           0.0000, 0.0000, 0.0000],\n",
       "          [0.0000, 0.1925, 0.0000, 0.0000, 0.1925, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.9623, 0.0000],\n",
       "          [0.5774, 0.5774, 0.0000, 0.0000, 0.0000, 0.0000, 0.5774, 0.0000, 0.0000,\n",
       "           0.0000, 0.0000, 0.0000],\n",
       "          [0.1857, 0.1857, 0.0000, 0.0000, 0.0000, 0.0000, 0.1857, 0.0000, 0.0000,\n",
       "           0.0000, 0.9285, 0.1857],\n",
       "          [0.0000, 0.0000, 0.1601, 0.0000, 0.0000, 0.0000, 0.0000, 0.1601, 0.0000,\n",
       "           0.0000, 0.9608, 0.1601],\n",
       "          [0.2774, 0.0000, 0.0000, 0.2774, 0.2774, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.8321, 0.2774],\n",
       "          [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.5000, 0.0000, 0.0000,\n",
       "           0.0000, 0.5000, 0.0000],\n",
       "          [0.0000, 0.0709, 0.0000, 0.0000, 0.0709, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.9924, 0.0709]], dtype=torch.float64),\n",
       "  'a_test': array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1]),\n",
       "  'y_test': tensor([0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 1])},\n",
       " 1: {'x': tensor([[0.5774, 0.5774, 0.0000, 0.0000, 0.0000, 0.0000, 0.5774, 0.0000, 0.0000,\n",
       "           0.0000, 0.0000, 0.0000],\n",
       "          [0.0000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.5000, 0.0000, 0.0000,\n",
       "           0.0000, 0.5000, 0.5000],\n",
       "          [0.3536, 0.3536, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.3536, 0.7071, 0.3536],\n",
       "          [0.0000, 0.0000, 0.1387, 0.0000, 0.0000, 0.0000, 0.1387, 0.0000, 0.0000,\n",
       "           0.0000, 0.9707, 0.1387],\n",
       "          [0.0000, 0.0000, 0.5774, 0.0000, 0.0000, 0.0000, 0.5774, 0.0000, 0.0000,\n",
       "           0.0000, 0.5774, 0.0000],\n",
       "          [0.0000, 0.0356, 0.0000, 0.0000, 0.0356, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.9981, 0.0356],\n",
       "          [0.0000, 0.0000, 0.0000, 0.5774, 0.5774, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.5774, 0.0000],\n",
       "          [0.5774, 0.5774, 0.0000, 0.0000, 0.5774, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.0000, 0.0000],\n",
       "          [0.0000, 0.0000, 0.0000, 0.1890, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.1890, 0.9449, 0.1890],\n",
       "          [0.5000, 0.0000, 0.0000, 0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.0000, 0.5000],\n",
       "          [0.0000, 0.1387, 0.0000, 0.0000, 0.0000, 0.0000, 0.1387, 0.0000, 0.0000,\n",
       "           0.0000, 0.9707, 0.1387],\n",
       "          [0.5774, 0.0000, 0.5774, 0.0000, 0.0000, 0.0000, 0.5774, 0.0000, 0.0000,\n",
       "           0.0000, 0.0000, 0.0000],\n",
       "          [0.0000, 0.2294, 0.0000, 0.0000, 0.0000, 0.0000, 0.2294, 0.0000, 0.0000,\n",
       "           0.0000, 0.9177, 0.2294],\n",
       "          [0.0433, 0.0000, 0.0433, 0.0000, 0.0433, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.9962, 0.0433],\n",
       "          [0.0000, 0.1091, 0.0000, 0.0000, 0.1091, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.9820, 0.1091],\n",
       "          [0.5774, 0.0000, 0.0000, 0.5774, 0.5774, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.0000, 0.0000],\n",
       "          [0.4472, 0.4472, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.4472, 0.4472, 0.4472],\n",
       "          [0.0000, 0.0000, 0.0000, 0.7071, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.7071, 0.0000, 0.0000],\n",
       "          [0.0000, 0.5774, 0.0000, 0.0000, 0.5774, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.5774, 0.0000],\n",
       "          [0.0000, 0.0762, 0.0000, 0.0000, 0.0000, 0.0000, 0.0762, 0.0000, 0.0000,\n",
       "           0.0000, 0.9912, 0.0762],\n",
       "          [0.0000, 0.1890, 0.0000, 0.0000, 0.1890, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.9449, 0.1890],\n",
       "          [0.0000, 0.0000, 0.5774, 0.0000, 0.0000, 0.0000, 0.0000, 0.5774, 0.0000,\n",
       "           0.0000, 0.5774, 0.0000],\n",
       "          [0.0000, 0.2887, 0.0000, 0.0000, 0.0000, 0.0000, 0.2887, 0.0000, 0.0000,\n",
       "           0.0000, 0.8660, 0.2887],\n",
       "          [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.5000, 0.0000, 0.0000,\n",
       "           0.0000, 0.5000, 0.0000],\n",
       "          [0.0000, 0.0985, 0.0000, 0.0000, 0.0985, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.9853, 0.0985],\n",
       "          [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.5000, 0.0000, 0.0000,\n",
       "           0.0000, 0.5000, 0.0000],\n",
       "          [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.5000, 0.0000, 0.0000,\n",
       "           0.0000, 0.0000, 0.5000],\n",
       "          [0.3536, 0.3536, 0.0000, 0.0000, 0.3536, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.7071, 0.3536],\n",
       "          [0.0000, 0.2294, 0.0000, 0.0000, 0.2294, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.9177, 0.2294],\n",
       "          [0.0000, 0.0000, 0.0000, 0.5774, 0.5774, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.0000, 0.5774],\n",
       "          [0.1890, 0.1890, 0.0000, 0.0000, 0.1890, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.9449, 0.0000],\n",
       "          [0.0000, 0.1222, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1222, 0.0000,\n",
       "           0.0000, 0.9774, 0.1222],\n",
       "          [0.0000, 0.7071, 0.0000, 0.0000, 0.0000, 0.0000, 0.7071, 0.0000, 0.0000,\n",
       "           0.0000, 0.0000, 0.0000],\n",
       "          [0.5774, 0.5774, 0.0000, 0.0000, 0.0000, 0.0000, 0.5774, 0.0000, 0.0000,\n",
       "           0.0000, 0.0000, 0.0000],\n",
       "          [0.0000, 0.0898, 0.0000, 0.0000, 0.0898, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.9878, 0.0898],\n",
       "          [0.0000, 0.0000, 0.4082, 0.0000, 0.0000, 0.0000, 0.0000, 0.4082, 0.0000,\n",
       "           0.0000, 0.8165, 0.0000],\n",
       "          [0.0000, 0.1925, 0.0000, 0.0000, 0.1925, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.9623, 0.0000],\n",
       "          [0.2887, 0.0000, 0.2887, 0.0000, 0.0000, 0.0000, 0.2887, 0.0000, 0.0000,\n",
       "           0.0000, 0.8660, 0.0000],\n",
       "          [0.0000, 0.0000, 0.5000, 0.0000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.5000, 0.5000],\n",
       "          [0.5000, 0.5000, 0.0000, 0.0000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.5000, 0.0000],\n",
       "          [0.0000, 0.0000, 0.1231, 0.0000, 0.1231, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.9847, 0.0000],\n",
       "          [0.0000, 0.2887, 0.0000, 0.0000, 0.2887, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.8660, 0.2887],\n",
       "          [0.0000, 0.1387, 0.0000, 0.0000, 0.1387, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.9707, 0.1387],\n",
       "          [0.0000, 0.1091, 0.0000, 0.0000, 0.1091, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.9820, 0.1091],\n",
       "          [0.4472, 0.0000, 0.4472, 0.0000, 0.0000, 0.0000, 0.4472, 0.0000, 0.0000,\n",
       "           0.0000, 0.4472, 0.4472],\n",
       "          [0.0000, 0.3780, 0.0000, 0.0000, 0.3780, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.7559, 0.3780],\n",
       "          [0.0000, 0.4082, 0.0000, 0.0000, 0.4082, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.8165, 0.0000],\n",
       "          [0.0000, 0.3780, 0.0000, 0.0000, 0.3780, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.7559, 0.3780],\n",
       "          [0.0000, 0.0664, 0.0000, 0.0000, 0.0664, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.9956, 0.0000],\n",
       "          [0.0000, 0.4082, 0.0000, 0.0000, 0.0000, 0.0000, 0.4082, 0.0000, 0.0000,\n",
       "           0.0000, 0.8165, 0.0000],\n",
       "          [0.5774, 0.5774, 0.0000, 0.0000, 0.5774, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.0000, 0.0000],\n",
       "          [0.0000, 0.0985, 0.0000, 0.0000, 0.0000, 0.0000, 0.0985, 0.0000, 0.0000,\n",
       "           0.0000, 0.9853, 0.0985],\n",
       "          [0.1213, 0.1213, 0.0000, 0.0000, 0.1213, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.9701, 0.1213],\n",
       "          [0.0000, 0.0000, 0.0000, 0.7071, 0.0000, 0.0000, 0.0000, 0.7071, 0.0000,\n",
       "           0.0000, 0.0000, 0.0000],\n",
       "          [0.5774, 0.5774, 0.0000, 0.0000, 0.5774, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.0000, 0.0000],\n",
       "          [0.0000, 0.0000, 0.7071, 0.0000, 0.0000, 0.0000, 0.7071, 0.0000, 0.0000,\n",
       "           0.0000, 0.0000, 0.0000],\n",
       "          [0.0000, 0.0000, 0.4082, 0.0000, 0.0000, 0.0000, 0.4082, 0.0000, 0.0000,\n",
       "           0.0000, 0.8165, 0.0000],\n",
       "          [0.3536, 0.3536, 0.0000, 0.0000, 0.0000, 0.0000, 0.3536, 0.0000, 0.0000,\n",
       "           0.0000, 0.7071, 0.3536],\n",
       "          [0.0000, 0.2294, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.2294, 0.9177, 0.2294],\n",
       "          [0.5000, 0.0000, 0.0000, 0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.5000, 0.0000],\n",
       "          [0.0000, 0.0434, 0.0000, 0.0000, 0.0434, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.9972, 0.0434],\n",
       "          [0.0000, 0.0000, 0.0000, 0.5774, 0.5774, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.5774, 0.0000],\n",
       "          [0.0000, 0.0000, 0.5774, 0.0000, 0.0000, 0.0000, 0.0000, 0.5774, 0.0000,\n",
       "           0.0000, 0.5774, 0.0000],\n",
       "          [0.0000, 0.7071, 0.0000, 0.0000, 0.7071, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.0000, 0.0000],\n",
       "          [0.0000, 0.0000, 0.1601, 0.0000, 0.0000, 0.0000, 0.1601, 0.0000, 0.0000,\n",
       "           0.0000, 0.9608, 0.1601],\n",
       "          [0.0000, 0.5774, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.5774, 0.0000,\n",
       "           0.0000, 0.5774, 0.0000],\n",
       "          [0.1213, 0.0000, 0.1213, 0.0000, 0.0000, 0.0000, 0.1213, 0.0000, 0.0000,\n",
       "           0.0000, 0.9701, 0.1213],\n",
       "          [0.0000, 0.0000, 0.7071, 0.0000, 0.0000, 0.0000, 0.7071, 0.0000, 0.0000,\n",
       "           0.0000, 0.0000, 0.0000],\n",
       "          [0.0000, 0.0524, 0.0000, 0.0000, 0.0524, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.9959, 0.0524],\n",
       "          [0.0000, 0.0000, 0.7071, 0.0000, 0.7071, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.0000, 0.0000],\n",
       "          [0.5000, 0.0000, 0.5000, 0.0000, 0.0000, 0.0000, 0.5000, 0.0000, 0.0000,\n",
       "           0.0000, 0.0000, 0.5000],\n",
       "          [0.0000, 0.0000, 0.0000, 0.5774, 0.5774, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.0000, 0.5774],\n",
       "          [0.3780, 0.3780, 0.0000, 0.0000, 0.3780, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.7559, 0.0000],\n",
       "          [0.0000, 0.0000, 0.0762, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0762, 0.9912, 0.0762],\n",
       "          [0.0000, 0.2887, 0.0000, 0.0000, 0.2887, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.8660, 0.2887],\n",
       "          [0.0000, 0.4082, 0.0000, 0.0000, 0.4082, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.8165, 0.0000],\n",
       "          [0.0000, 0.0762, 0.0000, 0.0000, 0.0000, 0.0000, 0.0762, 0.0000, 0.0000,\n",
       "           0.0000, 0.9912, 0.0762],\n",
       "          [0.0000, 0.0000, 0.0000, 0.4082, 0.4082, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.8165, 0.0000],\n",
       "          [0.0000, 0.5774, 0.0000, 0.0000, 0.0000, 0.0000, 0.5774, 0.0000, 0.0000,\n",
       "           0.0000, 0.0000, 0.5774],\n",
       "          [0.0000, 0.0000, 0.3015, 0.0000, 0.3015, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.9045, 0.0000],\n",
       "          [0.0000, 0.7071, 0.0000, 0.0000, 0.0000, 0.0000, 0.7071, 0.0000, 0.0000,\n",
       "           0.0000, 0.0000, 0.0000],\n",
       "          [0.3780, 0.0000, 0.3780, 0.0000, 0.0000, 0.0000, 0.3780, 0.0000, 0.0000,\n",
       "           0.0000, 0.7559, 0.0000]], dtype=torch.float64),\n",
       "  'a': array([0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1,\n",
       "         0, 1, 1, 1, 1, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 0, 1,\n",
       "         1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 1, 1,\n",
       "         1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]),\n",
       "  'y': tensor([0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0,\n",
       "          1, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1,\n",
       "          1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0,\n",
       "          1, 1, 1, 0, 1, 1, 0, 0, 1, 0]),\n",
       "  'x_test': tensor([[0.5774, 0.5774, 0.0000, 0.0000, 0.0000, 0.0000, 0.5774, 0.0000, 0.0000,\n",
       "           0.0000, 0.0000, 0.0000],\n",
       "          [0.0000, 0.0000, 0.0000, 0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.5000, 0.5000],\n",
       "          [0.0000, 0.3015, 0.0000, 0.0000, 0.3015, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.9045, 0.0000],\n",
       "          [0.3780, 0.0000, 0.3780, 0.0000, 0.0000, 0.0000, 0.0000, 0.3780, 0.0000,\n",
       "           0.0000, 0.7559, 0.0000],\n",
       "          [0.3780, 0.3780, 0.0000, 0.0000, 0.3780, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.7559, 0.0000],\n",
       "          [0.2294, 0.2294, 0.0000, 0.0000, 0.0000, 0.0000, 0.2294, 0.0000, 0.0000,\n",
       "           0.0000, 0.9177, 0.0000],\n",
       "          [0.3780, 0.3780, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.3780, 0.7559, 0.0000],\n",
       "          [0.3536, 0.3536, 0.0000, 0.0000, 0.3536, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.7071, 0.3536],\n",
       "          [0.0000, 0.0000, 0.7071, 0.0000, 0.0000, 0.0000, 0.0000, 0.7071, 0.0000,\n",
       "           0.0000, 0.0000, 0.0000],\n",
       "          [0.4472, 0.4472, 0.0000, 0.0000, 0.4472, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.4472, 0.4472],\n",
       "          [0.1581, 0.1581, 0.0000, 0.0000, 0.1581, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.9487, 0.1581],\n",
       "          [0.0000, 0.0898, 0.0000, 0.0000, 0.0898, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.9878, 0.0898],\n",
       "          [0.0000, 0.0000, 0.0902, 0.0000, 0.0902, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.9918, 0.0000],\n",
       "          [0.0000, 0.1622, 0.0000, 0.0000, 0.0000, 0.0000, 0.1622, 0.0000, 0.0000,\n",
       "           0.0000, 0.9733, 0.0000],\n",
       "          [0.5000, 0.5000, 0.0000, 0.0000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.5000, 0.0000],\n",
       "          [0.0000, 0.5774, 0.0000, 0.0000, 0.5774, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.0000, 0.5774],\n",
       "          [0.3536, 0.3536, 0.0000, 0.0000, 0.0000, 0.0000, 0.3536, 0.0000, 0.0000,\n",
       "           0.0000, 0.7071, 0.3536],\n",
       "          [0.0000, 0.0000, 0.5000, 0.0000, 0.0000, 0.0000, 0.5000, 0.0000, 0.0000,\n",
       "           0.0000, 0.5000, 0.5000],\n",
       "          [0.0000, 0.0000, 0.0000, 0.7071, 0.7071, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.0000, 0.0000],\n",
       "          [0.0000, 0.5774, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.5774, 0.5774, 0.0000],\n",
       "          [0.0000, 0.5774, 0.0000, 0.0000, 0.5774, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.5774, 0.0000]], dtype=torch.float64),\n",
       "  'a_test': array([0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1]),\n",
       "  'y_test': tensor([0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 0, 0, 0, 1, 0, 1])},\n",
       " 2: {'x': tensor([[0.0000, 0.0000, 0.0000, 0.7071, 0.7071, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.0000, 0.0000],\n",
       "          [0.0000, 0.0000, 0.3015, 0.0000, 0.0000, 0.0000, 0.3015, 0.0000, 0.0000,\n",
       "           0.0000, 0.9045, 0.0000],\n",
       "          [0.0000, 0.1925, 0.0000, 0.0000, 0.0000, 0.0000, 0.1925, 0.0000, 0.0000,\n",
       "           0.0000, 0.9623, 0.0000],\n",
       "          [0.0894, 0.0894, 0.0000, 0.0000, 0.0894, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.9839, 0.0894],\n",
       "          [0.0000, 0.0399, 0.0000, 0.0000, 0.0399, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.9976, 0.0399],\n",
       "          [0.5774, 0.5774, 0.0000, 0.0000, 0.5774, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.0000, 0.0000],\n",
       "          [0.5000, 0.0000, 0.5000, 0.0000, 0.0000, 0.0000, 0.5000, 0.0000, 0.0000,\n",
       "           0.0000, 0.5000, 0.0000],\n",
       "          [0.0000, 0.7071, 0.0000, 0.0000, 0.0000, 0.0000, 0.7071, 0.0000, 0.0000,\n",
       "           0.0000, 0.0000, 0.0000],\n",
       "          [0.0000, 0.0000, 0.0000, 0.1091, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.1091, 0.9820, 0.1091],\n",
       "          [0.0000, 0.0000, 0.0000, 0.5774, 0.0000, 0.0000, 0.5774, 0.0000, 0.0000,\n",
       "           0.0000, 0.0000, 0.5774],\n",
       "          [0.0000, 0.0898, 0.0000, 0.0000, 0.0898, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.9878, 0.0898],\n",
       "          [0.0000, 0.5774, 0.0000, 0.0000, 0.0000, 0.0000, 0.5774, 0.0000, 0.0000,\n",
       "           0.0000, 0.0000, 0.5774],\n",
       "          [0.4472, 0.0000, 0.0000, 0.4472, 0.0000, 0.0000, 0.4472, 0.0000, 0.0000,\n",
       "           0.0000, 0.4472, 0.4472],\n",
       "          [0.0000, 0.0000, 0.0000, 0.7071, 0.0000, 0.0000, 0.0000, 0.7071, 0.0000,\n",
       "           0.0000, 0.0000, 0.0000],\n",
       "          [0.0000, 0.0000, 0.0000, 0.5774, 0.5774, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.0000, 0.5774],\n",
       "          [0.0000, 0.0000, 0.7071, 0.0000, 0.7071, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.0000, 0.0000],\n",
       "          [0.0000, 0.3780, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.3780, 0.7559, 0.3780],\n",
       "          [0.0000, 0.7071, 0.0000, 0.0000, 0.7071, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.0000, 0.0000],\n",
       "          [0.4472, 0.4472, 0.0000, 0.0000, 0.4472, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.4472, 0.4472],\n",
       "          [0.0000, 0.0000, 0.0000, 0.5000, 0.0000, 0.0000, 0.0000, 0.5000, 0.0000,\n",
       "           0.0000, 0.5000, 0.5000],\n",
       "          [0.0000, 0.0000, 0.7071, 0.0000, 0.7071, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.0000, 0.0000],\n",
       "          [0.5774, 0.5774, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.5774, 0.0000, 0.0000],\n",
       "          [0.0000, 0.2357, 0.0000, 0.0000, 0.2357, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.9428, 0.0000],\n",
       "          [0.0000, 0.3015, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.3015, 0.9045, 0.0000],\n",
       "          [0.5774, 0.0000, 0.0000, 0.5774, 0.0000, 0.0000, 0.0000, 0.5774, 0.0000,\n",
       "           0.0000, 0.0000, 0.0000],\n",
       "          [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.5000, 0.0000, 0.0000,\n",
       "           0.0000, 0.0000, 0.5000],\n",
       "          [0.0000, 0.5774, 0.0000, 0.0000, 0.0000, 0.0000, 0.5774, 0.0000, 0.0000,\n",
       "           0.0000, 0.0000, 0.5774],\n",
       "          [0.0000, 0.7071, 0.0000, 0.0000, 0.7071, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.0000, 0.0000],\n",
       "          [0.0000, 0.2357, 0.0000, 0.0000, 0.2357, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.9428, 0.0000],\n",
       "          [0.3780, 0.3780, 0.0000, 0.0000, 0.3780, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.7559, 0.0000],\n",
       "          [0.5774, 0.5774, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.5774, 0.0000,\n",
       "           0.0000, 0.0000, 0.0000],\n",
       "          [0.0000, 0.1601, 0.0000, 0.0000, 0.1601, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.9608, 0.1601],\n",
       "          [0.5774, 0.0000, 0.0000, 0.5774, 0.5774, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.0000, 0.0000],\n",
       "          [0.0000, 0.0000, 0.0000, 0.7071, 0.0000, 0.0000, 0.0000, 0.7071, 0.0000,\n",
       "           0.0000, 0.0000, 0.0000],\n",
       "          [0.0000, 0.7071, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.7071, 0.0000, 0.0000],\n",
       "          [0.0000, 0.5774, 0.0000, 0.0000, 0.0000, 0.0000, 0.5774, 0.0000, 0.0000,\n",
       "           0.0000, 0.5774, 0.0000],\n",
       "          [0.1213, 0.1213, 0.0000, 0.0000, 0.0000, 0.0000, 0.1213, 0.0000, 0.0000,\n",
       "           0.0000, 0.9701, 0.1213],\n",
       "          [0.0000, 0.0000, 0.0000, 0.7071, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.7071, 0.0000, 0.0000],\n",
       "          [0.0000, 0.7071, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.7071, 0.0000, 0.0000],\n",
       "          [0.0000, 0.2357, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.2357, 0.0000,\n",
       "           0.0000, 0.9428, 0.0000],\n",
       "          [0.0000, 0.3780, 0.0000, 0.0000, 0.3780, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.7559, 0.3780],\n",
       "          [0.0000, 0.0000, 0.0000, 0.5774, 0.0000, 0.0000, 0.5774, 0.0000, 0.0000,\n",
       "           0.0000, 0.5774, 0.0000],\n",
       "          [0.0000, 0.1890, 0.0000, 0.0000, 0.1890, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.9449, 0.1890],\n",
       "          [0.5000, 0.5000, 0.0000, 0.0000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.5000, 0.0000],\n",
       "          [0.0000, 0.0000, 0.0000, 0.1890, 0.0000, 0.0000, 0.1890, 0.0000, 0.0000,\n",
       "           0.0000, 0.9449, 0.1890],\n",
       "          [0.0000, 0.0711, 0.0000, 0.0000, 0.0711, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.9949, 0.0000],\n",
       "          [0.0000, 0.0000, 0.2887, 0.0000, 0.2887, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.8660, 0.2887],\n",
       "          [0.2774, 0.2774, 0.0000, 0.0000, 0.0000, 0.0000, 0.2774, 0.0000, 0.0000,\n",
       "           0.0000, 0.8321, 0.2774],\n",
       "          [0.0000, 0.1925, 0.0000, 0.0000, 0.1925, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.9623, 0.0000],\n",
       "          [0.0000, 0.5774, 0.0000, 0.0000, 0.0000, 0.0000, 0.5774, 0.0000, 0.0000,\n",
       "           0.0000, 0.5774, 0.0000],\n",
       "          [0.0000, 0.7071, 0.0000, 0.0000, 0.0000, 0.0000, 0.7071, 0.0000, 0.0000,\n",
       "           0.0000, 0.0000, 0.0000],\n",
       "          [0.1890, 0.1890, 0.0000, 0.0000, 0.0000, 0.0000, 0.1890, 0.0000, 0.0000,\n",
       "           0.0000, 0.9449, 0.0000],\n",
       "          [0.5000, 0.0000, 0.0000, 0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.0000, 0.5000]], dtype=torch.float64),\n",
       "  'a': array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 0,\n",
       "         1, 0, 1, 1, 1, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1,\n",
       "         0, 1, 1, 1, 0, 1, 0, 1, 1]),\n",
       "  'y': tensor([0, 0, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0,\n",
       "          1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0,\n",
       "          0, 1, 0, 0, 1]),\n",
       "  'x_test': tensor([[0.0000, 0.0000, 0.0000, 0.7071, 0.7071, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.0000, 0.0000],\n",
       "          [0.0000, 0.0000, 0.7071, 0.0000, 0.0000, 0.0000, 0.7071, 0.0000, 0.0000,\n",
       "           0.0000, 0.0000, 0.0000],\n",
       "          [0.2774, 0.0000, 0.0000, 0.2774, 0.2774, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.8321, 0.2774],\n",
       "          [0.0000, 0.5774, 0.0000, 0.0000, 0.0000, 0.0000, 0.5774, 0.0000, 0.0000,\n",
       "           0.0000, 0.0000, 0.5774],\n",
       "          [0.0000, 0.0000, 0.7071, 0.0000, 0.7071, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.0000, 0.0000],\n",
       "          [0.1890, 0.1890, 0.0000, 0.0000, 0.1890, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.9449, 0.0000],\n",
       "          [0.0000, 0.0000, 0.0000, 0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.5000, 0.5000],\n",
       "          [0.5000, 0.0000, 0.5000, 0.0000, 0.0000, 0.0000, 0.5000, 0.0000, 0.0000,\n",
       "           0.0000, 0.5000, 0.0000],\n",
       "          [0.0000, 0.2887, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.2887, 0.0000,\n",
       "           0.0000, 0.8660, 0.2887],\n",
       "          [0.0000, 0.4082, 0.0000, 0.0000, 0.4082, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.8165, 0.0000],\n",
       "          [0.0000, 0.7071, 0.0000, 0.0000, 0.7071, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.0000, 0.0000],\n",
       "          [0.0000, 0.0000, 0.4082, 0.0000, 0.0000, 0.0000, 0.4082, 0.0000, 0.0000,\n",
       "           0.0000, 0.8165, 0.0000],\n",
       "          [0.0000, 0.3015, 0.0000, 0.0000, 0.0000, 0.0000, 0.3015, 0.0000, 0.0000,\n",
       "           0.0000, 0.9045, 0.0000],\n",
       "          [0.2236, 0.0000, 0.0000, 0.2236, 0.2236, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.8944, 0.2236]], dtype=torch.float64),\n",
       "  'a_test': array([1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1]),\n",
       "  'y_test': tensor([1, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1])},\n",
       " 3: {'x': tensor([[0.0000, 0.5774, 0.0000,  ..., 0.0000, 0.5774, 0.0000],\n",
       "          [0.5774, 0.0000, 0.5774,  ..., 0.0000, 0.0000, 0.0000],\n",
       "          [0.0000, 0.1231, 0.0000,  ..., 0.0000, 0.9847, 0.0000],\n",
       "          ...,\n",
       "          [0.0000, 0.0000, 0.0902,  ..., 0.0000, 0.9918, 0.0000],\n",
       "          [0.1085, 0.0000, 0.0000,  ..., 0.0000, 0.9762, 0.1085],\n",
       "          [0.5000, 0.5000, 0.0000,  ..., 0.0000, 0.0000, 0.5000]],\n",
       "         dtype=torch.float64),\n",
       "  'a': array([0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1,\n",
       "         0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 1, 0,\n",
       "         1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 0,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0,\n",
       "         1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1,\n",
       "         0, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1,\n",
       "         1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1,\n",
       "         1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0,\n",
       "         1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0,\n",
       "         1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1]),\n",
       "  'y': tensor([0, 0, 0, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1,\n",
       "          1, 0, 0, 1, 0, 1, 1, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1,\n",
       "          1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1,\n",
       "          0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 1,\n",
       "          1, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 1, 1, 0, 1,\n",
       "          1, 0, 0, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0,\n",
       "          0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1,\n",
       "          0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1,\n",
       "          1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1,\n",
       "          1, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0,\n",
       "          0, 1, 0, 1, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 0, 0, 1,\n",
       "          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0,\n",
       "          0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0,\n",
       "          0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0, 1, 1, 0, 0,\n",
       "          1, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1,\n",
       "          1, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0,\n",
       "          0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1,\n",
       "          1, 1, 1, 0, 1, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1,\n",
       "          1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 1, 1, 1, 0,\n",
       "          1, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0]),\n",
       "  'x_test': tensor([[0.0000, 0.0000, 0.0356,  ..., 0.0000, 0.9981, 0.0356],\n",
       "          [0.0000, 0.1387, 0.0000,  ..., 0.0000, 0.9707, 0.1387],\n",
       "          [0.0000, 0.2294, 0.0000,  ..., 0.2294, 0.9177, 0.2294],\n",
       "          ...,\n",
       "          [0.0000, 0.0000, 0.5774,  ..., 0.0000, 0.5774, 0.0000],\n",
       "          [0.0000, 0.5000, 0.0000,  ..., 0.5000, 0.5000, 0.5000],\n",
       "          [0.0000, 0.0000, 0.2887,  ..., 0.0000, 0.8660, 0.2887]],\n",
       "         dtype=torch.float64),\n",
       "  'a_test': array([1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 0, 0, 0, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1,\n",
       "         0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 1,\n",
       "         0, 1, 1, 0, 0, 1, 1, 1, 1, 1]),\n",
       "  'y_test': tensor([1, 1, 0, 1, 0, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0,\n",
       "          1, 1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1,\n",
       "          1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1,\n",
       "          1, 1, 0, 0, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 0, 1,\n",
       "          0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 1, 1, 0, 0, 0])},\n",
       " 4: {'x': tensor([[0.5774, 0.0000, 0.5774,  ..., 0.0000, 0.0000, 0.0000],\n",
       "          [0.0000, 0.5774, 0.0000,  ..., 0.5774, 0.0000, 0.5774],\n",
       "          [0.3536, 0.0000, 0.0000,  ..., 0.0000, 0.7071, 0.3536],\n",
       "          ...,\n",
       "          [0.0000, 0.4082, 0.0000,  ..., 0.0000, 0.8165, 0.0000],\n",
       "          [0.0000, 0.1222, 0.0000,  ..., 0.0000, 0.9774, 0.1222],\n",
       "          [0.5000, 0.5000, 0.0000,  ..., 0.0000, 0.5000, 0.0000]],\n",
       "         dtype=torch.float64),\n",
       "  'a': array([1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1,\n",
       "         1, 1, 0, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 0,\n",
       "         1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1,\n",
       "         1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1,\n",
       "         1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 0,\n",
       "         1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1,\n",
       "         1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 0,\n",
       "         1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1,\n",
       "         1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 0, 0,\n",
       "         0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0]),\n",
       "  'y': tensor([0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0,\n",
       "          0, 1, 1, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 0,\n",
       "          1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 1,\n",
       "          0, 0, 1, 0, 1, 1, 0, 1, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1, 0, 0, 1, 1, 1, 0,\n",
       "          0, 1, 0, 0, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0,\n",
       "          0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0,\n",
       "          1, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0,\n",
       "          1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0,\n",
       "          0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0,\n",
       "          0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1,\n",
       "          0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0, 1,\n",
       "          1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1,\n",
       "          1, 0, 0, 0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0, 1, 1, 0]),\n",
       "  'x_test': tensor([[0.0000, 0.0000, 0.0000, 0.3780, 0.3780, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.7559, 0.3780],\n",
       "          [0.0000, 0.5774, 0.0000, 0.0000, 0.0000, 0.0000, 0.5774, 0.0000, 0.0000,\n",
       "           0.0000, 0.0000, 0.5774],\n",
       "          [0.0000, 0.7071, 0.0000, 0.0000, 0.0000, 0.0000, 0.7071, 0.0000, 0.0000,\n",
       "           0.0000, 0.0000, 0.0000],\n",
       "          [0.3780, 0.0000, 0.3780, 0.0000, 0.0000, 0.0000, 0.3780, 0.0000, 0.0000,\n",
       "           0.0000, 0.7559, 0.0000],\n",
       "          [0.5774, 0.5774, 0.0000, 0.0000, 0.5774, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.0000, 0.0000],\n",
       "          [0.0000, 0.7071, 0.0000, 0.0000, 0.0000, 0.0000, 0.7071, 0.0000, 0.0000,\n",
       "           0.0000, 0.0000, 0.0000],\n",
       "          [0.0000, 0.2887, 0.0000, 0.0000, 0.2887, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.8660, 0.2887],\n",
       "          [0.0000, 0.3780, 0.0000, 0.0000, 0.0000, 0.0000, 0.3780, 0.0000, 0.0000,\n",
       "           0.0000, 0.7559, 0.3780],\n",
       "          [0.5000, 0.0000, 0.0000, 0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.5000, 0.0000],\n",
       "          [0.0000, 0.1091, 0.0000, 0.0000, 0.1091, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.9820, 0.1091],\n",
       "          [0.0000, 0.5774, 0.0000, 0.0000, 0.0000, 0.0000, 0.5774, 0.0000, 0.0000,\n",
       "           0.0000, 0.0000, 0.5774],\n",
       "          [0.0000, 0.7071, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.7071, 0.0000, 0.0000],\n",
       "          [0.0553, 0.0553, 0.0000, 0.0000, 0.0553, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.9954, 0.0000],\n",
       "          [0.0000, 0.0000, 0.0000, 0.7071, 0.7071, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.0000, 0.0000],\n",
       "          [0.0000, 0.0000, 0.1387, 0.0000, 0.0000, 0.0000, 0.1387, 0.0000, 0.0000,\n",
       "           0.0000, 0.9707, 0.1387],\n",
       "          [0.0000, 0.0709, 0.0000, 0.0000, 0.0709, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.9924, 0.0709],\n",
       "          [0.5774, 0.5774, 0.0000, 0.0000, 0.0000, 0.0000, 0.5774, 0.0000, 0.0000,\n",
       "           0.0000, 0.0000, 0.0000],\n",
       "          [0.5774, 0.5774, 0.0000, 0.0000, 0.5774, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.0000, 0.0000],\n",
       "          [0.0000, 0.1890, 0.0000, 0.0000, 0.1890, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.9449, 0.1890],\n",
       "          [0.0000, 0.4082, 0.0000, 0.0000, 0.4082, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.8165, 0.0000],\n",
       "          [0.2774, 0.0000, 0.2774, 0.0000, 0.0000, 0.0000, 0.2774, 0.0000, 0.0000,\n",
       "           0.0000, 0.8321, 0.2774],\n",
       "          [0.1601, 0.1601, 0.0000, 0.0000, 0.0000, 0.0000, 0.1601, 0.0000, 0.0000,\n",
       "           0.0000, 0.9608, 0.0000],\n",
       "          [0.0000, 0.5774, 0.0000, 0.0000, 0.0000, 0.0000, 0.5774, 0.0000, 0.0000,\n",
       "           0.0000, 0.0000, 0.5774],\n",
       "          [0.0000, 0.7071, 0.0000, 0.0000, 0.0000, 0.0000, 0.7071, 0.0000, 0.0000,\n",
       "           0.0000, 0.0000, 0.0000],\n",
       "          [0.0000, 0.0000, 0.5774, 0.0000, 0.5774, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.5774, 0.0000],\n",
       "          [0.0000, 0.3015, 0.0000, 0.0000, 0.3015, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.9045, 0.0000],\n",
       "          [0.5000, 0.0000, 0.0000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.5000, 0.0000, 0.5000],\n",
       "          [0.1387, 0.1387, 0.0000, 0.0000, 0.0000, 0.0000, 0.1387, 0.0000, 0.0000,\n",
       "           0.0000, 0.9707, 0.0000],\n",
       "          [0.5774, 0.5774, 0.0000, 0.0000, 0.5774, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.0000, 0.0000],\n",
       "          [0.0000, 0.0000, 0.3015, 0.0000, 0.0000, 0.0000, 0.0000, 0.3015, 0.0000,\n",
       "           0.0000, 0.9045, 0.0000],\n",
       "          [0.0000, 0.0000, 0.2357, 0.0000, 0.0000, 0.0000, 0.2357, 0.0000, 0.0000,\n",
       "           0.0000, 0.9428, 0.0000],\n",
       "          [0.0000, 0.0000, 0.0000, 0.5774, 0.5774, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.5774, 0.0000],\n",
       "          [0.1387, 0.1387, 0.0000, 0.0000, 0.0000, 0.0000, 0.1387, 0.0000, 0.0000,\n",
       "           0.0000, 0.9707, 0.0000],\n",
       "          [0.4472, 0.0000, 0.0000, 0.4472, 0.4472, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.4472, 0.4472],\n",
       "          [0.0000, 0.0000, 0.1925, 0.0000, 0.1925, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.9623, 0.0000],\n",
       "          [0.0000, 0.0000, 0.5774, 0.0000, 0.0000, 0.0000, 0.0000, 0.5774, 0.0000,\n",
       "           0.0000, 0.0000, 0.5774],\n",
       "          [0.4472, 0.0000, 0.0000, 0.4472, 0.4472, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.4472, 0.4472],\n",
       "          [0.0000, 0.0000, 0.0000, 0.5774, 0.5774, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.0000, 0.5774],\n",
       "          [0.0000, 0.7071, 0.0000, 0.0000, 0.7071, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.0000, 0.0000],\n",
       "          [0.5000, 0.0000, 0.0000, 0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.0000, 0.5000],\n",
       "          [0.5000, 0.5000, 0.0000, 0.0000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.5000, 0.0000],\n",
       "          [0.0000, 0.0000, 0.4082, 0.0000, 0.0000, 0.0000, 0.4082, 0.0000, 0.0000,\n",
       "           0.0000, 0.8165, 0.0000],\n",
       "          [0.1601, 0.1601, 0.0000, 0.0000, 0.1601, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.9608, 0.0000],\n",
       "          [0.0000, 0.0000, 0.7071, 0.0000, 0.0000, 0.0000, 0.7071, 0.0000, 0.0000,\n",
       "           0.0000, 0.0000, 0.0000],\n",
       "          [0.5774, 0.5774, 0.0000, 0.0000, 0.0000, 0.0000, 0.5774, 0.0000, 0.0000,\n",
       "           0.0000, 0.0000, 0.0000],\n",
       "          [0.0000, 0.0000, 0.7071, 0.0000, 0.0000, 0.0000, 0.7071, 0.0000, 0.0000,\n",
       "           0.0000, 0.0000, 0.0000],\n",
       "          [0.5774, 0.0000, 0.0000, 0.5774, 0.5774, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.0000, 0.0000],\n",
       "          [0.0000, 0.2294, 0.0000, 0.0000, 0.2294, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.9177, 0.2294],\n",
       "          [0.0000, 0.4082, 0.0000, 0.0000, 0.0000, 0.4082, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.8165, 0.0000],\n",
       "          [0.0000, 0.1622, 0.0000, 0.0000, 0.0000, 0.0000, 0.1622, 0.0000, 0.0000,\n",
       "           0.0000, 0.9733, 0.0000],\n",
       "          [0.0000, 0.5774, 0.0000, 0.0000, 0.0000, 0.0000, 0.5774, 0.0000, 0.0000,\n",
       "           0.0000, 0.5774, 0.0000],\n",
       "          [0.0000, 0.3015, 0.0000, 0.0000, 0.0000, 0.0000, 0.3015, 0.0000, 0.0000,\n",
       "           0.0000, 0.9045, 0.0000],\n",
       "          [0.0000, 0.0000, 0.1222, 0.0000, 0.0000, 0.0000, 0.1222, 0.0000, 0.0000,\n",
       "           0.0000, 0.9774, 0.1222],\n",
       "          [0.0000, 0.7071, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.7071, 0.0000,\n",
       "           0.0000, 0.0000, 0.0000],\n",
       "          [0.0000, 0.2294, 0.0000, 0.0000, 0.2294, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.9177, 0.2294],\n",
       "          [0.3536, 0.0000, 0.3536, 0.0000, 0.0000, 0.0000, 0.3536, 0.0000, 0.0000,\n",
       "           0.0000, 0.7071, 0.3536],\n",
       "          [0.0000, 0.0621, 0.0000, 0.0000, 0.0621, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.9942, 0.0621],\n",
       "          [0.0000, 0.0000, 0.0000, 0.5774, 0.0000, 0.0000, 0.0000, 0.5774, 0.0000,\n",
       "           0.0000, 0.0000, 0.5774],\n",
       "          [0.0000, 0.1387, 0.0000, 0.0000, 0.1387, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.9707, 0.1387],\n",
       "          [0.0000, 0.0000, 0.5774, 0.0000, 0.0000, 0.0000, 0.0000, 0.5774, 0.0000,\n",
       "           0.0000, 0.5774, 0.0000],\n",
       "          [0.5774, 0.0000, 0.5774, 0.0000, 0.0000, 0.0000, 0.5774, 0.0000, 0.0000,\n",
       "           0.0000, 0.0000, 0.0000],\n",
       "          [0.2294, 0.0000, 0.2294, 0.0000, 0.0000, 0.0000, 0.2294, 0.0000, 0.0000,\n",
       "           0.0000, 0.9177, 0.0000],\n",
       "          [0.0000, 0.1222, 0.0000, 0.0000, 0.1222, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.9774, 0.1222],\n",
       "          [0.5000, 0.5000, 0.0000, 0.0000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.0000, 0.5000],\n",
       "          [0.0000, 0.2294, 0.0000, 0.0000, 0.0000, 0.0000, 0.2294, 0.0000, 0.0000,\n",
       "           0.0000, 0.9177, 0.2294],\n",
       "          [0.0000, 0.5774, 0.0000, 0.0000, 0.0000, 0.0000, 0.5774, 0.0000, 0.0000,\n",
       "           0.0000, 0.0000, 0.5774],\n",
       "          [0.5000, 0.0000, 0.0000, 0.5000, 0.0000, 0.0000, 0.5000, 0.0000, 0.0000,\n",
       "           0.0000, 0.5000, 0.0000],\n",
       "          [0.0000, 0.2887, 0.0000, 0.0000, 0.2887, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.8660, 0.2887],\n",
       "          [0.4472, 0.4472, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.4472, 0.4472, 0.4472],\n",
       "          [0.0000, 0.0000, 0.0000, 0.3780, 0.0000, 0.0000, 0.0000, 0.3780, 0.0000,\n",
       "           0.0000, 0.7559, 0.3780],\n",
       "          [0.5774, 0.0000, 0.5774, 0.0000, 0.0000, 0.0000, 0.0000, 0.5774, 0.0000,\n",
       "           0.0000, 0.0000, 0.0000],\n",
       "          [0.0000, 0.0000, 0.0000, 0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.5000, 0.5000],\n",
       "          [0.0000, 0.5774, 0.0000, 0.0000, 0.5774, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.0000, 0.5774],\n",
       "          [0.0000, 0.0000, 0.0000, 0.3015, 0.3015, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.9045, 0.0000],\n",
       "          [0.0000, 0.0000, 0.0000, 0.2887, 0.2887, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.0000, 0.8660, 0.2887],\n",
       "          [0.0000, 0.0000, 0.5000, 0.0000, 0.0000, 0.0000, 0.5000, 0.0000, 0.0000,\n",
       "           0.0000, 0.5000, 0.5000],\n",
       "          [0.0898, 0.0898, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0898, 0.0000,\n",
       "           0.0000, 0.9878, 0.0000],\n",
       "          [0.0000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
       "           0.5000, 0.5000, 0.5000]], dtype=torch.float64),\n",
       "  'a_test': array([1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1,\n",
       "         0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1]),\n",
       "  'y_test': tensor([1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0, 0,\n",
       "          0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 1,\n",
       "          0, 0, 0, 0, 1, 0, 1, 0, 1, 1, 1, 0, 0, 0, 1, 0, 1, 1, 0, 1, 0, 1, 0, 1,\n",
       "          0, 1, 1, 1, 1, 0])},\n",
       " 5: {'x': tensor([[0.0000, 0.0553, 0.0000,  ..., 0.0000, 0.9954, 0.0553],\n",
       "          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],\n",
       "          [0.0000, 0.1387, 0.0000,  ..., 0.0000, 0.9707, 0.1387],\n",
       "          ...,\n",
       "          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.9948, 0.0585],\n",
       "          [0.0000, 0.5774, 0.0000,  ..., 0.0000, 0.5774, 0.0000],\n",
       "          [0.0000, 0.1091, 0.0000,  ..., 0.0000, 0.9820, 0.1091]],\n",
       "         dtype=torch.float64),\n",
       "  'a': array([1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1,\n",
       "         1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1,\n",
       "         1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1,\n",
       "         1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 0, 1,\n",
       "         1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1,\n",
       "         0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1,\n",
       "         0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1,\n",
       "         0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1,\n",
       "         0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 0, 0, 1, 1, 0,\n",
       "         1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 0,\n",
       "         0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1]),\n",
       "  'y': tensor([0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 1, 0, 0, 1, 0, 0, 1, 0,\n",
       "          1, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0,\n",
       "          0, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0,\n",
       "          0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 1,\n",
       "          1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 0, 0,\n",
       "          0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 0,\n",
       "          0, 0, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 1,\n",
       "          0, 1, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 0,\n",
       "          1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 1,\n",
       "          0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 1, 0, 0, 1, 0, 0,\n",
       "          1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0,\n",
       "          0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 1,\n",
       "          0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 0,\n",
       "          1, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1,\n",
       "          1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1,\n",
       "          1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 0, 0,\n",
       "          1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0,\n",
       "          0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1,\n",
       "          0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0,\n",
       "          0, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0, 0, 0, 0,\n",
       "          0, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 1, 1, 0,\n",
       "          0, 0, 1, 1, 0, 0, 1, 1, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0,\n",
       "          0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0,\n",
       "          1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0,\n",
       "          1, 0, 0, 1, 1, 0, 0, 1, 0, 1, 1, 1, 0]),\n",
       "  'x_test': tensor([[0.0000, 0.2357, 0.0000,  ..., 0.0000, 0.9428, 0.0000],\n",
       "          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.9177, 0.2294],\n",
       "          [0.5000, 0.5000, 0.0000,  ..., 0.0000, 0.0000, 0.5000],\n",
       "          ...,\n",
       "          [0.5000, 0.5000, 0.0000,  ..., 0.0000, 0.5000, 0.0000],\n",
       "          [0.0000, 0.0000, 0.5774,  ..., 0.0000, 0.5774, 0.0000],\n",
       "          [0.5000, 0.5000, 0.0000,  ..., 0.0000, 0.0000, 0.5000]],\n",
       "         dtype=torch.float64),\n",
       "  'a_test': array([0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0,\n",
       "         1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0,\n",
       "         1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]),\n",
       "  'y_test': tensor([0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0,\n",
       "          0, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0,\n",
       "          1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0,\n",
       "          0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0,\n",
       "          0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 1, 1, 0,\n",
       "          1, 1, 1, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1,\n",
       "          1, 0, 0, 0])},\n",
       " 6: {'x': tensor([[0.0000, 0.5774, 0.0000,  ..., 0.0000, 0.5774, 0.0000],\n",
       "          [0.0000, 0.0902, 0.0000,  ..., 0.0000, 0.9918, 0.0000],\n",
       "          [0.5000, 0.0000, 0.5000,  ..., 0.0000, 0.5000, 0.0000],\n",
       "          ...,\n",
       "          [0.4472, 0.0000, 0.4472,  ..., 0.0000, 0.4472, 0.4472],\n",
       "          [0.2887, 0.2887, 0.0000,  ..., 0.0000, 0.8660, 0.0000],\n",
       "          [0.0000, 0.0985, 0.0000,  ..., 0.0000, 0.9853, 0.0985]],\n",
       "         dtype=torch.float64),\n",
       "  'a': array([0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1,\n",
       "         1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1,\n",
       "         0, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0,\n",
       "         1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0,\n",
       "         0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1,\n",
       "         1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1,\n",
       "         1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1,\n",
       "         0, 0, 0, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0,\n",
       "         1, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1,\n",
       "         0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1,\n",
       "         0, 1, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1,\n",
       "         0, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 1,\n",
       "         1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1,\n",
       "         0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 0, 1, 0,\n",
       "         1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1,\n",
       "         0, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 0, 1,\n",
       "         1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1,\n",
       "         1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0,\n",
       "         1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 0, 1, 0,\n",
       "         1, 0, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1,\n",
       "         0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 1, 1, 0, 0, 1, 0,\n",
       "         0, 1, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0,\n",
       "         0, 1]),\n",
       "  'y': tensor([0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1,\n",
       "          1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 1,\n",
       "          1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0,\n",
       "          1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 1, 0, 0, 1, 1,\n",
       "          1, 0, 1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0,\n",
       "          1, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0,\n",
       "          0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0,\n",
       "          0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0,\n",
       "          0, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 0, 1, 0, 0, 0, 1, 0,\n",
       "          1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0,\n",
       "          0, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1, 0, 0, 0, 1, 1, 1, 0, 1, 0, 1, 1,\n",
       "          1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 0, 0, 1, 0, 0, 0,\n",
       "          0, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0,\n",
       "          1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 1, 1, 1, 0, 1, 0, 1, 1, 1,\n",
       "          1, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 1,\n",
       "          1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 0,\n",
       "          0, 0, 0, 1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0, 0,\n",
       "          1, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,\n",
       "          0, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 0,\n",
       "          1, 1, 1, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0,\n",
       "          0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1,\n",
       "          1, 1, 1, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0,\n",
       "          0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0,\n",
       "          1, 1, 1, 0, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0,\n",
       "          0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 1, 0, 0, 0, 1, 1, 1, 0, 0, 1,\n",
       "          0, 1, 0, 1, 1, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 1,\n",
       "          0, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0,\n",
       "          0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 1, 0,\n",
       "          1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0,\n",
       "          1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 1,\n",
       "          0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 0,\n",
       "          1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1,\n",
       "          1, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0,\n",
       "          1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0, 0, 1, 0, 1, 0, 1, 0,\n",
       "          1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 1, 1, 0, 1, 0, 0, 0,\n",
       "          0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 1, 0, 0,\n",
       "          0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1,\n",
       "          0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 1, 0, 1, 1,\n",
       "          1, 1, 1, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0, 0,\n",
       "          0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 1, 0, 1, 1,\n",
       "          0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 0, 0,\n",
       "          1, 1, 0, 1, 0, 0, 0, 1]),\n",
       "  'x_test': tensor([[0.3780, 0.3780, 0.0000,  ..., 0.0000, 0.7559, 0.0000],\n",
       "          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.8660, 0.2887],\n",
       "          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.5774, 0.0000],\n",
       "          ...,\n",
       "          [0.0000, 0.0000, 0.0985,  ..., 0.0000, 0.9853, 0.0985],\n",
       "          [0.0000, 0.2294, 0.0000,  ..., 0.0000, 0.9177, 0.2294],\n",
       "          [0.0000, 0.1400, 0.0000,  ..., 0.0000, 0.9802, 0.0000]],\n",
       "         dtype=torch.float64),\n",
       "  'a_test': array([0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1,\n",
       "         1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1,\n",
       "         1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1,\n",
       "         1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1,\n",
       "         0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0, 0, 1, 0, 1, 0, 0,\n",
       "         1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0,\n",
       "         1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1,\n",
       "         1, 1, 1, 1, 0, 0, 1]),\n",
       "  'y_test': tensor([1, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1,\n",
       "          1, 0, 1, 0, 0, 1, 1, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0,\n",
       "          0, 0, 1, 0, 1, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 1, 1, 0, 0, 1,\n",
       "          1, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1,\n",
       "          0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 1, 1, 0,\n",
       "          1, 1, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 0, 1, 0, 1, 1, 0, 0, 0, 1, 1, 0,\n",
       "          1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1,\n",
       "          0, 0, 0, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 0, 1,\n",
       "          1, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1,\n",
       "          1, 0, 1, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 1, 0, 1,\n",
       "          1, 0, 0, 1, 0, 1, 1, 0, 0])},\n",
       " 7: {'x': tensor([[0.0000, 0.0000, 0.5000,  ..., 0.0000, 0.5000, 0.5000],\n",
       "          [0.0000, 0.1387, 0.0000,  ..., 0.0000, 0.9707, 0.1387],\n",
       "          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.5000, 0.5000],\n",
       "          ...,\n",
       "          [0.0000, 0.0000, 0.0000,  ..., 0.2357, 0.9428, 0.0000],\n",
       "          [0.5774, 0.0000, 0.5774,  ..., 0.0000, 0.0000, 0.0000],\n",
       "          [0.0000, 0.5774, 0.0000,  ..., 0.0000, 0.0000, 0.5774]],\n",
       "         dtype=torch.float64),\n",
       "  'a': array([1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0,\n",
       "         1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1,\n",
       "         0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1,\n",
       "         1, 0, 1, 0, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1,\n",
       "         1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0,\n",
       "         1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 0,\n",
       "         0, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0,\n",
       "         1, 1, 0, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 0,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 0,\n",
       "         0, 1, 1, 1, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1,\n",
       "         1, 0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0, 1]),\n",
       "  'y': tensor([0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0, 1,\n",
       "          0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0, 0,\n",
       "          0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0,\n",
       "          0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 1, 0, 0, 0, 1,\n",
       "          0, 1, 1, 1, 1, 0, 1, 1, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1, 0, 0,\n",
       "          1, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0, 0, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0,\n",
       "          0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1,\n",
       "          1, 0, 0, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,\n",
       "          0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0,\n",
       "          0, 1, 1, 0, 0, 1, 1, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0, 0, 1,\n",
       "          1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0,\n",
       "          1, 1, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1,\n",
       "          0, 0, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1,\n",
       "          1, 1, 0, 0, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0,\n",
       "          1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0,\n",
       "          0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0,\n",
       "          0, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0,\n",
       "          0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0,\n",
       "          1, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1,\n",
       "          0, 0]),\n",
       "  'x_test': tensor([[0.0000, 0.1601, 0.0000,  ..., 0.0000, 0.9608, 0.1601],\n",
       "          [0.0000, 0.4082, 0.0000,  ..., 0.4082, 0.8165, 0.0000],\n",
       "          [0.0000, 0.0000, 0.0711,  ..., 0.0000, 0.9949, 0.0000],\n",
       "          ...,\n",
       "          [0.3780, 0.3780, 0.0000,  ..., 0.0000, 0.7559, 0.0000],\n",
       "          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.5774],\n",
       "          [0.4472, 0.0000, 0.0000,  ..., 0.0000, 0.4472, 0.4472]],\n",
       "         dtype=torch.float64),\n",
       "  'a_test': array([1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1,\n",
       "         0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1,\n",
       "         1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0,\n",
       "         1, 0, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 0,\n",
       "         0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 1,\n",
       "         0, 1, 0, 1, 1]),\n",
       "  'y_test': tensor([0, 0, 1, 1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1,\n",
       "          0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 1, 0,\n",
       "          0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1,\n",
       "          1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 1, 1, 0,\n",
       "          0, 1, 0, 0, 0, 1, 1, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 1])},\n",
       " 8: {'x': tensor([[0.0000, 0.1222, 0.0000,  ..., 0.0000, 0.9774, 0.1222],\n",
       "          [0.0000, 0.5774, 0.0000,  ..., 0.0000, 0.5774, 0.0000],\n",
       "          [0.5000, 0.5000, 0.0000,  ..., 0.0000, 0.0000, 0.5000],\n",
       "          ...,\n",
       "          [0.0000, 0.0000, 0.5774,  ..., 0.0000, 0.5774, 0.0000],\n",
       "          [0.0000, 0.1890, 0.0000,  ..., 0.0000, 0.9449, 0.1890],\n",
       "          [0.0000, 0.0765, 0.0000,  ..., 0.0000, 0.9941, 0.0000]],\n",
       "         dtype=torch.float64),\n",
       "  'a': array([1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1,\n",
       "         1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1,\n",
       "         1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 1,\n",
       "         1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1,\n",
       "         0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 1, 1, 0,\n",
       "         1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1,\n",
       "         1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         0, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1,\n",
       "         0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1,\n",
       "         1, 1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1,\n",
       "         1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1,\n",
       "         0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 0,\n",
       "         1, 0, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 1, 1,\n",
       "         1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1,\n",
       "         0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1,\n",
       "         0, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 1, 0, 0,\n",
       "         1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1,\n",
       "         1, 0, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         0, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 0, 1,\n",
       "         0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1,\n",
       "         0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0,\n",
       "         1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1]),\n",
       "  'y': tensor([1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0,\n",
       "          1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1,\n",
       "          1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0,\n",
       "          1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0, 0,\n",
       "          0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0,\n",
       "          1, 0, 0, 1, 1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 0,\n",
       "          0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 0,\n",
       "          0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 1, 1,\n",
       "          0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 1,\n",
       "          0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 1, 1, 1, 0, 1, 0, 0, 0, 1, 1, 0, 0,\n",
       "          0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1,\n",
       "          0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 0, 1, 0, 0, 1, 0,\n",
       "          1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 0, 0, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 0, 1,\n",
       "          1, 1, 0, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0,\n",
       "          0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 1, 0,\n",
       "          0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0, 1, 0,\n",
       "          1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0,\n",
       "          0, 0, 1, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0,\n",
       "          1, 1, 1, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0,\n",
       "          1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1, 0, 1,\n",
       "          0, 1, 1, 0, 1, 0, 0, 0, 0, 1, 0, 1, 1, 0, 1, 0, 0, 0, 1, 1, 1, 0, 1, 0,\n",
       "          0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 0,\n",
       "          0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 1, 0,\n",
       "          0, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0,\n",
       "          1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0,\n",
       "          0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0, 0, 1, 0,\n",
       "          0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0,\n",
       "          0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1,\n",
       "          1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 1,\n",
       "          0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 1, 1,\n",
       "          0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 0, 0, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0,\n",
       "          0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 0, 1, 0,\n",
       "          1, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 0, 1, 1,\n",
       "          0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 0, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 1,\n",
       "          1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1,\n",
       "          1, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1,\n",
       "          1, 1, 0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1,\n",
       "          0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1]),\n",
       "  'x_test': tensor([[0.0000, 0.7071, 0.0000,  ..., 0.7071, 0.0000, 0.0000],\n",
       "          [0.5774, 0.0000, 0.5774,  ..., 0.0000, 0.0000, 0.0000],\n",
       "          [0.5774, 0.0000, 0.5774,  ..., 0.0000, 0.0000, 0.0000],\n",
       "          ...,\n",
       "          [0.0000, 0.5000, 0.0000,  ..., 0.5000, 0.5000, 0.5000],\n",
       "          [0.0000, 0.3780, 0.0000,  ..., 0.0000, 0.7559, 0.3780],\n",
       "          [0.0000, 0.0762, 0.0000,  ..., 0.0000, 0.9912, 0.0762]],\n",
       "         dtype=torch.float64),\n",
       "  'a_test': array([1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1,\n",
       "         0, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0,\n",
       "         1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1,\n",
       "         0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 0,\n",
       "         0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1,\n",
       "         1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1,\n",
       "         1, 1, 0, 1, 1, 1]),\n",
       "  'y_test': tensor([0, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0,\n",
       "          1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 0, 0, 1, 0,\n",
       "          1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1,\n",
       "          0, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 0, 1,\n",
       "          1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 1,\n",
       "          0, 0, 0, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 1,\n",
       "          0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 1,\n",
       "          0, 1, 1, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0,\n",
       "          1, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0, 0, 0,\n",
       "          1, 0, 0, 1, 1, 0, 0, 0, 1, 1])},\n",
       " 9: {'x': tensor([[0.0000, 0.0000, 0.3015,  ..., 0.0000, 0.9045, 0.0000],\n",
       "          [0.0000, 0.1231, 0.0000,  ..., 0.0000, 0.9847, 0.0000],\n",
       "          [0.5000, 0.5000, 0.0000,  ..., 0.0000, 0.0000, 0.5000],\n",
       "          ...,\n",
       "          [0.0000, 0.0000, 0.4082,  ..., 0.0000, 0.8165, 0.0000],\n",
       "          [0.5000, 0.0000, 0.5000,  ..., 0.0000, 0.0000, 0.5000],\n",
       "          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.5774]],\n",
       "         dtype=torch.float64),\n",
       "  'a': array([1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1,\n",
       "         0, 1, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1,\n",
       "         1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1,\n",
       "         1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0, 0, 1, 0, 0, 1, 1, 1, 1, 1, 1,\n",
       "         0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 0,\n",
       "         1, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1,\n",
       "         1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0,\n",
       "         1, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1,\n",
       "         0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1,\n",
       "         1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 0,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1,\n",
       "         1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 1, 1, 1, 1, 0, 1, 1, 0, 1, 0,\n",
       "         1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0,\n",
       "         1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 0, 0,\n",
       "         1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1,\n",
       "         1, 1, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0,\n",
       "         1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 0, 1, 0, 1,\n",
       "         1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1,\n",
       "         0, 0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1,\n",
       "         1, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0, 1,\n",
       "         1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1,\n",
       "         1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1,\n",
       "         0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         0, 1, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1,\n",
       "         0, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1,\n",
       "         0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 0, 0,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1,\n",
       "         0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1,\n",
       "         1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1,\n",
       "         0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 0,\n",
       "         1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1,\n",
       "         1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0,\n",
       "         1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1,\n",
       "         1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1,\n",
       "         1, 1, 1, 0, 1, 1]),\n",
       "  'y': tensor([0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0,\n",
       "          0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0,\n",
       "          1, 1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 1, 0, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0,\n",
       "          1, 0, 0, 1, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 0, 0, 0,\n",
       "          1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0,\n",
       "          0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 1, 1,\n",
       "          0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0,\n",
       "          0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1,\n",
       "          1, 0, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 1,\n",
       "          0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0,\n",
       "          0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0,\n",
       "          1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 0,\n",
       "          1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0,\n",
       "          1, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1,\n",
       "          0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0,\n",
       "          1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1,\n",
       "          0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "          1, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 0, 0,\n",
       "          1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 1, 0, 1, 1, 0, 1, 0, 1, 0, 0, 0, 1, 1, 1,\n",
       "          0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0,\n",
       "          0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0,\n",
       "          1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 1, 1,\n",
       "          1, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0, 0, 1, 0, 0,\n",
       "          0, 1, 0, 1, 1, 1, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 1, 0, 0, 0,\n",
       "          1, 1, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1, 1,\n",
       "          0, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 1, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 1, 0,\n",
       "          1, 1, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0,\n",
       "          1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0,\n",
       "          0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 1, 1,\n",
       "          1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 1, 1,\n",
       "          1, 1, 0, 0, 1, 1, 1, 0, 1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 1,\n",
       "          0, 0, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 0, 1,\n",
       "          1, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0,\n",
       "          1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0,\n",
       "          0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 1, 0, 1, 0, 1, 0, 0,\n",
       "          1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0,\n",
       "          1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 1, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1,\n",
       "          1, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "          0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0,\n",
       "          1, 1, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 0, 0, 1, 0, 1, 1, 0, 1,\n",
       "          1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 1, 1, 1, 0, 1, 0, 0, 0, 0,\n",
       "          1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1]),\n",
       "  'x_test': tensor([[0.0000, 0.7071, 0.0000,  ..., 0.0000, 0.0000, 0.0000],\n",
       "          [0.0000, 0.0000, 0.5774,  ..., 0.0000, 0.5774, 0.0000],\n",
       "          [0.0000, 0.0000, 0.2887,  ..., 0.0000, 0.8660, 0.2887],\n",
       "          ...,\n",
       "          [0.0000, 0.2294, 0.0000,  ..., 0.0000, 0.9177, 0.2294],\n",
       "          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.8660, 0.2887],\n",
       "          [0.0000, 0.5774, 0.0000,  ..., 0.0000, 0.5774, 0.0000]],\n",
       "         dtype=torch.float64),\n",
       "  'a_test': array([1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1,\n",
       "         1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1,\n",
       "         1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1,\n",
       "         0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1,\n",
       "         1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1]),\n",
       "  'y_test': tensor([0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0,\n",
       "          1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0,\n",
       "          1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 1, 0, 1, 0, 0,\n",
       "          0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0,\n",
       "          0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0,\n",
       "          1, 0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 1,\n",
       "          1, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 1,\n",
       "          0, 1, 1, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1,\n",
       "          0, 1, 0, 1, 1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0,\n",
       "          1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 0,\n",
       "          0, 0, 1, 1, 1, 1, 0, 0, 1, 0])}}"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "all_clients = {}\n",
    "def partition_data(x,y,a,partition_indices):\n",
    "    \"\"\"Split (x, y, a) into per-client train/test sets stored in the module-level ``all_clients``.\n",
    "\n",
    "    Client ``k`` receives the rows between the previous entry of ``partition_indices``\n",
    "    (0 for the first client) and entry ``k``. NaNs in the features are replaced by -1\n",
    "    before ``normalize`` is applied, and every client is split 80/20 with a fixed seed.\n",
    "    Returns None; the result for client ``k`` is written to ``all_clients[k]``.\n",
    "    \"\"\"\n",
    "    start = 0\n",
    "    for client_id, stop in enumerate(partition_indices):\n",
    "        rows = slice(start, stop)\n",
    "        # NOTE(review): `normalize` is presumably sklearn.preprocessing.normalize (row-wise) -- confirm.\n",
    "        features = torch.tensor(normalize(np.nan_to_num(x[rows], nan=-1)))\n",
    "        labels = torch.LongTensor(y[rows])\n",
    "        x_tr, x_te, y_tr, y_te, a_tr, a_te = train_test_split(\n",
    "            features, labels, a[rows], test_size=0.2, random_state=0)\n",
    "        # Key order is kept as-is because it determines how the dict is displayed.\n",
    "        all_clients[client_id] = {\n",
    "            'x': x_tr, 'a': a_tr, 'y': y_tr,\n",
    "            'x_test': x_te, 'a_test': a_te, 'y_test': y_te,\n",
    "        }\n",
    "        start = stop\n",
    "    return\n",
    "partition_data(x,y,a,cumulative_samples_dirichlet)\n",
    "all_clients"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Values of the protected attribute; the training functions below compare client['a'] against each entry.\n",
    "# Note that FedAvg defines its own local list with the same values instead of reading this one.\n",
    "protected_attributes = [0,1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LogisticRegression(nn.Module):\n",
    "    \"\"\"A single linear layer mapping ``input_dim`` features to ``output_dim`` raw scores.\n",
    "\n",
    "    No activation is applied in ``forward``; the callers in this notebook feed the\n",
    "    output to ``nn.CrossEntropyLoss`` / ``F.softmax``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_dim, output_dim):\n",
    "        super().__init__()\n",
    "        self.linear = torch.nn.Linear(input_dim, output_dim)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return the linear layer's output for ``x``.\"\"\"\n",
    "        return self.linear(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_sizes(lst):\n",
    "    \"\"\"Return each tensor's element count and its slice bounds in a flat concatenation.\n",
    "\n",
    "    Args:\n",
    "        lst: iterable of torch tensors (a list or a generator such as ``net.parameters()``).\n",
    "\n",
    "    Returns:\n",
    "        ``(sizes, bounds)`` where ``sizes[k]`` is the number of elements of the k-th\n",
    "        tensor and ``bounds[k]`` is the ``(start, stop)`` pair locating that tensor in\n",
    "        a 1-D array holding all tensors back to back.\n",
    "    \"\"\"\n",
    "    # numel() also handles 0-dim tensors, whose empty size() made the previous\n",
    "    # reduce-over-dimensions raise TypeError.\n",
    "    sizes = [w.numel() for w in lst]\n",
    "    c = np.cumsum(sizes)\n",
    "    bounds = list(zip([0] + c[:-1].tolist(), c.tolist()))\n",
    "    return sizes, bounds\n",
    "\n",
    "def torch_to_numpy(lst, arr=None):\n",
    "    \"\"\"Copy a sequence of tensors, flattened and back to back, into one 1-D numpy array.\n",
    "\n",
    "    Args:\n",
    "        lst: iterable of tensors, e.g. ``net.parameters()`` or the result of ``torch.autograd.grad``.\n",
    "        arr: optional pre-allocated array to fill; its length must equal the total\n",
    "            number of elements. A new zero-initialised array is created when omitted.\n",
    "\n",
    "    Returns:\n",
    "        The filled array (``arr`` itself when one was passed in).\n",
    "    \"\"\"\n",
    "    tensors = list(lst)\n",
    "    sizes, bounds = get_sizes(tensors)\n",
    "    total = sum(sizes)\n",
    "    if arr is not None:\n",
    "        assert len(arr) == total\n",
    "    else:\n",
    "        arr = np.zeros(total)\n",
    "    for (lo, hi), tensor in zip(bounds, tensors):\n",
    "        arr[lo:hi] = tensor.data.cpu().numpy().reshape(-1)\n",
    "    return arr\n",
    "\n",
    "\n",
    "\n",
    "def numpy_to_torch(arr, net):\n",
    "    \"\"\"Overwrite the parameters of ``net`` in place with the values of a flat numpy array.\n",
    "\n",
    "    Args:\n",
    "        arr: 1-D numpy array laid out in the order of ``net.parameters()``; its length\n",
    "            must equal the total number of parameter elements.\n",
    "        net: module whose parameters are overwritten.\n",
    "\n",
    "    Returns:\n",
    "        ``net`` (the same object).\n",
    "    \"\"\"\n",
    "    target_device = next(net.parameters()).device\n",
    "    flat = torch.from_numpy(arr).to(target_device)\n",
    "    sizes, bounds = get_sizes(net.parameters())\n",
    "    assert len(flat) == sum(sizes)\n",
    "    for (lo, hi), param in zip(bounds, net.parameters()):\n",
    "        # Assigning through the flattened view writes into the parameter's own storage.\n",
    "        param.data.view(-1)[:] = flat[lo:hi]\n",
    "    return net"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _fedavg_group_mean(total, count):\n",
    "    \"\"\"Return ``total / count``, or NaN when ``count`` is zero (empty group).\"\"\"\n",
    "    return total / count if count else float('nan')\n",
    "\n",
    "\n",
    "def _fedavg_report_test_metrics(model, client_data, protected_attributes):\n",
    "    \"\"\"Print test metrics of ``model`` pooled over the test split of every client.\n",
    "\n",
    "    Prints, in order: a header line, the overall accuracy, the mean cross-entropy of\n",
    "    each group on its label-1 rows (\"TP loss\"), the same on its label-0 rows\n",
    "    (\"FP loss\"), and each group's mean softmax probability of class 1.\n",
    "    Each list has one plain float per protected-attribute value; a group with no\n",
    "    matching rows is reported as NaN.\n",
    "    \"\"\"\n",
    "    print('Printing test accuracy')\n",
    "    n_groups = len(protected_attributes)\n",
    "    loss_sum = {1: [0.0] * n_groups, 0: [0.0] * n_groups}\n",
    "    loss_count = {1: [0] * n_groups, 0: [0] * n_groups}\n",
    "    prob_sum = [0.0] * n_groups\n",
    "    prob_count = [0] * n_groups\n",
    "    correct = 0\n",
    "    total = 0\n",
    "    summed_ce = nn.CrossEntropyLoss(reduction='sum')\n",
    "    # Evaluation only: no autograd graph is needed.\n",
    "    with torch.no_grad():\n",
    "        for client in client_data.values():\n",
    "            model_output = model(client['x_test'].float())\n",
    "            labels = client['y_test']\n",
    "            groups = torch.as_tensor(client['a_test'])\n",
    "            # dim is explicit; calling softmax without it is deprecated.\n",
    "            acceptance_rate = F.softmax(model_output, dim=1)[:, 1]\n",
    "            for idx, p in enumerate(protected_attributes):\n",
    "                in_group = groups == p\n",
    "                # Accumulated for every client, whichever labels the group has there.\n",
    "                # Previously this was skipped when the group had no label-0 rows.\n",
    "                prob_sum[idx] += acceptance_rate[in_group].sum().item()\n",
    "                prob_count[idx] += int(in_group.sum())\n",
    "                for label in (1, 0):\n",
    "                    mask = in_group & (labels == label)\n",
    "                    n_rows = int(mask.sum())\n",
    "                    if n_rows == 0:\n",
    "                        continue\n",
    "                    loss_sum[label][idx] += summed_ce(model_output[mask], labels[mask]).item()\n",
    "                    loss_count[label][idx] += n_rows\n",
    "            correct += model_output.argmax(1).eq(labels).sum().item()\n",
    "            total += len(labels)\n",
    "    print(correct * 1.0 / total)\n",
    "    print(\"TP loss: \", [_fedavg_group_mean(s, c) for s, c in zip(loss_sum[1], loss_count[1])])\n",
    "    print(\"FP loss: \", [_fedavg_group_mean(s, c) for s, c in zip(loss_sum[0], loss_count[0])])\n",
    "    print([_fedavg_group_mean(s, c) for s, c in zip(prob_sum, prob_count)])\n",
    "\n",
    "\n",
    "def FedAvg(epochs, client_data, model, lr, global_lr, local_iterations):\n",
    "    \"\"\"Train ``model`` with federated averaging and print test metrics after the last round.\n",
    "\n",
    "    Each round, every client starts from a copy of the current global model and takes\n",
    "    ``local_iterations`` full-batch SGD steps on its training data. The global weights\n",
    "    then move by ``global_lr`` times the mean of the clients' weight changes.\n",
    "\n",
    "    Args:\n",
    "        epochs: number of rounds.\n",
    "        client_data: dict of per-client dicts with keys 'x', 'a', 'y', 'x_test',\n",
    "            'a_test', 'y_test'.\n",
    "        model: global model; it is updated in place.\n",
    "        lr: learning rate of each client's SGD optimizer.\n",
    "        global_lr: scale applied to the averaged client update.\n",
    "        local_iterations: SGD steps each client takes per round.\n",
    "\n",
    "    Returns:\n",
    "        ``model``, the same object that was passed in.\n",
    "    \"\"\"\n",
    "    protected_attributes = [0,1]\n",
    "    # Training rows per protected-attribute value, summed over all clients.\n",
    "    num_instances = np.array([\n",
    "        sum((client['a'] == p).sum() for client in client_data.values())\n",
    "        for p in protected_attributes\n",
    "    ])\n",
    "    print(num_instances)\n",
    "\n",
    "    for i in range(epochs):\n",
    "        # The global model does not change while clients train, so flatten it once per round.\n",
    "        global_weights = torch_to_numpy(model.parameters())\n",
    "        all_client_updates = 0\n",
    "        for client in client_data.values():\n",
    "            client_model_copy = deepcopy(model)\n",
    "            optimizer = torch.optim.SGD(client_model_copy.parameters(), lr=lr)\n",
    "            features = client['x'].float()\n",
    "            for j in range(local_iterations):\n",
    "                optimizer.zero_grad()\n",
    "                loss = nn.CrossEntropyLoss()(client_model_copy(features), client['y'])\n",
    "                loss.backward()\n",
    "                optimizer.step()\n",
    "            all_client_updates += torch_to_numpy(client_model_copy.parameters()) - global_weights\n",
    "        all_client_updates /= len(client_data)\n",
    "        numpy_to_torch(global_weights + all_client_updates * global_lr, model)\n",
    "\n",
    "        if i == epochs-1:\n",
    "            _fedavg_report_test_metrics(model, client_data, protected_attributes)\n",
    "\n",
    "    return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [],
   "source": [
    "def Fair_FedAvg(epochs, client_data, model, protected_attributes, initial_lr, local_iterations, rounds, B, lr_theta, threshold):\n",
    "    theta = torch.tensor(np.zeros(len(protected_attributes)))\n",
    "    average_iterate = 0\n",
    "    iterates = 0\n",
    "    avg_model = deepcopy(model)\n",
    "    num_instances = 0\n",
    "    num_test_instances = 0\n",
    "    test_accs = []\n",
    "    max_sps = []\n",
    "    for client in client_data.values():\n",
    "        running_instance = []\n",
    "        for p in protected_attributes:\n",
    "            instance = (client['a'] == p).sum()\n",
    "            running_instance.append(instance)\n",
    "        running_instance = np.array(running_instance)\n",
    "        num_instances += running_instance\n",
    "    for client in client_data.values():\n",
    "        running_instance = []\n",
    "        for p in protected_attributes:\n",
    "            instance = (client['a_test'] == p).sum()\n",
    "            running_instance.append(instance)\n",
    "        running_instance = np.array(running_instance)\n",
    "        num_test_instances += running_instance\n",
    "    print(num_instances, num_test_instances)\n",
    "    \n",
    "    for k in range(rounds):\n",
    "        lmbda = B*theta.exp()/(1+theta.exp().sum())\n",
    "        print(lmbda)\n",
    "        grad_theta = 0\n",
    "#         initial_lr /= 5\n",
    "        lr = initial_lr\n",
    "        for i in range(epochs):\n",
    "#             if i % 20 == 19:\n",
    "#                 lr /= 2\n",
    "            all_client_updates = 0\n",
    "            for client in client_data.values():\n",
    "                client_model_copy = deepcopy(model)\n",
    "                optimizer = torch.optim.SGD(client_model_copy.parameters(), lr=lr)\n",
    "                for j in range(local_iterations):\n",
    "                    optimizer.zero_grad()\n",
    "                    model_output = client_model_copy(client['x'].float())\n",
    "    #                 print(model_output)\n",
    "#                     model_prediction = model_output.argmax(1)\n",
    "#                     sum_positive_prediction = F.softmax(model_output[:,1]).sum() / N\n",
    "#                     sum_positive_prediction_per_group = 0\n",
    "                    loss = 0\n",
    "                    for p, l, ins in zip(protected_attributes, lmbda, num_instances):\n",
    "                        if len(model_output[client['a'] == p]) < 1:\n",
    "                            continue                        \n",
    "                        ################## Global Fairness ####################\n",
    "                        loss += l * (nn.CrossEntropyLoss(reduction='sum')(model_output[client['a'] == p], client['y'][client['a'] == p]) / ins - threshold)  \n",
    "                        ################## Local Fairness ####################\n",
    "#                         loss += l * (nn.CrossEntropyLoss(reduction='mean')(model_output[client['a'] == p], client['y'][client['a'] == p]) - threshold)\n",
    "#                         sum_positive_prediction_per_group += l * (sum_positive_prediction - torch.tensor([model_output[idx,1] if (client['a'][idx] == p) else 0 for idx in range(len(model_output))]).sum()/N_a[p])\n",
    "#                         sum_positive_prediction_per_group+=(torch.tensor([model_output[idx,1] if (client['a'][idx] == p) else 0 for idx in range(len(model_output))]).sum()/N_a[p])\n",
    "    #                 print(model_output, client['y'])\n",
    "#                     reg = torch.tensor(sum_positive_prediction - sum_positive_prediction_per_group)\n",
    "#                     reg = sum_positive_prediction_per_group\n",
    "#                     print(reg)\n",
    "#                     print(loss, nn.CrossEntropyLoss()(model_output, client['y']))\n",
    "                    loss += nn.CrossEntropyLoss()(model_output, client['y'])\n",
    "                    loss.backward()\n",
    "                    optimizer.step()\n",
    "#                 if i == epochs-1:\n",
    "#                     running_theta = []\n",
    "#                     for p in protected_attributes:\n",
    "#                         if len(model_output[client['a'] == p]) < 1:\n",
    "#                             running_theta.append(torch.tensor(0))\n",
    "#                             continue\n",
    "#                         running_theta.append(nn.CrossEntropyLoss()(model_output[client['a'] == p], client['y'][client['a'] == p]))\n",
    "#                     running_theta = torch.tensor(running_theta)\n",
    "# #                     print(running_theta)\n",
    "#                     grad_theta = grad_theta + running_theta\n",
    "                all_client_updates += torch_to_numpy(client_model_copy.parameters()) - torch_to_numpy(model.parameters())\n",
    "            all_client_updates /= len(client_data)\n",
    "            w = torch_to_numpy(model.parameters())\n",
    "            w += all_client_updates\n",
    "            numpy_to_torch(w, model)\n",
    "            correct = 0\n",
    "            total = 0\n",
    "            \n",
    "            if i == epochs-1:\n",
    "                for client in client_data.values():\n",
    "                    running_theta = []\n",
    "                    model_output = model(client['x'].float())\n",
    "                    for p, ins in zip(protected_attributes, num_instances):\n",
    "                        if len(model_output[client['a'] == p]) < 1:\n",
    "                            running_theta.append(torch.tensor(0))\n",
    "                            continue\n",
    "                        \n",
    "                        ############## Global Fairness ########\n",
    "                        the = nn.CrossEntropyLoss(reduction='sum')(model_output[client['a'] == p], client['y'][client['a'] == p]) / ins - threshold\n",
    "                        \n",
    "                        ############## Local Fairness ###########\n",
    "#                         the = nn.CrossEntropyLoss(reduction='mean')(model_output[client['a'] == p], client['y'][client['a'] == p]) - threshold\n",
    "                        running_theta.append(the)\n",
    "                    running_theta = torch.tensor(running_theta)\n",
    "                    grad_theta = grad_theta + running_theta\n",
    "                \n",
    "                iterates += 1\n",
    "                average_iterate += torch_to_numpy(model.parameters())\n",
    "                numpy_to_torch(average_iterate/iterates, avg_model)\n",
    "            \n",
    "#             print('Printing train accuracy..')\n",
    "#             for client in client_data.values():\n",
    "#                 model_output = model(client['x'].float())\n",
    "#                 model_prediction = model_output.argmax(1)\n",
    "#                 correct += model_prediction.eq(client['y']).sum()\n",
    "#                 total += len(client['y'])\n",
    "#     #         print(correct, total)\n",
    "#             print(correct.item()*1.0/total)\n",
    "            \n",
    "            if i == epochs-1:\n",
    "                print('Printing test accuracy')\n",
    "                correct = 0\n",
    "                PP = 0\n",
    "                total = 0\n",
    "                DP_by_group = [0]*9\n",
    "                total_by_group = [0]*9\n",
    "                running_theta_test = [0]*9\n",
    "                for client in client_data.values():\n",
    "                    model_output = model(client['x_test'].float())\n",
    "                    model_prediction = model_output.argmax(1)\n",
    "                    acceptance_rate = F.softmax(model_output)[:,1]\n",
    "                    PP += model_prediction.sum()\n",
    "\n",
    "                    for idx,p in enumerate(protected_attributes):\n",
    "                        if len(model_output[client['a_test'] == p]) < 1:\n",
    "                            running_theta_test[idx] += torch.tensor(0)\n",
    "                            continue\n",
    "                        the = nn.CrossEntropyLoss(reduction='sum')(model_output[client['a_test'] == p], client['y_test'][client['a_test'] == p])\n",
    "                        running_theta_test[idx] += the\n",
    "                        DP_by_group[idx] += acceptance_rate[client['a_test'] == p].sum()\n",
    "                        total_by_group[idx] += len(model_prediction[client['a_test'] == p])\n",
    "                    correct += model_prediction.eq(client['y_test']).sum()\n",
    "                    total += len(client['y_test'])\n",
    "                print(correct.item()*1.0/total)\n",
    "                test_accs.append(correct.item()*1.0/total)\n",
    "\n",
    "                SP_by_group = [0]*9\n",
    "                for idx,p in enumerate(protected_attributes):\n",
    "                    SP_by_group[idx] = (DP_by_group[idx] / total_by_group[idx]).item()\n",
    "                    running_theta_test[idx] /= num_test_instances[idx]\n",
    "                print(running_theta_test)\n",
    "                max_sps.append(running_theta_test)\n",
    "                print('Max SP gap: ', SP_by_group)\n",
    "#                 print('SP std: ', np.array(SP_by_group[[0,]).std())\n",
    "        \n",
    "        for idx,p in enumerate(protected_attributes):\n",
    "            grad_theta[idx] = grad_theta[idx] / num_instances[idx]\n",
    "#         print(grad_theta)\n",
    "        theta += lr_theta * grad_theta\n",
    "#         print(theta, grad_theta)\n",
    "        \n",
    "        \n",
    "    return avg_model, test_accs, max_sps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 90,
   "metadata": {},
   "outputs": [],
   "source": [
    "def Fair_FedAvg_TP_FP(epochs, client_data, model, protected_attributes, initial_lr, local_iterations, rounds, B, lr_theta, pos_threshold, neg_threshold):\n",
    "    \"\"\"Federated averaging with per-group loss constraints normalised by pooled counts.\n",
    "\n",
    "    Each outer round derives multipliers lmbda = B * exp(theta) / (1 + sum(exp(theta))),\n",
    "    runs `epochs` federated epochs and then moves theta by `lr_theta` times the\n",
    "    constraint values measured on the training data after the last epoch.\n",
    "    In every epoch each client copies `model`, takes `local_iterations` SGD steps,\n",
    "    and the mean parameter change over the clients is added to `model`.\n",
    "\n",
    "    A constraint value for a (group, label) pair is the summed cross-entropy of a\n",
    "    client's matching rows divided by the number of matching rows over all clients,\n",
    "    minus the threshold. The local TP flag enables the label-1 constraints\n",
    "    (`pos_threshold`), the FP flag the label-0 constraints (`neg_threshold`).\n",
    "\n",
    "    Args:\n",
    "        epochs: federated epochs per outer round.\n",
    "        client_data: mapping whose values are dicts read through the keys\n",
    "            'x', 'y', 'a', 'x_test', 'y_test' and 'a_test'.\n",
    "        model: torch module; new parameters are written back via numpy_to_torch.\n",
    "        protected_attributes: group ids compared against 'a' / 'a_test'.\n",
    "        initial_lr: learning rate of every client-side SGD optimizer.\n",
    "        local_iterations: SGD steps each client takes per epoch.\n",
    "        rounds: number of outer rounds (theta updates).\n",
    "        B: scale factor of the multipliers.\n",
    "        lr_theta: step size of the theta update.\n",
    "        pos_threshold: constant subtracted in the label-1 constraints.\n",
    "        neg_threshold: constant subtracted in the label-0 constraints.\n",
    "\n",
    "    Returns:\n",
    "        Tuple (avg_model, test_accs, max_sps): a copy of the model loaded with the\n",
    "        running average of the end-of-round parameters, the test accuracy of each\n",
    "        round, and for each round the list of normalised per-group test losses\n",
    "        (label-1 entries first, label-0 entries after them, zero padded).\n",
    "    \"\"\"\n",
    "    TP = True\n",
    "    FP = False\n",
    "    weighted = False\n",
    "    n_groups = len(protected_attributes)\n",
    "    # The report lists used to be fixed at 9 slots; keep at least that many so the\n",
    "    # printed/returned layout is unchanged, and grow them when more groups are passed.\n",
    "    n_slots = max(9, 2 * n_groups)\n",
    "    sum_ce = nn.CrossEntropyLoss(reduction='sum')\n",
    "    mean_ce = nn.CrossEntropyLoss()\n",
    "\n",
    "    def subset_mask(client, attr_key, label_key, p, label):\n",
    "        # Boolean mask of the rows that belong to group p and carry the given label.\n",
    "        return (torch.as_tensor(client[attr_key]) == p) & (client[label_key] == label)\n",
    "\n",
    "    def has_group(client, p):\n",
    "        # True when the client's training split holds at least one row of group p.\n",
    "        return bool((torch.as_tensor(client['a']) == p).any())\n",
    "\n",
    "    def subset_loss(output, labels, mask):\n",
    "        # Summed cross-entropy over the masked rows; an empty subset contributes zero.\n",
    "        if not mask.any():\n",
    "            return torch.tensor(0.)\n",
    "        return sum_ce(output[mask], labels[mask])\n",
    "\n",
    "    def group_scale(p):\n",
    "        # Extra weight for group 1, only active when `weighted` is switched on.\n",
    "        return 10 if ((p == 1) and weighted) else 1\n",
    "\n",
    "    def count_rows(attr_key, label_key, label):\n",
    "        # Rows per group with the given label, summed over every client.\n",
    "        counts = 0\n",
    "        for client in client_data.values():\n",
    "            counts += np.array([subset_mask(client, attr_key, label_key, p, label).sum().item() for p in protected_attributes])\n",
    "        return counts\n",
    "\n",
    "    def add_constraints(loss, client, output, duals, counts, label, threshold):\n",
    "        # Add the multiplier-weighted constraint terms of one client to its loss.\n",
    "        for p, l, ins in zip(protected_attributes, duals, counts):\n",
    "            # ins == 0 means no client has such rows; dividing by it would give NaN.\n",
    "            if ins == 0 or not has_group(client, p):\n",
    "                continue\n",
    "            mask = subset_mask(client, 'a', 'y', p, label)\n",
    "            loss = loss + l * group_scale(p) * (subset_loss(output, client['y'], mask) / ins - threshold)\n",
    "        return loss\n",
    "\n",
    "    def constraint_values(client, output, counts, label, threshold):\n",
    "        # Constraint value per group for one client; zero when the group is absent.\n",
    "        values = []\n",
    "        for p, ins in zip(protected_attributes, counts):\n",
    "            if ins == 0 or not has_group(client, p):\n",
    "                values.append(torch.tensor(0.))\n",
    "                continue\n",
    "            mask = subset_mask(client, 'a', 'y', p, label)\n",
    "            values.append(group_scale(p) * subset_loss(output, client['y'], mask) / ins - threshold)\n",
    "        return values\n",
    "\n",
    "    theta = torch.tensor(np.zeros(n_groups * 2 if (TP and FP) else n_groups))\n",
    "    average_iterate = 0\n",
    "    iterates = 0\n",
    "    avg_model = deepcopy(model)\n",
    "    test_accs = []\n",
    "    max_sps = []\n",
    "    num_pos_instances = count_rows('a', 'y', 1)\n",
    "    num_neg_instances = count_rows('a', 'y', 0)\n",
    "    num_test_pos_instances = count_rows('a_test', 'y_test', 1)\n",
    "    num_test_neg_instances = count_rows('a_test', 'y_test', 0)\n",
    "    print(num_pos_instances, num_neg_instances, num_test_pos_instances, num_test_neg_instances)\n",
    "\n",
    "    for _ in range(rounds):\n",
    "        lmbda = B*theta.exp()/(1+theta.exp().sum())\n",
    "        print(theta)\n",
    "        grad_theta = 0\n",
    "        lr = initial_lr\n",
    "        for i in range(epochs):\n",
    "            all_client_updates = 0\n",
    "            for client in client_data.values():\n",
    "                client_model_copy = deepcopy(model)\n",
    "                optimizer = torch.optim.SGD(client_model_copy.parameters(), lr=lr)\n",
    "                for _step in range(local_iterations):\n",
    "                    optimizer.zero_grad()\n",
    "                    model_output = client_model_copy(client['x'].float())\n",
    "                    loss = 0\n",
    "                    if TP:\n",
    "                        loss = add_constraints(loss, client, model_output, lmbda[:n_groups], num_pos_instances, 1, pos_threshold)\n",
    "                    if FP:\n",
    "                        # The label-0 multipliers follow the label-1 ones only when both are active.\n",
    "                        target_dual = lmbda[n_groups:] if TP else lmbda[:n_groups]\n",
    "                        loss = add_constraints(loss, client, model_output, target_dual, num_neg_instances, 0, neg_threshold)\n",
    "                    if weighted:\n",
    "                        neg_rows = client['y'] == 0\n",
    "                        pos_rows = client['y'] == 1\n",
    "                        loss = loss + (mean_ce(model_output[neg_rows], client['y'][neg_rows]) + mean_ce(model_output[pos_rows], client['y'][pos_rows]))/2\n",
    "                    else:\n",
    "                        loss = loss + mean_ce(model_output, client['y'])\n",
    "                    loss.backward()\n",
    "                    optimizer.step()\n",
    "                # torch_to_numpy / numpy_to_torch are defined elsewhere in the notebook;\n",
    "                # presumably they flatten the parameters to an array and load one back.\n",
    "                all_client_updates += torch_to_numpy(client_model_copy.parameters()) - torch_to_numpy(model.parameters())\n",
    "            all_client_updates /= len(client_data)\n",
    "            w = torch_to_numpy(model.parameters())\n",
    "            w += all_client_updates\n",
    "            numpy_to_torch(w, model)\n",
    "\n",
    "            if i != epochs-1:\n",
    "                continue\n",
    "\n",
    "            # Last epoch of the round: measure the constraints with the aggregated model.\n",
    "            # no_grad keeps the evaluation from building autograd graphs.\n",
    "            with torch.no_grad():\n",
    "                for client in client_data.values():\n",
    "                    model_output = model(client['x'].float())\n",
    "                    running_theta = []\n",
    "                    if TP:\n",
    "                        running_theta += constraint_values(client, model_output, num_pos_instances, 1, pos_threshold)\n",
    "                    if FP:\n",
    "                        running_theta += constraint_values(client, model_output, num_neg_instances, 0, neg_threshold)\n",
    "                    grad_theta = grad_theta + torch.tensor(running_theta)\n",
    "\n",
    "            iterates += 1\n",
    "            average_iterate += torch_to_numpy(model.parameters())\n",
    "            numpy_to_torch(average_iterate/iterates, avg_model)\n",
    "\n",
    "            print('Printing test accuracy')\n",
    "            correct = 0\n",
    "            total = 0\n",
    "            EO_by_group = [0]*n_slots\n",
    "            total_pos_by_group = [0]*n_slots\n",
    "            running_theta_test = [0]*n_slots\n",
    "            with torch.no_grad():\n",
    "                for client in client_data.values():\n",
    "                    model_output = model(client['x_test'].float())\n",
    "                    model_prediction = model_output.argmax(1)\n",
    "                    # dim=1 names the class axis explicitly instead of relying on the\n",
    "                    # deprecated implicit choice.\n",
    "                    acceptance_rate = F.softmax(model_output, dim=1)[:,1]\n",
    "                    for idx, p in enumerate(protected_attributes):\n",
    "                        pos_mask = subset_mask(client, 'a_test', 'y_test', p, 1)\n",
    "                        neg_mask = subset_mask(client, 'a_test', 'y_test', p, 0)\n",
    "                        # Label-1 losses occupy the first n_groups slots, label-0 losses the next n_groups.\n",
    "                        running_theta_test[idx] = running_theta_test[idx] + subset_loss(model_output, client['y_test'], pos_mask)\n",
    "                        running_theta_test[idx+n_groups] = running_theta_test[idx+n_groups] + subset_loss(model_output, client['y_test'], neg_mask)\n",
    "                        # Label-1 statistics are collected for every client, including\n",
    "                        # clients that hold no label-0 rows of this group.\n",
    "                        EO_by_group[idx] = EO_by_group[idx] + acceptance_rate[pos_mask].sum()\n",
    "                        total_pos_by_group[idx] += int(pos_mask.sum())\n",
    "                    correct += model_prediction.eq(client['y_test']).sum()\n",
    "                    total += len(client['y_test'])\n",
    "            print(correct.item()*1.0/total)\n",
    "            test_accs.append(correct.item()*1.0/total)\n",
    "\n",
    "            EO = [0]*n_slots\n",
    "            for idx in range(n_groups):\n",
    "                EO[idx] = (EO_by_group[idx] / total_pos_by_group[idx]).item()\n",
    "                running_theta_test[idx] = running_theta_test[idx] / num_test_pos_instances[idx]\n",
    "                running_theta_test[idx+n_groups] = running_theta_test[idx+n_groups] / num_test_neg_instances[idx]\n",
    "            print('TP loss: ', running_theta_test[:n_groups])\n",
    "            print('FP loss: ', running_theta_test[n_groups:])\n",
    "            max_sps.append(running_theta_test)\n",
    "            print(EO)\n",
    "\n",
    "        theta += lr_theta * grad_theta\n",
    "\n",
    "    return avg_model, test_accs, max_sps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [],
   "source": [
    "def Fair_FedAvg_TP_FP_local(epochs, client_data, model, protected_attributes, initial_lr, local_iterations, rounds, B, lr_theta, pos_threshold, neg_threshold):\n",
    "    \"\"\"Federated averaging with per-group loss constraints measured on each client.\n",
    "\n",
    "    Each outer round derives multipliers lmbda = B * exp(theta) / (1 + sum(exp(theta))),\n",
    "    runs `epochs` federated epochs and then moves theta by `lr_theta` times the\n",
    "    constraint values measured on the training data after the last epoch.\n",
    "    In every epoch each client copies `model`, takes `local_iterations` SGD steps,\n",
    "    and the mean parameter change over the clients is added to `model`.\n",
    "\n",
    "    A constraint value for a (group, label) pair is the mean cross-entropy of a\n",
    "    client's matching rows minus the threshold divided by 50. The local TP flag\n",
    "    enables the label-1 constraints (`pos_threshold`), the FP flag the label-0\n",
    "    constraints (`neg_threshold`).\n",
    "\n",
    "    Args:\n",
    "        epochs: federated epochs per outer round.\n",
    "        client_data: mapping whose values are dicts read through the keys\n",
    "            'x', 'y', 'a', 'x_test', 'y_test' and 'a_test'.\n",
    "        model: torch module; new parameters are written back via numpy_to_torch.\n",
    "        protected_attributes: group ids compared against 'a' / 'a_test'.\n",
    "        initial_lr: learning rate of every client-side SGD optimizer.\n",
    "        local_iterations: SGD steps each client takes per epoch.\n",
    "        rounds: number of outer rounds (theta updates).\n",
    "        B: scale factor of the multipliers.\n",
    "        lr_theta: step size of the theta update.\n",
    "        pos_threshold: threshold of the label-1 constraints (before the division).\n",
    "        neg_threshold: threshold of the label-0 constraints (before the division).\n",
    "\n",
    "    Returns:\n",
    "        Tuple (avg_model, test_accs, max_sps): a copy of the model loaded with the\n",
    "        running average of the end-of-round parameters, the test accuracy of each\n",
    "        round, and for each round the list of normalised per-group test losses\n",
    "        (label-1 entries first, label-0 entries after them, zero padded).\n",
    "    \"\"\"\n",
    "    TP = False\n",
    "    FP = True\n",
    "    weighted = False\n",
    "    n_groups = len(protected_attributes)\n",
    "    # The report list used to be fixed at 9 slots; keep at least that many so the\n",
    "    # printed/returned layout is unchanged, and grow it when more groups are passed.\n",
    "    n_slots = max(9, 2 * n_groups)\n",
    "    # NOTE(review): both thresholds are divided by this constant wherever they are\n",
    "    # used; presumably it is the number of clients - confirm against the experiment.\n",
    "    threshold_divisor = 50\n",
    "    sum_ce = nn.CrossEntropyLoss(reduction='sum')\n",
    "    mean_ce = nn.CrossEntropyLoss()\n",
    "\n",
    "    def subset_mask(client, attr_key, label_key, p, label):\n",
    "        # Boolean mask of the rows that belong to group p and carry the given label.\n",
    "        return (torch.as_tensor(client[attr_key]) == p) & (client[label_key] == label)\n",
    "\n",
    "    def subset_sum_loss(output, labels, mask):\n",
    "        # Summed cross-entropy over the masked rows; an empty subset contributes zero.\n",
    "        if not mask.any():\n",
    "            return torch.tensor(0.)\n",
    "        return sum_ce(output[mask], labels[mask])\n",
    "\n",
    "    def group_scale(p):\n",
    "        # Extra weight for group 1, only active when `weighted` is switched on.\n",
    "        return 10 if ((p == 1) and weighted) else 1\n",
    "\n",
    "    def count_rows(attr_key, label_key, label):\n",
    "        # Rows per group with the given label, summed over every client.\n",
    "        counts = 0\n",
    "        for client in client_data.values():\n",
    "            counts += np.array([subset_mask(client, attr_key, label_key, p, label).sum().item() for p in protected_attributes])\n",
    "        return counts\n",
    "\n",
    "    def add_constraints(loss, client, output, duals, label, threshold):\n",
    "        # Add the multiplier-weighted constraint terms of one client to its loss.\n",
    "        for p, l in zip(protected_attributes, duals):\n",
    "            mask = subset_mask(client, 'a', 'y', p, label)\n",
    "            # A mean over an empty subset is NaN and would poison the whole update.\n",
    "            if not mask.any():\n",
    "                continue\n",
    "            loss = loss + l * group_scale(p) * (mean_ce(output[mask], client['y'][mask]) - threshold / threshold_divisor)\n",
    "        return loss\n",
    "\n",
    "    def constraint_values(client, output, label, threshold):\n",
    "        # Constraint value per group for one client; zero when it has no matching rows.\n",
    "        # The expression mirrors the term used in the training loss (no division by\n",
    "        # the pooled row count), so theta tracks the constraint that is being trained.\n",
    "        values = []\n",
    "        for p in protected_attributes:\n",
    "            mask = subset_mask(client, 'a', 'y', p, label)\n",
    "            if not mask.any():\n",
    "                values.append(torch.tensor(0.))\n",
    "                continue\n",
    "            values.append(group_scale(p) * mean_ce(output[mask], client['y'][mask]) - threshold / threshold_divisor)\n",
    "        return values\n",
    "\n",
    "    theta = torch.tensor(np.zeros(n_groups * 2 if (TP and FP) else n_groups))\n",
    "    average_iterate = 0\n",
    "    iterates = 0\n",
    "    avg_model = deepcopy(model)\n",
    "    test_accs = []\n",
    "    max_sps = []\n",
    "    num_pos_instances = count_rows('a', 'y', 1)\n",
    "    num_neg_instances = count_rows('a', 'y', 0)\n",
    "    num_test_pos_instances = count_rows('a_test', 'y_test', 1)\n",
    "    num_test_neg_instances = count_rows('a_test', 'y_test', 0)\n",
    "    print(num_pos_instances, num_neg_instances, num_test_pos_instances, num_test_neg_instances)\n",
    "\n",
    "    for _ in range(rounds):\n",
    "        lmbda = B*theta.exp()/(1+theta.exp().sum())\n",
    "        print(theta)\n",
    "        grad_theta = 0\n",
    "        lr = initial_lr\n",
    "        for i in range(epochs):\n",
    "            all_client_updates = 0\n",
    "            for client in client_data.values():\n",
    "                client_model_copy = deepcopy(model)\n",
    "                optimizer = torch.optim.SGD(client_model_copy.parameters(), lr=lr)\n",
    "                for _step in range(local_iterations):\n",
    "                    optimizer.zero_grad()\n",
    "                    model_output = client_model_copy(client['x'].float())\n",
    "                    loss = 0\n",
    "                    if TP:\n",
    "                        loss = add_constraints(loss, client, model_output, lmbda[:n_groups], 1, pos_threshold)\n",
    "                    if FP:\n",
    "                        # The label-0 multipliers follow the label-1 ones only when both are active.\n",
    "                        target_dual = lmbda[n_groups:] if TP else lmbda[:n_groups]\n",
    "                        loss = add_constraints(loss, client, model_output, target_dual, 0, neg_threshold)\n",
    "                    if weighted:\n",
    "                        neg_rows = client['y'] == 0\n",
    "                        pos_rows = client['y'] == 1\n",
    "                        loss = loss + (mean_ce(model_output[neg_rows], client['y'][neg_rows]) + mean_ce(model_output[pos_rows], client['y'][pos_rows]))/2\n",
    "                    else:\n",
    "                        loss = loss + mean_ce(model_output, client['y'])\n",
    "                    loss.backward()\n",
    "                    optimizer.step()\n",
    "                # torch_to_numpy / numpy_to_torch are defined elsewhere in the notebook;\n",
    "                # presumably they flatten the parameters to an array and load one back.\n",
    "                all_client_updates += torch_to_numpy(client_model_copy.parameters()) - torch_to_numpy(model.parameters())\n",
    "            all_client_updates /= len(client_data)\n",
    "            w = torch_to_numpy(model.parameters())\n",
    "            w += all_client_updates\n",
    "            numpy_to_torch(w, model)\n",
    "\n",
    "            if i != epochs-1:\n",
    "                continue\n",
    "\n",
    "            # Last epoch of the round: measure the constraints with the aggregated model.\n",
    "            # no_grad keeps the evaluation from building autograd graphs.\n",
    "            with torch.no_grad():\n",
    "                for client in client_data.values():\n",
    "                    model_output = model(client['x'].float())\n",
    "                    running_theta = []\n",
    "                    if TP:\n",
    "                        running_theta += constraint_values(client, model_output, 1, pos_threshold)\n",
    "                    if FP:\n",
    "                        running_theta += constraint_values(client, model_output, 0, neg_threshold)\n",
    "                    grad_theta = grad_theta + torch.tensor(running_theta)\n",
    "\n",
    "            iterates += 1\n",
    "            average_iterate += torch_to_numpy(model.parameters())\n",
    "            numpy_to_torch(average_iterate/iterates, avg_model)\n",
    "\n",
    "            print('Printing test accuracy')\n",
    "            correct = 0\n",
    "            total = 0\n",
    "            running_theta_test = [0]*n_slots\n",
    "            with torch.no_grad():\n",
    "                for client in client_data.values():\n",
    "                    model_output = model(client['x_test'].float())\n",
    "                    model_prediction = model_output.argmax(1)\n",
    "                    for idx, p in enumerate(protected_attributes):\n",
    "                        pos_mask = subset_mask(client, 'a_test', 'y_test', p, 1)\n",
    "                        neg_mask = subset_mask(client, 'a_test', 'y_test', p, 0)\n",
    "                        # Label-1 losses occupy the first n_groups slots, label-0 losses the next n_groups.\n",
    "                        running_theta_test[idx] = running_theta_test[idx] + subset_sum_loss(model_output, client['y_test'], pos_mask)\n",
    "                        running_theta_test[idx+n_groups] = running_theta_test[idx+n_groups] + subset_sum_loss(model_output, client['y_test'], neg_mask)\n",
    "                    correct += model_prediction.eq(client['y_test']).sum()\n",
    "                    total += len(client['y_test'])\n",
    "            print(correct.item()*1.0/total)\n",
    "            test_accs.append(correct.item()*1.0/total)\n",
    "\n",
    "            # Test losses are reported per pooled row, independent of the training constraint form.\n",
    "            for idx in range(n_groups):\n",
    "                running_theta_test[idx] = running_theta_test[idx] / num_test_pos_instances[idx]\n",
    "                running_theta_test[idx+n_groups] = running_theta_test[idx+n_groups] / num_test_neg_instances[idx]\n",
    "            print('TP loss: ', running_theta_test[:n_groups])\n",
    "            print('FP loss: ', running_theta_test[n_groups:])\n",
    "            max_sps.append(running_theta_test)\n",
    "\n",
    "        theta += lr_theta * grad_theta\n",
    "\n",
    "    return avg_model, test_accs, max_sps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [],
   "source": [
    "def FedMinMax(epochs, client_data, model, protected_attributes, initial_lr, local_iterations, rounds, B, lr_theta, threshold):\n",
    "    theta = torch.tensor(np.ones(len(protected_attributes)))*1.0 / len(protected_attributes)\n",
    "    average_iterate = 0\n",
    "    iterates = 0\n",
    "    avg_model = deepcopy(model)\n",
    "    num_instances = 0\n",
    "    num_test_instances = 0\n",
    "    test_accs = []\n",
    "    max_sps = []\n",
    "    for client in client_data.values():\n",
    "        running_instance = []\n",
    "        for p in protected_attributes:\n",
    "            instance = (client['a'] == p).sum()\n",
    "            running_instance.append(instance)\n",
    "        running_instance = np.array(running_instance)\n",
    "        num_instances += running_instance\n",
    "    for client in client_data.values():\n",
    "        running_instance = []\n",
    "        for p in protected_attributes:\n",
    "            instance = (client['a_test'] == p).sum()\n",
    "            running_instance.append(instance)\n",
    "        running_instance = np.array(running_instance)\n",
    "        num_test_instances += running_instance\n",
    "    print(num_instances, num_test_instances)\n",
    "    \n",
    "    for i in range(epochs):\n",
    "        lr = initial_lr\n",
    "        lmbda = theta\n",
    "#         if i % 20 == 19:\n",
    "#             lr /= 2\n",
    "        all_client_updates = 0\n",
    "        for client in client_data.values():\n",
    "            client_model_copy = deepcopy(model)\n",
    "            optimizer = torch.optim.SGD(client_model_copy.parameters(), lr=lr)\n",
    "            for j in range(local_iterations):\n",
    "                optimizer.zero_grad()\n",
    "                model_output = client_model_copy(client['x'].float())\n",
    "                loss = 0\n",
    "                for p, l, ins in zip(protected_attributes, lmbda, num_instances):\n",
    "                    if len(model_output[client['a'] == p]) < 1:\n",
    "                        continue                        \n",
    "                    loss += l * (nn.CrossEntropyLoss(reduction='sum')(model_output[client['a'] == p], client['y'][client['a'] == p]) / ins)  \n",
    "                loss.backward()\n",
    "                optimizer.step()\n",
    "            all_client_updates += torch_to_numpy(client_model_copy.parameters()) - torch_to_numpy(model.parameters())\n",
    "        all_client_updates /= len(client_data)\n",
    "        w = torch_to_numpy(model.parameters())\n",
    "        w += all_client_updates\n",
    "        numpy_to_torch(w, model)\n",
    "        correct = 0\n",
    "        total = 0\n",
    "        grad_theta = 0\n",
    "\n",
    "        for client in client_data.values():\n",
    "            running_theta = []\n",
    "            model_output = model(client['x'].float())\n",
    "            for p, ins in zip(protected_attributes, num_instances):\n",
    "                if len(model_output[client['a'] == p]) < 1:\n",
    "                    running_theta.append(torch.tensor(0))\n",
    "                    continue\n",
    "                the = nn.CrossEntropyLoss(reduction='sum')(model_output[client['a'] == p], client['y'][client['a'] == p]) / ins\n",
    "                running_theta.append(the)\n",
    "            running_theta = torch.tensor(running_theta)\n",
    "            grad_theta = grad_theta + running_theta\n",
    "\n",
    "        iterates += 1\n",
    "        average_iterate += torch_to_numpy(model.parameters())\n",
    "        numpy_to_torch(average_iterate/iterates, avg_model)\n",
    "\n",
    "        if i == epochs-1:\n",
    "            print('Printing test accuracy')\n",
    "            correct = 0\n",
    "            PP = 0\n",
    "            total = 0\n",
    "            PP_by_group = [0]*9\n",
    "            total_by_group = [0]*9\n",
    "            running_theta_test = [0]*9\n",
    "            for client in client_data.values():\n",
    "                model_output = model(client['x_test'].float())\n",
    "                model_prediction = model_output.argmax(1)\n",
    "                PP += model_prediction.sum()\n",
    "\n",
    "                for idx,p in enumerate(protected_attributes):\n",
    "                    if len(model_output[client['a_test'] == p]) < 1:\n",
    "                        running_theta_test[idx] += torch.tensor(0)\n",
    "                        continue\n",
    "                    the = nn.CrossEntropyLoss(reduction='sum')(model_output[client['a_test'] == p], client['y_test'][client['a_test'] == p])\n",
    "                    running_theta_test[idx] += the\n",
    "                    PP_by_group[idx] += model_prediction[client['a_test'] == p].sum()\n",
    "                    total_by_group[idx] += len(model_prediction[client['a_test'] == p])\n",
    "                correct += model_prediction.eq(client['y_test']).sum()\n",
    "                total += len(client['y_test'])\n",
    "            print(correct.item()*1.0/total)\n",
    "            test_accs.append(correct.item()*1.0/total)\n",
    "\n",
    "            SP_by_group = [0]*9\n",
    "            for idx,p in enumerate(protected_attributes):\n",
    "                SP_by_group[idx] = (PP_by_group[idx] / total_by_group[idx]).item()\n",
    "                running_theta_test[idx] /= num_test_instances[idx]\n",
    "            print(running_theta_test)\n",
    "            max_sps.append(running_theta_test)\n",
    "            print('Max SP gap: ', max(SP_by_group)-min(SP_by_group))\n",
    "            print('SP std: ', np.array(SP_by_group).std())\n",
    "\n",
    "        for idx,p in enumerate(protected_attributes):\n",
    "            grad_theta[idx] = grad_theta[idx] / num_instances[idx]\n",
    "        theta += lr_theta * grad_theta\n",
    "#         theta /= theta.sum()  \n",
    "        \n",
    "        \n",
    "    return avg_model, test_accs, max_sps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 91,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment configuration. These are plain module-level names; which of the\n",
    "# training routines defined in earlier cells read them is not visible here.\n",
    "\n",
    "# Model under training. NOTE(review): (12, 2) looks like (input features,\n",
    "# output classes) -- confirm against the LogisticRegression definition.\n",
    "model = LogisticRegression(12, 2)\n",
    "\n",
    "# Integer counts. NOTE(review): names suggest loop lengths; the meaning of B\n",
    "# is not visible from this cell -- TODO confirm.\n",
    "rounds = 200\n",
    "epochs = 5\n",
    "local_iterations = 3\n",
    "B = 5\n",
    "\n",
    "# NOTE(review): names suggest optimiser step sizes -- confirm where each one\n",
    "# is consumed.\n",
    "lr = 0.1\n",
    "global_lr = 1\n",
    "lr_theta = 0.005\n",
    "\n",
    "# Threshold settings (same values as before, paired on one line).\n",
    "threshold = 0.3\n",
    "pos_threshold, neg_threshold = 1, 0"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
