{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "065cf208-6aa6-45bd-b6d5-eb8c8ee4fae1",
   "metadata": {},
   "source": [
    "# AWGN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "1e711648-051f-4e1b-9721-721a2f9f4571",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "8e63aded-3ddc-4680-a090-ad47cd924167",
   "metadata": {},
   "outputs": [],
   "source": [
    "from mutinfo.estimators.knn import KSG\n",
    "from mutinfo.estimators.smi import SMI"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "bdf3de64-91bc-4f7c-a097-ca3ed0afed6e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import bebeziana\n",
    "\n",
    "bebeziana.seed_everything(42, [\"numpy\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "75fb5e86-fdea-4561-8ceb-8368a9114207",
   "metadata": {},
   "outputs": [],
   "source": [
    "d = 5\n",
    "N = 10000\n",
    "alpha = 0.01\n",
    "sigma = 0.1\n",
    "\n",
    "#A = np.tril(np.ones((d,d)))\n",
    "A = alpha * np.eye(d) + np.ones((d,d))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "16ab5334-b034-4885-b7df-1d8eecab5ea3",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[1.01, 1.  , 1.  , 1.  , 1.  ],\n",
       "       [1.  , 1.01, 1.  , 1.  , 1.  ],\n",
       "       [1.  , 1.  , 1.01, 1.  , 1.  ],\n",
       "       [1.  , 1.  , 1.  , 1.01, 1.  ],\n",
       "       [1.  , 1.  , 1.  , 1.  , 1.01]])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "A"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "4ed3d8f7-1839-46be-937e-e623339cc11b",
   "metadata": {},
   "outputs": [],
   "source": [
    "samplers = {\n",
    "    \"uniform\": lambda : np.random.rand(N, d) @ A.T,\n",
    "    \"normal\":  lambda : np.random.randn(N, d) @ A.T,\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "3627ef40-c563-45ec-b91a-fd6b413de8fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "estimator_factories = {\n",
    "    \"MI (KSG)\": lambda : KSG(),\n",
    "    \"1-SMI (KSG)\": lambda : SMI(KSG()),\n",
    "    \"2-SMI (KSG)\": lambda : SMI(KSG(), projection_dim=2),\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "71a400a9-e09a-413c-91f6-e500bede2e46",
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.linalg import sqrtm\n",
    "\n",
    "normalizers = {\n",
    "    \"whitening\":       lambda x : x @ np.linalg.inv(sqrtm(np.cov(x, rowvar=False))),\n",
    "    \"standardization\": lambda x : x / np.std(x, axis=0),\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "eb2015a1-a608-448a-a162-8d960abd9c48",
   "metadata": {},
   "outputs": [],
   "source": [
    "n_runs = 10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "adc06d34-2ba9-4cee-809a-b9ea7a2c0fa3",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:02<00:00,  3.83it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "uniform MI (KSG) whitening\n",
      "7.476070495332041\n",
      "0.0075427928783968045\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:01<00:00,  5.52it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "uniform MI (KSG) standardization\n",
      "3.0443893277129463\n",
      "0.007025174108918013\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [01:04<00:00,  6.48s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "uniform 1-SMI (KSG) whitening\n",
      "0.1692904595941308\n",
      "0.015285713337081966\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [01:00<00:00,  6.01s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "uniform 1-SMI (KSG) standardization\n",
      "1.8241368064779064\n",
      "0.04435625394173614\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [02:42<00:00, 16.21s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "uniform 2-SMI (KSG) whitening\n",
      "0.9622642112142479\n",
      "0.04081657179705653\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [01:56<00:00, 11.62s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "uniform 2-SMI (KSG) standardization\n",
      "2.4631704930619933\n",
      "0.025279752269319777\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:02<00:00,  3.39it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "normal MI (KSG) whitening\n",
      "7.488873460014633\n",
      "0.013498498683502209\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:02<00:00,  4.71it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "normal MI (KSG) standardization\n",
      "3.044994970612822\n",
      "0.006879558850745593\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [01:09<00:00,  6.91s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "normal 1-SMI (KSG) whitening\n",
      "0.13932919184371712\n",
      "0.014590613751521164\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [01:02<00:00,  6.27s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "normal 1-SMI (KSG) standardization\n",
      "1.8345291608567706\n",
      "0.03752776443657762\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 70%|██████████████████████████████████████████████████████████████████████████████████████████████████▋                                          | 7/10 [01:56<00:48, 16.20s/it]"
     ]
    }
   ],
   "source": [
    "from tqdm import trange\n",
    "\n",
    "for sampler_name, sampler in samplers.items():\n",
    "    for estimator_name, estimator_factory in estimator_factories.items():\n",
    "        for normalizer_name, normalizer in normalizers.items():\n",
    "            estimates = []\n",
    "            for run in trange(n_runs):\n",
    "                bebeziana.seed_everything(42 + run, [\"numpy\"])\n",
    "                \n",
    "                estimator = estimator_factory()\n",
    "                X = normalizer(sampler())\n",
    "                Z = sigma * np.random.randn(*X.shape)\n",
    "                Y = X + Z\n",
    "\n",
    "                estimates.append(estimator(X, Y))\n",
    "\n",
    "            estimates = np.array(estimates)\n",
    "            print(\" \".join([sampler_name, estimator_name, normalizer_name]))\n",
    "            print(estimates.mean())\n",
    "            print(estimates.std())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5a1edc4f-b2e7-4000-8312-4593d2335312",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
