{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from agm_te.dataset import DataSet\n",
    "from agm_te.estimate import agm_estimate_TE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CUDA is available. Using GPU.\n"
     ]
    }
   ],
   "source": [
    "# Pick the compute device: the GPU when CUDA is usable, otherwise the CPU\n",
    "import torch\n",
    "\n",
    "# sci_mode=None leaves the choice of scientific notation to PyTorch's formatter\n",
    "torch.set_printoptions(sci_mode=None)\n",
    "\n",
    "compute_device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "if compute_device.type == \"cuda\":\n",
    "    print(\"CUDA is available. Using GPU.\")\n",
    "else:\n",
    "    print(\"CUDA is not available. Using CPU.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Initialize the Simulator (from the `te_datasim` package) and generate time series for variables $X$ and $Y$:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Analytical TE X -> Y:  0.0\n",
      "Analytical TE Y -> X:  0.1276 \n",
      "\n",
      "DataSet of 20000 timesteps across 10 trajectories for 2 variables:\n",
      "\t X is 1 dimensional\n",
      "\t Y is 1 dimensional\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from te_datasim.lineargaussian import BVLinearGaussianSimulator\n",
    "\n",
    "# Print the simulator's analytical transfer entropy in both directions as reference values\n",
    "bivar = BVLinearGaussianSimulator()\n",
    "print(\"Analytical TE X -> Y: \", bivar.analytic_transfer_entropy('X', 'Y'))\n",
    "print(\"Analytical TE Y -> X: \", bivar.analytic_transfer_entropy('Y', 'X'), \"\\n\")\n",
    "\n",
    "# Draw 10 trajectories of 2000 steps each; the trajectory index doubles as the seed\n",
    "trajectories = [bivar.simulate(2000, seed=seed) for seed in range(10)]\n",
    "bivar_data_dict = {\n",
    "    'X': [x_series for x_series, _ in trajectories],\n",
    "    'Y': [y_series for _, y_series in trajectories],\n",
    "}\n",
    "\n",
    "# Wrap the per-variable trajectory lists in a DataSet and print its summary\n",
    "bivar_data = DataSet(bivar_data_dict)\n",
    "print(bivar_data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Set up the parameters for the AGMs:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model settings for the AGMs\n",
    "model_params = dict(\n",
    "    obs_model_type='Gaussian',      # Gaussian observation model\n",
    "    dyn_model_type='MLPTanh',       # dynamics model: multi-layer perceptron with tanh activation\n",
    "    hidden_size=16,                 # 16 hidden units per layer\n",
    "    num_layers=2,                   # 2 layers\n",
    "    compute_device=compute_device,  # the compute device set up earlier (ideally the GPU)\n",
    ")\n",
    "\n",
    "# Training settings passed to the estimator alongside the model settings\n",
    "train_params = dict(\n",
    "    batch_size=1,\n",
    "    epochs=1000,\n",
    "    optimize='sgd',  # NOTE(review): presumably selects stochastic gradient descent -- confirm against agm_te\n",
    "    learning_rate=0.01,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Estimate $\\mathcal{T}_{X \\to Y}$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Estimated TE X -> Y:  -0.0                         \n"
     ]
    }
   ],
   "source": [
    "# Estimate the transfer entropy in the X -> Y direction\n",
    "# NOTE(review): model_1 / model_2 are presumably the two fitted AGMs -- confirm against agm_te\n",
    "model_1, model_2, te = agm_estimate_TE(\n",
    "    bivar_data,\n",
    "    model_params,\n",
    "    train_params,\n",
    "    var_from='X',\n",
    "    var_to='Y',\n",
    ")\n",
    "# NOTE(review): the trailing blanks presumably wipe leftover progress text on the line -- confirm\n",
    "print(\"Estimated TE X -> Y: \", te, \"                        \")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Estimate $\\mathcal{T}_{Y \\to X}$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Estimated TE Y -> X:  0.1266                         \n"
     ]
    }
   ],
   "source": [
    "# Estimate the transfer entropy in the Y -> X direction\n",
    "# NOTE(review): this rebinds model_1, model_2 and te, replacing any values they held before\n",
    "model_1, model_2, te = agm_estimate_TE(\n",
    "    bivar_data,\n",
    "    model_params,\n",
    "    train_params,\n",
    "    var_from='Y',\n",
    "    var_to='X',\n",
    ")\n",
    "# NOTE(review): the trailing blanks presumably wipe leftover progress text on the line -- confirm\n",
    "print(\"Estimated TE Y -> X: \", te, \"                        \")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
