{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "97e21764",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "import os\n",
    "import sys\n",
    "sys.path.insert(0, './../')\n",
    "import json\n",
    "import utils\n",
    "\n",
    "from FileManager import FileManager"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "e54a81ac",
   "metadata": {},
   "outputs": [],
   "source": [
    "data_dir = os.path.join(os.getenv(\"DATASETPATH\"), \"qwem\")\n",
    "data_fm = FileManager(data_dir)\n",
    "\n",
    "analogy_dict = data_fm.load(\"analogies.pickle\")\n",
    "if analogy_dict is None:\n",
    "    raise FileNotFoundError(\"Analogy file not found.\")\n",
    "\n",
    "data_fm.set_filepath(\"min500\")\n",
    "word_counts = data_fm.load(\"word_counts.pickle\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "53fe6703",
   "metadata": {},
   "outputs": [],
   "source": [
    "expt_dir = os.path.join(os.getenv(\"QWEMPATH\"), \"stepwise-sgns\")\n",
    "fm = FileManager(expt_dir)\n",
    "with open(fm.get_filename(\"hypers.json\")) as f:\n",
    "    H = json.load(f)\n",
    "\n",
    "VOCAB_SZ = H[\"vocab_sz\"]\n",
    "EMBEDDIM = H[\"embeddim\"]\n",
    "vocab = utils.Vocabulary(word_counts[:VOCAB_SZ])\n",
    "unigram = vocab.counts / vocab.counts.sum()\n",
    "analogy_dataset = utils.AnalogyDataset(analogy_dict, vocab)\n",
    "\n",
    "save_fm = FileManager('../analysis/bench_models')\n",
    "models = save_fm.load(f\"models_d{EMBEDDIM}_V{VOCAB_SZ}.pickle\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9aa8ee10",
   "metadata": {},
   "source": [
    "## Create and save models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3db94a8f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_W(model_dir):\n",
    "    model_fm = FileManager(model_dir)\n",
    "    W = model_fm.load(\"W_final.npy\")\n",
    "    V, S, _ = np.linalg.svd(W, full_matrices=False)\n",
    "    W = V @ np.diag(S)\n",
    "    return W\n",
    "\n",
    "model_dir = os.path.join(os.getenv(\"QWEMPATH\"), \"stepwise-qwem/models\")\n",
    "W_QWEM = get_W(model_dir)\n",
    "\n",
    "model_dir = os.path.join(os.getenv(\"QWEMPATH\"), \"stepwise-sgns/models\")\n",
    "W_SGNS = get_W(model_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 81,
   "id": "ad932a41",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Computing M*... done.\n"
     ]
    }
   ],
   "source": [
    "print(f\"Computing M*... \", end=\"\")\n",
    "corpus_stats = data_fm.load(\"corpus_stats.pickle\")\n",
    "cL = corpus_stats[\"context_len\"]\n",
    "Cij, Crwij = corpus_stats[\"counts\"], corpus_stats[\"counts_reweight\"]\n",
    "numcounts = Cij[:VOCAB_SZ, :VOCAB_SZ].sum()\n",
    "Pij = Crwij[:VOCAB_SZ, :VOCAB_SZ] / (numcounts * (cL + 1)/2)\n",
    "PiPj = np.outer(unigram, unigram)\n",
    "Mstar = 2*(Pij - PiPj)/(Pij + PiPj)\n",
    "PMI = np.log((Pij / PiPj) + 1e-25)\n",
    "print(\"done.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 82,
   "id": "dfde4853",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "starting.. done.\n",
      "starting.. done.\n",
      "starting.. done.\n"
     ]
    }
   ],
   "source": [
    "def get_W_from_M(M, d):\n",
    "    print(\"starting.. \", end='')\n",
    "    M = torch.tensor(M, dtype=torch.float64).cuda()\n",
    "    eigvals, eigvecs = torch.linalg.eigh(M)\n",
    "    eigvals, eigvecs = eigvals.flip(dims=(0,)), eigvecs.flip(dims=(1,))\n",
    "    eigvals, eigvecs = eigvals.cpu().numpy(), eigvecs.cpu().numpy()\n",
    "    W = eigvecs[:, :d] @ np.diag(np.sqrt(eigvals[:d]))\n",
    "    print(\"done.\")\n",
    "    return W\n",
    "\n",
    "W_Mstar = get_W_from_M(Mstar, EMBEDDIM)\n",
    "W_PMI = get_W_from_M(PMI, EMBEDDIM)\n",
    "W_PPMI = get_W_from_M(np.maximum(0, PMI), EMBEDDIM)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 83,
   "id": "045e647f",
   "metadata": {},
   "outputs": [],
   "source": [
    "models = {\n",
    "    \"SGNS\": W_SGNS,\n",
    "    \"QWEM\": W_QWEM,\n",
    "    \"Mstar\": W_Mstar,\n",
    "    \"PPMI\": W_PPMI,\n",
    "    \"PMI\": W_PMI,\n",
    "}\n",
    "\n",
    "save_fm = FileManager('../analysis/bench_models')\n",
    "save_fm.save(models, f\"models_d{EMBEDDIM}_V{VOCAB_SZ}.pickle\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a9114be2",
   "metadata": {},
   "source": [
    "## Eval benchmarks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 84,
   "id": "3b13ee9c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1588\n",
      "249\n"
     ]
    }
   ],
   "source": [
    "from scipy.stats import spearmanr\n",
    "\n",
    "\n",
    "def read_similarity_data(file_path, vocab):\n",
    "    similarity_data = []\n",
    "    with open(file_path, 'r', encoding='utf-8') as file:\n",
    "        for line in file:\n",
    "            parts = line.strip().split()\n",
    "            if len(parts) == 3:\n",
    "                word1, word2, similarity = parts[0], parts[1], float(parts[2])\n",
    "                if word1 in vocab.word2token and word2 in vocab.word2token:\n",
    "                    similarity_data.append([vocab.word2token[word1],\n",
    "                                            vocab.word2token[word2],\n",
    "                                            similarity])\n",
    "    return similarity_data\n",
    "\n",
    "\n",
    "def evaluate_similarity(W, similarity_data):\n",
    "    norms = np.linalg.norm(W, axis=1, keepdims=True)\n",
    "    embeds = W / (norms + 1e-10)\n",
    "    \n",
    "    predicted_sims, human_sims = [], []\n",
    "    \n",
    "    for tok1, tok2, similarity in similarity_data:\n",
    "        predicted_sims.append(np.dot(embeds[tok1], embeds[tok2]).item())\n",
    "        human_sims.append(similarity)\n",
    "    \n",
    "    if len(predicted_sims) == 0:\n",
    "        raise ValueError(\"No valid word pairs found in embeddings.\")\n",
    "    \n",
    "    return spearmanr(predicted_sims, human_sims).correlation\n",
    "\n",
    "dataset_dir = os.getenv(\"DATASETPATH\")\n",
    "mendir = os.path.join(dataset_dir, \"qwem/benchmarks/MEN.txt\")\n",
    "ws353dir = os.path.join(dataset_dir, \"qwem/benchmarks/ws353.txt\")\n",
    "men_dataset = read_similarity_data(mendir, vocab)\n",
    "ws353_dataset = read_similarity_data(ws353dir, vocab)\n",
    "print(len(men_dataset))\n",
    "print(len(ws353_dataset))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 87,
   "id": "42bf616f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "SGNS\n",
      "Analogy acc: 68.0\n",
      "MEN score: 0.744\n",
      "ws353 score: 0.6976\n",
      "\n",
      "QWEM\n",
      "Analogy acc: 65.1\n",
      "MEN score: 0.755\n",
      "ws353 score: 0.6815\n",
      "\n",
      "Mstar\n",
      "Analogy acc: 66.3\n",
      "MEN score: 0.755\n",
      "ws353 score: 0.6829\n",
      "\n",
      "PPMI\n",
      "Analogy acc: 50.6\n",
      "MEN score: 0.744\n",
      "ws353 score: 0.6904\n",
      "\n",
      "PMI\n",
      "Analogy acc: 8.4\n",
      "MEN score: 0.448\n",
      "ws353 score: 0.2206\n",
      "\n"
     ]
    }
   ],
   "source": [
    "benchmarks = [\"Google analogies\", \"MEN\", \"ws353\"]\n",
    "results = np.empty((len(models.items()), len(benchmarks)))\n",
    "for i, (k, W) in enumerate(models.items()):\n",
    "    print(k)\n",
    "    acc = analogy_dataset.eval_accuracy(W)\n",
    "    results[i, 0] = acc\n",
    "    print(f\"Analogy acc: {100*acc:.1f}\")\n",
    "    \n",
    "    rho = evaluate_similarity(W, men_dataset).mean().item()\n",
    "    results[i, 1] = rho\n",
    "    print(f\"MEN score: {rho:.3f}\")\n",
    "    \n",
    "    rho = evaluate_similarity(W, ws353_dataset).mean().item()\n",
    "    results[i, 2] = rho\n",
    "    print(f\"ws353 score: {rho:.4f}\")\n",
    "    \n",
    "    print()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "179c61a8",
   "metadata": {},
   "source": [
    "## Eigenfeatures"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 79,
   "id": "ba634f94",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "PCA dir 1\n",
      "-0.025 lemmon kitt socorro spacewatch fefefe id sort km median right peak households expatriate mount establishments hispanic survey footballers census races\n",
      "-0.519 eric cooper jones dennis oliver sam tom robinson roberts jack michael thompson miller harris scott lewis taylor alex barry wilson\n",
      "\n",
      "PCA dir 2\n",
      "0.475 furthermore requires can specific useful require therefore particular processes example typically such whereas component types specifically components additionally possible instance\n",
      "-0.589 jones dennis eric robinson scott michael oliver taylor roberts david miller harris smith lewis russell mitchell wilson thompson frank cooper\n",
      "\n",
      "PCA dir 3\n",
      "0.320 like can uses makes soft simple baby typical typically shape featuring eyes combination usually using surface dark charlie happy similar\n",
      "-0.410 government establishment governments foreign authorities leaders declared officials behalf civil political independence citizens sought union federal relations commission regime organisation\n",
      "\n",
      "PCA dir 9\n",
      "0.301 equipment operating provided enterprise company services designed equipped operate customers installed service maintenance vehicles mobile ltd companies vehicle purchase corporation\n",
      "-0.472 team win playoff tigers season consecutive playoffs winning giants finished league seasons scoring tied champions played wins lions losing victory\n",
      "\n",
      "PCA dir 10\n",
      "0.295 interpretation simple frame clearly object element simply reference view mounted context describes elements william theory indeed sense principle writings definition\n",
      "-0.333 food habitat affected populations plants diseases growth agricultural forests agriculture fish increase sugar areas disease grown plant growing sector farmers\n",
      "\n",
      "PCA dir 11\n",
      "0.390 deployed force combat forces patrol command naval squadron allied attack army armed troops war submarine invasion enemy artillery fighters missile\n",
      "-0.295 if property tax any pay must otherwise not shall accept maria apply money without buy villa de cannot limit granted\n",
      "\n",
      "PCA dir 12\n",
      "0.303 opposition opposed independence government backing guitar drums bill bass parties vocals producer produced ruling regime constitutional coalition democratic featured party\n",
      "-0.317 she her decides goes seeing reveals find tells asks wants sees herself begins everyone learn help tries meets hospital thinking\n",
      "\n",
      "PCA dir 13\n",
      "0.273 glass painted made skin pieces wooden clothing twice competition competitive legs food finishing competitions first winning finished occasions placed trophy\n",
      "-0.400 southwest northeast northwest north southeast boundary valley highway route southern east river south west lake along lies hills crossing creek\n",
      "\n",
      "PCA dir 14\n",
      "0.352 piano vocal orchestra solo music instrumental recordings songs choir tracks recording violin op symphony concert organ performances album lyrics musical\n",
      "-0.299 dragon clan appears voiced han spider uses dynasty princess giant uncle hero elder legend son software evil king mother daughter\n",
      "\n",
      "PCA dir 15\n",
      "0.318 wall shaped inside alleged arrest accused roof investigation criminal walls wooden floor glass interior arrested window crimes doors painted victim\n",
      "-0.261 england thus great price meant liverpool share earl enjoyed biggest came lord expected therefore ever anglo britain time last amount\n",
      "\n",
      "PCA dir 100\n",
      "0.190 org figure standing riding with http green date www parent despite whom relationship external link links close holding child archive\n",
      "-0.179 advertising newspaper senior newspapers promoted posted knight freedom magazines flying grand colonial rise post examples spread order honours range reporting\n",
      "\n"
     ]
    }
   ],
   "source": [
    "NORMALIZE = True\n",
    "# W = models[\"QWEM\"]\n",
    "W = models[\"PPMI\"]\n",
    "# W = models[\"SGNS\"]\n",
    "\n",
    "V, S, Ut = np.linalg.svd(W, full_matrices=False)\n",
    "assert np.allclose(np.abs(Ut), np.eye(EMBEDDIM))\n",
    "norms = np.linalg.norm(W, axis=1, keepdims=True) if NORMALIZE else 1\n",
    "embeds = W / norms\n",
    "\n",
    "dd = [1, 2, 3, 9, 10, 11, 12, 13, 14, 15, 100]\n",
    "for d in dd:\n",
    "    vec = embeds[:, d-1]\n",
    "    idxs = np.argsort(vec[:4000])[::-1]\n",
    "    vec_sort = vec[idxs]\n",
    "    print(f\"PCA dir {d}\")\n",
    "    print(f'{(vec_sort[:10]).mean():.3f} {vocab.to_words(idxs[:20])}')\n",
    "    print(f'{(vec_sort[-10:]).mean():.3f} {vocab.to_words(idxs[-20:][::-1])}')\n",
    "    print()   "
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "win-research",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
