{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "a238f0af-5110-4ed1-a520-c3131bd0a1be",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2023-05-22 14:19:30.935052: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
      "To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
      "2023-05-22 14:19:31.471947: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n",
      "/home/leiding/pyotrch/lib/python3.8/site-packages/transformers/generation_utils.py:24: FutureWarning: Importing `GenerationMixin` from `src/transformers/generation_utils.py` is deprecated and will be removed in Transformers v5. Import as `from transformers import GenerationMixin` instead.\n",
      "  warnings.warn(\n",
      "/home/leiding/pyotrch/lib/python3.8/site-packages/transformers/generation_tf_utils.py:24: FutureWarning: Importing `TFGenerationMixin` from `src/transformers/generation_tf_utils.py` is deprecated and will be removed in Transformers v5. Import as `from transformers import TFGenerationMixin` instead.\n",
      "  warnings.warn(\n",
      "/tmp/ipykernel_410150/685032672.py:28: DeprecationWarning: Please use `pearsonr` from the `scipy.stats` namespace, the `scipy.stats.stats` namespace is deprecated.\n",
      "  from scipy.stats.stats import pearsonr\n"
     ]
    }
   ],
   "source": [
    "import sys\n",
    "import gensim\n",
    "import codecs\n",
    "import json\n",
    "from gensim.models.keyedvectors import Word2VecKeyedVectors\n",
    "from gensim.models import KeyedVectors\n",
    "import numpy as np\n",
    "import random\n",
    "import sklearn\n",
    "from sklearn import model_selection\n",
    "from sklearn import cluster\n",
    "from sklearn import metrics\n",
    "from sklearn.manifold import TSNE\n",
    "from sklearn.svm import LinearSVC, SVC\n",
    "from sklearn.neural_network import MLPClassifier\n",
    "from sklearn.metrics.pairwise import cosine_similarity\n",
    "from sklearn.feature_extraction import DictVectorizer\n",
    "import numpy as np\n",
    "from docopt import docopt\n",
    "import torch\n",
    "from transformers import *\n",
    "import pickle\n",
    "from tqdm import tqdm\n",
    "\n",
    "import scipy\n",
    "from scipy import linalg\n",
    "from scipy import sparse\n",
    "from scipy.stats.stats import pearsonr\n",
    "import tqdm\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.linear_model import SGDClassifier, SGDRegressor, Perceptron, LogisticRegression\n",
    "\n",
    "\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "import pickle\n",
    "from collections import defaultdict, Counter\n",
    "from typing import List, Dict\n",
    "\n",
    "import torch\n",
    "from torch import utils\n",
    "\n",
    "\n",
    "import pandas as pd\n",
    "\n",
    "import time\n",
    "from mlxtend.math import vectorspace_orthonormalization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "77bb6aa6-ec9c-4093-9981-d7cc3cf083f5",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def load_dataset(path):\n",
    "    \n",
    "    with open(path, \"rb\") as f:\n",
    "        \n",
    "        data = pickle.load(f)\n",
    "    return data\n",
    "\n",
    "def load_dictionary(path):\n",
    "    \n",
    "    with open(path, \"r\", encoding = \"utf-8\") as f:\n",
    "        \n",
    "        lines = f.readlines()\n",
    "        \n",
    "    k2v, v2k = {}, {}\n",
    "    for line in lines:\n",
    "        \n",
    "        k,v = line.strip().split(\"\\t\")\n",
    "        v = int(v)\n",
    "        k2v[k] = v\n",
    "        v2k[v] = k\n",
    "    \n",
    "    return k2v, v2k"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "700d986c-7d4f-460d-b2f4-24c36a379e77",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "path = \"./data/\"\n",
    "x_train = np.load(path + \"train_cls.npy\")\n",
    "x_dev = np.load(path + \"dev_cls.npy\")\n",
    "x_test = np.load(path + \"test_cls.npy\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "3a11da2b-a18c-4def-80d0-a3df8d9fbd91",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "train = pd.read_csv('./data/wiki_train.csv')\n",
    "test = pd.read_csv('./data/wiki_test.csv')\n",
    "dev = pd.read_csv('./data/wiki_dev.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "c64577d0-f23d-474e-9f17-773c771f5955",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "y_train = train['is_toxic']\n",
    "y_dev = dev['is_toxic']\n",
    "y_test = test['is_toxic']\n",
    "# encode categorical label into numbers\n",
    "from sklearn import preprocessing\n",
    "le = preprocessing.LabelEncoder()\n",
    "y_train = le.fit_transform(y_train).reshape((-1, 1))\n",
    "y_dev = le.transform(y_dev).reshape((-1, 1))\n",
    "y_test = le.transform(y_test).reshape((-1, 1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "628d2f39-c0d1-4844-8168-d59411e0aee6",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "z_train = np.load(path + \"Z_train.npy\")\n",
    "z_dev = np.load(path + \"Z_dev.npy\")\n",
    "z_test = np.load(path + \"Z_test.npy\")\n",
    "\n",
    "idx_train = (z_train[:,-1]!=0)\n",
    "idx_test = (z_test[:,-1]!=0)\n",
    "idx_dev = (z_dev[:,-1]!=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "629d495c-2a76-4af0-8180-7b8480d69627",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "z_train = z_train[idx_train]\n",
    "z_test = z_test[idx_test]\n",
    "z_dev = z_dev[idx_dev]\n",
    "x_train = x_train[idx_train]\n",
    "x_test = x_test[idx_test]\n",
    "x_dev = x_dev[idx_dev]\n",
    "y_train = y_train[idx_train]\n",
    "y_test = y_test[idx_test]\n",
    "y_dev = y_dev[idx_dev]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "a4960574-fc42-4926-adc8-39a10bafb443",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "z_dev = z_dev[:,0:-1]/z_dev[:,-1].reshape(-1,1)\n",
    "z_train = z_train[:,0:-1]/z_train[:,-1].reshape(-1,1)\n",
    "z_test = z_test[:,0:-1]/z_test[:,-1].reshape(-1,1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "03da8403-8af8-4bf2-a22e-68eb5ec2e188",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(32127, 50)"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.shape(z_dev)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "b2fc5175-57ba-4418-b8d2-ec24d8a6dcca",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "95679"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.sum(z_train[:,-1].reshape(-1,1)==0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "4f102529-6cd2-4d16-ac1d-99f8b26ed1f3",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[0., 0., 0., ..., 0., 0., 0.],\n",
       "       [0., 0., 0., ..., 0., 0., 0.],\n",
       "       [0., 0., 0., ..., 0., 0., 0.],\n",
       "       ...,\n",
       "       [0., 0., 0., ..., 0., 0., 0.],\n",
       "       [0., 0., 0., ..., 0., 0., 0.],\n",
       "       [0., 0., 0., ..., 0., 0., 0.]])"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "z_train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "4171fd62-8f0c-4ece-ba23-17069db16401",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "z_train2 = z_train[:,(np.max(z_train,0) != 0)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "7fa87767-19c4-48cb-8f4e-9576bc93f40e",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0.2       , 0.5       , 0.16666667, 0.06666667, 0.08695652,\n",
       "       0.25      , 0.16666667, 0.02380952, 0.2       , 0.16666667,\n",
       "       0.01960784, 0.25      , 0.2       , 0.5       , 0.5       ,\n",
       "       0.33333333, 0.25      , 0.25      , 0.14285714, 0.04444444,\n",
       "       0.14285714, 0.14285714, 0.25      , 0.5       , 0.16666667,\n",
       "       0.28571429, 0.04347826, 0.2       , 0.33333333, 0.33333333,\n",
       "       0.2       , 0.33333333, 0.06666667, 0.2       , 0.06666667,\n",
       "       0.13793103, 0.33333333, 0.15384615, 0.33333333, 0.16666667,\n",
       "       0.07142857, 0.00246305, 0.05555556, 0.05263158, 0.16666667,\n",
       "       0.2       ])"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.max(z_train2,0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "385b858c-d26d-4caf-8e1a-42225be73d09",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1549ba18-7581-4d78-b452-0cc62053ebaa",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "d66f3f46-b6c2-4b0c-b502-917c1f627020",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def slice_freq(Y, interval):\n",
    "    y_f = pd.Series(Y)\n",
    "    a = pd.cut(y_f,interval)\n",
    "    b=a.value_counts()\n",
    "    frequency = b.sort_index().values\n",
    "    return frequency"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "4b2901a9-0adf-4179-93e2-0faac79b4367",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def get_rowspace_projection(W: np.ndarray) -> np.ndarray:\n",
    "    \"\"\"\n",
    "    :param W: the matrix over its nullspace to project\n",
    "    :return: the projection matrix over the rowspace\n",
    "    \"\"\"\n",
    "    if np.allclose(W, 0):\n",
    "        w_basis = np.zeros_like(W.T)\n",
    "    else:\n",
    "        w_basis = scipy.linalg.orth(W.T) # orthogonal basis\n",
    "\n",
    "    P_W = w_basis.dot(w_basis.T) # orthogonal projection on W's rowspace\n",
    "\n",
    "    return P_W"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "af02728c-d05e-449f-8058-e6bdd515d250",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def SDR_multi_new(X_train, Y, K, h):\n",
    "\n",
    "    n = X_train.shape[0]\n",
    "    d_x = X_train.shape[1]\n",
    "    d_y = Y.shape[1]\n",
    "    \n",
    "    X = X_train\n",
    "\n",
    "    interval = np.linspace(0, 1, h+1,endpoint=True)\n",
    "    Gamma_PMS = np.zeros([d_x,d_x])\n",
    "    for i in range(d_y):  \n",
    "        #interval = np.linspace(np.min(Y[:,i])-0.1, np.max(Y[:,i])+0.1, h+1,endpoint=True)\n",
    "        freq = slice_freq(Y[:,i], interval)\n",
    "        Gamma = np.zeros([d_x,d_x])\n",
    "        for j in range(h):\n",
    "            ph = freq[j]/n\n",
    "            if ph == 0:\n",
    "                mh = np.zeros(d_x)\n",
    "            else:\n",
    "                index = np.where((Y[:,i]>=interval[j])&(Y[:,i]<=interval[j+1]))\n",
    "                mh = (1/(n*ph))*np.sum(X[index,:],0)\n",
    "            Gamma = Gamma + ph*np.dot(mh.T,mh)\n",
    "        la_g, v_g = np.linalg.eig(Gamma)\n",
    "        la_g = la_g.real\n",
    "        Gamma_PMS = Gamma_PMS + Gamma/d_y\n",
    "    la_PMS, v_PMS = np.linalg.eig(Gamma_PMS)\n",
    "    v_PMS = v_PMS.real\n",
    "    b1 = v_PMS[:,:K] #+ #0.001*np.random.rand(d_x,K)\n",
    "    b2 = v_PMS[:,K:] #+ #0.001*np.random.rand(d_x,d_x-K)\n",
    "    #print(b1)\n",
    "    beta1 = get_rowspace_projection(b1.T) #vectorspace_orthonormalization(v_PMS[:,:K])\n",
    "    beta2 = get_rowspace_projection(b2.T)\n",
    "    return beta1, beta2  #np.dot(beta1,beta1.T), np.eye(d_x)-np.dot(beta1,beta1.T) \n",
    "\n",
    "def debias_proj(X_t, Y1_train, Y2_train, Iter, q, h):\n",
    "    n = len(X_t)\n",
    "    d = X_t.shape[1]\n",
    "    X1 = X_t-np.mean(X_t,0).reshape(1,d)\n",
    "    Cx = np.dot(X1.T,X1)/n\n",
    "#     la, v = np.linalg.eig(Cx)\n",
    "#     la = np.abs(la)\n",
    "#     v = v.real\n",
    "#     Cx12 = np.dot(np.dot(v,np.diag(la**(0.5))),v.T)\n",
    "    InvC = np.linalg.pinv(Cx,rcond=1e-8)\n",
    "    X_t = np.dot(X_t, InvC)\n",
    "\n",
    "    X_ = X_t\n",
    "    Y2_ = Y2_train\n",
    "    beta1_g, beta2_g = SDR_multi_new(X_, Y2_, q, h)\n",
    "\n",
    "    return beta1_g, beta2_g"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "5a8467c5-4c77-471e-921b-3116c5320599",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "idx = np.random.rand(x_train.shape[0]) < 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "bc09a179-4240-4bd6-a5b3-d0ec792070a3",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Parallel(n_jobs=90)]: Using backend ThreadingBackend with 90 concurrent workers.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "max_iter reached after 5 seconds\n",
      "0.9520431862406629\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Parallel(n_jobs=90)]: Done   1 out of   1 | elapsed:    5.4s finished\n"
     ]
    }
   ],
   "source": [
    "clf_prof = LogisticRegression(warm_start = True, penalty = 'l2',\n",
    "                         solver = \"saga\", multi_class = 'multinomial', fit_intercept = False,\n",
    "                         verbose = 5, n_jobs = 90, random_state = 1, max_iter = 7)\n",
    "\n",
    "clf_prof.fit(x_train[idx], y_train[idx])\n",
    "print(clf_prof.score(x_test, y_test))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "db795d44-f5cb-4737-8ad2-e5e263088b8b",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(95679, 768)"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x_train.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "4e900bd3-3c1f-46f5-9f4d-2458fb45e9d9",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(95679, 1)"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y_train.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "000433c8-b971-4edd-80d0-121a0db6c294",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(31862, 50)"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "z_test.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "65d9046a-25e6-4baf-9be0-aceba5d1b8c0",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "y_pred_after = clf_prof.predict(x_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "82ea0ce5-5831-44ff-887e-0357954500f9",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "29474"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sum(y_pred_after==0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "6f7dcf83-e0b8-4f17-9178-79be05faca81",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def GAPs(y_pred, y_true, z):\n",
    "    n, p = z.shape\n",
    "    y_pred = y_pred.reshape(n)\n",
    "    y_true = y_true.reshape(n)\n",
    "    FPR0 = np.mean(y_pred[y_true>0]==0)\n",
    "    print(FPR0)\n",
    "    FPR_p = np.zeros(p)\n",
    "    for i in range(p):\n",
    "        idx = (z[:,i].reshape(-1)>0)\n",
    "        if np.sum(idx)==0:\n",
    "            FPR_p[i] = FPR0\n",
    "        else:\n",
    "            y_pred_ = y_pred[idx]\n",
    "            y_true_ = y_true[idx]\n",
    "            if np.sum(y_pred_[y_true_>0]==0) == 0:\n",
    "                FPR_p[i] = FPR0\n",
    "            else:\n",
    "                FPR_p[i] = np.mean(y_pred_[y_true_>0]==0)\n",
    "        #print(FPR_p[i])\n",
    "    return np.sum(np.abs(FPR_p-FPR0))\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "149c7d7c-7574-4dd1-a419-cf7594e5df6e",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[0., 0., 0., ..., 0., 0., 0.],\n",
       "       [0., 0., 0., ..., 0., 0., 0.],\n",
       "       [0., 0., 0., ..., 0., 0., 0.],\n",
       "       ...,\n",
       "       [0., 0., 0., ..., 0., 0., 0.],\n",
       "       [0., 0., 0., ..., 0., 0., 0.],\n",
       "       [0., 0., 0., ..., 0., 0., 0.]])"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "z_train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "a53af82f-9f37-44e2-b07a-0fd89f1a47c9",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.3589238845144357\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "7.3371804643178296"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "GAPs(y_pred_after, y_test, z_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f276be7b-e765-4baa-9141-7fd02d94ab97",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "9f803915-1d3d-4032-853e-6cdd4d45e418",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def SDR_PMS(X_train, Y, K, h):\n",
    "\n",
    "    n = X_train.shape[0]\n",
    "    d_x = X_train.shape[1]\n",
    "    d_y = Y.shape[1]\n",
    "    \n",
    "    X = X_train\n",
    "\n",
    "    #interval = np.linspace(0, 1, h+1,endpoint=True)\n",
    "    Gamma_PMS = np.zeros([d_x,d_x])\n",
    "    for i in range(d_y):  \n",
    "        interval = np.linspace(np.min(Y[:,i])-0.01, np.max(Y[:,i])+0.01, h+1,endpoint=True)\n",
    "        freq = slice_freq(Y[:,i], interval)\n",
    "        Gamma = np.zeros([d_x,d_x])\n",
    "        for j in range(h):\n",
    "            ph = freq[j]/n\n",
    "            if ph == 0:\n",
    "                mh = np.zeros(d_x)\n",
    "            else:\n",
    "                index = np.where((Y[:,i]>=interval[j])&(Y[:,i]<=interval[j+1]))\n",
    "                mh = (1/(n*ph))*np.sum(X[index,:],0)\n",
    "            Gamma = Gamma + ph*np.dot(mh.T,mh)\n",
    "        la_g, v_g = np.linalg.eig(Gamma)\n",
    "        la_g = la_g.real\n",
    "        Gamma_PMS = Gamma_PMS + Gamma/d_y\n",
    "    la_PMS, v_PMS = np.linalg.eig(Gamma_PMS)\n",
    "    v_PMS = v_PMS.real\n",
    "    b1 = v_PMS[:,:K] \n",
    "    return b1\n",
    "\n",
    "def get_PMS(X_t, Y2_train, q, h):\n",
    "    n = len(X_t)\n",
    "    d = X_t.shape[1]\n",
    "    X1 = X_t-np.mean(X_t,0).reshape(1,d)\n",
    "    Cx = np.dot(X1.T,X1)/n\n",
    "    InvC = np.linalg.pinv(Cx,rcond=1e-8)\n",
    "    X_t = np.dot(X_t, InvC)\n",
    "\n",
    "    X_ = X_t\n",
    "    Y2_ = Y2_train\n",
    "    \n",
    "    pp_list = []\n",
    "    for i in range(Y2_.shape[1]):\n",
    "        cc = SDR_PMS(X_, Y2_[:,i].reshape(1,-1).T, 768, h)\n",
    "        pp_list.append(cc)\n",
    "    return pp_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "8c07dc1c-c067-4f3b-9168-8ed9eabdd26c",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def new_Z(X,Z):\n",
    "    Z[Z>0] = 1\n",
    "    Z_ = Z.copy()\n",
    "    K = np.shape(Z)[1]\n",
    "    for i in range(K):\n",
    "        lr = LogisticRegression(warm_start = True, penalty = 'l2',\n",
    "                         solver = \"saga\", multi_class = 'multinomial', fit_intercept = False,\n",
    "                         verbose = 5, n_jobs = 90, random_state = 1, max_iter = 20)\n",
    "        lr.fit(X, Z[:,i])\n",
    "        Z_[:,i] = np.log(lr.predict_proba(X)[:,0])\n",
    "    return Z_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "e696911e-9cf2-4d62-81dc-02d40910c59f",
   "metadata": {
    "scrolled": true,
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Parallel(n_jobs=90)]: Using backend ThreadingBackend with 90 concurrent workers.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "max_iter reached after 15 seconds\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Parallel(n_jobs=90)]: Done   1 out of   1 | elapsed:   15.3s finished\n",
      "[Parallel(n_jobs=90)]: Using backend ThreadingBackend with 90 concurrent workers.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "max_iter reached after 15 seconds\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Parallel(n_jobs=90)]: Done   1 out of   1 | elapsed:   15.3s finished\n",
      "[Parallel(n_jobs=90)]: Using backend ThreadingBackend with 90 concurrent workers.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "max_iter reached after 16 seconds\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Parallel(n_jobs=90)]: Done   1 out of   1 | elapsed:   15.4s finished\n",
      "[Parallel(n_jobs=90)]: Using backend ThreadingBackend with 90 concurrent workers.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "max_iter reached after 15 seconds\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Parallel(n_jobs=90)]: Done   1 out of   1 | elapsed:   15.3s finished\n",
      "[Parallel(n_jobs=90)]: Using backend ThreadingBackend with 90 concurrent workers.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "max_iter reached after 15 seconds\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Parallel(n_jobs=90)]: Done   1 out of   1 | elapsed:   15.3s finished\n",
      "[Parallel(n_jobs=90)]: Using backend ThreadingBackend with 90 concurrent workers.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "max_iter reached after 15 seconds\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Parallel(n_jobs=90)]: Done   1 out of   1 | elapsed:   15.3s finished\n",
      "[Parallel(n_jobs=90)]: Using backend ThreadingBackend with 90 concurrent workers.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1, change: 1.00000000\n",
      "Epoch 2, change: 0.29878992\n",
      "Epoch 3, change: 0.17384325\n",
      "Epoch 4, change: 0.10131030\n",
      "Epoch 5, change: 0.08043452\n",
      "Epoch 6, change: 0.05668465\n",
      "Epoch 7, change: 0.04847571\n",
      "Epoch 1, change: 1.00000000\n",
      "Epoch 2, change: 0.19429012\n",
      "Epoch 3, change: 0.13390577\n",
      "Epoch 4, change: 0.11251716\n",
      "Epoch 5, change: 0.09308229\n",
      "Epoch 6, change: 0.08010010\n",
      "Epoch 7, change: 0.06498801\n",
      "Epoch 8, change: 0.05482737\n",
      "Epoch 9, change: 0.04907443\n",
      "Epoch 10, change: 0.04318935\n",
      "Epoch 11, change: 0.03788785\n",
      "Epoch 12, change: 0.03651959\n",
      "Epoch 13, change: 0.03465665\n",
      "Epoch 14, change: 0.03284229\n",
      "Epoch 15, change: 0.03030117\n",
      "Epoch 16, change: 0.02734818\n",
      "Epoch 17, change: 0.02670017\n",
      "Epoch 18, change: 0.02614135\n",
      "Epoch 19, change: 0.02691340\n",
      "Epoch 20, change: 0.02503308\n",
      "Epoch 1, change: 1.00000000\n",
      "Epoch 2, change: 0.26725543\n",
      "Epoch 3, change: 0.15633385\n",
      "Epoch 4, change: 0.13208780\n",
      "Epoch 5, change: 0.09603459\n",
      "Epoch 6, change: 0.07061110\n",
      "Epoch 7, change: 0.06911717\n",
      "Epoch 8, change: 0.05717146\n",
      "Epoch 9, change: 0.05111691\n",
      "Epoch 10, change: 0.04764172\n",
      "Epoch 11, change: 0.04127735\n",
      "Epoch 12, change: 0.03636226\n",
      "Epoch 13, change: 0.03406669\n",
      "Epoch 14, change: 0.03057568\n",
      "Epoch 15, change: 0.03044359\n",
      "Epoch 16, change: 0.02924166\n",
      "Epoch 17, change: 0.02815950\n",
      "Epoch 18, change: 0.02705636\n",
      "Epoch 19, change: 0.02703557\n",
      "Epoch 20, change: 0.02617593\n",
      "Epoch 1, change: 1.00000000\n",
      "Epoch 2, change: 0.21482301\n",
      "Epoch 3, change: 0.15042946\n",
      "Epoch 4, change: 0.10038739\n",
      "Epoch 5, change: 0.07630071\n",
      "Epoch 6, change: 0.07916271\n",
      "Epoch 7, change: 0.06558553\n",
      "Epoch 8, change: 0.06091453\n",
      "Epoch 9, change: 0.05426849\n",
      "Epoch 10, change: 0.04878654\n",
      "Epoch 11, change: 0.04590133\n",
      "Epoch 12, change: 0.04353949\n",
      "Epoch 13, change: 0.03962308\n",
      "Epoch 14, change: 0.03627651\n",
      "Epoch 15, change: 0.03447271\n",
      "Epoch 16, change: 0.03262867\n",
      "Epoch 17, change: 0.03166505\n",
      "Epoch 18, change: 0.03326062\n",
      "Epoch 19, change: 0.02978826\n",
      "Epoch 20, change: 0.02575621\n",
      "Epoch 1, change: 1.00000000\n",
      "Epoch 2, change: 0.10979418\n",
      "Epoch 3, change: 0.08174751\n",
      "Epoch 4, change: 0.05825482\n",
      "Epoch 5, change: 0.04910790\n",
      "Epoch 6, change: 0.04410408\n",
      "Epoch 7, change: 0.03742170\n",
      "Epoch 8, change: 0.03365514\n",
      "Epoch 9, change: 0.03119178\n",
      "Epoch 10, change: 0.02893296\n",
      "Epoch 11, change: 0.02671564\n",
      "Epoch 12, change: 0.02726826\n",
      "Epoch 13, change: 0.02327285\n",
      "Epoch 14, change: 0.02231121\n",
      "Epoch 15, change: 0.02160396\n",
      "Epoch 16, change: 0.02114069\n",
      "Epoch 17, change: 0.01722026\n",
      "Epoch 18, change: 0.01698594\n",
      "Epoch 19, change: 0.01734327\n",
      "Epoch 20, change: 0.01720042\n",
      "Epoch 1, change: 1.00000000\n",
      "Epoch 2, change: 0.18788186\n",
      "Epoch 3, change: 0.07636895\n",
      "Epoch 4, change: 0.06465704\n",
      "Epoch 5, change: 0.05861834\n",
      "Epoch 6, change: 0.05128339\n",
      "Epoch 7, change: 0.04594584\n",
      "Epoch 8, change: 0.04177603\n",
      "Epoch 9, change: 0.03663207\n",
      "Epoch 10, change: 0.03526376\n",
      "Epoch 11, change: 0.03438762\n",
      "Epoch 12, change: 0.03262027\n",
      "Epoch 13, change: 0.03002361\n",
      "Epoch 14, change: 0.02950239\n",
      "Epoch 15, change: 0.02687844\n",
      "Epoch 16, change: 0.02742990\n",
      "Epoch 17, change: 0.02573403\n",
      "Epoch 18, change: 0.02371714\n",
      "Epoch 19, change: 0.02279404\n",
      "Epoch 20, change: 0.02273410\n",
      "Epoch 1, change: 1.00000000\n",
      "Epoch 2, change: 0.20190790\n",
      "Epoch 3, change: 0.11825608\n",
      "Epoch 4, change: 0.10006402\n",
      "Epoch 5, change: 0.07555687\n",
      "Epoch 6, change: 0.06464089\n",
      "Epoch 7, change: 0.06056824\n",
      "Epoch 8, change: 0.05288722\n",
      "Epoch 9, change: 0.04784453\n",
      "Epoch 10, change: 0.04105386\n",
      "Epoch 11, change: 0.04059249\n",
      "Epoch 12, change: 0.03693550\n",
      "Epoch 13, change: 0.03600852\n",
      "Epoch 14, change: 0.03221655\n",
      "Epoch 15, change: 0.03100622\n",
      "Epoch 16, change: 0.03011983\n",
      "Epoch 17, change: 0.03036647\n",
      "Epoch 18, change: 0.02965244\n",
      "Epoch 19, change: 0.02891567\n",
      "Epoch 20, change: 0.02339042\n",
      "Epoch 1, change: 1.00000000\n",
      "Epoch 2, change: 0.23802604\n",
      "Epoch 3, change: 0.15481052\n",
      "Epoch 4, change: 0.10676067\n",
      "Epoch 5, change: 0.08904231\n",
      "Epoch 6, change: 0.08465334\n",
      "Epoch 7, change: 0.07721976\n",
      "Epoch 8, change: 0.06925983\n",
      "Epoch 9, change: 0.06582047\n",
      "Epoch 10, change: 0.06299322\n",
      "Epoch 11, change: 0.05679798\n",
      "Epoch 12, change: 0.05624313\n",
      "Epoch 13, change: 0.05585380\n",
      "Epoch 14, change: 0.04863888\n",
      "Epoch 15, change: 0.04658595\n",
      "Epoch 16, change: 0.04287278\n",
      "Epoch 17, change: 0max_iter reached after 15 seconds\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Parallel(n_jobs=90)]: Done   1 out of   1 | elapsed:   15.3s finished\n",
      "[Parallel(n_jobs=90)]: Using backend ThreadingBackend with 90 concurrent workers.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "max_iter reached after 16 seconds\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Parallel(n_jobs=90)]: Done   1 out of   1 | elapsed:   15.4s finished\n",
      "[Parallel(n_jobs=90)]: Using backend ThreadingBackend with 90 concurrent workers.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "max_iter reached after 15 seconds\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Parallel(n_jobs=90)]: Done   1 out of   1 | elapsed:   15.4s finished\n",
      "[Parallel(n_jobs=90)]: Using backend ThreadingBackend with 90 concurrent workers.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "max_iter reached after 16 seconds\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Parallel(n_jobs=90)]: Done   1 out of   1 | elapsed:   15.4s finished\n",
      "[Parallel(n_jobs=90)]: Using backend ThreadingBackend with 90 concurrent workers.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "max_iter reached after 15 seconds\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Parallel(n_jobs=90)]: Done   1 out of   1 | elapsed:   15.4s finished\n",
      "[Parallel(n_jobs=90)]: Using backend ThreadingBackend with 90 concurrent workers.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "max_iter reached after 16 seconds\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Parallel(n_jobs=90)]: Done   1 out of   1 | elapsed:   15.4s finished\n",
      "[Parallel(n_jobs=90)]: Using backend ThreadingBackend with 90 concurrent workers.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "max_iter reached after 15 seconds\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Parallel(n_jobs=90)]: Done   1 out of   1 | elapsed:   15.3s finished\n",
      "[Parallel(n_jobs=90)]: Using backend ThreadingBackend with 90 concurrent workers.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "max_iter reached after 15 seconds\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Parallel(n_jobs=90)]: Done   1 out of   1 | elapsed:   15.4s finished\n",
      "[Parallel(n_jobs=90)]: Using backend ThreadingBackend with 90 concurrent workers.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      ".04325789\n",
      "Epoch 18, change: 0.03943856\n",
      "Epoch 19, change: 0.03465563\n",
      "Epoch 20, change: 0.03362581\n",
      "Epoch 1, change: 1.00000000\n",
      "Epoch 2, change: 0.01264601\n",
      "Epoch 3, change: 0.00817275\n",
      "Epoch 4, change: 0.18308620\n",
      "Epoch 5, change: 0.13379298\n",
      "Epoch 6, change: 0.15195003\n",
      "Epoch 7, change: 0.16541733\n",
      "Epoch 8, change: 0.14497358\n",
      "Epoch 9, change: 0.06707393\n",
      "Epoch 10, change: 0.04695895\n",
      "Epoch 11, change: 0.04106708\n",
      "Epoch 12, change: 0.03489358\n",
      "Epoch 13, change: 0.02910454\n",
      "Epoch 14, change: 0.02468155\n",
      "Epoch 15, change: 0.02471816\n",
      "Epoch 16, change: 0.01843582\n",
      "Epoch 17, change: 0.01449527\n",
      "Epoch 18, change: 0.01428928\n",
      "Epoch 19, change: 0.01397789\n",
      "Epoch 20, change: 0.01387163\n",
      "Epoch 1, change: 1.00000000\n",
      "Epoch 2, change: 0.27497211\n",
      "Epoch 3, change: 0.15939511\n",
      "Epoch 4, change: 0.15389642\n",
      "Epoch 5, change: 0.11515816\n",
      "Epoch 6, change: 0.10326861\n",
      "Epoch 7, change: 0.09527942\n",
      "Epoch 8, change: 0.07941049\n",
      "Epoch 9, change: 0.06705323\n",
      "Epoch 10, change: 0.05800911\n",
      "Epoch 11, change: 0.05115857\n",
      "Epoch 12, change: 0.04635592\n",
      "Epoch 13, change: 0.04057816\n",
      "Epoch 14, change: 0.03693065\n",
      "Epoch 15, change: 0.03606382\n",
      "Epoch 16, change: 0.03501296\n",
      "Epoch 17, change: 0.03337674\n",
      "Epoch 18, change: 0.02707882\n",
      "Epoch 19, change: 0.02713788\n",
      "Epoch 20, change: 0.02580771\n",
      "Epoch 1, change: 1.00000000\n",
      "Epoch 2, change: 0.25652683\n",
      "Epoch 3, change: 0.18295656\n",
      "Epoch 4, change: 0.12315384\n",
      "Epoch 5, change: 0.10904384\n",
      "Epoch 6, change: 0.09564403\n",
      "Epoch 7, change: 0.08741108\n",
      "Epoch 8, change: 0.07619260\n",
      "Epoch 9, change: 0.06587951\n",
      "Epoch 10, change: 0.05725588\n",
      "Epoch 11, change: 0.05260345\n",
      "Epoch 12, change: 0.04925496\n",
      "Epoch 13, change: 0.04335935\n",
      "Epoch 14, change: 0.04113914\n",
      "Epoch 15, change: 0.03949102\n",
      "Epoch 16, change: 0.03725862\n",
      "Epoch 17, change: 0.03150288\n",
      "Epoch 18, change: 0.03079695\n",
      "Epoch 19, change: 0.03032892\n",
      "Epoch 20, change: 0.02885332\n",
      "Epoch 1, change: 1.00000000\n",
      "Epoch 2, change: 0.14679235\n",
      "Epoch 3, change: 0.06153478\n",
      "Epoch 4, change: 0.06104724\n",
      "Epoch 5, change: 0.04659370\n",
      "Epoch 6, change: 0.04218030\n",
      "Epoch 7, change: 0.03557871\n",
      "Epoch 8, change: 0.03408663\n",
      "Epoch 9, change: 0.03313026\n",
      "Epoch 10, change: 0.02826390\n",
      "Epoch 11, change: 0.02652137\n",
      "Epoch 12, change: 0.02459409\n",
      "Epoch 13, change: 0.02399410\n",
      "Epoch 14, change: 0.02112682\n",
      "Epoch 15, change: 0.02104809\n",
      "Epoch 16, change: 0.02127457\n",
      "Epoch 17, change: 0.02114419\n",
      "Epoch 18, change: 0.01753883\n",
      "Epoch 19, change: 0.01706026\n",
      "Epoch 20, change: 0.01727426\n",
      "Epoch 1, change: 1.00000000\n",
      "Epoch 2, change: 0.30532122\n",
      "Epoch 3, change: 0.18711370\n",
      "Epoch 4, change: 0.13303693\n",
      "Epoch 5, change: 0.11394518\n",
      "Epoch 6, change: 0.08909375\n",
      "Epoch 7, change: 0.07646616\n",
      "Epoch 8, change: 0.06470254\n",
      "Epoch 9, change: 0.05709524\n",
      "Epoch 10, change: 0.05478057\n",
      "Epoch 11, change: 0.04820701\n",
      "Epoch 12, change: 0.03932649\n",
      "Epoch 13, change: 0.03830642\n",
      "Epoch 14, change: 0.03601619\n",
      "Epoch 15, change: 0.03540811\n",
      "Epoch 16, change: 0.02613734\n",
      "Epoch 17, change: 0.02575687\n",
      "Epoch 18, change: 0.02440375\n",
      "Epoch 19, change: 0.02451761\n",
      "Epoch 20, change: 0.02356268\n",
      "Epoch 1, change: 1.00000000\n",
      "Epoch 2, change: 0.32305977\n",
      "Epoch 3, change: 0.16981003\n",
      "Epoch 4, change: 0.14362454\n",
      "Epoch 5, change: 0.09196902\n",
      "Epoch 6, change: 0.08284803\n",
      "Epoch 7, change: 0.07695711\n",
      "Epoch 8, change: 0.06933918\n",
      "Epoch 9, change: 0.06720365\n",
      "Epoch 10, change: 0.06008548\n",
      "Epoch 11, change: 0.05391717\n",
      "Epoch 12, change: 0.04815635\n",
      "Epoch 13, change: 0.04358651\n",
      "Epoch 14, change: 0.04127787\n",
      "Epoch 15, change: 0.03926857\n",
      "Epoch 16, change: 0.03552602\n",
      "Epoch 17, change: 0.03681647\n",
      "Epoch 18, change: 0.02811855\n",
      "Epoch 19, change: 0.02530653\n",
      "Epoch 20, change: 0.02318184\n",
      "Epoch 1, change: 1.00000000\n",
      "Epoch 2, change: 0.29045898\n",
      "Epoch 3, change: 0.14109513\n",
      "Epoch 4, change: 0.13340963\n",
      "Epoch 5, change: 0.09863740\n",
      "Epoch 6, change: 0.07959287\n",
      "Epoch 7, change: 0.06060199\n",
      "Epoch 8, change: 0.05589413\n",
      "Epoch 9, change: 0.04614247\n",
      "Epoch 10, change: 0.04324754\n",
      "Epoch 11, change: 0.04030161\n",
      "Epoch 12, change: 0.03662723\n",
      "Epoch 13, change: 0.03577093\n",
      "Epoch 14, change: 0.03340637\n",
      "Epoch 15, change: 0.02916364\n",
      "Epoch 16, change: 0.02767672\n",
      "Epoch 17, change: 0.02610474\n",
      "Epoch 18, change: 0.02574347\n",
      "Epoch 19, change: 0.02407935\n",
      "Epoch 20, change: 0.02093751\n",
      "Epmax_iter reached after 16 seconds\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Parallel(n_jobs=90)]: Done   1 out of   1 | elapsed:   15.3s finished\n",
      "[Parallel(n_jobs=90)]: Using backend ThreadingBackend with 90 concurrent workers.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "max_iter reached after 15 seconds\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Parallel(n_jobs=90)]: Done   1 out of   1 | elapsed:   15.4s finished\n",
      "[Parallel(n_jobs=90)]: Using backend ThreadingBackend with 90 concurrent workers.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "max_iter reached after 16 seconds\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Parallel(n_jobs=90)]: Done   1 out of   1 | elapsed:   15.3s finished\n",
      "[Parallel(n_jobs=90)]: Using backend ThreadingBackend with 90 concurrent workers.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "max_iter reached after 15 seconds\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Parallel(n_jobs=90)]: Done   1 out of   1 | elapsed:   15.3s finished\n",
      "[Parallel(n_jobs=90)]: Using backend ThreadingBackend with 90 concurrent workers.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "max_iter reached after 15 seconds\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Parallel(n_jobs=90)]: Done   1 out of   1 | elapsed:   15.3s finished\n",
      "[Parallel(n_jobs=90)]: Using backend ThreadingBackend with 90 concurrent workers.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "max_iter reached after 15 seconds\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Parallel(n_jobs=90)]: Done   1 out of   1 | elapsed:   15.4s finished\n",
      "[Parallel(n_jobs=90)]: Using backend ThreadingBackend with 90 concurrent workers.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "max_iter reached after 15 seconds\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Parallel(n_jobs=90)]: Done   1 out of   1 | elapsed:   15.4s finished\n",
      "[Parallel(n_jobs=90)]: Using backend ThreadingBackend with 90 concurrent workers.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "och 1, change: 1.00000000\n",
      "Epoch 2, change: 0.22717836\n",
      "Epoch 3, change: 0.11251558\n",
      "Epoch 4, change: 0.09807485\n",
      "Epoch 5, change: 0.07701577\n",
      "Epoch 6, change: 0.06103359\n",
      "Epoch 7, change: 0.06729671\n",
      "Epoch 8, change: 0.05415684\n",
      "Epoch 9, change: 0.04850182\n",
      "Epoch 10, change: 0.04683423\n",
      "Epoch 11, change: 0.04079762\n",
      "Epoch 12, change: 0.03621990\n",
      "Epoch 13, change: 0.03415143\n",
      "Epoch 14, change: 0.03092950\n",
      "Epoch 15, change: 0.03035305\n",
      "Epoch 16, change: 0.02927312\n",
      "Epoch 17, change: 0.02944721\n",
      "Epoch 18, change: 0.02393844\n",
      "Epoch 19, change: 0.02304449\n",
      "Epoch 20, change: 0.02190579\n",
      "Epoch 1, change: 1.00000000\n",
      "Epoch 2, change: 0.29864675\n",
      "Epoch 3, change: 0.20545371\n",
      "Epoch 4, change: 0.11083160\n",
      "Epoch 5, change: 0.10546890\n",
      "Epoch 6, change: 0.08640527\n",
      "Epoch 7, change: 0.07306442\n",
      "Epoch 8, change: 0.06540997\n",
      "Epoch 9, change: 0.05966038\n",
      "Epoch 10, change: 0.05404966\n",
      "Epoch 11, change: 0.04739702\n",
      "Epoch 12, change: 0.04432220\n",
      "Epoch 13, change: 0.03490148\n",
      "Epoch 14, change: 0.03140642\n",
      "Epoch 15, change: 0.03174919\n",
      "Epoch 16, change: 0.02977045\n",
      "Epoch 17, change: 0.02844809\n",
      "Epoch 18, change: 0.02794759\n",
      "Epoch 19, change: 0.02170327\n",
      "Epoch 20, change: 0.01891594\n",
      "Epoch 1, change: 1.00000000\n",
      "Epoch 2, change: 0.27506584\n",
      "Epoch 3, change: 0.16857609\n",
      "Epoch 4, change: 0.11794159\n",
      "Epoch 5, change: 0.09436136\n",
      "Epoch 6, change: 0.07813244\n",
      "Epoch 7, change: 0.06097619\n",
      "Epoch 8, change: 0.05350374\n",
      "Epoch 9, change: 0.04829787\n",
      "Epoch 10, change: 0.04155895\n",
      "Epoch 11, change: 0.03784787\n",
      "Epoch 12, change: 0.03626384\n",
      "Epoch 13, change: 0.03444540\n",
      "Epoch 14, change: 0.02852129\n",
      "Epoch 15, change: 0.02655753\n",
      "Epoch 16, change: 0.02525806\n",
      "Epoch 17, change: 0.02362569\n",
      "Epoch 18, change: 0.02345018\n",
      "Epoch 19, change: 0.02180491\n",
      "Epoch 20, change: 0.01696987\n",
      "Epoch 1, change: 1.00000000\n",
      "Epoch 2, change: 0.28661871\n",
      "Epoch 3, change: 0.15676764\n",
      "Epoch 4, change: 0.16116694\n",
      "Epoch 5, change: 0.11100370\n",
      "Epoch 6, change: 0.07949363\n",
      "Epoch 7, change: 0.06676464\n",
      "Epoch 8, change: 0.06390772\n",
      "Epoch 9, change: 0.04894716\n",
      "Epoch 10, change: 0.05334482\n",
      "Epoch 11, change: 0.04136777\n",
      "Epoch 12, change: 0.03769561\n",
      "Epoch 13, change: 0.03564484\n",
      "Epoch 14, change: 0.03316160\n",
      "Epoch 15, change: 0.03280906\n",
      "Epoch 16, change: 0.03160327\n",
      "Epoch 17, change: 0.02771494\n",
      "Epoch 18, change: 0.02647798\n",
      "Epoch 19, change: 0.02446191\n",
      "Epoch 20, change: 0.02382095\n",
      "Epoch 1, change: 1.00000000\n",
      "Epoch 2, change: 0.20762670\n",
      "Epoch 3, change: 0.12347144\n",
      "Epoch 4, change: 0.08534203\n",
      "Epoch 5, change: 0.09035747\n",
      "Epoch 6, change: 0.06372764\n",
      "Epoch 7, change: 0.05636129\n",
      "Epoch 8, change: 0.05672605\n",
      "Epoch 9, change: 0.05178861\n",
      "Epoch 10, change: 0.04730889\n",
      "Epoch 11, change: 0.04456064\n",
      "Epoch 12, change: 0.04051116\n",
      "Epoch 13, change: 0.03681107\n",
      "Epoch 14, change: 0.03760772\n",
      "Epoch 15, change: 0.03285763\n",
      "Epoch 16, change: 0.03324739\n",
      "Epoch 17, change: 0.02809284\n",
      "Epoch 18, change: 0.02543820\n",
      "Epoch 19, change: 0.02494172\n",
      "Epoch 20, change: 0.02483169\n",
      "Epoch 1, change: 1.00000000\n",
      "Epoch 2, change: 0.11991099\n",
      "Epoch 3, change: 0.08978162\n",
      "Epoch 4, change: 0.07652599\n",
      "Epoch 5, change: 0.06739399\n",
      "Epoch 6, change: 0.06329168\n",
      "Epoch 7, change: 0.05669705\n",
      "Epoch 8, change: 0.05699285\n",
      "Epoch 9, change: 0.04935531\n",
      "Epoch 10, change: 0.04903838\n",
      "Epoch 11, change: 0.04545763\n",
      "Epoch 12, change: 0.04438295\n",
      "Epoch 13, change: 0.03831228\n",
      "Epoch 14, change: 0.03705936\n",
      "Epoch 15, change: 0.03556893\n",
      "Epoch 16, change: 0.03226724\n",
      "Epoch 17, change: 0.03157677\n",
      "Epoch 18, change: 0.03036124\n",
      "Epoch 19, change: 0.02684650\n",
      "Epoch 20, change: 0.02235683\n",
      "Epoch 1, change: 1.00000000\n",
      "Epoch 2, change: 0.18492664\n",
      "Epoch 3, change: 0.17254369\n",
      "Epoch 4, change: 0.08829578\n",
      "Epoch 5, change: 0.06093951\n",
      "Epoch 6, change: 0.04310906\n",
      "Epoch 7, change: 0.03301736\n",
      "Epoch 8, change: 0.02451215\n",
      "Epoch 9, change: 0.02343789\n",
      "Epoch 10, change: 0.02238654\n",
      "Epoch 11, change: 0.02163447\n",
      "Epoch 12, change: 0.02114363\n",
      "Epoch 13, change: 0.02081967\n",
      "Epoch 14, change: 0.02045417\n",
      "Epoch 15, change: 0.01992527\n",
      "Epoch 16, change: 0.01988873\n",
      "Epoch 17, change: 0.01959810\n",
      "Epoch 18, change: 0.01902003\n",
      "Epoch 19, change: 0.01909974\n",
      "Epoch 20, change: 0.01879367\n",
      "Epoch 1, change: 1.00000000\n",
      "Epoch 2, change: 0.25278366\n",
      "Epoch 3, change: 0.16525163\n",
      "Epoch 4, change: max_iter reached after 15 seconds\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Parallel(n_jobs=90)]: Done   1 out of   1 | elapsed:   15.4s finished\n",
      "[Parallel(n_jobs=90)]: Using backend ThreadingBackend with 90 concurrent workers.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "max_iter reached after 15 seconds\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Parallel(n_jobs=90)]: Done   1 out of   1 | elapsed:   15.4s finished\n",
      "[Parallel(n_jobs=90)]: Using backend ThreadingBackend with 90 concurrent workers.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "max_iter reached after 16 seconds\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Parallel(n_jobs=90)]: Done   1 out of   1 | elapsed:   15.3s finished\n",
      "[Parallel(n_jobs=90)]: Using backend ThreadingBackend with 90 concurrent workers.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "max_iter reached after 15 seconds\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Parallel(n_jobs=90)]: Done   1 out of   1 | elapsed:   15.3s finished\n",
      "[Parallel(n_jobs=90)]: Using backend ThreadingBackend with 90 concurrent workers.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "max_iter reached after 16 seconds\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Parallel(n_jobs=90)]: Done   1 out of   1 | elapsed:   15.4s finished\n",
      "[Parallel(n_jobs=90)]: Using backend ThreadingBackend with 90 concurrent workers.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "max_iter reached after 15 seconds\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Parallel(n_jobs=90)]: Done   1 out of   1 | elapsed:   15.4s finished\n",
      "[Parallel(n_jobs=90)]: Using backend ThreadingBackend with 90 concurrent workers.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "max_iter reached after 16 seconds\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Parallel(n_jobs=90)]: Done   1 out of   1 | elapsed:   15.4s finished\n",
      "[Parallel(n_jobs=90)]: Using backend ThreadingBackend with 90 concurrent workers.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.12267846\n",
      "Epoch 5, change: 0.10972810\n",
      "Epoch 6, change: 0.09343598\n",
      "Epoch 7, change: 0.08710282\n",
      "Epoch 8, change: 0.07474903\n",
      "Epoch 9, change: 0.06740711\n",
      "Epoch 10, change: 0.06171302\n",
      "Epoch 11, change: 0.05267249\n",
      "Epoch 12, change: 0.05137428\n",
      "Epoch 13, change: 0.04649624\n",
      "Epoch 14, change: 0.04018385\n",
      "Epoch 15, change: 0.03854440\n",
      "Epoch 16, change: 0.03856120\n",
      "Epoch 17, change: 0.03649823\n",
      "Epoch 18, change: 0.03176294\n",
      "Epoch 19, change: 0.02880062\n",
      "Epoch 20, change: 0.02802260\n",
      "Epoch 1, change: 1.00000000\n",
      "Epoch 2, change: 0.26896000\n",
      "Epoch 3, change: 0.19621013\n",
      "Epoch 4, change: 0.13524981\n",
      "Epoch 5, change: 0.11366735\n",
      "Epoch 6, change: 0.09131941\n",
      "Epoch 7, change: 0.08229771\n",
      "Epoch 8, change: 0.06945738\n",
      "Epoch 9, change: 0.05756110\n",
      "Epoch 10, change: 0.05170762\n",
      "Epoch 11, change: 0.04537066\n",
      "Epoch 12, change: 0.04126325\n",
      "Epoch 13, change: 0.04041053\n",
      "Epoch 14, change: 0.03520043\n",
      "Epoch 15, change: 0.03241969\n",
      "Epoch 16, change: 0.03121885\n",
      "Epoch 17, change: 0.02986313\n",
      "Epoch 18, change: 0.02713620\n",
      "Epoch 19, change: 0.02416928\n",
      "Epoch 20, change: 0.02362882\n",
      "Epoch 1, change: 1.00000000\n",
      "Epoch 2, change: 0.26919881\n",
      "Epoch 3, change: 0.18483806\n",
      "Epoch 4, change: 0.13911559\n",
      "Epoch 5, change: 0.12134190\n",
      "Epoch 6, change: 0.09348058\n",
      "Epoch 7, change: 0.07003859\n",
      "Epoch 8, change: 0.06592509\n",
      "Epoch 9, change: 0.05601830\n",
      "Epoch 10, change: 0.05303227\n",
      "Epoch 11, change: 0.04942447\n",
      "Epoch 12, change: 0.04027031\n",
      "Epoch 13, change: 0.04049996\n",
      "Epoch 14, change: 0.03802933\n",
      "Epoch 15, change: 0.03466126\n",
      "Epoch 16, change: 0.02913525\n",
      "Epoch 17, change: 0.02861009\n",
      "Epoch 18, change: 0.02834597\n",
      "Epoch 19, change: 0.02633496\n",
      "Epoch 20, change: 0.02614886\n",
      "Epoch 1, change: 1.00000000\n",
      "Epoch 2, change: 0.32185960\n",
      "Epoch 3, change: 0.22689125\n",
      "Epoch 4, change: 0.13538159\n",
      "Epoch 5, change: 0.11462823\n",
      "Epoch 6, change: 0.08271781\n",
      "Epoch 7, change: 0.06462748\n",
      "Epoch 8, change: 0.05933930\n",
      "Epoch 9, change: 0.04917655\n",
      "Epoch 10, change: 0.04312444\n",
      "Epoch 11, change: 0.04278922\n",
      "Epoch 12, change: 0.03995549\n",
      "Epoch 13, change: 0.03608126\n",
      "Epoch 14, change: 0.03285840\n",
      "Epoch 15, change: 0.03220885\n",
      "Epoch 16, change: 0.03140189\n",
      "Epoch 17, change: 0.03063084\n",
      "Epoch 18, change: 0.02950541\n",
      "Epoch 19, change: 0.02872794\n",
      "Epoch 20, change: 0.02843225\n",
      "Epoch 1, change: 1.00000000\n",
      "Epoch 2, change: 0.30677977\n",
      "Epoch 3, change: 0.18229143\n",
      "Epoch 4, change: 0.14083274\n",
      "Epoch 5, change: 0.09872343\n",
      "Epoch 6, change: 0.07096925\n",
      "Epoch 7, change: 0.07386824\n",
      "Epoch 8, change: 0.07009234\n",
      "Epoch 9, change: 0.06480679\n",
      "Epoch 10, change: 0.05230926\n",
      "Epoch 11, change: 0.04513287\n",
      "Epoch 12, change: 0.04390407\n",
      "Epoch 13, change: 0.03571774\n",
      "Epoch 14, change: 0.03407313\n",
      "Epoch 15, change: 0.03458934\n",
      "Epoch 16, change: 0.03380796\n",
      "Epoch 17, change: 0.02924647\n",
      "Epoch 18, change: 0.02853325\n",
      "Epoch 19, change: 0.02514124\n",
      "Epoch 20, change: 0.02385961\n",
      "Epoch 1, change: 1.00000000\n",
      "Epoch 2, change: 0.10499111\n",
      "Epoch 3, change: 0.08475672\n",
      "Epoch 4, change: 0.06834457\n",
      "Epoch 5, change: 0.08031294\n",
      "Epoch 6, change: 0.06514052\n",
      "Epoch 7, change: 0.05843407\n",
      "Epoch 8, change: 0.05724125\n",
      "Epoch 9, change: 0.05138730\n",
      "Epoch 10, change: 0.04774524\n",
      "Epoch 11, change: 0.04659243\n",
      "Epoch 12, change: 0.04526709\n",
      "Epoch 13, change: 0.04393787\n",
      "Epoch 14, change: 0.04367122\n",
      "Epoch 15, change: 0.04356769\n",
      "Epoch 16, change: 0.03989852\n",
      "Epoch 17, change: 0.03988143\n",
      "Epoch 18, change: 0.04196526\n",
      "Epoch 19, change: 0.04085970\n",
      "Epoch 20, change: 0.03099308\n",
      "Epoch 1, change: 1.00000000\n",
      "Epoch 2, change: 0.28099248\n",
      "Epoch 3, change: 0.17714298\n",
      "Epoch 4, change: 0.11168522\n",
      "Epoch 5, change: 0.10062396\n",
      "Epoch 6, change: 0.07803569\n",
      "Epoch 7, change: 0.06302588\n",
      "Epoch 8, change: 0.05894641\n",
      "Epoch 9, change: 0.04762200\n",
      "Epoch 10, change: 0.04100873\n",
      "Epoch 11, change: 0.03838533\n",
      "Epoch 12, change: 0.03857714\n",
      "Epoch 13, change: 0.03355967\n",
      "Epoch 14, change: 0.03182422\n",
      "Epoch 15, change: 0.02970864\n",
      "Epoch 16, change: 0.02800906\n",
      "Epoch 17, change: 0.02695866\n",
      "Epoch 18, change: 0.02652252\n",
      "Epoch 19, change: 0.02605413\n",
      "Epoch 20, change: 0.02544156\n",
      "Epoch 1, change: 1.00000000\n",
      "Epoch 2, change: 0.25935626\n",
      "Epoch 3, change: 0.16024515\n",
      "Epoch 4, change: 0.11582939\n",
      "Epoch 5, change: 0.09818604\n",
      "Epoch 6, change: 0.07553544\n",
      "Epoch 7, change: 0.06874735\n",
      "Epocmax_iter reached after 15 seconds\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Parallel(n_jobs=90)]: Done   1 out of   1 | elapsed:   15.3s finished\n",
      "[Parallel(n_jobs=90)]: Using backend ThreadingBackend with 90 concurrent workers.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "max_iter reached after 15 seconds\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Parallel(n_jobs=90)]: Done   1 out of   1 | elapsed:   15.4s finished\n",
      "[Parallel(n_jobs=90)]: Using backend ThreadingBackend with 90 concurrent workers.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "max_iter reached after 15 seconds\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Parallel(n_jobs=90)]: Done   1 out of   1 | elapsed:   15.4s finished\n",
      "[Parallel(n_jobs=90)]: Using backend ThreadingBackend with 90 concurrent workers.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "max_iter reached after 15 seconds\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Parallel(n_jobs=90)]: Done   1 out of   1 | elapsed:   15.4s finished\n",
      "[Parallel(n_jobs=90)]: Using backend ThreadingBackend with 90 concurrent workers.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "max_iter reached after 16 seconds\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Parallel(n_jobs=90)]: Done   1 out of   1 | elapsed:   15.3s finished\n",
      "[Parallel(n_jobs=90)]: Using backend ThreadingBackend with 90 concurrent workers.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "max_iter reached after 15 seconds\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Parallel(n_jobs=90)]: Done   1 out of   1 | elapsed:   15.4s finished\n",
      "[Parallel(n_jobs=90)]: Using backend ThreadingBackend with 90 concurrent workers.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "max_iter reached after 16 seconds\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Parallel(n_jobs=90)]: Done   1 out of   1 | elapsed:   15.4s finished\n",
      "[Parallel(n_jobs=90)]: Using backend ThreadingBackend with 90 concurrent workers.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "h 8, change: 0.06800864\n",
      "Epoch 9, change: 0.05838026\n",
      "Epoch 10, change: 0.05338601\n",
      "Epoch 11, change: 0.05023703\n",
      "Epoch 12, change: 0.04802536\n",
      "Epoch 13, change: 0.04310314\n",
      "Epoch 14, change: 0.04007102\n",
      "Epoch 15, change: 0.03960499\n",
      "Epoch 16, change: 0.03819723\n",
      "Epoch 17, change: 0.03296255\n",
      "Epoch 18, change: 0.03310589\n",
      "Epoch 19, change: 0.03139027\n",
      "Epoch 20, change: 0.03109676\n",
      "Epoch 1, change: 1.00000000\n",
      "Epoch 2, change: 0.29554102\n",
      "Epoch 3, change: 0.15456241\n",
      "Epoch 4, change: 0.09457746\n",
      "Epoch 5, change: 0.08627687\n",
      "Epoch 6, change: 0.08078042\n",
      "Epoch 7, change: 0.07328122\n",
      "Epoch 8, change: 0.05680581\n",
      "Epoch 9, change: 0.05909497\n",
      "Epoch 10, change: 0.04631816\n",
      "Epoch 11, change: 0.04558051\n",
      "Epoch 12, change: 0.04117604\n",
      "Epoch 13, change: 0.03846572\n",
      "Epoch 14, change: 0.03545265\n",
      "Epoch 15, change: 0.03534595\n",
      "Epoch 16, change: 0.03468075\n",
      "Epoch 17, change: 0.03045679\n",
      "Epoch 18, change: 0.02509518\n",
      "Epoch 19, change: 0.02416522\n",
      "Epoch 20, change: 0.02210655\n",
      "Epoch 1, change: 1.00000000\n",
      "Epoch 2, change: 0.26574019\n",
      "Epoch 3, change: 0.22899932\n",
      "Epoch 4, change: 0.12391909\n",
      "Epoch 5, change: 0.10076319\n",
      "Epoch 6, change: 0.08720949\n",
      "Epoch 7, change: 0.07073999\n",
      "Epoch 8, change: 0.06663395\n",
      "Epoch 9, change: 0.05233923\n",
      "Epoch 10, change: 0.04840617\n",
      "Epoch 11, change: 0.04697825\n",
      "Epoch 12, change: 0.04071569\n",
      "Epoch 13, change: 0.03580577\n",
      "Epoch 14, change: 0.03325501\n",
      "Epoch 15, change: 0.03359709\n",
      "Epoch 16, change: 0.02976418\n",
      "Epoch 17, change: 0.02576557\n",
      "Epoch 18, change: 0.02433221\n",
      "Epoch 19, change: 0.02335152\n",
      "Epoch 20, change: 0.02299741\n",
      "Epoch 1, change: 1.00000000\n",
      "Epoch 2, change: 0.23751622\n",
      "Epoch 3, change: 0.13500243\n",
      "Epoch 4, change: 0.10559546\n",
      "Epoch 5, change: 0.08054624\n",
      "Epoch 6, change: 0.06377900\n",
      "Epoch 7, change: 0.06988524\n",
      "Epoch 8, change: 0.05740873\n",
      "Epoch 9, change: 0.04716877\n",
      "Epoch 10, change: 0.03927939\n",
      "Epoch 11, change: 0.03645024\n",
      "Epoch 12, change: 0.03221225\n",
      "Epoch 13, change: 0.03075424\n",
      "Epoch 14, change: 0.02837575\n",
      "Epoch 15, change: 0.02707563\n",
      "Epoch 16, change: 0.02446559\n",
      "Epoch 17, change: 0.02139823\n",
      "Epoch 18, change: 0.02028402\n",
      "Epoch 19, change: 0.01978609\n",
      "Epoch 20, change: 0.01886747\n",
      "Epoch 1, change: 1.00000000\n",
      "Epoch 2, change: 0.20612961\n",
      "Epoch 3, change: 0.14704080\n",
      "Epoch 4, change: 0.11746798\n",
      "Epoch 5, change: 0.09363024\n",
      "Epoch 6, change: 0.07841414\n",
      "Epoch 7, change: 0.06752630\n",
      "Epoch 8, change: 0.05502224\n",
      "Epoch 9, change: 0.04819372\n",
      "Epoch 10, change: 0.04313770\n",
      "Epoch 11, change: 0.03975655\n",
      "Epoch 12, change: 0.03393771\n",
      "Epoch 13, change: 0.03141496\n",
      "Epoch 14, change: 0.03125613\n",
      "Epoch 15, change: 0.02921430\n",
      "Epoch 16, change: 0.02791670\n",
      "Epoch 17, change: 0.02307738\n",
      "Epoch 18, change: 0.02358721\n",
      "Epoch 19, change: 0.02242805\n",
      "Epoch 20, change: 0.02120661\n",
      "Epoch 1, change: 1.00000000\n",
      "Epoch 2, change: 0.23227775\n",
      "Epoch 3, change: 0.15044039\n",
      "Epoch 4, change: 0.11086776\n",
      "Epoch 5, change: 0.09896626\n",
      "Epoch 6, change: 0.08035759\n",
      "Epoch 7, change: 0.06660448\n",
      "Epoch 8, change: 0.05950093\n",
      "Epoch 9, change: 0.05170196\n",
      "Epoch 10, change: 0.04820872\n",
      "Epoch 11, change: 0.04310622\n",
      "Epoch 12, change: 0.04266378\n",
      "Epoch 13, change: 0.03623091\n",
      "Epoch 14, change: 0.03457338\n",
      "Epoch 15, change: 0.03111645\n",
      "Epoch 16, change: 0.02869010\n",
      "Epoch 17, change: 0.02709116\n",
      "Epoch 18, change: 0.02551249\n",
      "Epoch 19, change: 0.02471405\n",
      "Epoch 20, change: 0.02496568\n",
      "Epoch 1, change: 1.00000000\n",
      "Epoch 2, change: 0.24972712\n",
      "Epoch 3, change: 0.17422092\n",
      "Epoch 4, change: 0.14302820\n",
      "Epoch 5, change: 0.10912494\n",
      "Epoch 6, change: 0.09072451\n",
      "Epoch 7, change: 0.07439022\n",
      "Epoch 8, change: 0.06768478\n",
      "Epoch 9, change: 0.05562280\n",
      "Epoch 10, change: 0.04715926\n",
      "Epoch 11, change: 0.04489271\n",
      "Epoch 12, change: 0.03690534\n",
      "Epoch 13, change: 0.03705003\n",
      "Epoch 14, change: 0.03104443\n",
      "Epoch 15, change: 0.02851199\n",
      "Epoch 16, change: 0.02776177\n",
      "Epoch 17, change: 0.02540681\n",
      "Epoch 18, change: 0.02615576\n",
      "Epoch 19, change: 0.02259930\n",
      "Epoch 20, change: 0.01985866\n",
      "Epoch 1, change: 1.00000000\n",
      "Epoch 2, change: 0.15266170\n",
      "Epoch 3, change: 0.09599254\n",
      "Epoch 4, change: 0.05828709\n",
      "Epoch 5, change: 0.04911096\n",
      "Epoch 6, change: 0.04345327\n",
      "Epoch 7, change: 0.03825685\n",
      "Epoch 8, change: 0.03694197\n",
      "Epoch 9, change: 0.03179735\n",
      "Epoch 10, change: 0.02828100\n",
      "Epoch 11, change: max_iter reached after 15 seconds\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Parallel(n_jobs=90)]: Done   1 out of   1 | elapsed:   15.3s finished\n",
      "[Parallel(n_jobs=90)]: Using backend ThreadingBackend with 90 concurrent workers.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "max_iter reached after 15 seconds\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Parallel(n_jobs=90)]: Done   1 out of   1 | elapsed:   15.3s finished\n",
      "[Parallel(n_jobs=90)]: Using backend ThreadingBackend with 90 concurrent workers.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "max_iter reached after 15 seconds\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Parallel(n_jobs=90)]: Done   1 out of   1 | elapsed:   15.4s finished\n",
      "[Parallel(n_jobs=90)]: Using backend ThreadingBackend with 90 concurrent workers.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "max_iter reached after 15 seconds\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Parallel(n_jobs=90)]: Done   1 out of   1 | elapsed:   15.4s finished\n",
      "[Parallel(n_jobs=90)]: Using backend ThreadingBackend with 90 concurrent workers.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "max_iter reached after 16 seconds\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Parallel(n_jobs=90)]: Done   1 out of   1 | elapsed:   15.4s finished\n",
      "[Parallel(n_jobs=90)]: Using backend ThreadingBackend with 90 concurrent workers.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "max_iter reached after 15 seconds\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Parallel(n_jobs=90)]: Done   1 out of   1 | elapsed:   15.3s finished\n",
      "[Parallel(n_jobs=90)]: Using backend ThreadingBackend with 90 concurrent workers.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "max_iter reached after 16 seconds\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Parallel(n_jobs=90)]: Done   1 out of   1 | elapsed:   15.4s finished\n",
      "[Parallel(n_jobs=90)]: Using backend ThreadingBackend with 90 concurrent workers.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.02516605\n",
      "Epoch 12, change: 0.02401905\n",
      "Epoch 13, change: 0.02233984\n",
      "Epoch 14, change: 0.02204493\n",
      "Epoch 15, change: 0.02255839\n",
      "Epoch 16, change: 0.02136740\n",
      "Epoch 17, change: 0.02055626\n",
      "Epoch 18, change: 0.01761787\n",
      "Epoch 19, change: 0.01826754\n",
      "Epoch 20, change: 0.01735444\n",
      "Epoch 1, change: 1.00000000\n",
      "Epoch 2, change: 0.31534991\n",
      "Epoch 3, change: 0.17116573\n",
      "Epoch 4, change: 0.13716325\n",
      "Epoch 5, change: 0.09618021\n",
      "Epoch 6, change: 0.06730897\n",
      "Epoch 7, change: 0.05938310\n",
      "Epoch 8, change: 0.04508926\n",
      "Epoch 9, change: 0.04190752\n",
      "Epoch 10, change: 0.03077974\n",
      "Epoch 11, change: 0.02841954\n",
      "Epoch 12, change: 0.02620609\n",
      "Epoch 13, change: 0.02271299\n",
      "Epoch 14, change: 0.02467970\n",
      "Epoch 15, change: 0.01998835\n",
      "Epoch 16, change: 0.01714212\n",
      "Epoch 17, change: 0.01706825\n",
      "Epoch 18, change: 0.01559709\n",
      "Epoch 19, change: 0.01442908\n",
      "Epoch 20, change: 0.01424436\n",
      "Epoch 1, change: 1.00000000\n",
      "Epoch 2, change: 0.26308036\n",
      "Epoch 3, change: 0.14570166\n",
      "Epoch 4, change: 0.11579746\n",
      "Epoch 5, change: 0.09348176\n",
      "Epoch 6, change: 0.08544495\n",
      "Epoch 7, change: 0.07625514\n",
      "Epoch 8, change: 0.07035332\n",
      "Epoch 9, change: 0.06259672\n",
      "Epoch 10, change: 0.05501523\n",
      "Epoch 11, change: 0.04852444\n",
      "Epoch 12, change: 0.04417152\n",
      "Epoch 13, change: 0.04121453\n",
      "Epoch 14, change: 0.03713568\n",
      "Epoch 15, change: 0.03419761\n",
      "Epoch 16, change: 0.03216071\n",
      "Epoch 17, change: 0.03140821\n",
      "Epoch 18, change: 0.02930530\n",
      "Epoch 19, change: 0.02451167\n",
      "Epoch 20, change: 0.02395559\n",
      "Epoch 1, change: 1.00000000\n",
      "Epoch 2, change: 0.26852673\n",
      "Epoch 3, change: 0.14572456\n",
      "Epoch 4, change: 0.12052780\n",
      "Epoch 5, change: 0.09193307\n",
      "Epoch 6, change: 0.08078162\n",
      "Epoch 7, change: 0.06483574\n",
      "Epoch 8, change: 0.05733146\n",
      "Epoch 9, change: 0.05308044\n",
      "Epoch 10, change: 0.04501369\n",
      "Epoch 11, change: 0.04322927\n",
      "Epoch 12, change: 0.04056388\n",
      "Epoch 13, change: 0.03727151\n",
      "Epoch 14, change: 0.03426178\n",
      "Epoch 15, change: 0.03152880\n",
      "Epoch 16, change: 0.02713562\n",
      "Epoch 17, change: 0.02586028\n",
      "Epoch 18, change: 0.02524194\n",
      "Epoch 19, change: 0.02346481\n",
      "Epoch 20, change: 0.02286536\n",
      "Epoch 1, change: 1.00000000\n",
      "Epoch 2, change: 0.16469547\n",
      "Epoch 3, change: 0.13928396\n",
      "Epoch 4, change: 0.10251690\n",
      "Epoch 5, change: 0.09807996\n",
      "Epoch 6, change: 0.09321561\n",
      "Epoch 7, change: 0.07580203\n",
      "Epoch 8, change: 0.07444584\n",
      "Epoch 9, change: 0.06084797\n",
      "Epoch 10, change: 0.05625969\n",
      "Epoch 11, change: 0.04814595\n",
      "Epoch 12, change: 0.04571138\n",
      "Epoch 13, change: 0.04317028\n",
      "Epoch 14, change: 0.03748411\n",
      "Epoch 15, change: 0.03363065\n",
      "Epoch 16, change: 0.03257611\n",
      "Epoch 17, change: 0.03103237\n",
      "Epoch 18, change: 0.03008093\n",
      "Epoch 19, change: 0.02917248\n",
      "Epoch 20, change: 0.02623053\n",
      "Epoch 1, change: 1.00000000\n",
      "Epoch 2, change: 0.14965802\n",
      "Epoch 3, change: 0.21865363\n",
      "Epoch 4, change: 0.07498015\n",
      "Epoch 5, change: 0.06226104\n",
      "Epoch 6, change: 0.06940006\n",
      "Epoch 7, change: 0.05574093\n",
      "Epoch 8, change: 0.05088939\n",
      "Epoch 9, change: 0.04675990\n",
      "Epoch 10, change: 0.04144712\n",
      "Epoch 11, change: 0.03814123\n",
      "Epoch 12, change: 0.03942867\n",
      "Epoch 13, change: 0.03392496\n",
      "Epoch 14, change: 0.03215348\n",
      "Epoch 15, change: 0.03058406\n",
      "Epoch 16, change: 0.02872809\n",
      "Epoch 17, change: 0.02891118\n",
      "Epoch 18, change: 0.02873023\n",
      "Epoch 19, change: 0.02628154\n",
      "Epoch 20, change: 0.02438173\n",
      "Epoch 1, change: 1.00000000\n",
      "Epoch 2, change: 0.04722588\n",
      "Epoch 3, change: 0.03711120\n",
      "Epoch 4, change: 0.03377127\n",
      "Epoch 5, change: 0.03651736\n",
      "Epoch 6, change: 0.03658089\n",
      "Epoch 7, change: 0.03729797\n",
      "Epoch 8, change: 0.03845207\n",
      "Epoch 9, change: 0.03952068\n",
      "Epoch 10, change: 0.03363056\n",
      "Epoch 11, change: 0.03106654\n",
      "Epoch 12, change: 0.03402508\n",
      "Epoch 13, change: 0.03100311\n",
      "Epoch 14, change: 0.03032618\n",
      "Epoch 15, change: 0.03688557\n",
      "Epoch 16, change: 0.02692729\n",
      "Epoch 17, change: 0.02564827\n",
      "Epoch 18, change: 0.03143947\n",
      "Epoch 19, change: 0.01499585\n",
      "Epoch 20, change: 0.01822113\n",
      "Epoch 1, change: 1.00000000\n",
      "Epoch 2, change: 0.13335676\n",
      "Epoch 3, change: 0.05742165\n",
      "Epoch 4, change: 0.05810973\n",
      "Epoch 5, change: 0.03939699\n",
      "Epoch 6, change: 0.03609814\n",
      "Epoch 7, change: 0.03462229\n",
      "Epoch 8, change: 0.03366181\n",
      "Epoch 9, change: 0.03188349\n",
      "Epoch 10, change: 0.02999078\n",
      "Epoch 11, change: 0.02945834\n",
      "Epoch 12, change: 0.02771664\n",
      "Epoch 13, change: 0.02816360\n",
      "Epoch 14, change: 0.02645437\n",
      "Emax_iter reached after 15 seconds\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Parallel(n_jobs=90)]: Done   1 out of   1 | elapsed:   15.4s finished\n",
      "[Parallel(n_jobs=90)]: Using backend ThreadingBackend with 90 concurrent workers.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "max_iter reached after 16 seconds\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Parallel(n_jobs=90)]: Done   1 out of   1 | elapsed:   15.4s finished\n",
      "[Parallel(n_jobs=90)]: Using backend ThreadingBackend with 90 concurrent workers.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "max_iter reached after 15 seconds\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Parallel(n_jobs=90)]: Done   1 out of   1 | elapsed:   15.4s finished\n",
      "[Parallel(n_jobs=90)]: Using backend ThreadingBackend with 90 concurrent workers.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "max_iter reached after 15 seconds\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Parallel(n_jobs=90)]: Done   1 out of   1 | elapsed:   15.4s finished\n"
     ]
    }
   ],
   "source": [
    "# NOTE(review): new_Z is defined outside this cell. The recorded output of this cell is a long\n",
    "# log of per-epoch 'change' values and 'max_iter reached' messages - consider silencing it.\n",
    "z_train3 = new_Z(x_train, z_train2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "5cf89f8c-d227-4f9f-bc87-be1df0e25258",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# NOTE(review): get_PMS is defined outside this cell; the literal arguments 0 and 100 are\n",
    "# not explained here - TODO name them or document their meaning.\n",
    "V = get_PMS(x_train, z_train3, 0, 100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "c62d1aa0-4c1e-482e-86e8-466f75e30933",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Alias: the evaluation cell further down passes V_list (not V) to proj().\n",
    "V_list = V"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "557fe151-8c54-4533-9718-86133098c2cf",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def get_projection_to_intersection_of_nullspaces(rowspace_projection_matrices: List[np.ndarray], input_dim: int):\n",
    "    \"\"\"\n",
    "    Project onto the common nullspace of several matrices w_1, ..., w_n.\n",
    "    The caller supplies the rowspace projections P_R(w_1), ..., P_R(w_n). Following the\n",
    "    intersection-projection formula of Ben-Israel 2013 (http://benisrael.net/BEN-ISRAEL-NOV-30-13.pdf),\n",
    "    N(w1) ∩ N(w2) ∩ ... ∩ N(wn) = N(P_R(w1) + P_R(w2) + ... + P_R(wn)),\n",
    "    so the result is the identity minus the rowspace projection of the summed matrices.\n",
    "    :param rowspace_projection_matrices: List[np.array], the rowspace projections to combine\n",
    "    :param input_dim: size of the identity matrix the rowspace projection is subtracted from\n",
    "    :return: I - get_rowspace_projection(sum of the given projections)\n",
    "    \"\"\"\n",
    "    identity = np.eye(input_dim)\n",
    "    summed_projections = np.sum(rowspace_projection_matrices, axis = 0)\n",
    "    return identity - get_rowspace_projection(summed_projections)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "78120f92-dde9-4128-a004-21e8597b05a2",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def proj(pp_list, q_list, input_dim: int = 768):\n",
    "    \"\"\"\n",
    "    Build one projection matrix from the leading columns of every matrix in pp_list.\n",
    "    For each index i, the first int(q_list[i]) columns of pp_list[i] are sliced out, transposed so\n",
    "    they become rows, and converted with get_rowspace_projection; the resulting projections are\n",
    "    combined by get_projection_to_intersection_of_nullspaces.\n",
    "    :param pp_list: sequence of 2-D arrays; presumably each has input_dim rows - TODO confirm against the producer\n",
    "    :param q_list: per-matrix number of leading columns to use; entries are cast with int() and only\n",
    "                   the first len(pp_list) entries are read\n",
    "    :param input_dim: dimension passed to get_projection_to_intersection_of_nullspaces; defaults to 768,\n",
    "                      the value that used to be hard-coded, so existing two-argument calls are unchanged\n",
    "    :return: the result of get_projection_to_intersection_of_nullspaces for the collected projections\n",
    "    \"\"\"\n",
    "    rowspace_projections = [\n",
    "        get_rowspace_projection(basis[:, :int(q_list[i])].T)\n",
    "        for i, basis in enumerate(pp_list)\n",
    "    ]\n",
    "    return get_projection_to_intersection_of_nullspaces(rowspace_projections, input_dim)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "0759b142-79dc-4165-bcf1-7711f38421a8",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Per-matrix column counts handed to proj() as q_list: proj() keeps the first int(q) columns of\n",
    "# the matching matrix. NOTE(review): how these values were chosen is not shown in this part of\n",
    "# the notebook - TODO record their provenance.\n",
    "q_vec = np.array([25, 19, 19, 8, 6, 10, 26, 0, 24, 3, 17, 8, 25, 3, 7, 11, 13, 20, 9, 26, 9, 15, 22, 6, 25, 10, 23, 24, 12, 2, 1, 14, 0, 24, 14, 4, 18, 5, 1, 5, 4, 25, 15, 18 ,15, 9])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "id": "e0b6ec31-86eb-49cd-bdf5-40c7aee82ab5",
   "metadata": {
    "scrolled": true,
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Parallel(n_jobs=90)]: Using backend ThreadingBackend with 90 concurrent workers.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "max_iter reached after 5 seconds\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Parallel(n_jobs=90)]: Done   1 out of   1 | elapsed:    5.3s finished\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.3900918635170604\n",
      "5.951626389105669 0.948433871068985 0.9496786578210833\n"
     ]
    }
   ],
   "source": [
    "# Projection matrix built from V_list and the per-matrix column counts in q_vec.\n",
    "pp = proj(V_list, q_vec)\n",
    "\n",
    "# Project the test features once; the score, predict and predict_proba calls below all reuse it\n",
    "# instead of recomputing np.dot(x_test, pp) three times.\n",
    "x_test_proj = np.dot(x_test, pp)\n",
    "\n",
    "# NOTE(review): the recorded output of this cell reports 'max_iter reached' - confirm that the\n",
    "# small max_iter cap is intentional.\n",
    "lr2 = LogisticRegression(warm_start = True, penalty = 'l2',\n",
    "                     solver = \"saga\", multi_class = 'multinomial', fit_intercept = False,\n",
    "                     verbose = 5, n_jobs = 90, random_state = 1, max_iter = 7)\n",
    "lr2.fit(np.dot(x_train,pp), y_train)\n",
    "\n",
    "accuracy = lr2.score(x_test_proj, y_test)\n",
    "\n",
    "Y_pred = lr2.predict(x_test_proj)\n",
    "\n",
    "rms_diff_vec = GAPs(Y_pred, y_test, z_test)\n",
    "\n",
    "# ROC/AUC is computed from column 1 of predict_proba - assumes binary labels whose positive\n",
    "# class sits in that column; TODO confirm.\n",
    "y = y_test.reshape(-1)\n",
    "pred = lr2.predict_proba(x_test_proj)[:,1].reshape(-1)\n",
    "fpr, tpr, thresholds = metrics.roc_curve(y, pred)\n",
    "AUC = metrics.auc(fpr, tpr)\n",
    "\n",
    "print(rms_diff_vec, accuracy, AUC)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "34c07e24-01a0-4beb-bf4e-e0db90dc1e23",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "08c0fae2-e8f6-4143-95bc-2298bd4ab318",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
