{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "2e57c4c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "import os\n",
    "import sys\n",
    "sys.path.insert(0, \"./../\")\n",
    "import utils\n",
    "\n",
    "from FileManager import FileManager"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "0de08af7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Computing M*... done.\n"
     ]
    }
   ],
   "source": [
    "VOCAB_SZ = 10_000\n",
    "\n",
    "# Fail fast with a readable message: os.getenv returns None when the variable is unset,\n",
    "# which would otherwise surface as an opaque TypeError from os.path.join.\n",
    "dataset_root = os.getenv(\"DATASETPATH\")\n",
    "if dataset_root is None:\n",
    "    raise RuntimeError(\"Environment variable DATASETPATH is not set; point it at the dataset root directory.\")\n",
    "data_dir = os.path.join(dataset_root, \"qwem\")\n",
    "data_fm = FileManager(data_dir)\n",
    "\n",
    "data_fm.set_filepath(\"min500\")\n",
    "word_counts = data_fm.load(\"word_counts.pickle\")\n",
    "# Keep the first VOCAB_SZ entries of word_counts (NOTE(review): presumably sorted by descending frequency -- confirm).\n",
    "vocab = utils.Vocabulary(word_counts[:VOCAB_SZ])\n",
    "# Unigram distribution renormalized over the truncated vocabulary.\n",
    "unigram = vocab.counts / vocab.counts.sum()\n",
    "\n",
    "print(\"Computing M*... \", end=\"\")\n",
    "corpus_stats = data_fm.load(\"corpus_stats.pickle\")\n",
    "cL = corpus_stats[\"context_len\"]\n",
    "Cij, Crwij = corpus_stats[\"counts\"], corpus_stats[\"counts_reweight\"]\n",
    "# Sum of the raw (un-reweighted) counts over the VOCAB_SZ x VOCAB_SZ block.\n",
    "numcounts = Cij[:VOCAB_SZ, :VOCAB_SZ].sum()\n",
    "# NOTE(review): dividing by (cL + 1)/2 presumably undoes the average per-pair weight applied\n",
    "# when counts_reweight was built -- confirm against the script that produced corpus_stats.\n",
    "Pij = Crwij[:VOCAB_SZ, :VOCAB_SZ] / (numcounts * (cL + 1)/2)\n",
    "PiPj = np.outer(unigram, unigram)\n",
    "# Relative difference between Pij and the unigram product PiPj: 2(P - Q)/(P + Q).\n",
    "Mstar = 2*(Pij - PiPj)/(Pij + PiPj)\n",
    "# Zero denominators (Pij == PiPj == 0) would yield NaNs that silently poison the eigendecomposition.\n",
    "assert np.isfinite(Mstar).all(), \"M* contains non-finite entries\"\n",
    "print(\"done.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c8a42386",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Computing eigenfeatures... done.\n"
     ]
    }
   ],
   "source": [
    "print(\"Computing eigenfeatures... \", end=\"\")\n",
    "# Use the GPU when present, but don't hard-fail on CPU-only machines.\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "# Keep the NumPy Mstar from the previous cell untouched; a separate name makes this cell re-runnable.\n",
    "Mstar_t = torch.tensor(Mstar, dtype=torch.float64, device=device)\n",
    "# eigh treats its input as symmetric and returns eigenvalues in ascending order;\n",
    "# flip both outputs so eigenpairs are sorted by descending eigenvalue.\n",
    "eigvals, eigvecs = torch.linalg.eigh(Mstar_t)\n",
    "eigvals, eigvecs = eigvals.flip(dims=(0,)), eigvecs.flip(dims=(1,))\n",
    "# Check V diag(lambda) V^T == M*. Broadcasting scales the columns of V directly,\n",
    "# avoiding a dense N x N diagonal matrix and one extra N^3 matmul.\n",
    "assert torch.allclose(Mstar_t, (eigvecs * eigvals) @ eigvecs.T)\n",
    "eigvecs, eigvals = eigvecs.cpu().numpy(), eigvals.cpu().numpy()\n",
    "del Mstar_t  # release the device copy\n",
    "print(\"done.\")\n",
    "\n",
    "analysis_fm = FileManager(\"../analysis\")\n",
    "analysis_fm.save(eigvecs, \"mstar-eigvecs.npy\")\n",
    "analysis_fm.save(eigvals, \"mstar-eigvals.npy\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "win-research",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
