{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "28b4e4d7",
   "metadata": {},
   "source": [
    "# Application"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "69ab9214",
   "metadata": {},
   "source": [
    "The code can be divided into the following three parts：\n",
    "\n",
    "1. Read text files that represent different domains.\n",
    "\n",
    "2. Encode each line in the text files with a pretrained model.\n",
    "\n",
    "3. Cluster the encodings，using a GMM-k, k=5. In this step, the PCA method is used to reduce the dimension.The part involving the covariance function is replaced by our proposed four sparsified covariance estimators -- RK, RK-Spat, BS, BS-Spat, and the corresponding functional principle components (FPC) scores are used as the extracted features in the modeling step."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "id": "42cea4a8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "encoding medical_dev with transformers...\n",
      "[('Pack size of 1 and 3 pens.\\n', 1), ('Effects of tadalafil on other medicinal products\\n', 1), ('• Your doctor will prescribe Truvada with other antiretroviral medicines.\\n', 1), ('Type 2 diabetes is a disease in which the pancreas does not make enough insulin to control the level of glucose in the blood or when the body is unable to use insulin effectively.\\n', 1), ('Number of delegates\\n', 1), ('What is Fertavid?\\n', 1), ('The EDGE and EDGE II studies compared the gastrointestinal tolerability of etoricoxib versus diclofenac.\\n', 1), ('For more information, contact the manufacturer.\\n', 1), ('135 Do not store above 30 °C.\\n', 1), ('12 single dose vials, 12 injection syringes, 12 injection needles and 12 cleansing swabs uc rod\\n', 1)]\n",
      "found 2000 lines\n",
      "encoding with roberta-base...\n",
      "encoded with roberta-base in 159.33207321166992 seconds\n",
      "encoding with roberta-large...\n",
      "encoded with roberta-large in 657.2413794994354 seconds\n",
      "encoding with bert-base-uncased...\n",
      "encoded with bert-base-uncased in 163.80584406852722 seconds\n",
      "encoding with bert-large-cased...\n",
      "encoded with bert-large-cased in 653.3901450634003 seconds\n",
      "encoding with distilbert-base-uncased...\n",
      "encoded with distilbert-base-uncased in 82.24364709854126 seconds\n",
      "encoding with gpt2...\n",
      "encoded with gpt2 in 213.47763442993164 seconds\n",
      "encoding with xlnet-base-cased...\n",
      "encoded with xlnet-base-cased in 237.16789841651917 seconds\n",
      "encoding it_dev with transformers...\n",
      "[('nitrogen\\n', 1), ('FIND( \" FindText \" ; \" Text \" ; Position)\\n', 1), ('flowcharts\\n', 1), ('Depending on the chart type, the texts are shown on the x-axis or as data labels.\\n', 1), ('On the Fly Operations (JPEG only)\\n', 1), ('The HTML and plaintext formats are provided as tar files compressed using the bzip2 archiver.\\n', 1), ('Additional options are available on a menu accessed by right-clicking on a Budget.\\n', 1), ('http: / /www. kexi-project. org/ support. html\\n', 1), ('is the content to be tested.\\n', 1), ('Not every & kmahjongg; game can be finished. Sometimes the tiles are mixed in such a way that it may not be possible to find all the matches. This problem can be avoided. Please refer to this section of the configuration dialog.\\n', 1)]\n",
      "found 2000 lines\n",
      "encoding with roberta-base...\n",
      "encoded with roberta-base in 126.22167706489563 seconds\n",
      "encoding with roberta-large...\n",
      "encoded with roberta-large in 551.7870955467224 seconds\n",
      "encoding with bert-base-uncased...\n",
      "encoded with bert-base-uncased in 130.78419876098633 seconds\n",
      "encoding with bert-large-cased...\n",
      "encoded with bert-large-cased in 556.1838870048523 seconds\n",
      "encoding with distilbert-base-uncased...\n",
      "encoded with distilbert-base-uncased in 66.13374900817871 seconds\n",
      "encoding with gpt2...\n",
      "encoded with gpt2 in 184.20440411567688 seconds\n",
      "encoding with xlnet-base-cased...\n",
      "encoded with xlnet-base-cased in 191.21682143211365 seconds\n",
      "encoding koran_dev with transformers...\n",
      "[('He created man and surely know what misdoubts arise in their hearts; for We are closer to him than his jugular vein.\\n', 1), ('And the disbelievers are the heirs of one another – if you do not do so, there will be turmoil in the land and a great chaos.\\n', 1), ('“And they have misled a large number; and (I pray that) You increase nothing for the unjust except error.”\\n', 1), (\"haply I shall do righteousness in that I forsook.'\\n\", 1), ('But did not add: \"If God may please.\"\\n', 1), (\"Say: 'Who sent down the Book that Moses brought as a light and a guidance to men?\\n\", 1), ('\"The punishment for that (should be),\" they said, \"that he in whose luggage it is found should be held as punishment.\\n', 1), (\"We have made lawful for you, O Prophet, wives to whom you have given their dower, and God-given maids and captives you have married, and the daughters of your father's brothers and daughters of your father's sisters, and daughters of your mother's brothers and sisters, who migrated with you; and a believing woman who offers herself to the Prophet if the Prophet desires to marry her.\\n\", 1), ('With Us is a Record that preserves everything.\\n', 1), ('What they spend in the life of this world is like a frosty wind which smites and destroys the crops of a people who had wronged themselves.\\n', 1)]\n",
      "found 2000 lines\n",
      "encoding with roberta-base...\n",
      "encoded with roberta-base in 165.14546275138855 seconds\n",
      "encoding with roberta-large...\n",
      "encoded with roberta-large in 640.5100984573364 seconds\n",
      "encoding with bert-base-uncased...\n",
      "encoded with bert-base-uncased in 167.91539525985718 seconds\n",
      "encoding with bert-large-cased...\n",
      "encoded with bert-large-cased in 650.7303652763367 seconds\n",
      "encoding with distilbert-base-uncased...\n",
      "encoded with distilbert-base-uncased in 86.02785634994507 seconds\n",
      "encoding with gpt2...\n",
      "encoded with gpt2 in 215.74912858009338 seconds\n",
      "encoding with xlnet-base-cased...\n",
      "encoded with xlnet-base-cased in 239.5588619709015 seconds\n",
      "encoding subtitles_dev with transformers...\n",
      "[(\"You've bandaged it too tight.\\n\", 1), ('Elijah, you keep watch.\\n', 1), (\"I've heard the Japanese have killed hundreds of people in the north.\\n\", 1), ('In recent years, with the modern technology of high ISO cameras, fast shutter speeds, that means I can freeze the motion far better than I ever could have, just a few years ago.\\n', 1), ('- Did he tell you where he went?\\n', 1), ('No, that went right by me.\\n', 1), ('I have two tickets to the Norman Rockwell exhibit at the Smithsonian.\\n', 1), ('But they were all decommissioned by the time Clinton took office.\\n', 1), ('It never went out of there.\\n', 1), (\"not the I Once Had Three Shows on Broadway Simultaneously But I Blew All My Money on Coke and Now Here I Am Play, so let's do this thing!\\n\", 1)]\n",
      "found 2000 lines\n",
      "encoding with roberta-base...\n",
      "encoded with roberta-base in 114.55825662612915 seconds\n",
      "encoding with roberta-large...\n",
      "encoded with roberta-large in 486.2361433506012 seconds\n",
      "encoding with bert-base-uncased...\n",
      "encoded with bert-base-uncased in 129.22134065628052 seconds\n",
      "encoding with bert-large-cased...\n",
      "encoded with bert-large-cased in 525.8559105396271 seconds\n",
      "encoding with distilbert-base-uncased...\n",
      "encoded with distilbert-base-uncased in 58.48565602302551 seconds\n",
      "encoding with gpt2...\n",
      "encoded with gpt2 in 182.45723509788513 seconds\n",
      "encoding with xlnet-base-cased...\n",
      "encoded with xlnet-base-cased in 177.89297652244568 seconds\n",
      "encoding law_dev with transformers...\n",
      "[('(6) Whereas such investigations should be carried out under equivalent conditions in all the Community institutions, bodies and offices and agencies; whereas assignment of this task to the Office should not affect the responsibilities of the institutions, bodies, offices or agencies themselves and should in no way reduce the legal protection of the persons concerned;\\n', 1), ('- Viruses and virus-like organisms\\n', 1), (\"(c) the area within 12 miles of the west coast of Denmark from latitude 57° 00' N as far north as the Hirtshals Lighthouse, measured from the baselines.\\n\", 1), ('- for glucose and glucose syrup, the amount of the export refund for such products in their unprocessed state, fixed for each of those products in accordance with Article 13 of Council Regulation (EEC) No 1766/92 of 30 June 1992 on the common organization of the market in cereals (9), and its implementing provisions.\\n', 1), ('The Annex to Decision 2002/253/EC is amended in accordance with Annex III to this Decision.\\n', 1), ('(d) Member States which introduce the system referred to in (c) shall apply to the Commission to have it approved under the procedure provided for in Article 13(2).\\n', 1), ('(4) The Commission will inform the Advisory Committee established by Article 13 of the Directive on the criteria,\\n', 1), ('Over the same period, production destined for sales fell by over 14 %, from 2,16mt to 1,85mt.\\n', 1), ('Whereas controls carried out on the denaturing agents defined in this Regulation have shown that results for the same product may sometimes differ from those set out in the definitions ; whereas these results may not be strictly accurate ; whereas a certain technical latitude should be allowed with regard to the minimum contents specified;\\n', 1), ('(6) The Commission officially advised the complainant, other Community producers, the exporting producers in the PRC and the USA, importers and users known to be concerned and representatives of the PRC 
and USA governments, of the opening of the proceedings.\\n', 1)]\n",
      "found 2000 lines\n",
      "encoding with roberta-base...\n",
      "encoded with roberta-base in 208.59782218933105 seconds\n",
      "encoding with roberta-large...\n",
      "encoded with roberta-large in 773.0728940963745 seconds\n",
      "encoding with bert-base-uncased...\n",
      "encoded with bert-base-uncased in 208.24601030349731 seconds\n",
      "encoding with bert-large-cased...\n",
      "encoded with bert-large-cased in 794.4485681056976 seconds\n",
      "encoding with distilbert-base-uncased...\n",
      "encoded with distilbert-base-uncased in 104.98448300361633 seconds\n",
      "encoding with gpt2...\n",
      "encoded with gpt2 in 250.78755974769592 seconds\n",
      "encoding with xlnet-base-cased...\n",
      "encoded with xlnet-base-cased in 286.736314535141 seconds\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import time\n",
    "from transformers import *\n",
    "import numpy as np\n",
    "import gensim\n",
    "from transformers import *\n",
    "from collections import Counter\n",
    "from collections import defaultdict\n",
    "from sklearn import svm, datasets\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.utils.multiclass import unique_labels\n",
    "from sklearn.mixture import GaussianMixture\n",
    "from sklearn.model_selection import StratifiedKFold\n",
    "from sklearn.decomposition import PCA\n",
    "\n",
    "MODELS = [\n",
    "        (RobertaModel, RobertaTokenizer, 'roberta-base'),\n",
    "        (RobertaModel, RobertaTokenizer, 'roberta-large'),\n",
    "        (BertModel, BertTokenizer, 'bert-base-uncased'),\n",
    "        (BertModel, BertTokenizer, 'bert-large-cased'),\n",
    "        (DistilBertModel, DistilBertTokenizer, 'distilbert-base-uncased'),\n",
    "        (OpenAIGPTModel,  OpenAIGPTTokenizer,  'openai-gpt'),\n",
    "        (GPT2Model, GPT2Tokenizer, 'gpt2'),\n",
    "        (CTRLModel,       CTRLTokenizer,       'ctrl'),\n",
    "        (TransfoXLModel,  TransfoXLTokenizer,  'transfo-xl-wt103'),\n",
    "        (XLNetModel, XLNetTokenizer, 'xlnet-base-cased'),\n",
    "        (XLMModel,        XLMTokenizer,        'xlm-mlm-enfr-1024'),\n",
    "        (BertModel, BertTokenizer, 'bert-base-multilingual-cased'),\n",
    "        (XLMRobertaModel, XLMRobertaTokenizer, 'xlm-roberta-base'),\n",
    "        (XLMRobertaModel, XLMRobertaTokenizer, 'xlm-roberta-large')\n",
    "    ]\n",
    "\n",
    "\n",
    "def encode_with_transformers(corpus, models_to_use = ['roberta-large']):\n",
    "    model_to_states = {}\n",
    "    for model_class, tokenizer_class, model_name in MODELS:\n",
    "        if model_name not in models_to_use:\n",
    "            continue\n",
    "        print('encoding with {}...'.format(model_name))\n",
    "        model_to_states[model_name] = {}\n",
    "        model_to_states[model_name]['states'] = []\n",
    "        model_to_states[model_name]['sents'] = []\n",
    "\n",
    "        # Load pretrained model/tokenizer\n",
    "        tokenizer = tokenizer_class.from_pretrained(model_name)\n",
    "        model = model_class.from_pretrained(model_name)\n",
    "        model.to(torch.device('cuda' if torch. cuda.is_available () else \"cpu\"))\n",
    "        \n",
    "        # Encode text\n",
    "        start = time.time()\n",
    "        for sentence in corpus:\n",
    "            model_to_states[model_name]['sents'].append(sentence)\n",
    "            input_ids = torch.tensor([tokenizer.encode(sentence, add_special_tokens=True, max_length=128)])  \n",
    "            input_ids = input_ids.to(torch.device('cuda' if torch. cuda.is_available () else \"cpu\"))\n",
    "            with torch.no_grad():\n",
    "                output = model(input_ids)\n",
    "                last_hidden_states = output[0]\n",
    "                \n",
    "                # avg pool last hidden layer\n",
    "                squeezed = last_hidden_states.squeeze(dim=0)\n",
    "                masked = squeezed[:input_ids.shape[1],:]\n",
    "                avg_pooled = masked.mean(dim=0)                \n",
    "                model_to_states[model_name]['states'].append(avg_pooled.cpu())\n",
    "                \n",
    "        end = time.time()\n",
    "        print('encoded with {} in {} seconds'.format(model_name, end - start))\n",
    "        np_tensors = [np.array(tensor) for tensor in model_to_states[model_name]['states']]\n",
    "        model_to_states[model_name]['states'] = np.stack(np_tensors)\n",
    "    return model_to_states\n",
    "\n",
    "def map_clusters_to_classes_by_majority(y_train, y_train_pred):\n",
    "    cluster_to_class = {}\n",
    "    for cluster in np.unique(y_train_pred):\n",
    "        # run on indices where this is the cluster\n",
    "        original_classes = []\n",
    "        for i, pred in enumerate(y_train_pred):\n",
    "            if pred == cluster:\n",
    "                original_classes.append(y_train[i])\n",
    "        # take majority         \n",
    "        cluster_to_class[cluster] = max(set(original_classes), key = original_classes.count)\n",
    "    return cluster_to_class      \n",
    "\n",
    "# Load pretrained word2vec model \n",
    "# The model file can be downloaded from here: https://drive.google.com/file/d/0B7XkCwpI5KDYNlNUTTlSS21pQmM/edit?usp=sharing\n",
    "model = gensim.models.KeyedVectors.load_word2vec_format('new-root\\GoogleNews-vectors-negative300.bin', binary=True)\n",
    "\n",
    "def encode_sent_with_w2v_cbow(sent, model, max_pool=False):\n",
    "    MODEL_SIZE = 300\n",
    "    MAX_LEN = 50\n",
    "    toks = list(gensim.utils.tokenize(sent))[:MAX_LEN]\n",
    "    vecs = []\n",
    "    for tok in toks:\n",
    "        if tok in model:\n",
    "            vecs.append(model[tok])\n",
    "    if len(vecs):\n",
    "        if max_pool:\n",
    "            pooled = np.max(np.stack(vecs), axis=0)\n",
    "        else:\n",
    "            pooled = np.mean(np.stack(vecs), axis=0)\n",
    "    else:\n",
    "        pooled = model['unk']\n",
    "    return pooled, toks\n",
    "\n",
    "\n",
    "def get_w2v_sent_reps(file_path, model, limit=3000, max_pool=False):\n",
    "    vecs = [] \n",
    "    toks = []\n",
    "    with open(file_path, encoding = 'utf_8') as f:\n",
    "        lines = f.readlines()\n",
    "        if len(lines) > limit:\n",
    "            lines = lines[:limit]\n",
    "        for line in lines:\n",
    "            vec, sent_toks = encode_sent_with_w2v_cbow(line.strip(), model, max_pool)\n",
    "            vecs.append(vec)\n",
    "            toks.append(sent_toks)\n",
    "    return np.stack(vecs), toks\n",
    "\n",
    "# the split text files can be download from https://drive.google.com/file/d/1yvB-pvlojtT2UpOX1JvwtD6rw9joQ49A/view\n",
    "base_path_new = 'new-root\\multi_domain_new_split'\n",
    "file_paths_new = {\n",
    "                'medical_dev':base_path_new + '\\medical\\dev.en',\n",
    "                'it_dev':base_path_new + '\\it\\dev.en',\n",
    "                'koran_dev':base_path_new + '\\koran\\dev.en',\n",
    "                'subtitles_dev':base_path_new + '\\subtitles\\dev.en',\n",
    "                'law_dev': base_path_new + '\\law\\dev.en'\n",
    "             }\n",
    "\n",
    "models_to_use = ['bert-base-uncased','bert-large-cased', 'distilbert-base-uncased', 'roberta-base', 'roberta-large','gpt2', 'xlnet-base-cased']    #encode with roberta-base model\n",
    "\n",
    "model_to_domain_to_encodings_new = defaultdict(dict)\n",
    "for domain_name in file_paths_new:\n",
    "    print('encoding {} with transformers...'.format(domain_name))\n",
    "    file_path = file_paths_new[domain_name]\n",
    "    counter = Counter(open(file_path, encoding = 'utf_8').readlines())\n",
    "    print(counter.most_common(10))\n",
    "    lines = list(set(open(file_path, encoding = 'utf_8').readlines())) # eliminate duplicate sentences\n",
    "    print('found {} lines'.format(len(lines)))\n",
    "    res = encode_with_transformers(lines, models_to_use)\n",
    "    for model_name in models_to_use:\n",
    "        model_to_domain_to_encodings_new[model_name][domain_name] = res[model_name]\n"
   ]
  },
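  {
   "cell_type": "markdown",
   "id": "mean-pool-sketch-md",
   "metadata": {},
   "source": [
    "`encode_with_transformers` runs one sentence per forward pass, so the last hidden layer contains no padding and slicing it to `input_ids.shape[1]` keeps every token. If sentences were batched with padding instead (an assumption; the notebook does not batch), the padded positions would have to be left out of the average. The cell below is a minimal NumPy sketch of such masked mean pooling on toy arrays. The function name `masked_mean_pool` and the values are hypothetical; in a transformers pipeline the mask would be the `attention_mask` returned by the tokenizer when padding is enabled."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "mean-pool-sketch",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def masked_mean_pool(hidden_states, attention_mask):\n",
    "    # hidden_states: (batch, seq_len, dim); attention_mask: (batch, seq_len), 1 = real token, 0 = padding\n",
    "    mask = attention_mask[:, :, None].astype(hidden_states.dtype)\n",
    "    summed = (hidden_states * mask).sum(axis=1)\n",
    "    # Number of real tokens per sentence, clipped to 1 to avoid dividing by zero\n",
    "    counts = np.clip(mask.sum(axis=1), 1.0, None)\n",
    "    return summed / counts\n",
    "\n",
    "# Toy batch: two sentences padded to length 3, hidden size 2 (hypothetical values)\n",
    "hidden = np.array([[[1.0, 2.0], [3.0, 4.0], [9.0, 9.0]],\n",
    "                   [[5.0, 6.0], [0.0, 0.0], [0.0, 0.0]]])\n",
    "mask = np.array([[1, 1, 0],\n",
    "                 [1, 0, 0]])\n",
    "pooled = masked_mean_pool(hidden, mask)\n",
    "print(pooled)  # rows [2, 3] and [5, 6]; the padded [9, 9] token is ignored"
   ]
  },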
  {
   "cell_type": "code",
   "execution_count": 53,
   "id": "ff7af84c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "encoded medical_dev in 1.7659687995910645 seconds\n",
      "encoded it_dev in 0.5115230083465576 seconds\n",
      "encoded koran_dev in 0.5432131290435791 seconds\n",
      "encoded subtitles_dev in 0.2175920009613037 seconds\n",
      "encoded law_dev in 0.48760223388671875 seconds\n"
     ]
    }
   ],
   "source": [
    "# word2vec  \n",
    "for domain_name in file_paths_new:\n",
    "    file_path = file_paths_new[domain_name]\n",
    "    start = time.time()\n",
    "    vecs, sents = get_w2v_sent_reps(file_path, model)\n",
    "    end = time.time()\n",
    "    print('encoded {} in {} seconds'.format(domain_name, end - start))\n",
    "    model_to_domain_to_encodings_new['word2vec'][domain_name] = {'states':vecs, 'sents':sents}"
   ]
  },
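  {
   "cell_type": "markdown",
   "id": "w2v-pool-sketch-md",
   "metadata": {},
   "source": [
    "`encode_sent_with_w2v_cbow` represents a sentence by the mean (or, with `max_pool=True`, the element-wise maximum) of the word2vec vectors of its in-vocabulary tokens. The cell below is a small self-contained sketch of the same pooling on a hypothetical three-dimensional toy vocabulary, so it runs without the GoogleNews file. Two assumptions differ from the notebook: a simple alphabetic regex stands in for `gensim.utils.tokenize`, and a sentence with no known token falls back to a zero vector instead of looking up `'unk'`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "w2v-pool-sketch",
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "import numpy as np\n",
    "\n",
    "# Hypothetical 3-d toy vectors standing in for the 300-d GoogleNews word2vec model\n",
    "TOY_VECTORS = {\n",
    "    'Pack': np.array([1.0, 0.0, 2.0]),\n",
    "    'size': np.array([3.0, 4.0, 0.0]),\n",
    "    'pens': np.array([2.0, 2.0, 2.0]),\n",
    "}\n",
    "DIM = 3\n",
    "\n",
    "def pool_word_vectors(sent, vectors, dim, max_pool=False, max_len=50):\n",
    "    # Alphabetic tokens only (an ASCII stand-in for gensim.utils.tokenize), truncated to max_len\n",
    "    toks = re.findall(r'[A-Za-z]+', sent)[:max_len]\n",
    "    vecs = [vectors[tok] for tok in toks if tok in vectors]\n",
    "    if not vecs:\n",
    "        # No in-vocabulary token: fall back to a zero vector (assumption)\n",
    "        return np.zeros(dim), toks\n",
    "    stacked = np.stack(vecs)\n",
    "    # Mean pooling by default, element-wise max pooling if requested\n",
    "    pooled = stacked.max(axis=0) if max_pool else stacked.mean(axis=0)\n",
    "    return pooled, toks\n",
    "\n",
    "# A sentence from the medical_dev sample printed earlier\n",
    "sent = 'Pack size of 1 and 3 pens.'\n",
    "mean_vec, toks = pool_word_vectors(sent, TOY_VECTORS, DIM)\n",
    "max_vec, _ = pool_word_vectors(sent, TOY_VECTORS, DIM, max_pool=True)\n",
    "empty_vec, _ = pool_word_vectors('42 !', TOY_VECTORS, DIM)\n",
    "print(toks)      # ['Pack', 'size', 'of', 'and', 'pens']\n",
    "print(mean_vec)  # mean of the three known vectors\n",
    "print(max_vec)   # [3. 4. 2.]"
   ]
  },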
  {
   "cell_type": "markdown",
   "id": "905ce304",
   "metadata": {},
   "source": [
    "After PCA dimension reduction and feature extraction, the specific process of completing the classification task is as follows"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "id": "de9f9c6b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "def pcafun(X,k):#standard PCA\n",
    "  n_samples, n_features = X.shape\n",
    "  mean = np.array([np.mean(X[:,i]) for i in range(n_features)])\n",
    "  #normalization\n",
    "  norm_X = X-mean\n",
    "  scatter_matrix = np.dot(np.transpose(norm_X),norm_X)\n",
    "  #eigenvectors and eigenvalues\n",
    "  eig_val, eig_vec = np.linalg.eig(scatter_matrix)\n",
    "  eig_val = eig_val.real\n",
    "  eig_vec = eig_vec.real\n",
    "  eig_pairs = [(np.abs(eig_val[i]), eig_vec[:,i]) for i in range(n_features)]\n",
    "  # sort eig_vec based on eig_val from highest to lowest\n",
    "  # select the top k eig_vec\n",
    "  feature=np.array([ele[1] for ele in eig_pairs[:k]])\n",
    "  #extracted features\n",
    "  data = np.dot(norm_X,np.transpose(feature))\n",
    "  return data\n",
    "\n",
    "def fit_gmm(name_to_embeddings, class_names, first_principal_component_shown=0, \n",
    "            last_principal_component_shown=1, clusters=5, pca=True, examples_per_class = 2000):\n",
    "    all_states = []\n",
    "    all_sents = []\n",
    "    num_classes = len(class_names)\n",
    "    if last_principal_component_shown <= first_principal_component_shown:\n",
    "        raise Exception('first PCA component must be smaller than the 2nd')\n",
    "    \n",
    "    # Concatenate the data to one matrix\n",
    "    for label in class_names:\n",
    "        all_states.append(name_to_embeddings[label]['states'][0:examples_per_class])\n",
    "        all_sents += name_to_embeddings[label]['sents']\n",
    "    concat_all_embs = np.concatenate(all_states)\n",
    "    \n",
    "    # Compute PCA\n",
    "    if pca:\n",
    "        pca_data = pcafun(concat_all_embs,last_principal_component_shown)\n",
    "    else:\n",
    "        pca_data = concat_all_embs\n",
    "        \n",
    "    pca_labels = []\n",
    "    for i in range(len(class_names)):\n",
    "        for j in range(examples_per_class):\n",
    "            pca_labels.append(i)\n",
    "    pca_labels = np.array(pca_labels)\n",
    "\n",
    "    train_index = list(range(0, pca_data.shape[0]))\n",
    "    test_index = list(range(0, pca_data.shape[0]))\n",
    "\n",
    "    X_train = pca_data[train_index]\n",
    "    y_train = pca_labels[train_index]\n",
    "    X_test = pca_data[test_index]\n",
    "    y_test = pca_labels[test_index]\n",
    "\n",
    "    n_classes = len(np.unique(y_train))\n",
    "    if clusters > 0:\n",
    "        n_clusters = clusters\n",
    "    else:\n",
    "        n_clusters = n_classes\n",
    "    \n",
    "    estimators = {cov_type: GaussianMixture(n_components=n_clusters,\n",
    "                  covariance_type=cov_type, max_iter=150, random_state=0)\n",
    "                  for cov_type in ['full']} \n",
    "\n",
    "    n_estimators = len(estimators)\n",
    "\n",
    "    best_accuracy = 0\n",
    "    for index, (name, estimator) in enumerate(estimators.items()):\n",
    "        \n",
    "        # train the GMM         \n",
    "        estimator.fit(X_train)       \n",
    "        y_train_pred = estimator.predict(X_train)\n",
    "        y_test_pred = estimator.predict(X_test)       \n",
    "        clusters_to_classes = map_clusters_to_classes_by_majority(y_train, y_train_pred)\n",
    "        \n",
    "        # Purity metric         \n",
    "        count=0\n",
    "        for i, pred in enumerate(y_train_pred):\n",
    "            if clusters_to_classes[pred] == y_train[i]:\n",
    "                count += 1\n",
    "        train_accuracy = float(count)/len(y_train_pred) * 100\n",
    "        \n",
    "        correct_count=0\n",
    "        for i, pred in enumerate(y_test_pred):\n",
    "            if clusters_to_classes[pred] == y_test[i]:\n",
    "                correct_count += 1\n",
    "        test_accuracy = float(correct_count)/len(y_test_pred) * 100\n",
    "        \n",
    "        if test_accuracy > best_accuracy:\n",
    "            best_accuracy = test_accuracy\n",
    "        \n",
    "    return best_accuracy\n",
    "        "
   ]
  },
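  {
   "cell_type": "markdown",
   "id": "pca-eigh-sketch-md",
   "metadata": {},
   "source": [
    "In `pcafun`, the principal axes are the eigenvectors of the scatter matrix of the centered data, ranked by eigenvalue, and `np.linalg.eig` returns eigenvalues in no particular order, so the ranking requires an explicit sort. The cell below is a minimal sketch on random toy data, relying on the scatter matrix being symmetric: it uses `np.linalg.eigh` (which returns eigenvalues in ascending order) followed by a descending sort, and cross-checks the scores against an SVD of the centered data, whose squared singular values equal the eigenvalues. The function name `pca_scores_eigh` is hypothetical. Eigenvectors are unique only up to sign, so the comparison uses absolute values."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "pca-eigh-sketch",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def pca_scores_eigh(X, k):\n",
    "    # Center each feature, then eigendecompose the symmetric scatter matrix\n",
    "    Xc = X - X.mean(axis=0)\n",
    "    scatter = Xc.T @ Xc\n",
    "    eig_val, eig_vec = np.linalg.eigh(scatter)  # eigenvalues in ascending order\n",
    "    order = np.argsort(eig_val)[::-1][:k]       # indices of the k largest eigenvalues\n",
    "    return Xc @ eig_vec[:, order], eig_val[order]\n",
    "\n",
    "# Toy data whose five features have clearly different spreads\n",
    "rng = np.random.default_rng(0)\n",
    "X = rng.normal(size=(200, 5)) * np.array([5.0, 3.0, 1.0, 0.5, 0.1])\n",
    "scores, top_vals = pca_scores_eigh(X, 2)\n",
    "\n",
    "# Cross-check with an SVD of the centered data: rows of vt are the principal axes\n",
    "Xc = X - X.mean(axis=0)\n",
    "_, s, vt = np.linalg.svd(Xc, full_matrices=False)\n",
    "svd_scores = Xc @ vt[:2].T\n",
    "# Eigenvectors are unique only up to sign, so compare absolute values\n",
    "print(np.allclose(np.abs(scores), np.abs(svd_scores)))\n",
    "print(np.allclose(top_vals, s[:2] ** 2))"
   ]
  },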
  {
   "cell_type": "code",
   "execution_count": 67,
   "id": "09c75e97",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "defaultdict(<class 'list'>, {'bert-base-uncased': [78.62], 'bert-large-cased': [88.94999999999999], 'distilbert-base-uncased': [87.53999999999999], 'roberta-base': [78.95], 'roberta-large': [73.74000000000001], 'gpt2': [70.30999999999999], 'xlnet-base-cased': [56.589999999999996], 'word2vec': [54.37]})\n"
     ]
    }
   ],
   "source": [
    "domains = ['it_dev', 'koran_dev', 'subtitles_dev', 'medical_dev', 'law_dev']\n",
    "first_principal = 1\n",
    "last_principal = 50\n",
    "num_clusters = 5\n",
    "use_pca = True\n",
    "\n",
    "model_to_accuracies = defaultdict(list)\n",
    "for model_name in model_to_domain_to_encodings_new:\n",
    "    accuracy = fit_gmm(model_to_domain_to_encodings_new[model_name], domains, \n",
    "                       first_principal_component_shown = first_principal,\n",
    "                       last_principal_component_shown = last_principal, \n",
    "                       clusters = num_clusters, pca = use_pca)\n",
    "    model_to_accuracies[model_name].append(accuracy)\n",
    "#standard PCA\n",
    "print(model_to_accuracies)                                             "
   ]
  },
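  {
   "cell_type": "markdown",
   "id": "purity-sketch-md",
   "metadata": {},
   "source": [
    "The accuracies printed above are cluster purities: `map_clusters_to_classes_by_majority` assigns each GMM cluster the domain that occurs most often in it, and purity is the percentage of sentences whose cluster's assigned domain equals their true domain. Because the GMM is fit and scored on the same sentences, this is a training-set measure. The cell below is a vectorized NumPy sketch of the same computation on hypothetical toy labels. One assumed difference: ties are broken toward the smallest class label by `np.argmax`, whereas `max(set(...), key=...)` breaks ties by set iteration order."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "purity-sketch",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def cluster_purity(y_true, y_pred):\n",
    "    # Assign each cluster its most frequent true class, then score the share of matching points\n",
    "    y_true = np.asarray(y_true)\n",
    "    y_pred = np.asarray(y_pred)\n",
    "    mapping = {}\n",
    "    for cluster in np.unique(y_pred):\n",
    "        classes, counts = np.unique(y_true[y_pred == cluster], return_counts=True)\n",
    "        mapping[cluster] = classes[np.argmax(counts)]  # ties go to the smallest class label\n",
    "    mapped = np.array([mapping[c] for c in y_pred])\n",
    "    return 100.0 * np.mean(mapped == y_true), mapping\n",
    "\n",
    "# Hypothetical labels: cluster 0 holds one class-0 point and three class-1 points\n",
    "y_true = [0, 0, 0, 1, 1, 1, 2, 2]\n",
    "y_pred = [1, 1, 0, 0, 0, 0, 2, 2]\n",
    "purity, mapping = cluster_purity(y_true, y_pred)\n",
    "print(purity)  # 87.5: only the class-0 point in cluster 0 is misassigned"
   ]
  },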
  {
   "cell_type": "markdown",
   "id": "a0345ff9",
   "metadata": {},
   "source": [
    "For different sparsified estimation of the covariance function, only the function pcafun needs to be changed. Take Random-knots estimator as an example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 68,
   "id": "7ef77b58",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "def pcafun(X,k):#PCA with random-knots covariance\n",
    "  n_samples, n_features = X.shape\n",
    "  norm_X=X-np.array([np.mean(X[:,i]) for i in range(n_features)])\n",
    "  xnew = X*0\n",
    "  gamma = 5/8\n",
    "  Js = round(3*n_features**gamma*math.log(math.log(n_features)))         \n",
    "  for i in range(n_samples):\n",
    "      random_nums = random.sample(range(n_features),Js)\n",
    "      for k1 in random_nums:\n",
    "          xnew[i,k1] = X[i,k1]\n",
    "  mean = np.array([np.mean(xnew[:,i]) for i in range(n_features)])\n",
    "  norm_xnew=xnew-mean\n",
    "  scatter_matrix = ((n_features/Js)**2)*np.dot(np.transpose(norm_xnew),norm_xnew)\n",
    "  eig_val, eig_vec = np.linalg.eig(scatter_matrix)\n",
    "  eig_val = eig_val.real\n",
    "  eig_vec = eig_vec.real\n",
    "  eig_pairs = [(np.abs(eig_val[i]), eig_vec[:,i]) for i in range(n_features)]\n",
    "  feature=np.array([ele[1] for ele in eig_pairs[:k]])\n",
    "  data=np.dot(norm_X,np.transpose(feature))\n",
    "  return data"
   ]
  },
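  {
   "cell_type": "markdown",
   "id": "rk-mask-sketch-md",
   "metadata": {},
   "source": [
    "In the random-knots version of `pcafun`, each sample keeps `Js = round(3 * p**(5/8) * log(log(p)))` randomly chosen coordinates out of its p features (natural logarithms), and the other coordinates are set to zero before the sample mean and covariance are formed. The rescaling by `(n_features/Js)**2` multiplies every eigenvalue by the same positive constant, so in exact arithmetic it does not change the eigenvectors or the resulting FPC scores. The cell below is a minimal NumPy sketch that evaluates the `Js` formula for a few embedding sizes (300 for word2vec, 768 for the base models, 1024 for the large models) and applies the same per-sample masking without Python loops. The helper names are hypothetical, and drawing the kept coordinates by sorting i.i.d. uniform numbers is an assumed substitute for `random.sample` that selects an equally uniform random subset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "rk-mask-sketch",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import numpy as np\n",
    "\n",
    "def random_knots_count(p, gamma=5/8):\n",
    "    # Js from the RK pcafun: round(3 * p**gamma * log(log(p))), natural logarithms\n",
    "    return round(3 * p ** gamma * math.log(math.log(p)))\n",
    "\n",
    "def random_knots_mask(X, n_keep, rng):\n",
    "    # Keep n_keep distinct, uniformly chosen coordinates in every row and zero the rest\n",
    "    n_samples, n_features = X.shape\n",
    "    # argsort of i.i.d. uniforms gives an independent random permutation for each row\n",
    "    keep_idx = np.argsort(rng.random((n_samples, n_features)), axis=1)[:, :n_keep]\n",
    "    mask = np.zeros(X.shape, dtype=bool)\n",
    "    np.put_along_axis(mask, keep_idx, True, axis=1)\n",
    "    return np.where(mask, X, 0.0)\n",
    "\n",
    "for p in (300, 768, 1024):\n",
    "    print(p, random_knots_count(p))  # 300 -> 185, 768 -> 361, 1024 -> 442\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "X = rng.normal(size=(4, 768))\n",
    "J = random_knots_count(768)\n",
    "X_sparse = random_knots_mask(X, J, rng)\n",
    "print((X_sparse != 0).sum(axis=1))  # J kept entries per row"
   ]
  },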
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e3431e12",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
