{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "954ea015-8a54-4b9e-b3a0-cceb90817e84",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "322636\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import scipy, requests, codecs, os, re, nltk, itertools, csv\n",
    "from sklearn.metrics.pairwise import cosine_similarity\n",
    "import tensorflow as tf\n",
    "from scipy.stats import spearmanr\n",
    "import functools as ft\n",
    "from sklearn.decomposition import PCA\n",
    "from sklearn.preprocessing import normalize\n",
    "import gdown\n",
    "\n",
    "\n",
    "# Load the GloVe vectors into a {word: vector} dictionary.\n",
    "file_name = \"glove_wiki_vectors.txt\"\n",
    "VEC_LEN = 300\n",
    "glove_word = {}\n",
    "# The context manager closes the file even if a malformed row raises below.\n",
    "# NOTE(review): the vector file is assumed to be UTF-8 encoded -- confirm.\n",
    "with open(\"../\" + file_name, 'r', encoding='utf-8') as glove_file:\n",
    "    for line in glove_file:\n",
    "        # Each row is: <word> <v1> <v2> ... separated by single spaces.\n",
    "        _word = line.strip().split(' ')\n",
    "        vector = np.array([float(num) for num in _word[1:]])\n",
    "        if len(vector) != VEC_LEN:\n",
    "            raise Exception(\"Word dimension is wrong\")\n",
    "        glove_word[_word[0]] = vector\n",
    "print(len(glove_word))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "57d6a359-760f-454f-9d18-c3a79d7e2979",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "import gensim\n",
    "import codecs\n",
    "import json\n",
    "from gensim.models.keyedvectors import Word2VecKeyedVectors\n",
    "from gensim.models import KeyedVectors\n",
    "import numpy as np\n",
    "import random\n",
    "import sklearn\n",
    "from sklearn import model_selection\n",
    "from sklearn import cluster\n",
    "from sklearn import metrics\n",
    "from sklearn.manifold import TSNE\n",
    "from sklearn.svm import LinearSVC, SVC\n",
    "from sklearn.neural_network import MLPClassifier\n",
    "from sklearn.metrics.pairwise import cosine_similarity\n",
    "from sklearn.feature_extraction import DictVectorizer\n",
    "import numpy as np\n",
    "from docopt import docopt\n",
    "import torch\n",
    "from transformers import *\n",
    "import pickle\n",
    "from tqdm import tqdm\n",
    "\n",
    "import scipy\n",
    "from scipy import linalg\n",
    "from scipy import sparse\n",
    "from scipy.stats.stats import pearsonr\n",
    "import tqdm\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.linear_model import SGDClassifier, SGDRegressor, Perceptron, LogisticRegression\n",
    "\n",
    "\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "import pickle\n",
    "from collections import defaultdict, Counter\n",
    "from typing import List, Dict\n",
    "\n",
    "import torch\n",
    "from torch import utils\n",
    "\n",
    "\n",
    "import pandas as pd\n",
    "\n",
    "import time\n",
    "from mlxtend.math import vectorspace_orthonormalization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "3a22c7c2-d8ce-48a2-8df6-6b7035116d13",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "from keras.models import Sequential\n",
    "from keras.layers import Dense\n",
    "from keras.wrappers.scikit_learn import KerasClassifier\n",
    "from keras.utils import np_utils\n",
    "from sklearn.model_selection import cross_val_score\n",
    "from sklearn.preprocessing import MinMaxScaler\n",
    "from sklearn.model_selection import KFold\n",
    "from sklearn.preprocessing import LabelEncoder\n",
    "from sklearn.pipeline import Pipeline\n",
    "import keras"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "cd3d464b-8627-4f39-b126-fedb2a46d946",
   "metadata": {
    "scrolled": true,
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([ 1.245750e-01, -4.009240e-01,  6.271400e-02,  4.111900e-01,\n",
       "       -8.816970e-01,  2.350280e-01,  1.360400e-01, -2.392700e-02,\n",
       "        1.160930e-01, -1.206110e-01, -1.762200e-02,  2.154890e-01,\n",
       "        1.726310e-01, -4.240420e-01,  2.806960e-01, -1.301440e-01,\n",
       "       -2.338660e-01,  4.867020e-01,  7.754460e-01, -2.059050e-01,\n",
       "        3.617610e-01,  4.187000e-02, -2.559800e-01, -4.998360e-01,\n",
       "       -1.035480e-01,  6.258390e-01, -2.862690e-01, -3.280910e-01,\n",
       "       -2.078430e-01,  3.054890e-01,  1.082030e-01, -3.312780e-01,\n",
       "       -2.726420e-01,  1.926860e-01,  5.072300e-02,  3.893320e-01,\n",
       "        2.933890e-01, -4.360580e-01, -3.010200e-02,  3.158700e-01,\n",
       "       -3.962860e-01,  4.540110e-01,  7.849120e-01,  6.861510e-01,\n",
       "       -6.444350e-01,  2.352310e-01, -5.320800e-02, -2.613060e-01,\n",
       "        1.339090e-01, -2.140570e-01,  1.470220e-01,  1.431780e-01,\n",
       "       -1.402500e-01,  2.238600e-02, -6.906300e-02, -2.349210e-01,\n",
       "        4.022160e-01, -1.911420e-01, -2.725900e-01,  8.616200e-02,\n",
       "        1.466570e-01, -1.760270e-01, -6.033100e-01,  1.708400e-02,\n",
       "        2.039600e-02,  7.842400e-02, -4.272000e-03,  3.163090e-01,\n",
       "       -3.120360e-01,  1.026717e+00,  1.112710e-01,  9.714800e-02,\n",
       "        2.597710e-01,  4.699100e-01,  6.015900e-02,  1.760000e-04,\n",
       "        1.428810e-01,  2.050910e-01, -4.586080e-01, -7.705200e-02,\n",
       "        1.461880e-01, -2.274320e-01,  1.223067e+00,  2.363410e-01,\n",
       "       -2.824230e-01,  9.096000e-02, -1.195820e-01,  3.237190e-01,\n",
       "        4.136700e-02,  9.524600e-02,  2.431990e-01,  6.146060e-01,\n",
       "        2.866200e-02,  4.417430e-01,  1.592440e-01,  3.938820e-01,\n",
       "       -4.809000e-01,  3.291680e-01,  5.334500e-02,  1.704500e-01,\n",
       "        3.302640e-01, -1.594420e-01,  5.438560e-01,  3.328100e-02,\n",
       "        3.474750e-01,  3.214290e-01,  2.622150e-01, -1.878060e-01,\n",
       "       -5.873710e-01, -1.328790e-01, -3.773310e-01,  7.291460e-01,\n",
       "       -1.799160e-01,  4.147620e-01,  5.010880e-01, -9.682800e-02,\n",
       "        1.615120e-01, -1.357960e-01, -3.380400e-01, -6.110700e-02,\n",
       "       -2.563300e-02, -3.925900e-02, -2.155800e-01, -1.598960e-01,\n",
       "       -2.708180e-01, -1.268510e-01, -5.768100e-02,  2.152360e-01,\n",
       "        1.781310e-01,  1.894660e-01, -2.774290e-01, -4.530830e-01,\n",
       "       -2.279000e-02,  3.099390e-01,  2.470090e-01, -1.350130e-01,\n",
       "        1.947490e-01, -2.476300e-02, -2.600710e-01, -2.648480e-01,\n",
       "       -7.398700e-02,  7.080900e-02,  2.453920e-01, -6.196310e-01,\n",
       "       -1.910290e-01, -7.923700e-02, -1.796830e-01, -9.178000e-03,\n",
       "       -3.727910e-01,  3.758540e-01,  2.170070e-01,  3.655080e-01,\n",
       "        7.891900e-02,  2.987970e-01,  3.976790e-01, -2.104860e-01,\n",
       "        4.083240e-01,  3.031760e-01,  2.688640e-01,  6.663300e-02,\n",
       "       -2.820050e-01, -1.843410e-01,  2.332610e-01, -4.863000e-03,\n",
       "        1.447080e-01, -9.031200e-02, -2.512840e-01,  2.355370e-01,\n",
       "        9.874000e-03,  3.145620e-01,  9.107700e-02, -2.474430e-01,\n",
       "       -1.460010e-01, -1.526980e-01,  4.654380e-01, -4.790490e-01,\n",
       "       -2.707420e-01, -1.189260e-01,  1.046150e-01, -2.011160e-01,\n",
       "        3.900410e-01,  2.937310e-01, -5.070000e-04, -1.001980e-01,\n",
       "       -2.752070e-01, -3.678833e+00,  8.109250e-01, -9.212800e-02,\n",
       "       -1.853410e-01, -2.419600e-01, -1.730690e-01, -1.664830e-01,\n",
       "        2.310580e-01, -1.280820e-01,  4.042000e-03, -1.109530e-01,\n",
       "        4.494920e-01, -1.000000e-01,  7.992210e-01, -8.177100e-01,\n",
       "       -5.375000e-03, -3.172290e-01,  6.082460e-01,  2.230090e-01,\n",
       "        5.591240e-01,  2.868070e-01, -2.684860e-01,  2.731700e-02,\n",
       "        2.052650e-01,  2.661490e-01,  3.134270e-01, -6.906600e-02,\n",
       "       -3.775420e-01, -4.094080e-01,  2.992300e-02,  3.021800e-02,\n",
       "        1.407120e-01,  9.994700e-02,  2.832260e-01, -2.871340e-01,\n",
       "        4.863310e-01, -2.805150e-01, -9.746900e-01,  1.256730e-01,\n",
       "       -2.355250e-01, -4.555900e-02,  1.550870e-01, -6.355000e-03,\n",
       "       -1.402990e-01, -1.191180e-01,  1.452540e-01, -1.914720e-01,\n",
       "        1.165930e-01,  3.597150e-01, -3.949600e-01,  1.556840e-01,\n",
       "        5.246460e-01,  1.562470e-01, -3.804140e-01,  3.600060e-01,\n",
       "        4.260870e-01,  1.673060e-01, -4.975800e-02,  6.811950e-01,\n",
       "        6.392730e-01, -1.238390e-01,  2.491530e-01, -5.067020e-01,\n",
       "        1.320740e-01,  3.338180e-01,  4.922470e-01,  1.305930e-01,\n",
       "       -4.405410e-01,  1.539660e-01,  4.150670e-01,  6.232110e-01,\n",
       "       -2.471900e-01,  1.381830e-01,  3.231810e-01, -5.452000e-03,\n",
       "       -8.212830e-01, -4.289580e-01,  1.191760e-01, -3.325380e-01,\n",
       "       -2.178680e-01,  1.579350e-01,  6.260450e-01,  6.973200e-02,\n",
       "        6.005940e-01,  3.533650e-01, -3.129010e-01, -2.923500e-02,\n",
       "        1.589280e-01,  4.170780e-01, -2.250080e-01, -4.019380e-01,\n",
       "       -3.734570e-01, -2.132290e-01,  1.480550e-01, -2.328940e-01,\n",
       "        8.839900e-02, -1.510390e-01, -6.121740e-01,  3.113400e-02,\n",
       "       -3.863580e-01,  4.258630e-01,  4.743820e-01, -2.695040e-01,\n",
       "       -1.885540e-01,  1.429960e-01, -7.470900e-02,  1.292790e-01,\n",
       "       -2.822310e-01,  1.664800e-01, -2.380640e-01,  7.218100e-02,\n",
       "       -7.023700e-02, -4.396470e-01, -8.029300e-02, -2.783630e-01])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "glove_word['he']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "b80ece97-5840-4f23-abf4-38a205c16c50",
   "metadata": {},
   "outputs": [],
   "source": [
    "all_words = list(glove_word.keys())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "7c7d2e22-0ba3-41bb-a8b8-d1c08cec83e0",
   "metadata": {},
   "outputs": [],
   "source": [
    "vec = glove_word['he'] - glove_word['she']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "6385364c-266e-4ec1-b2d1-844b1ea5938d",
   "metadata": {},
   "outputs": [],
   "source": [
    "df = pd.DataFrame(columns = ['word','sim'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "4aff7e76-bf32-48c0-9667-fb6d94951c0e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.where(np.array([1,2,2])==2)[0][0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "20de65a4-ef21-4507-b7f6-8cf5537bf68a",
   "metadata": {},
   "outputs": [],
   "source": [
    "similarity = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "b71aaaa9-978c-45ef-95c6-6a06d7727754",
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.spatial.distance import cosine"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "c7f0d1bd-01ed-4779-80af-272ecea2d826",
   "metadata": {
    "scrolled": true,
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Cosine similarity between `vec` and the vector of every word in `all_words`.\n",
    "# The list is rebuilt from scratch here (rather than appended to a list created\n",
    "# in another cell) so that re-running this cell never duplicates entries.\n",
    "similarity = [1 - cosine(vec, glove_word[word]) for word in all_words]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "68b319e5-529e-4328-a123-256d11c05dbd",
   "metadata": {
    "scrolled": true,
    "tags": []
   },
   "outputs": [],
   "source": [
    "df['word'] = all_words"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "37e47e11-f8b8-434f-84de-c59580a155b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "df['sim'] = similarity"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "7b6605d2-af2a-4c3d-afad-4e09866482a6",
   "metadata": {},
   "outputs": [],
   "source": [
    "df = df.sort_values(by=\"sim\" , ascending=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "6945558b-ce41-45fd-a5a5-26d863fc0279",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>word</th>\n",
       "      <th>sim</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>51069</th>\n",
       "      <td>benefices</td>\n",
       "      <td>0.285936</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>54251</th>\n",
       "      <td>outfielders</td>\n",
       "      <td>0.284975</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>28859</th>\n",
       "      <td>suspensions</td>\n",
       "      <td>0.262140</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>62175</th>\n",
       "      <td>uppingham</td>\n",
       "      <td>0.259404</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7769</th>\n",
       "      <td>bulls</td>\n",
       "      <td>0.257339</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1576</th>\n",
       "      <td>actress</td>\n",
       "      <td>-0.403790</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>199445</th>\n",
       "      <td>wittum</td>\n",
       "      <td>-0.407194</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>39</th>\n",
       "      <td>her</td>\n",
       "      <td>-0.418087</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>246652</th>\n",
       "      <td>wifely</td>\n",
       "      <td>-0.420689</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>40</th>\n",
       "      <td>she</td>\n",
       "      <td>-0.427518</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>322636 rows × 2 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "               word       sim\n",
       "51069     benefices  0.285936\n",
       "54251   outfielders  0.284975\n",
       "28859   suspensions  0.262140\n",
       "62175     uppingham  0.259404\n",
       "7769          bulls  0.257339\n",
       "...             ...       ...\n",
       "1576        actress -0.403790\n",
       "199445       wittum -0.407194\n",
       "39              her -0.418087\n",
       "246652       wifely -0.420689\n",
       "40              she -0.427518\n",
       "\n",
       "[322636 rows x 2 columns]"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "a4bf5d26-2ad5-46f1-af0f-a56e30ba005d",
   "metadata": {},
   "outputs": [],
   "source": [
    "N = 9500"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "90337823-90a9-42c8-8589-34c7c4e258cf",
   "metadata": {},
   "outputs": [],
   "source": [
    "male_words = list(df['word'][:N])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "52f6f631-96cc-44b9-8aaa-8e8786c5f821",
   "metadata": {},
   "outputs": [],
   "source": [
    "female_words = list(df['word'][-N:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "fb015ae3-d5f5-4e4a-ad0b-4aa473428d9b",
   "metadata": {},
   "outputs": [],
   "source": [
    "X_train = []\n",
    "Y_train = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "41854784-c28f-42d0-824f-426113c7f920",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Interleave the vectors of `male_words` (label 1) and `female_words` (label 0).\n",
    "# Starting from fresh lists keeps this cell re-runnable even after a later cell\n",
    "# has converted X_train / Y_train to numpy arrays.\n",
    "X_train, Y_train = [], []\n",
    "for male_word, female_word in zip(male_words, female_words):\n",
    "    X_train.append(glove_word[male_word])\n",
    "    Y_train.append(1)\n",
    "    X_train.append(glove_word[female_word])\n",
    "    Y_train.append(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "13516df9-ca7d-4752-aa66-5720119365d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "X_train = np.array(X_train)\n",
    "Y_train = np.array(Y_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "7da0dd95-acf4-4663-a22f-e1014d0b1b76",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Parallel(n_jobs=90)]: Using backend ThreadingBackend with 90 concurrent workers.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "max_iter reached after 1 seconds\n",
      "time: 0.5868015289306641\n",
      "1.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Parallel(n_jobs=90)]: Done   1 out of   1 | elapsed:    0.6s finished\n"
     ]
    }
   ],
   "source": [
    "# Logistic-regression classifier: L2 penalty, SAGA solver, no intercept term.\n",
    "clf = LogisticRegression(\n",
    "    penalty='l2',\n",
    "    solver=\"saga\",\n",
    "    multi_class='multinomial',\n",
    "    fit_intercept=False,\n",
    "    warm_start=True,\n",
    "    max_iter=7,\n",
    "    random_state=1,\n",
    "    n_jobs=90,\n",
    "    verbose=5,\n",
    ")\n",
    "\n",
    "# Time the fit, then report the accuracy on the training data itself.\n",
    "start = time.time()\n",
    "clf.fit(X_train, Y_train)\n",
    "elapsed = time.time() - start\n",
    "print(\"time: {}\".format(elapsed))\n",
    "print(clf.score(X_train, Y_train))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "f22189eb-df3d-4186-8a2c-4286b9dc0f26",
   "metadata": {},
   "outputs": [],
   "source": [
    "Y = clf.predict_proba(X_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "3c399218-0f70-4e3a-975d-e6ea41ae8d88",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[1.49216341e-09, 9.99999999e-01],\n",
       "       [9.98001404e-01, 1.99859592e-03],\n",
       "       [1.77303085e-08, 9.99999982e-01],\n",
       "       ...,\n",
       "       [9.99999995e-01, 5.33623701e-09],\n",
       "       [2.67024636e-04, 9.99732975e-01],\n",
       "       [1.00000000e+00, 2.40800032e-14]])"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "Y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "430e451e-3e41-4c68-a1cc-a91ef7946507",
   "metadata": {},
   "outputs": [],
   "source": [
    "def slice_freq(Y, interval):\n",
    "    y_f = pd.Series(Y)\n",
    "    a = pd.cut(y_f,interval)\n",
    "    b=a.value_counts()\n",
    "    frequency = b.sort_index().values\n",
    "    return frequency"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "dd857c37-eea6-4928-8fdc-de4347f30887",
   "metadata": {},
   "outputs": [],
   "source": [
    "def SDR_multi(X_train, Y, K, h):\n",
    "    \"\"\"Build a rank-K projection from slice-wise statistics of X given Y.\n",
    "\n",
    "    Every column of ``Y`` is cut into ``h`` equal-width slices spanning the\n",
    "    global range of ``Y``. A kernel matrix is accumulated per column, the\n",
    "    per-column matrices are pooled (each weighted by its own largest\n",
    "    eigenvalue), and the K leading eigenvectors of the pooled matrix are kept.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    X_train : 2-D array, one row per sample.\n",
    "    Y : 2-D array with the same number of rows; each column is sliced in turn.\n",
    "    K : number of leading eigenvectors to keep.\n",
    "    h : number of slices.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    (P, I - P) where ``P = B @ B.T`` and ``B`` holds the K orthonormalised\n",
    "    leading eigenvectors of the pooled matrix.\n",
    "    \"\"\"\n",
    "    X = X_train\n",
    "    n = X.shape[0]\n",
    "    d_x = X.shape[1]\n",
    "    d_y = Y.shape[1]\n",
    "\n",
    "    interval = np.linspace(np.min(Y), np.max(Y), h+1, endpoint=True)\n",
    "    Gamma_PMS = np.zeros([d_x, d_x])\n",
    "    for i in range(d_y):\n",
    "        freq = slice_freq(Y[:, i], interval)\n",
    "        Gamma = np.zeros([d_x, d_x])\n",
    "        for j in range(h):\n",
    "            ph = freq[j]/n\n",
    "            if ph == 0:\n",
    "                # Empty slice: it contributes nothing to Gamma.\n",
    "                continue\n",
    "            # NOTE(review): slice_freq counts with pd.cut bins while this mask is\n",
    "            # closed on both ends, so the two can disagree for values lying\n",
    "            # exactly on a bin edge -- confirm which convention is intended.\n",
    "            index = np.where((Y[:, i] >= interval[j]) & (Y[:, i] <= interval[j+1]))\n",
    "            # NOTE(review): `index` is the tuple returned by np.where, so\n",
    "            # X[index, :] has shape (1, m, d_x) and the axis-0 sum leaves the\n",
    "            # (m, d_x) stack of slice rows rather than a single summed row.\n",
    "            # mh.T @ mh is therefore the scaled second-moment matrix of the\n",
    "            # slice, not the outer product of a slice mean. Left unchanged;\n",
    "            # confirm this is the intended kernel.\n",
    "            mh = (1/(n*ph))*np.sum(X[index, :], 0)\n",
    "            Gamma = Gamma + ph*np.dot(mh.T, mh)\n",
    "        # Gamma is symmetric, so eigvalsh applies; it returns eigenvalues in\n",
    "        # ascending order, hence [-1] is the largest one. (np.linalg.eig gives\n",
    "        # no ordering guarantee, so element [0] was not reliably the largest.)\n",
    "        Gamma_PMS = Gamma_PMS + np.linalg.eigvalsh(Gamma)[-1]*Gamma\n",
    "    # Pick the K eigenvectors with the largest eigenvalues explicitly instead\n",
    "    # of relying on the (unspecified) output order of np.linalg.eig.\n",
    "    la_PMS, v_PMS = np.linalg.eigh(Gamma_PMS)\n",
    "    leading = np.argsort(la_PMS)[::-1][:K]\n",
    "    beta1 = vectorspace_orthonormalization(v_PMS[:, leading])\n",
    "    projection = np.dot(beta1, beta1.T)\n",
    "    return projection, np.eye(d_x) - projection"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "7e751ab8-7a3e-4609-9ff7-a20de06780d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "def iteration_estimation_v2(X_train, Y, Iter, q, h):\n",
    "    \"\"\"Whiten ``X_train``, run ``SDR_multi`` once, then refine it ``Iter`` times.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    X_train : 2-D array, one row per sample.\n",
    "    Y : array with one row per sample, passed through to ``SDR_multi``.\n",
    "    Iter : number of refinement passes; 0 returns the first estimate.\n",
    "    q : subspace dimension passed to ``SDR_multi`` as ``K``.\n",
    "    h : number of slices passed to ``SDR_multi``.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    The ``(beta1, beta2)`` pair returned by the last ``SDR_multi`` call.\n",
    "    \"\"\"\n",
    "    n = X_train.shape[0]\n",
    "\n",
    "    # NOTE(review): the previous version also built a mean-centred copy of the\n",
    "    # data but never used it; the second moment and the whitening below have\n",
    "    # always operated on the uncentred data. Confirm whether centring was intended.\n",
    "    Cx = np.dot(X_train.T, X_train)/n\n",
    "    # Cx is symmetric: eigh returns real eigenpairs. Clip round-off negatives\n",
    "    # so the square root cannot produce NaN when Cx is rank-deficient.\n",
    "    la, v = np.linalg.eigh(Cx)\n",
    "    la = np.clip(la, 0.0, None)\n",
    "    Cx12 = np.dot(np.dot(v, np.diag(la**(0.5))), v.T)\n",
    "    X_white = np.dot(X_train, np.linalg.pinv(Cx12, rcond=1e-8))\n",
    "\n",
    "    beta1, beta2 = SDR_multi(X_white, Y, q, h)\n",
    "    for _ in range(Iter):\n",
    "        # rand() is always < 1, so this mask keeps every row; it is retained\n",
    "        # as the hook for row subsampling.\n",
    "        idx = np.random.rand(X_white.shape[0]) < 1\n",
    "        X_white = np.dot(X_white, beta1)\n",
    "        beta1, beta2 = SDR_multi(X_white[idx], Y[idx], q, h)\n",
    "        print(np.diag(beta1)[:5])\n",
    "    return beta1, beta2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "23dbe762-0be1-49c9-b085-584b2db777a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "q_vec = list(range(60,81,5))\n",
    "Iter_vec = [1]\n",
    "h = 100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "421ea8dc-8b34-4311-8773-55235f66b6d6",
   "metadata": {
    "scrolled": true,
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Iter is 1 dim of subspace is 60\n",
      "[0.2544181  0.16928745 0.19203799 0.22119565 0.30440337]\n",
      "Iter is 1 dim of subspace is 65\n",
      "[0.25816793 0.19148486 0.20454973 0.22904384 0.32404595]\n",
      "Iter is 1 dim of subspace is 70\n",
      "[0.28206085 0.19803897 0.21570116 0.282366   0.33044335]\n",
      "Iter is 1 dim of subspace is 75\n",
      "[0.29069778 0.22084213 0.22405737 0.29377991 0.3389129 ]\n",
      "Iter is 1 dim of subspace is 80\n",
      "[0.31736804 0.25666796 0.24709929 0.31160829 0.37161865]\n"
     ]
    }
   ],
   "source": [
    "# Estimate the projection for every (Iter, q) setting and save beta1 as CSV.\n",
    "# DataFrame.to_csv does not create missing directories, so make sure the\n",
    "# output folder exists before the first write.\n",
    "os.makedirs(\"results\", exist_ok=True)\n",
    "for Iter in Iter_vec:\n",
    "    for q in q_vec:\n",
    "        print(\"Iter is \"+ str(Iter)+ \" dim of subspace is \"+ str(q))\n",
    "\n",
    "        beta1, beta2 = iteration_estimation_v2(X_train, Y, Iter, q, h)\n",
    "        df_proj = pd.DataFrame(beta1)\n",
    "        df_proj.to_csv(\"results/\"+\"Iter_\"+str(Iter)+'_q_'+str(q)+'_'+str(N)+\"_Glove.csv\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e3e338d9-b43b-4fec-a2d9-c07845175a50",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2e563c8e-19e9-471a-9b77-60a7789c56c5",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a6112fb5-f270-4fb9-b0c0-8260d8272840",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ab52f34-de84-4997-951b-f1d39b4f2a76",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
