{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "3eda6466-7a73-4c35-83d2-58b9920535e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "import estimators.ksg_estimator_norm as ksgn\n",
    "import estimators.mine_estimator_norm as minen\n",
    "import numpy as np\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "496d37ab-dbb6-4884-a4dd-58bdcad0ccba",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment configuration.\n",
    "SEED = 0  # fixes the NumPy draws below so the synthetic data set is reproducible\n",
    "SAMPLE_SIZE = 60000\n",
    "SIGNAL_POWER = 3  # variance of each component of X\n",
    "SIGNAL_NOISE = 0.2  # variance of the noise added to each component\n",
    "dim_n = 3  # number of columns in X and Y\n",
    "\n",
    "# The generator seeds only the draws made in this cell; any randomness inside\n",
    "# the estimators is not controlled from here.\n",
    "rng = np.random.default_rng(SEED)\n",
    "\n",
    "# Y is X observed through additive zero-mean Gaussian noise.\n",
    "X = rng.normal(0., np.sqrt(SIGNAL_POWER), (SAMPLE_SIZE, dim_n))\n",
    "Y = X + rng.normal(0., np.sqrt(SIGNAL_NOISE), (SAMPLE_SIZE, dim_n))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "fdd2d51e-6e04-4af3-a117-02da11e2327d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Estimated Mutual Information (Global KSG): 4.093520362341166\n",
      "Estimated Mutual Information (Local KSG): 4.0934964442282284\n",
      "Estimated Mutual Information (Global MINE): 6.27440383938345\n",
      "Estimated Mutual Information (Local MINE): 6.2678292866062675\n"
     ]
    }
   ],
   "source": [
    "# Each estimator stores its results under its own names so the KSG values\n",
    "# remain available after the MINE run instead of being overwritten.\n",
    "\n",
    "# KSG estimates (global and local normalisation variants).\n",
    "ksg = ksgn.KSGEstimator(mode='gpu', kraskov_k=3)\n",
    "ksg_mi_global = ksg.global_norm_ksg_estimator(X, Y, C_y_min=0.5, C_y_max=2)\n",
    "ksg_mi_local = ksg.local_norm_ksg_estimator(X, Y)\n",
    "print(f\"Estimated Mutual Information (Global KSG): {ksg_mi_global}\")\n",
    "print(f\"Estimated Mutual Information (Local KSG): {ksg_mi_local}\")\n",
    "\n",
    "# MINE estimates (global and local normalisation variants).\n",
    "mine_estimator = minen.MINEEstimator(num_epoch=50)\n",
    "mine_mi_global = mine_estimator.global_norm_mine_estimator(X, Y)\n",
    "mine_mi_local = mine_estimator.local_norm_mine_estimator(X, Y)\n",
    "print(f\"Estimated Mutual Information (Global MINE): {mine_mi_global}\")\n",
    "print(f\"Estimated Mutual Information (Local MINE): {mine_mi_local}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "230ed863-0afb-472e-abbf-fefb09452074",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "true_mi= 6.0\n"
     ]
    }
   ],
   "source": [
    "def theoretic_mutual_information_gaussian(power, noise, dim=None):\n",
    "    \"\"\"Return the closed-form Gaussian mutual information, in bits.\n",
    "\n",
    "    Evaluates ``dim * 0.5 * log2(1 + power / noise)``.\n",
    "\n",
    "    Args:\n",
    "        power: Signal power (the notebook passes the variance of X).\n",
    "        noise: Noise power (the notebook passes the variance of the added noise).\n",
    "        dim: Number of dimensions to scale by. When omitted, the notebook-level\n",
    "            ``dim_n`` is used, which matches the previous behaviour.\n",
    "\n",
    "    Returns:\n",
    "        The mutual information in bits (base-2 logarithm).\n",
    "    \"\"\"\n",
    "    if dim is None:\n",
    "        dim = dim_n\n",
    "    return dim * 0.5 * np.log2(1 + power / noise)\n",
    "\n",
    "\n",
    "# NOTE(review): this reference value is in bits; confirm that the KSG and MINE\n",
    "# estimators report the same unit before comparing it with their outputs.\n",
    "true_mi = theoretic_mutual_information_gaussian(SIGNAL_POWER, SIGNAL_NOISE, dim=dim_n)\n",
    "print('true_mi=', true_mi)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ecce84f3-dcbd-4b0d-97a5-368e9040f05f",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
