{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "import time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from util_results import Results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from te_datasim.jointprocess import MVJointProcessSimulator\n",
    "from te_datasim.lineargaussian import MVLinearGaussianSimulator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CUDA is available. Using GPU.\n"
     ]
    }
   ],
   "source": [
    "import torch; torch.set_printoptions(sci_mode=None)\n",
    "# Check if CUDA is available\n",
    "if torch.cuda.is_available():\n",
    "    compute_device = torch.device(\"cuda\")\n",
    "    print(\"CUDA is available. Using GPU.\")\n",
    "else:\n",
    "    compute_device = torch.device(\"cpu\")\n",
    "    print(\"CUDA is not available. Using CPU.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from agm_te.model import _train_agm\n",
    "from agm_te.dataset import DataSet\n",
    "from agm_te.estimate import init_agms_from_loaders"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_dataset(generator, n_samples, batch_size, seed):\n",
    "    data_dict = {'X':[], 'Y':[]}\n",
    "    \n",
    "    X, Y = generator.simulate(n_samples, seed=seed)\n",
    "    for i in range(0, n_samples, batch_size):\n",
    "        data_dict['X'].append(X[i:i+batch_size])\n",
    "        data_dict['Y'].append(Y[i:i+batch_size])\n",
    "    data = DataSet(data_dict)\n",
    "    return data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "EPOCHS = 1000\n",
    "LR = 0.01\n",
    "L2P = 0.00\n",
    "OPT = 'sgd'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def TE_agmte(dataset, device, var_from, var_to, hidden_size, batch_size, plot_loss=False):\n",
    "    smoothing = 100\n",
    "    \n",
    "    tl_dataloader_1, tl_dataloader_2 = dataset.get_TE_dataloaders(var_from=var_from, var_to=var_to)\n",
    "    tl_model_1, tl_model_2 = init_agms_from_loaders(tl_dataloader_1, tl_dataloader_2, dyn_model_type='MLPTanh', hidden_size=hidden_size, num_layers=2)\n",
    "    tl_model_1.to(device)\n",
    "    tl_model_2.to(device)\n",
    "    tl_model_1, loss_1 = _train_agm(tl_model_1, tl_dataloader_1,\n",
    "                                    batch_size=batch_size, epochs=EPOCHS, learning_rate=0.01, optimize='sgd')\n",
    "    print()\n",
    "    tl_model_2, loss_2 = _train_agm(tl_model_2, tl_dataloader_2, \n",
    "                                    batch_size=batch_size, epochs=EPOCHS, learning_rate=0.01, optimize='sgd')\n",
    "    print()\n",
    "    if plot_loss:\n",
    "        plt.figure(figsize=(12, 4))\n",
    "        plt.plot(loss_1, alpha=0.5, color='blue')\n",
    "        plt.plot(loss_2, alpha=0.5, color='orange')\n",
    "\n",
    "    smooth_loss_1 = np.array([np.mean(loss_1[i:i+smoothing]) for i in range(0, len(loss_1)-smoothing)])\n",
    "    smooth_loss_2 = np.array([np.mean(loss_2[i:i+smoothing]) for i in range(0, len(loss_2)-smoothing)])\n",
    "    \n",
    "    if plot_loss:\n",
    "        plt.plot(range(smoothing, len(loss_1)), smooth_loss_1, color='blue', linewidth=2, label='Loss 1')\n",
    "        plt.plot(range(smoothing, len(loss_1)), smooth_loss_2, color='orange', linewidth=2, label='Loss 2')\n",
    "        plt.legend()\n",
    "        plt.show()\n",
    "    \n",
    "    TE = np.round(smooth_loss_1[-1]-smooth_loss_2[-1],4)\n",
    "    print(f\"TE: {var_from}->{var_to} {TE}\")\n",
    "    return TE"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Basic Validity"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "REPLICATES = 5\n",
    "SAMPLE_SIZE = 10000\n",
    "BATCH_SIZE = 500\n",
    "NB = 1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Linear Gaussian"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Specify the range of lambda values to test\n",
    "lg_lambda_range = list(np.linspace(0, 1, 9, endpoint=True))\n",
    "\n",
    "# Initialize the list of generators with one for each lambda value\n",
    "lg_generator_lst = [MVLinearGaussianSimulator(n_dim=1, coupling=lam) for lam in lg_lambda_range]\n",
    "\n",
    "# get the reference values\n",
    "lg_TE_X2Y_ref_lst = [generator.analytic_transfer_entropy('X', 'Y') for generator in lg_generator_lst]\n",
    "lg_TE_Y2X_ref_lst = [generator.analytic_transfer_entropy('Y', 'X') for generator in lg_generator_lst]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "### REPLICATE 1/5 ###\n",
      "\n",
      "# Coupling =  0.0 True TE X->Y =  0.0 True TE Y->X =  0.0\n",
      "Epoch [990/1000], Loss: 0.59990708\n",
      "Epoch [990/1000], Loss: 0.59973125\n",
      "TE: X->Y 0.0002\n",
      "Epoch [990/1000], Loss: 0.60939604\n",
      "Epoch [990/1000], Loss: 0.60893184\n",
      "TE: Y->X 0.0005\n",
      "# Coupling =  0.125 True TE X->Y =  0.0 True TE Y->X =  0.0092\n",
      "Epoch [990/1000], Loss: 0.60000257\n",
      "Epoch [990/1000], Loss: 0.59968642\n",
      "TE: X->Y 0.0003\n",
      "Epoch [990/1000], Loss: 0.61861269\n",
      "Epoch [990/1000], Loss: 0.60882527\n",
      "TE: Y->X 0.0098\n",
      "# Coupling =  0.25 True TE X->Y =  0.0 True TE Y->X =  0.0356\n",
      "Epoch [990/1000], Loss: 0.60000213\n",
      "Epoch [990/1000], Loss: 0.59975647\n",
      "TE: X->Y 0.0002\n",
      "Epoch [990/1000], Loss: 0.64449712\n",
      "Epoch [990/1000], Loss: 0.60909259\n",
      "TE: Y->X 0.0354\n",
      "# Coupling =  0.375 True TE X->Y =  0.0 True TE Y->X =  0.0763\n",
      "Epoch [990/1000], Loss: 0.60000004\n",
      "Epoch [990/1000], Loss: 0.59968157\n",
      "TE: X->Y 0.0003\n",
      "Epoch [990/1000], Loss: 0.68509194\n",
      "Epoch [990/1000], Loss: 0.60878105\n",
      "TE: Y->X 0.0763\n",
      "# Coupling =  0.5 True TE X->Y =  0.0 True TE Y->X =  0.1276\n",
      "Epoch [990/1000], Loss: 0.60002472\n",
      "Epoch [990/1000], Loss: 0.59960938\n",
      "TE: X->Y 0.0004\n",
      "Epoch [990/1000], Loss: 0.73611995\n",
      "Epoch [990/1000], Loss: 0.60929893\n",
      "TE: Y->X 0.1268\n",
      "# Coupling =  0.625 True TE X->Y =  0.0 True TE Y->X =  0.1861\n",
      "Epoch [990/1000], Loss: 0.60001305\n",
      "Epoch [990/1000], Loss: 0.59954482\n",
      "TE: X->Y 0.0005\n",
      "Epoch [990/1000], Loss: 0.79454127\n",
      "Epoch [990/1000], Loss: 0.60892726\n",
      "TE: Y->X 0.1856\n",
      "# Coupling =  0.75 True TE X->Y =  0.0 True TE Y->X =  0.249\n",
      "Epoch [990/1000], Loss: 0.60000304\n",
      "Epoch [990/1000], Loss: 0.59966401\n",
      "TE: X->Y 0.0003\n",
      "Epoch [990/1000], Loss: 0.85708609\n",
      "Epoch [990/1000], Loss: 0.60923831\n",
      "TE: Y->X 0.2478\n",
      "# Coupling =  0.875 True TE X->Y =  0.0 True TE Y->X =  0.3141\n",
      "Epoch [990/1000], Loss: 0.60000544\n",
      "Epoch [990/1000], Loss: 0.59956952\n",
      "TE: X->Y 0.0004\n",
      "Epoch [990/1000], Loss: 0.92209829\n",
      "Epoch [990/1000], Loss: 0.60912514\n",
      "TE: Y->X 0.313\n",
      "# Coupling =  1.0 True TE X->Y =  0.0 True TE Y->X =  0.3797\n",
      "Epoch [990/1000], Loss: 0.60000777\n",
      "Epoch [990/1000], Loss: 0.59960703\n",
      "TE: X->Y 0.0004\n",
      "Epoch [990/1000], Loss: 0.98763049\n",
      "Epoch [990/1000], Loss: 0.60974054\n",
      "TE: Y->X 0.3779\n",
      "\n",
      "### REPLICATE 2/5 ###\n",
      "\n",
      "# Coupling =  0.0 True TE X->Y =  0.0 True TE Y->X =  0.0\n",
      "Epoch [990/1000], Loss: 0.61114147\n",
      "Epoch [990/1000], Loss: 0.61087356\n",
      "TE: X->Y 0.0003\n",
      "Epoch [990/1000], Loss: 0.62309248\n",
      "Epoch [990/1000], Loss: 0.62191085\n",
      "TE: Y->X 0.0012\n",
      "# Coupling =  0.125 True TE X->Y =  0.0 True TE Y->X =  0.0092\n",
      "Epoch [990/1000], Loss: 0.61107127\n",
      "Epoch [990/1000], Loss: 0.61070848\n",
      "TE: X->Y 0.0004\n",
      "Epoch [990/1000], Loss: 0.62958085\n",
      "Epoch [990/1000], Loss: 0.62276927\n",
      "TE: Y->X 0.0068\n",
      "# Coupling =  0.25 True TE X->Y =  0.0 True TE Y->X =  0.0356\n",
      "Epoch [990/1000], Loss: 0.61106967\n",
      "Epoch [990/1000], Loss: 0.61065799\n",
      "TE: X->Y 0.0004\n",
      "Epoch [990/1000], Loss: 0.65341659\n",
      "Epoch [990/1000], Loss: 0.62170389\n",
      "TE: Y->X 0.0317\n",
      "# Coupling =  0.375 True TE X->Y =  0.0 True TE Y->X =  0.0763\n",
      "Epoch [990/1000], Loss: 0.61102871\n",
      "Epoch [990/1000], Loss: 0.61063831\n",
      "TE: X->Y 0.0004\n",
      "Epoch [990/1000], Loss: 0.69219328\n",
      "Epoch [990/1000], Loss: 0.62206263\n",
      "TE: Y->X 0.0701\n",
      "# Coupling =  0.5 True TE X->Y =  0.0 True TE Y->X =  0.1276\n",
      "Epoch [990/1000], Loss: 0.61101194\n",
      "Epoch [990/1000], Loss: 0.61045749\n",
      "TE: X->Y 0.0005\n",
      "Epoch [990/1000], Loss: 0.74082843\n",
      "Epoch [990/1000], Loss: 0.62256664\n",
      "TE: Y->X 0.1182\n",
      "# Coupling =  0.625 True TE X->Y =  0.0 True TE Y->X =  0.1861\n",
      "Epoch [990/1000], Loss: 0.61113621\n",
      "Epoch [990/1000], Loss: 0.61065485\n",
      "TE: X->Y 0.0005\n",
      "Epoch [990/1000], Loss: 0.79757773\n",
      "Epoch [990/1000], Loss: 0.62288861\n",
      "TE: Y->X 0.1747\n",
      "# Coupling =  0.75 True TE X->Y =  0.0 True TE Y->X =  0.249\n",
      "Epoch [990/1000], Loss: 0.61110789\n",
      "Epoch [990/1000], Loss: 0.61066634\n",
      "TE: X->Y 0.0004\n",
      "Epoch [990/1000], Loss: 0.85857898\n",
      "Epoch [990/1000], Loss: 0.62342852\n",
      "TE: Y->X 0.2351\n",
      "# Coupling =  0.875 True TE X->Y =  0.0 True TE Y->X =  0.3141\n",
      "Epoch [990/1000], Loss: 0.61112332\n",
      "Epoch [990/1000], Loss: 0.61062197\n",
      "TE: X->Y 0.0005\n",
      "Epoch [990/1000], Loss: 0.92251163\n",
      "Epoch [990/1000], Loss: 0.62346063\n",
      "TE: Y->X 0.299\n",
      "# Coupling =  1.0 True TE X->Y =  0.0 True TE Y->X =  0.3797\n",
      "Epoch [990/1000], Loss: 0.61106755\n",
      "Epoch [990/1000], Loss: 0.61072376\n",
      "TE: X->Y 0.0003\n",
      "Epoch [990/1000], Loss: 0.98713443\n",
      "Epoch [990/1000], Loss: 0.62386627\n",
      "TE: Y->X 0.3633\n",
      "\n",
      "### REPLICATE 3/5 ###\n",
      "\n",
      "# Coupling =  0.0 True TE X->Y =  0.0 True TE Y->X =  0.0\n",
      "Epoch [990/1000], Loss: 0.60890695\n",
      "Epoch [990/1000], Loss: 0.60881106\n",
      "TE: X->Y 0.0001\n",
      "Epoch [990/1000], Loss: 0.62659264\n",
      "Epoch [990/1000], Loss: 0.62647188\n",
      "TE: Y->X 0.0001\n",
      "# Coupling =  0.125 True TE X->Y =  0.0 True TE Y->X =  0.0092\n",
      "Epoch [990/1000], Loss: 0.60881081\n",
      "Epoch [990/1000], Loss: 0.60891169\n",
      "TE: X->Y -0.0001\n",
      "Epoch [990/1000], Loss: 0.63681487\n",
      "Epoch [990/1000], Loss: 0.62668719\n",
      "TE: Y->X 0.0101\n",
      "# Coupling =  0.25 True TE X->Y =  0.0 True TE Y->X =  0.0356\n",
      "Epoch [990/1000], Loss: 0.60884633\n",
      "Epoch [990/1000], Loss: 0.60895922\n",
      "TE: X->Y -0.0001\n",
      "Epoch [990/1000], Loss: 0.66313948\n",
      "Epoch [990/1000], Loss: 0.62678728\n",
      "TE: Y->X 0.0364\n",
      "# Coupling =  0.375 True TE X->Y =  0.0 True TE Y->X =  0.0763\n",
      "Epoch [990/1000], Loss: 0.60894428\n",
      "Epoch [990/1000], Loss: 0.60889414\n",
      "TE: X->Y 0.0001\n",
      "Epoch [990/1000], Loss: 0.70315151\n",
      "Epoch [990/1000], Loss: 0.62745423\n",
      "TE: Y->X 0.0757\n",
      "# Coupling =  0.5 True TE X->Y =  0.0 True TE Y->X =  0.1276\n",
      "Epoch [990/1000], Loss: 0.60878914\n",
      "Epoch [990/1000], Loss: 0.60899018\n",
      "TE: X->Y -0.0002\n",
      "Epoch [990/1000], Loss: 0.75319436\n",
      "Epoch [990/1000], Loss: 0.62727091\n",
      "TE: Y->X 0.1259\n",
      "# Coupling =  0.625 True TE X->Y =  0.0 True TE Y->X =  0.1861\n",
      "Epoch [990/1000], Loss: 0.60889079\n",
      "Epoch [990/1000], Loss: 0.60895116\n",
      "TE: X->Y -0.0001\n",
      "Epoch [990/1000], Loss: 0.81016859\n",
      "Epoch [990/1000], Loss: 0.62737282\n",
      "TE: Y->X 0.1828\n",
      "# Coupling =  0.75 True TE X->Y =  0.0 True TE Y->X =  0.249\n",
      "Epoch [990/1000], Loss: 0.60889241\n",
      "Epoch [990/1000], Loss: 0.60886918\n",
      "TE: X->Y 0.0\n",
      "Epoch [990/1000], Loss: 0.87146857\n",
      "Epoch [990/1000], Loss: 0.62794854\n",
      "TE: Y->X 0.2435\n",
      "# Coupling =  0.875 True TE X->Y =  0.0 True TE Y->X =  0.3141\n",
      "Epoch [990/1000], Loss: 0.60883351\n",
      "Epoch [990/1000], Loss: 0.60880543\n",
      "TE: X->Y 0.0\n",
      "Epoch [990/1000], Loss: 0.93470093\n",
      "Epoch [990/1000], Loss: 0.62779332\n",
      "TE: Y->X 0.3069\n",
      "# Coupling =  1.0 True TE X->Y =  0.0 True TE Y->X =  0.3797\n",
      "Epoch [990/1000], Loss: 0.60886399\n",
      "Epoch [990/1000], Loss: 0.60884655\n",
      "TE: X->Y 0.0\n",
      "Epoch [990/1000], Loss: 0.99881639\n",
      "Epoch [990/1000], Loss: 0.62802459\n",
      "TE: Y->X 0.3708\n",
      "\n",
      "### REPLICATE 4/5 ###\n",
      "\n",
      "# Coupling =  0.0 True TE X->Y =  0.0 True TE Y->X =  0.0\n",
      "Epoch [990/1000], Loss: 0.61952315\n",
      "Epoch [990/1000], Loss: 0.61904074\n",
      "TE: X->Y 0.0005\n",
      "Epoch [990/1000], Loss: 0.59894236\n",
      "Epoch [990/1000], Loss: 0.59916292\n",
      "TE: Y->X -0.0002\n",
      "# Coupling =  0.125 True TE X->Y =  0.0 True TE Y->X =  0.0092\n",
      "Epoch [990/1000], Loss: 0.61952575\n",
      "Epoch [990/1000], Loss: 0.61913158\n",
      "TE: X->Y 0.0004\n",
      "Epoch [990/1000], Loss: 0.60572656\n",
      "Epoch [990/1000], Loss: 0.59916963\n",
      "TE: Y->X 0.0066\n",
      "# Coupling =  0.25 True TE X->Y =  0.0 True TE Y->X =  0.0356\n",
      "Epoch [990/1000], Loss: 0.61953692\n",
      "Epoch [990/1000], Loss: 0.61921746\n",
      "TE: X->Y 0.0003\n",
      "Epoch [990/1000], Loss: 0.63047765\n",
      "Epoch [990/1000], Loss: 0.59935952\n",
      "TE: Y->X 0.0311\n",
      "# Coupling =  0.375 True TE X->Y =  0.0 True TE Y->X =  0.0763\n",
      "Epoch [990/1000], Loss: 0.61949904\n",
      "Epoch [990/1000], Loss: 0.61917778\n",
      "TE: X->Y 0.0003\n",
      "Epoch [990/1000], Loss: 0.67045733\n",
      "Epoch [990/1000], Loss: 0.59932746\n",
      "TE: Y->X 0.0711\n",
      "# Coupling =  0.5 True TE X->Y =  0.0 True TE Y->X =  0.1276\n",
      "Epoch [990/1000], Loss: 0.61946423\n",
      "Epoch [990/1000], Loss: 0.61913152\n",
      "TE: X->Y 0.0003\n",
      "Epoch [990/1000], Loss: 0.72155003\n",
      "Epoch [990/1000], Loss: 0.60004416\n",
      "TE: Y->X 0.1215\n",
      "# Coupling =  0.625 True TE X->Y =  0.0 True TE Y->X =  0.1861\n",
      "Epoch [990/1000], Loss: 0.61961173\n",
      "Epoch [990/1000], Loss: 0.61910699\n",
      "TE: X->Y 0.0005\n",
      "Epoch [990/1000], Loss: 0.78031104\n",
      "Epoch [990/1000], Loss: 0.59951279\n",
      "TE: Y->X 0.1808\n",
      "# Coupling =  0.75 True TE X->Y =  0.0 True TE Y->X =  0.249\n",
      "Epoch [990/1000], Loss: 0.61953163\n",
      "Epoch [990/1000], Loss: 0.61904532\n",
      "TE: X->Y 0.0005\n",
      "Epoch [990/1000], Loss: 0.84401988\n",
      "Epoch [990/1000], Loss: 0.60027118\n",
      "TE: Y->X 0.2437\n",
      "# Coupling =  0.875 True TE X->Y =  0.0 True TE Y->X =  0.3141\n",
      "Epoch [990/1000], Loss: 0.61942402\n",
      "Epoch [990/1000], Loss: 0.61898306\n",
      "TE: X->Y 0.0004\n",
      "Epoch [990/1000], Loss: 0.91015017\n",
      "Epoch [990/1000], Loss: 0.60042245\n",
      "TE: Y->X 0.3097\n",
      "# Coupling =  1.0 True TE X->Y =  0.0 True TE Y->X =  0.3797\n",
      "Epoch [990/1000], Loss: 0.61950858\n",
      "Epoch [990/1000], Loss: 0.61892677\n",
      "TE: X->Y 0.0006\n",
      "Epoch [990/1000], Loss: 0.97681613\n",
      "Epoch [990/1000], Loss: 0.60101684\n",
      "TE: Y->X 0.3758\n",
      "\n",
      "### REPLICATE 5/5 ###\n",
      "\n",
      "# Coupling =  0.0 True TE X->Y =  0.0 True TE Y->X =  0.0\n",
      "Epoch [990/1000], Loss: 0.61373787\n",
      "Epoch [990/1000], Loss: 0.61320029\n",
      "TE: X->Y 0.0005\n",
      "Epoch [990/1000], Loss: 0.60458185\n",
      "Epoch [990/1000], Loss: 0.60461512\n",
      "TE: Y->X -0.0\n",
      "# Coupling =  0.125 True TE X->Y =  0.0 True TE Y->X =  0.0092\n",
      "Epoch [990/1000], Loss: 0.61358666\n",
      "Epoch [990/1000], Loss: 0.61321474\n",
      "TE: X->Y 0.0004\n",
      "Epoch [990/1000], Loss: 0.61453729\n",
      "Epoch [990/1000], Loss: 0.60484101\n",
      "TE: Y->X 0.0097\n",
      "# Coupling =  0.25 True TE X->Y =  0.0 True TE Y->X =  0.0356\n",
      "Epoch [990/1000], Loss: 0.61368774\n",
      "Epoch [990/1000], Loss: 0.61343755\n",
      "TE: X->Y 0.0002\n",
      "Epoch [990/1000], Loss: 0.64223671\n",
      "Epoch [990/1000], Loss: 0.60496067\n",
      "TE: Y->X 0.0373\n",
      "# Coupling =  0.375 True TE X->Y =  0.0 True TE Y->X =  0.0763\n",
      "Epoch [990/1000], Loss: 0.61367974\n",
      "Epoch [990/1000], Loss: 0.61355756\n",
      "TE: X->Y 0.0001\n",
      "Epoch [990/1000], Loss: 0.68430333\n",
      "Epoch [990/1000], Loss: 0.60459936\n",
      "TE: Y->X 0.0797\n",
      "# Coupling =  0.5 True TE X->Y =  0.0 True TE Y->X =  0.1276\n",
      "Epoch [990/1000], Loss: 0.61354098\n",
      "Epoch [990/1000], Loss: 0.61355221\n",
      "TE: X->Y -0.0\n",
      "Epoch [990/1000], Loss: 0.73663669\n",
      "Epoch [990/1000], Loss: 0.60495615\n",
      "TE: Y->X 0.1317\n",
      "# Coupling =  0.625 True TE X->Y =  0.0 True TE Y->X =  0.1861\n",
      "Epoch [990/1000], Loss: 0.61364726\n",
      "Epoch [990/1000], Loss: 0.61358119\n",
      "TE: X->Y 0.0001\n",
      "Epoch [990/1000], Loss: 0.79636636\n",
      "Epoch [990/1000], Loss: 0.60526812\n",
      "TE: Y->X 0.1911\n",
      "# Coupling =  0.75 True TE X->Y =  0.0 True TE Y->X =  0.249\n",
      "Epoch [990/1000], Loss: 0.61368899\n",
      "Epoch [990/1000], Loss: 0.61364732\n",
      "TE: X->Y 0.0\n",
      "Epoch [990/1000], Loss: 0.86007468\n",
      "Epoch [990/1000], Loss: 0.60513762\n",
      "TE: Y->X 0.2549\n",
      "# Coupling =  0.875 True TE X->Y =  0.0 True TE Y->X =  0.3141\n",
      "Epoch [990/1000], Loss: 0.61368497\n",
      "Epoch [990/1000], Loss: 0.61360389\n",
      "TE: X->Y 0.0001\n",
      "Epoch [990/1000], Loss: 0.92619146\n",
      "Epoch [990/1000], Loss: 0.60580542\n",
      "TE: Y->X 0.3203\n",
      "# Coupling =  1.0 True TE X->Y =  0.0 True TE Y->X =  0.3797\n",
      "Epoch [990/1000], Loss: 0.61321077\n",
      "Epoch [990/1000], Loss: 0.61353863\n",
      "TE: X->Y -0.0003\n",
      "Epoch [990/1000], Loss: 0.99264988\n",
      "Epoch [990/1000], Loss: 0.60590659\n",
      "TE: Y->X 0.3867\n"
     ]
    }
   ],
   "source": [
    "lg_results_TE_X2Y = Results(columns=['method', 'coupling'])\n",
    "lg_results_TE_Y2X = Results(columns=['method', 'coupling'])\n",
    "\n",
    "for r in range(REPLICATES):\n",
    "    print(f\"\\n### REPLICATE {r+1}/{REPLICATES} ###\\n\")\n",
    "    for lam, generator in zip(lg_lambda_range, lg_generator_lst):\n",
    "        print(\"# Coupling = \", lam, \"True TE X->Y = \", generator.analytic_transfer_entropy('X', 'Y'), \"True TE Y->X = \", generator.analytic_transfer_entropy('Y', 'X'))\n",
    "        # Simulate data\n",
    "        dataset = get_dataset(generator, SAMPLE_SIZE, BATCH_SIZE, seed=r)\n",
    "        # Estimate X -> Y\n",
    "        TE_X2Y = TE_agmte(dataset, compute_device, 'X', 'Y', 16, NB)\n",
    "        lg_results_TE_X2Y.write(method='agmte', coupling=lam, value=TE_X2Y)\n",
    "        # Estimate Y -> X\n",
    "        TE_Y2X = TE_agmte(dataset, compute_device, 'Y', 'X', 16, NB)\n",
    "        lg_results_TE_Y2X.write(method='agmte', coupling=lam, value=TE_Y2X)\n",
    "\n",
    "lg_results_TE_X2Y.df.to_csv('results/agmte/lg_results_TE_X2Y_bv.csv', index=False)\n",
    "lg_results_TE_Y2X.df.to_csv('results/agmte/lg_results_TE_Y2X_bv.csv', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Specify the range of lambda values to test\n",
    "jp_lambda_range = list(np.linspace(-3, 3, 9, endpoint=True))\n",
    "\n",
    "# Initialize the list of generators with one for each lambda value\n",
    "jp_generator_lst = [MVJointProcessSimulator(n_dim=1, lam=lam) for lam in jp_lambda_range]\n",
    "\n",
    "# get the reference values\n",
    "jp_TE_X2Y_ref_lst = [generator.analytic_transfer_entropy('X', 'Y') for generator in jp_generator_lst]\n",
    "jp_TE_Y2X_ref_lst = [generator.analytic_transfer_entropy('Y', 'X') for generator in jp_generator_lst]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "### REPLICATE 1/5 ###\n",
      "\n",
      "# Coupling =  -3.0 True TE X->Y =  0.8292 True TE Y->X =  0.0\n",
      "Epoch [990/1000], Loss: 1.40761787\n",
      "Epoch [990/1000], Loss: 0.57996737\n",
      "TE: X->Y 0.8276\n",
      "Epoch [990/1000], Loss: 1.41203045\n",
      "Epoch [990/1000], Loss: 1.41183854\n",
      "TE: Y->X 0.0002\n",
      "# Coupling =  -2.25 True TE X->Y =  0.8202 True TE Y->X =  0.0\n",
      "Epoch [990/1000], Loss: 1.40814165\n",
      "Epoch [990/1000], Loss: 0.59516223\n",
      "TE: X->Y 0.8127\n",
      "Epoch [990/1000], Loss: 1.41204264\n",
      "Epoch [990/1000], Loss: 1.41192353\n",
      "TE: Y->X 0.0001\n",
      "# Coupling =  -1.5 True TE X->Y =  0.7749 True TE Y->X =  0.0\n",
      "Epoch [990/1000], Loss: 1.40796404\n",
      "Epoch [990/1000], Loss: 0.64315374\n",
      "TE: X->Y 0.7644\n",
      "Epoch [990/1000], Loss: 1.41204413\n",
      "Epoch [990/1000], Loss: 1.41181412\n",
      "TE: Y->X 0.0002\n",
      "# Coupling =  -0.75 True TE X->Y =  0.6422 True TE Y->X =  0.0\n",
      "Epoch [990/1000], Loss: 1.40410725\n",
      "Epoch [990/1000], Loss: 0.77391904\n",
      "TE: X->Y 0.6298\n",
      "Epoch [990/1000], Loss: 1.41201511\n",
      "Epoch [990/1000], Loss: 1.41185322\n",
      "TE: Y->X 0.0002\n",
      "# Coupling =  0.0 True TE X->Y =  0.4152 True TE Y->X =  0.0\n",
      "Epoch [990/1000], Loss: 1.40366122\n",
      "Epoch [990/1000], Loss: 1.00561724\n",
      "TE: X->Y 0.3976\n",
      "Epoch [990/1000], Loss: 1.41206894\n",
      "Epoch [990/1000], Loss: 1.41190132\n",
      "TE: Y->X 0.0002\n",
      "# Coupling =  0.75 True TE X->Y =  0.1882 True TE Y->X =  0.0\n",
      "Epoch [990/1000], Loss: 1.41142624\n",
      "Epoch [990/1000], Loss: 1.23550136\n",
      "TE: X->Y 0.1756\n",
      "Epoch [990/1000], Loss: 1.41205766\n",
      "Epoch [990/1000], Loss: 1.41168939\n",
      "TE: Y->X 0.0004\n",
      "# Coupling =  1.5 True TE X->Y =  0.0555 True TE Y->X =  0.0\n",
      "Epoch [990/1000], Loss: 1.40793718\n",
      "Epoch [990/1000], Loss: 1.38940065\n",
      "TE: X->Y 0.0181\n",
      "Epoch [990/1000], Loss: 1.41200132\n",
      "Epoch [990/1000], Loss: 1.41168187\n",
      "TE: Y->X 0.0003\n",
      "# Coupling =  2.25 True TE X->Y =  0.0102 True TE Y->X =  0.0\n",
      "Epoch [990/1000], Loss: 1.40460621\n",
      "Epoch [990/1000], Loss: 1.40436921\n",
      "TE: X->Y 0.0002\n",
      "Epoch [990/1000], Loss: 1.41205055\n",
      "Epoch [990/1000], Loss: 1.41172736\n",
      "TE: Y->X 0.0003\n",
      "# Coupling =  3.0 True TE X->Y =  0.0011 True TE Y->X =  0.0\n",
      "Epoch [990/1000], Loss: 1.40634387\n",
      "Epoch [990/1000], Loss: 1.40633318\n",
      "TE: X->Y 0.0\n",
      "Epoch [990/1000], Loss: 1.41201812\n",
      "Epoch [990/1000], Loss: 1.41177116\n",
      "TE: Y->X 0.0002\n",
      "\n",
      "### REPLICATE 2/5 ###\n",
      "\n",
      "# Coupling =  -3.0 True TE X->Y =  0.8292 True TE Y->X =  0.0\n",
      "Epoch [990/1000], Loss: 1.42480551\n",
      "Epoch [990/1000], Loss: 0.59282341\n",
      "TE: X->Y 0.8319\n",
      "Epoch [990/1000], Loss: 1.41976307\n",
      "Epoch [990/1000], Loss: 1.41971612\n",
      "TE: Y->X 0.0\n",
      "# Coupling =  -2.25 True TE X->Y =  0.8202 True TE Y->X =  0.0\n",
      "Epoch [990/1000], Loss: 1.42410675\n",
      "Epoch [990/1000], Loss: 0.60879664\n",
      "TE: X->Y 0.8151\n",
      "Epoch [990/1000], Loss: 1.41946345\n",
      "Epoch [990/1000], Loss: 1.41975783\n",
      "TE: Y->X -0.0003\n",
      "# Coupling =  -1.5 True TE X->Y =  0.7749 True TE Y->X =  0.0\n",
      "Epoch [990/1000], Loss: 1.42435932\n",
      "Epoch [990/1000], Loss: 0.65773594\n",
      "TE: X->Y 0.7662\n",
      "Epoch [990/1000], Loss: 1.41958011\n",
      "Epoch [990/1000], Loss: 1.41967131\n",
      "TE: Y->X -0.0001\n",
      "# Coupling =  -0.75 True TE X->Y =  0.6422 True TE Y->X =  0.0\n",
      "Epoch [990/1000], Loss: 1.42633534\n",
      "Epoch [990/1000], Loss: 0.78845249\n",
      "TE: X->Y 0.6373\n",
      "Epoch [990/1000], Loss: 1.41944928\n",
      "Epoch [990/1000], Loss: 1.41964769\n",
      "TE: Y->X -0.0002\n",
      "# Coupling =  0.0 True TE X->Y =  0.4152 True TE Y->X =  0.0\n",
      "Epoch [990/1000], Loss: 1.42555631\n",
      "Epoch [990/1000], Loss: 1.00688852\n",
      "TE: X->Y 0.4184\n",
      "Epoch [990/1000], Loss: 1.41938289\n",
      "Epoch [990/1000], Loss: 1.41964843\n",
      "TE: Y->X -0.0003\n",
      "# Coupling =  0.75 True TE X->Y =  0.1882 True TE Y->X =  0.0\n",
      "Epoch [990/1000], Loss: 1.42042563\n",
      "Epoch [990/1000], Loss: 1.24467387\n",
      "TE: X->Y 0.1753\n",
      "Epoch [990/1000], Loss: 1.41945796\n",
      "Epoch [990/1000], Loss: 1.41937346\n",
      "TE: Y->X 0.0001\n",
      "# Coupling =  1.5 True TE X->Y =  0.0555 True TE Y->X =  0.0\n",
      "Epoch [990/1000], Loss: 1.41693026\n",
      "Epoch [990/1000], Loss: 1.38459948\n",
      "TE: X->Y 0.0309\n",
      "Epoch [990/1000], Loss: 1.41945713\n",
      "Epoch [990/1000], Loss: 1.41963879\n",
      "TE: Y->X -0.0002\n",
      "# Coupling =  2.25 True TE X->Y =  0.0102 True TE Y->X =  0.0\n",
      "Epoch [990/1000], Loss: 1.41741052\n",
      "Epoch [990/1000], Loss: 1.41673977\n",
      "TE: X->Y 0.0006\n",
      "Epoch [990/1000], Loss: 1.41942644\n",
      "Epoch [990/1000], Loss: 1.41944284\n",
      "TE: Y->X -0.0\n",
      "# Coupling =  3.0 True TE X->Y =  0.0011 True TE Y->X =  0.0\n",
      "Epoch [990/1000], Loss: 1.41746505\n",
      "Epoch [990/1000], Loss: 1.41731971\n",
      "TE: X->Y 0.0001\n",
      "Epoch [990/1000], Loss: 1.41973366\n",
      "Epoch [990/1000], Loss: 1.41965663\n",
      "TE: Y->X 0.0001\n",
      "\n",
      "### REPLICATE 3/5 ###\n",
      "\n",
      "# Coupling =  -3.0 True TE X->Y =  0.8292 True TE Y->X =  0.0\n",
      "Epoch [990/1000], Loss: 1.42304359\n",
      "Epoch [990/1000], Loss: 0.59018286\n",
      "TE: X->Y 0.8328\n",
      "Epoch [990/1000], Loss: 1.42695906\n",
      "Epoch [990/1000], Loss: 1.42681448\n",
      "TE: Y->X 0.0001\n",
      "# Coupling =  -2.25 True TE X->Y =  0.8202 True TE Y->X =  0.0\n",
      "Epoch [990/1000], Loss: 1.42402534\n",
      "Epoch [990/1000], Loss: 0.60783176\n",
      "TE: X->Y 0.8159\n",
      "Epoch [990/1000], Loss: 1.42689914\n",
      "Epoch [990/1000], Loss: 1.42677039\n",
      "TE: Y->X 0.0001\n",
      "# Coupling =  -1.5 True TE X->Y =  0.7749 True TE Y->X =  0.0\n",
      "Epoch [990/1000], Loss: 1.42297741\n",
      "Epoch [990/1000], Loss: 0.65377094\n",
      "TE: X->Y 0.7689\n",
      "Epoch [990/1000], Loss: 1.42690597\n",
      "Epoch [990/1000], Loss: 1.42688208\n",
      "TE: Y->X 0.0\n",
      "# Coupling =  -0.75 True TE X->Y =  0.6422 True TE Y->X =  0.0\n",
      "Epoch [990/1000], Loss: 1.42143773\n",
      "Epoch [990/1000], Loss: 0.79157102\n",
      "TE: X->Y 0.6295\n",
      "Epoch [990/1000], Loss: 1.42694751\n",
      "Epoch [990/1000], Loss: 1.42666346\n",
      "TE: Y->X 0.0003\n",
      "# Coupling =  0.0 True TE X->Y =  0.4152 True TE Y->X =  0.0\n",
      "Epoch [990/1000], Loss: 1.42634393\n",
      "Epoch [990/1000], Loss: 1.01568211\n",
      "TE: X->Y 0.4102\n",
      "Epoch [990/1000], Loss: 1.42695107\n",
      "Epoch [990/1000], Loss: 1.42680724\n",
      "TE: Y->X 0.0001\n",
      "# Coupling =  0.75 True TE X->Y =  0.1882 True TE Y->X =  0.0\n",
      "Epoch [990/1000], Loss: 1.43046574\n",
      "Epoch [990/1000], Loss: 1.24766375\n",
      "TE: X->Y 0.1823\n",
      "Epoch [990/1000], Loss: 1.42693599\n",
      "Epoch [990/1000], Loss: 1.42670345\n",
      "TE: Y->X 0.0002\n",
      "# Coupling =  1.5 True TE X->Y =  0.0555 True TE Y->X =  0.0\n",
      "Epoch [990/1000], Loss: 1.41971754\n",
      "Epoch [990/1000], Loss: 1.38730121\n",
      "TE: X->Y 0.031\n",
      "Epoch [990/1000], Loss: 1.42691186\n",
      "Epoch [990/1000], Loss: 1.42658214\n",
      "TE: Y->X 0.0003\n",
      "# Coupling =  2.25 True TE X->Y =  0.0102 True TE Y->X =  0.0\n",
      "Epoch [990/1000], Loss: 1.41938817\n",
      "Epoch [990/1000], Loss: 1.41922199\n",
      "TE: X->Y 0.0002\n",
      "Epoch [990/1000], Loss: 1.42695616\n",
      "Epoch [990/1000], Loss: 1.42642085\n",
      "TE: Y->X 0.0005\n",
      "# Coupling =  3.0 True TE X->Y =  0.0011 True TE Y->X =  0.0\n",
      "Epoch [990/1000], Loss: 1.41937968\n",
      "Epoch [990/1000], Loss: 1.41950717\n",
      "TE: X->Y -0.0001\n",
      "Epoch [990/1000], Loss: 1.42689946\n",
      "Epoch [990/1000], Loss: 1.42657399\n",
      "TE: Y->X 0.0003\n",
      "\n",
      "### REPLICATE 4/5 ###\n",
      "\n",
      "# Coupling =  -3.0 True TE X->Y =  0.8292 True TE Y->X =  0.0\n",
      "Epoch [990/1000], Loss: 1.41146431\n",
      "Epoch [990/1000], Loss: 0.59042616\n",
      "TE: X->Y 0.821\n",
      "Epoch [990/1000], Loss: 1.41297606\n",
      "Epoch [990/1000], Loss: 1.41269863\n",
      "TE: Y->X 0.0003\n",
      "# Coupling =  -2.25 True TE X->Y =  0.8202 True TE Y->X =  0.0\n",
      "Epoch [990/1000], Loss: 1.41045916\n",
      "Epoch [990/1000], Loss: 0.60437639\n",
      "TE: X->Y 0.8058\n",
      "Epoch [990/1000], Loss: 1.41296979\n",
      "Epoch [990/1000], Loss: 1.41278461\n",
      "TE: Y->X 0.0002\n",
      "# Coupling =  -1.5 True TE X->Y =  0.7749 True TE Y->X =  0.0\n",
      "Epoch [990/1000], Loss: 1.41135625\n",
      "Epoch [990/1000], Loss: 0.65864933\n",
      "TE: X->Y 0.7522\n",
      "Epoch [990/1000], Loss: 1.41294996\n",
      "Epoch [990/1000], Loss: 1.41235605\n",
      "TE: Y->X 0.0006\n",
      "# Coupling =  -0.75 True TE X->Y =  0.6422 True TE Y->X =  0.0\n",
      "Epoch [990/1000], Loss: 1.41228111\n",
      "Epoch [990/1000], Loss: 0.79201596\n",
      "TE: X->Y 0.6197\n",
      "Epoch [990/1000], Loss: 1.41296375\n",
      "Epoch [990/1000], Loss: 1.41278059\n",
      "TE: Y->X 0.0002\n",
      "# Coupling =  0.0 True TE X->Y =  0.4152 True TE Y->X =  0.0\n",
      "Epoch [990/1000], Loss: 1.41523381\n",
      "Epoch [990/1000], Loss: 1.01778427\n",
      "TE: X->Y 0.3971\n",
      "Epoch [990/1000], Loss: 1.41297384\n",
      "Epoch [990/1000], Loss: 1.41246955\n",
      "TE: Y->X 0.0005\n",
      "# Coupling =  0.75 True TE X->Y =  0.1882 True TE Y->X =  0.0\n",
      "Epoch [990/1000], Loss: 1.41668825\n",
      "Epoch [990/1000], Loss: 1.24534028\n",
      "TE: X->Y 0.171\n",
      "Epoch [990/1000], Loss: 1.41291791\n",
      "Epoch [990/1000], Loss: 1.41268806\n",
      "TE: Y->X 0.0002\n",
      "# Coupling =  1.5 True TE X->Y =  0.0555 True TE Y->X =  0.0\n",
      "Epoch [990/1000], Loss: 1.41377548\n",
      "Epoch [990/1000], Loss: 1.39561996\n",
      "TE: X->Y 0.0175\n",
      "Epoch [990/1000], Loss: 1.41298945\n",
      "Epoch [990/1000], Loss: 1.41292038\n",
      "TE: Y->X 0.0001\n",
      "# Coupling =  2.25 True TE X->Y =  0.0102 True TE Y->X =  0.0\n",
      "Epoch [990/1000], Loss: 1.41361239\n",
      "Epoch [990/1000], Loss: 1.41331422\n",
      "TE: X->Y 0.0003\n",
      "Epoch [990/1000], Loss: 1.41291164\n",
      "Epoch [990/1000], Loss: 1.41290518\n",
      "TE: Y->X 0.0\n",
      "# Coupling =  3.0 True TE X->Y =  0.0011 True TE Y->X =  0.0\n",
      "Epoch [990/1000], Loss: 1.41433193\n",
      "Epoch [990/1000], Loss: 1.41417326\n",
      "TE: X->Y 0.0002\n",
      "Epoch [990/1000], Loss: 1.41300797\n",
      "Epoch [990/1000], Loss: 1.41286412\n",
      "TE: Y->X 0.0001\n",
      "\n",
      "### REPLICATE 5/5 ###\n",
      "\n",
      "# Coupling =  -3.0 True TE X->Y =  0.8292 True TE Y->X =  0.0\n",
      "Epoch [990/1000], Loss: 1.41601429\n",
      "Epoch [990/1000], Loss: 0.58626415\n",
      "TE: X->Y 0.8297\n",
      "Epoch [990/1000], Loss: 1.41262157\n",
      "Epoch [990/1000], Loss: 1.41248428\n",
      "TE: Y->X 0.0001\n",
      "# Coupling =  -2.25 True TE X->Y =  0.8202 True TE Y->X =  0.0\n",
      "Epoch [990/1000], Loss: 1.41651998\n",
      "Epoch [990/1000], Loss: 0.60434508\n",
      "TE: X->Y 0.8119\n",
      "Epoch [990/1000], Loss: 1.41269617\n",
      "Epoch [990/1000], Loss: 1.41253006\n",
      "TE: Y->X 0.0002\n",
      "# Coupling =  -1.5 True TE X->Y =  0.7749 True TE Y->X =  0.0\n",
      "Epoch [990/1000], Loss: 1.41625025\n",
      "Epoch [990/1000], Loss: 0.65143481\n",
      "TE: X->Y 0.7644\n",
      "Epoch [990/1000], Loss: 1.41266461\n",
      "Epoch [990/1000], Loss: 1.41257209\n",
      "TE: Y->X 0.0001\n",
      "# Coupling =  -0.75 True TE X->Y =  0.6422 True TE Y->X =  0.0\n",
      "Epoch [990/1000], Loss: 1.41595628\n",
      "Epoch [990/1000], Loss: 0.78277323\n",
      "TE: X->Y 0.6328\n",
      "Epoch [990/1000], Loss: 1.41263967\n",
      "Epoch [990/1000], Loss: 1.41241889\n",
      "TE: Y->X 0.0002\n",
      "# Coupling =  0.0 True TE X->Y =  0.4152 True TE Y->X =  0.0\n",
      "Epoch [990/1000], Loss: 1.41028456\n",
      "Epoch [990/1000], Loss: 1.00139599\n",
      "TE: X->Y 0.4086\n",
      "Epoch [990/1000], Loss: 1.41276408\n",
      "Epoch [990/1000], Loss: 1.41253612\n",
      "TE: Y->X 0.0002\n",
      "# Coupling =  0.75 True TE X->Y =  0.1882 True TE Y->X =  0.0\n",
      "Epoch [990/1000], Loss: 1.41550823\n",
      "Epoch [990/1000], Loss: 1.23530771\n",
      "TE: X->Y 0.1799\n",
      "Epoch [990/1000], Loss: 1.41264556\n",
      "Epoch [990/1000], Loss: 1.41262423\n",
      "TE: Y->X 0.0\n",
      "# Coupling =  1.5 True TE X->Y =  0.0555 True TE Y->X =  0.0\n",
      "Epoch [990/1000], Loss: 1.41410326\n",
      "Epoch [990/1000], Loss: 1.38806712\n",
      "TE: X->Y 0.0241\n",
      "Epoch [990/1000], Loss: 1.41263594\n",
      "Epoch [990/1000], Loss: 1.41257399\n",
      "TE: Y->X 0.0001\n",
      "# Coupling =  2.25 True TE X->Y =  0.0102 True TE Y->X =  0.0\n",
      "Epoch [990/1000], Loss: 1.41304278\n",
      "Epoch [990/1000], Loss: 1.41220945\n",
      "TE: X->Y 0.0008\n",
      "Epoch [990/1000], Loss: 1.41264503\n",
      "Epoch [990/1000], Loss: 1.41266578\n",
      "TE: Y->X -0.0\n",
      "# Coupling =  3.0 True TE X->Y =  0.0011 True TE Y->X =  0.0\n",
      "Epoch [990/1000], Loss: 1.41290747\n",
      "Epoch [990/1000], Loss: 1.41306601\n",
      "TE: X->Y -0.0002\n",
      "Epoch [990/1000], Loss: 1.41265087\n",
      "Epoch [990/1000], Loss: 1.41252666\n",
      "TE: Y->X 0.0001\n"
     ]
    }
   ],
   "source": [
    "jp_results_TE_X2Y = Results(columns=['method', 'coupling'])\n",
    "jp_results_TE_Y2X = Results(columns=['method', 'coupling'])\n",
    "\n",
    "for r in range(REPLICATES):\n",
    "    print(f\"\\n### REPLICATE {r+1}/{REPLICATES} ###\\n\")\n",
    "    for lam, generator in zip(jp_lambda_range, jp_generator_lst):\n",
    "        print(\"# Coupling = \", lam, \"True TE X->Y = \", generator.analytic_transfer_entropy('X', 'Y'), \"True TE Y->X = \", generator.analytic_transfer_entropy('Y', 'X'))\n",
    "        # Simulate data\n",
    "        dataset = get_dataset(generator, SAMPLE_SIZE, BATCH_SIZE, seed=r)\n",
    "        # Estimate X -> Y\n",
    "        TE_X2Y = TE_agmte(dataset, compute_device, 'X', 'Y', 16, NB)\n",
    "        jp_results_TE_X2Y.write(method='agmte', coupling=lam, value=TE_X2Y)\n",
    "        # Estimate Y -> X\n",
    "        TE_Y2X = TE_agmte(dataset, compute_device, 'Y', 'X', 16, NB)\n",
    "        jp_results_TE_Y2X.write(method='agmte', coupling=lam, value=TE_Y2X)\n",
    "\n",
    "jp_results_TE_X2Y.df.to_csv('results/agmte/jp_results_TE_X2Y_bv.csv', index=False)\n",
    "jp_results_TE_Y2X.df.to_csv('results/agmte/jp_results_TE_Y2X_bv.csv', index=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Sample size scaling"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "lg_generator = MVLinearGaussianSimulator(n_dim=1, coupling=0.5)\n",
    "jp_generator = MVJointProcessSimulator(n_dim=1, lam=0.0)\n",
    "sample_sizes = [500, 1000, 5000, 10000, 50000, 100000]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "### REPLICATE 1/5 ###\n",
      "\n",
      "# Samples =  500 #\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [990/1000], Loss: 0.57833002\n",
      "Epoch [990/1000], Loss: 0.56550968\n",
      "TE: X->Y 0.0126\n",
      "Epoch [990/1000], Loss: 0.68325674\n",
      "Epoch [990/1000], Loss: 0.56250367\n",
      "TE: Y->X 0.1208\n",
      "# Samples =  1000 #\n",
      "Epoch [990/1000], Loss: 0.59310173\n",
      "Epoch [990/1000], Loss: 0.58539343\n",
      "TE: X->Y 0.0073\n",
      "Epoch [990/1000], Loss: 0.70054152\n",
      "Epoch [990/1000], Loss: 0.58913923\n",
      "TE: Y->X 0.1114\n",
      "# Samples =  5000 #\n",
      "Epoch [990/1000], Loss: 0.59610704\n",
      "Epoch [990/1000], Loss: 0.59592941\n",
      "TE: X->Y 0.0002\n",
      "Epoch [990/1000], Loss: 0.73392021\n",
      "Epoch [990/1000], Loss: 0.60799118\n",
      "TE: Y->X 0.1259\n",
      "# Samples =  10000 #\n",
      "Epoch [990/1000], Loss: 0.60006024\n",
      "Epoch [990/1000], Loss: 0.59962094\n",
      "TE: X->Y 0.0004\n",
      "Epoch [990/1000], Loss: 0.73606497\n",
      "Epoch [990/1000], Loss: 0.60894982\n",
      "TE: Y->X 0.1271\n",
      "# Samples =  50000 #\n",
      "Epoch [990/1000], Loss: 0.60970547\n",
      "Epoch [990/1000], Loss: 0.60960642\n",
      "TE: X->Y 0.0001\n",
      "Epoch [990/1000], Loss: 0.73881574\n",
      "Epoch [990/1000], Loss: 0.61453434\n",
      "TE: Y->X 0.1243\n",
      "# Samples =  100000 #\n",
      "Epoch [990/1000], Loss: 0.61212491\n",
      "Epoch [990/1000], Loss: 0.61209932\n",
      "TE: X->Y 0.0\n",
      "Epoch [990/1000], Loss: 0.73669941\n",
      "Epoch [990/1000], Loss: 0.61238337\n",
      "TE: Y->X 0.1243\n",
      "\n",
      "### REPLICATE 2/5 ###\n",
      "\n",
      "# Samples =  500 #\n",
      "Epoch [990/1000], Loss: 0.59964303\n",
      "Epoch [990/1000], Loss: 0.58727952\n",
      "TE: X->Y 0.0121\n",
      "Epoch [990/1000], Loss: 0.74090166\n",
      "Epoch [990/1000], Loss: 0.62327966\n",
      "TE: Y->X 0.1173\n",
      "# Samples =  1000 #\n",
      "Epoch [990/1000], Loss: 0.60848286\n",
      "Epoch [990/1000], Loss: 0.60768502\n",
      "TE: X->Y 0.0008\n",
      "Epoch [990/1000], Loss: 0.76036755\n",
      "Epoch [990/1000], Loss: 0.63098147\n",
      "TE: Y->X 0.1293\n",
      "# Samples =  5000 #\n",
      "Epoch [990/1000], Loss: 0.60861856\n",
      "Epoch [990/1000], Loss: 0.60724579\n",
      "TE: X->Y 0.0014\n",
      "Epoch [990/1000], Loss: 0.74853398\n",
      "Epoch [990/1000], Loss: 0.62036422\n",
      "TE: Y->X 0.1282\n",
      "# Samples =  10000 #\n",
      "Epoch [990/1000], Loss: 0.61111611\n",
      "Epoch [990/1000], Loss: 0.61053923\n",
      "TE: X->Y 0.0006\n",
      "Epoch [990/1000], Loss: 0.74071111\n",
      "Epoch [990/1000], Loss: 0.62299011\n",
      "TE: Y->X 0.1177\n",
      "# Samples =  50000 #\n",
      "Epoch [990/1000], Loss: 0.60963732\n",
      "Epoch [990/1000], Loss: 0.60965548\n",
      "TE: X->Y -0.0\n",
      "Epoch [990/1000], Loss: 0.74217783\n",
      "Epoch [990/1000], Loss: 0.61576059\n",
      "TE: Y->X 0.1264\n",
      "# Samples =  100000 #\n",
      "Epoch [990/1000], Loss: 0.60906663\n",
      "Epoch [990/1000], Loss: 0.60912842\n",
      "TE: X->Y -0.0001\n",
      "Epoch [990/1000], Loss: 0.74196433\n",
      "Epoch [990/1000], Loss: 0.61621067\n",
      "TE: Y->X 0.1257\n",
      "\n",
      "### REPLICATE 3/5 ###\n",
      "\n",
      "# Samples =  500 #\n",
      "Epoch [990/1000], Loss: 0.59465663\n",
      "Epoch [990/1000], Loss: 0.58017654\n",
      "TE: X->Y 0.0125\n",
      "Epoch [990/1000], Loss: 0.73101924\n",
      "Epoch [990/1000], Loss: 0.55120432\n",
      "TE: Y->X 0.1788\n",
      "# Samples =  1000 #\n",
      "Epoch [990/1000], Loss: 0.59356952\n",
      "Epoch [990/1000], Loss: 0.59307998\n",
      "TE: X->Y 0.0006\n",
      "Epoch [990/1000], Loss: 0.73461671\n",
      "Epoch [990/1000], Loss: 0.58904747\n",
      "TE: Y->X 0.1455\n",
      "# Samples =  5000 #\n",
      "Epoch [990/1000], Loss: 0.61655529\n",
      "Epoch [990/1000], Loss: 0.61631458\n",
      "TE: X->Y 0.0002\n",
      "Epoch [990/1000], Loss: 0.74305539\n",
      "Epoch [990/1000], Loss: 0.61246823\n",
      "TE: Y->X 0.1305\n",
      "# Samples =  10000 #\n",
      "Epoch [990/1000], Loss: 0.60888164\n",
      "Epoch [990/1000], Loss: 0.60895429\n",
      "TE: X->Y -0.0001\n",
      "Epoch [990/1000], Loss: 0.75301016\n",
      "Epoch [990/1000], Loss: 0.62715567\n",
      "TE: Y->X 0.1258\n",
      "# Samples =  50000 #\n",
      "Epoch [990/1000], Loss: 0.61615435\n",
      "Epoch [990/1000], Loss: 0.61613523\n",
      "TE: X->Y 0.0\n",
      "Epoch [990/1000], Loss: 0.74402177\n",
      "Epoch [990/1000], Loss: 0.61651261\n",
      "TE: Y->X 0.1275\n",
      "# Samples =  100000 #\n",
      "Epoch [990/1000], Loss: 0.61183955\n",
      "Epoch [990/1000], Loss: 0.61181348\n",
      "TE: X->Y 0.0\n",
      "Epoch [990/1000], Loss: 0.73916752\n",
      "Epoch [990/1000], Loss: 0.61317908\n",
      "TE: Y->X 0.126\n",
      "\n",
      "### REPLICATE 4/5 ###\n",
      "\n",
      "# Samples =  500 #\n",
      "Epoch [990/1000], Loss: 0.65247409\n",
      "Epoch [990/1000], Loss: 0.63311015\n",
      "TE: X->Y 0.0192\n",
      "Epoch [990/1000], Loss: 0.68984786\n",
      "Epoch [990/1000], Loss: 0.60473653\n",
      "TE: Y->X 0.0851\n",
      "# Samples =  1000 #\n",
      "Epoch [990/1000], Loss: 0.64159791\n",
      "Epoch [990/1000], Loss: 0.63541603\n",
      "TE: X->Y 0.0061\n",
      "Epoch [990/1000], Loss: 0.70581248\n",
      "Epoch [990/1000], Loss: 0.59917242\n",
      "TE: Y->X 0.1067\n",
      "# Samples =  5000 #\n",
      "Epoch [990/1000], Loss: 0.61562084\n",
      "Epoch [990/1000], Loss: 0.61500019\n",
      "TE: X->Y 0.0006\n",
      "Epoch [990/1000], Loss: 0.72083794\n",
      "Epoch [990/1000], Loss: 0.60499734\n",
      "TE: Y->X 0.1158\n",
      "# Samples =  10000 #\n",
      "Epoch [990/1000], Loss: 0.61945858\n",
      "Epoch [990/1000], Loss: 0.61920773\n",
      "TE: X->Y 0.0003\n",
      "Epoch [990/1000], Loss: 0.72140443\n",
      "Epoch [990/1000], Loss: 0.59962377\n",
      "TE: Y->X 0.1218\n",
      "# Samples =  50000 #\n",
      "Epoch [990/1000], Loss: 0.61254724\n",
      "Epoch [990/1000], Loss: 0.61254435\n",
      "TE: X->Y 0.0\n",
      "Epoch [990/1000], Loss: 0.73554085\n",
      "Epoch [990/1000], Loss: 0.61041978\n",
      "TE: Y->X 0.1251\n",
      "# Samples =  100000 #\n",
      "Epoch [990/1000], Loss: 0.61094611\n",
      "Epoch [990/1000], Loss: 0.61094754\n",
      "TE: X->Y -0.0\n",
      "Epoch [990/1000], Loss: 0.73681116\n",
      "Epoch [990/1000], Loss: 0.61106572\n",
      "TE: Y->X 0.1257\n",
      "\n",
      "### REPLICATE 5/5 ###\n",
      "\n",
      "# Samples =  500 #\n",
      "Epoch [990/1000], Loss: 0.60707017\n",
      "Epoch [990/1000], Loss: 0.59321602\n",
      "TE: X->Y 0.0137\n",
      "Epoch [990/1000], Loss: 0.70788527\n",
      "Epoch [990/1000], Loss: 0.53550183\n",
      "TE: Y->X 0.1719\n",
      "# Samples =  1000 #\n",
      "Epoch [990/1000], Loss: 0.60443727\n",
      "Epoch [990/1000], Loss: 0.60476529\n",
      "TE: X->Y -0.0003\n",
      "Epoch [990/1000], Loss: 0.70992387\n",
      "Epoch [990/1000], Loss: 0.54412854\n",
      "TE: Y->X 0.1656\n",
      "# Samples =  5000 #\n",
      "Epoch [990/1000], Loss: 0.62084543\n",
      "Epoch [990/1000], Loss: 0.62002926\n",
      "TE: X->Y 0.0008\n",
      "Epoch [990/1000], Loss: 0.72887997\n",
      "Epoch [990/1000], Loss: 0.59820472\n",
      "TE: Y->X 0.1307\n",
      "# Samples =  10000 #\n",
      "Epoch [990/1000], Loss: 0.61368799\n",
      "Epoch [990/1000], Loss: 0.61353421\n",
      "TE: X->Y 0.0002\n",
      "Epoch [990/1000], Loss: 0.73668577\n",
      "Epoch [990/1000], Loss: 0.60474968\n",
      "TE: Y->X 0.1319\n",
      "# Samples =  50000 #\n",
      "Epoch [990/1000], Loss: 0.60888546\n",
      "Epoch [990/1000], Loss: 0.60887684\n",
      "TE: X->Y 0.0\n",
      "Epoch [990/1000], Loss: 0.73991366\n",
      "Epoch [990/1000], Loss: 0.61430036\n",
      "TE: Y->X 0.1256\n",
      "# Samples =  100000 #\n",
      "Epoch [990/1000], Loss: 0.61019419\n",
      "Epoch [990/1000], Loss: 0.61024377\n",
      "TE: X->Y -0.0\n",
      "Epoch [990/1000], Loss: 0.74375026\n",
      "Epoch [990/1000], Loss: 0.61724661\n",
      "TE: Y->X 0.1265\n"
     ]
    }
   ],
   "source": [
    "lg_results_TE_X2Y = Results(columns=['method', 'sample_size'])\n",
    "lg_results_TE_Y2X = Results(columns=['method', 'sample_size'])\n",
    "\n",
    "for r in range(REPLICATES):\n",
    "    print(f\"\\n### REPLICATE {r+1}/{REPLICATES} ###\\n\")\n",
    "    for samples in sample_sizes:\n",
    "        print(\"# Samples = \", samples, \"#\")\n",
    "        # Simulate data\n",
    "        dataset = get_dataset(lg_generator, samples, int(np.round(samples/20)), seed=r)\n",
    "        # Estimate X -> Y\n",
    "        TE_X2Y = TE_agmte(dataset, compute_device, 'X', 'Y', 16, NB)\n",
    "        lg_results_TE_X2Y.write(method='agmte', sample_size=samples, value=TE_X2Y)\n",
    "        # Estimate Y -> X\n",
    "        TE_Y2X = TE_agmte(dataset, compute_device, 'Y', 'X', 16, NB)\n",
    "        lg_results_TE_Y2X.write(method='agmte', sample_size=samples, value=TE_Y2X)\n",
    "\n",
    "lg_results_TE_X2Y.df.to_csv('results/agmte/lg_results_TE_X2Y_ss.csv', index=False)\n",
    "lg_results_TE_Y2X.df.to_csv('results/agmte/lg_results_TE_Y2X_ss.csv', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "### REPLICATE 1/5 ###\n",
      "\n",
      "# Samples =  500 #\n",
      "Epoch [990/1000], Loss: 1.37343058\n",
      "Epoch [990/1000], Loss: 0.97759535\n",
      "TE: X->Y 0.394\n",
      "Epoch [990/1000], Loss: 1.38815508\n",
      "Epoch [990/1000], Loss: 1.38908573\n",
      "TE: Y->X -0.0011\n",
      "# Samples =  1000 #\n",
      "Epoch [990/1000], Loss: 1.37973456\n",
      "Epoch [990/1000], Loss: 0.99105465\n",
      "TE: X->Y 0.3881\n",
      "Epoch [990/1000], Loss: 1.38568337\n",
      "Epoch [990/1000], Loss: 1.38173084\n",
      "TE: Y->X 0.0038\n",
      "# Samples =  5000 #\n",
      "Epoch [990/1000], Loss: 1.40149035\n",
      "Epoch [990/1000], Loss: 1.00933048\n",
      "TE: X->Y 0.3917\n",
      "Epoch [990/1000], Loss: 1.40969162\n",
      "Epoch [990/1000], Loss: 1.40762321\n",
      "TE: Y->X 0.0021\n",
      "# Samples =  10000 #\n",
      "Epoch [990/1000], Loss: 1.40369313\n",
      "Epoch [990/1000], Loss: 1.00713383\n",
      "TE: X->Y 0.3962\n",
      "Epoch [990/1000], Loss: 1.41201619\n",
      "Epoch [990/1000], Loss: 1.41188147\n",
      "TE: Y->X 0.0001\n",
      "# Samples =  50000 #\n",
      "Epoch [990/1000], Loss: 1.41717659\n",
      "Epoch [990/1000], Loss: 1.01247535\n",
      "TE: X->Y 0.4044\n",
      "Epoch [990/1000], Loss: 1.41689053\n",
      "Epoch [990/1000], Loss: 1.41689813\n",
      "TE: Y->X -0.0\n",
      "# Samples =  100000 #\n",
      "Epoch [990/1000], Loss: 1.41730554\n",
      "Epoch [990/1000], Loss: 1.01233903\n",
      "TE: X->Y 0.4046\n",
      "Epoch [990/1000], Loss: 1.41759627\n",
      "Epoch [990/1000], Loss: 1.41756569\n",
      "TE: Y->X 0.0\n",
      "\n",
      "### REPLICATE 2/5 ###\n",
      "\n",
      "# Samples =  500 #\n",
      "Epoch [990/1000], Loss: 1.38018054\n",
      "Epoch [990/1000], Loss: 0.94187855\n",
      "TE: X->Y 0.4367\n",
      "Epoch [990/1000], Loss: 1.38730868\n",
      "Epoch [990/1000], Loss: 1.37575037\n",
      "TE: Y->X 0.0111\n",
      "# Samples =  1000 #\n",
      "Epoch [990/1000], Loss: 1.43584055\n",
      "Epoch [990/1000], Loss: 0.97678378\n",
      "TE: X->Y 0.4584\n",
      "Epoch [990/1000], Loss: 1.44530811\n",
      "Epoch [990/1000], Loss: 1.44009645\n",
      "TE: Y->X 0.0051\n",
      "# Samples =  5000 #\n",
      "Epoch [990/1000], Loss: 1.40958586\n",
      "Epoch [990/1000], Loss: 1.00169355\n",
      "TE: X->Y 0.4076\n",
      "Epoch [990/1000], Loss: 1.41293136\n",
      "Epoch [990/1000], Loss: 1.41269762\n",
      "TE: Y->X 0.0002\n",
      "# Samples =  10000 #\n",
      "Epoch [990/1000], Loss: 1.42534407\n",
      "Epoch [990/1000], Loss: 1.00742172\n",
      "TE: X->Y 0.4176\n",
      "Epoch [990/1000], Loss: 1.41953098\n",
      "Epoch [990/1000], Loss: 1.41965695\n",
      "TE: Y->X -0.0001\n",
      "# Samples =  50000 #\n",
      "Epoch [990/1000], Loss: 1.41293392\n",
      "Epoch [990/1000], Loss: 1.01361627\n",
      "TE: X->Y 0.399\n",
      "Epoch [990/1000], Loss: 1.41497576\n",
      "Epoch [990/1000], Loss: 1.41491925\n",
      "TE: Y->X 0.0001\n",
      "# Samples =  100000 #\n",
      "Epoch [990/1000], Loss: 1.41705467\n",
      "Epoch [990/1000], Loss: 1.01303751\n",
      "TE: X->Y 0.4036\n",
      "Epoch [990/1000], Loss: 1.41591889\n",
      "Epoch [990/1000], Loss: 1.41587557\n",
      "TE: Y->X 0.0\n",
      "\n",
      "### REPLICATE 3/5 ###\n",
      "\n",
      "# Samples =  500 #\n",
      "Epoch [990/1000], Loss: 1.46487303\n",
      "Epoch [990/1000], Loss: 0.99739558\n",
      "TE: X->Y 0.4667\n",
      "Epoch [990/1000], Loss: 1.38550033\n",
      "Epoch [990/1000], Loss: 1.37844434\n",
      "TE: Y->X 0.0066\n",
      "# Samples =  1000 #\n",
      "Epoch [990/1000], Loss: 1.41784685\n",
      "Epoch [990/1000], Loss: 1.01376958\n",
      "TE: X->Y 0.4032\n",
      "Epoch [990/1000], Loss: 1.38951332\n",
      "Epoch [990/1000], Loss: 1.38447951\n",
      "TE: Y->X 0.0046\n",
      "# Samples =  5000 #\n",
      "Epoch [990/1000], Loss: 1.41580538\n",
      "Epoch [990/1000], Loss: 1.03315323\n",
      "TE: X->Y 0.3823\n",
      "Epoch [990/1000], Loss: 1.41422029\n",
      "Epoch [990/1000], Loss: 1.41391675\n",
      "TE: Y->X 0.0003\n",
      "# Samples =  10000 #\n",
      "Epoch [990/1000], Loss: 1.42628643\n",
      "Epoch [990/1000], Loss: 1.01558849\n",
      "TE: X->Y 0.4104\n",
      "Epoch [990/1000], Loss: 1.42693346\n",
      "Epoch [990/1000], Loss: 1.42681942\n",
      "TE: Y->X 0.0001\n",
      "# Samples =  50000 #\n",
      "Epoch [990/1000], Loss: 1.42238432\n",
      "Epoch [990/1000], Loss: 1.02021185\n",
      "TE: X->Y 0.4018\n",
      "Epoch [990/1000], Loss: 1.42192739\n",
      "Epoch [990/1000], Loss: 1.42185454\n",
      "TE: Y->X 0.0001\n",
      "# Samples =  100000 #\n",
      "Epoch [990/1000], Loss: 1.41404079\n",
      "Epoch [990/1000], Loss: 1.01950417\n",
      "TE: X->Y 0.3942\n",
      "Epoch [990/1000], Loss: 1.41313888\n",
      "Epoch [990/1000], Loss: 1.41312641\n",
      "TE: Y->X 0.0\n",
      "\n",
      "### REPLICATE 4/5 ###\n",
      "\n",
      "# Samples =  500 #\n",
      "Epoch [990/1000], Loss: 1.39906516\n",
      "Epoch [990/1000], Loss: 0.97933692\n",
      "TE: X->Y 0.4167\n",
      "Epoch [990/1000], Loss: 1.39419569\n",
      "Epoch [990/1000], Loss: 1.35911538\n",
      "TE: Y->X 0.0337\n",
      "# Samples =  1000 #\n",
      "Epoch [990/1000], Loss: 1.45379896\n",
      "Epoch [990/1000], Loss: 1.01208054\n",
      "TE: X->Y 0.4412\n",
      "Epoch [990/1000], Loss: 1.42826103\n",
      "Epoch [990/1000], Loss: 1.41270332\n",
      "TE: Y->X 0.0154\n",
      "# Samples =  5000 #\n",
      "Epoch [990/1000], Loss: 1.40928032\n",
      "Epoch [990/1000], Loss: 1.02655728\n",
      "TE: X->Y 0.3824\n",
      "Epoch [990/1000], Loss: 1.41528217\n",
      "Epoch [990/1000], Loss: 1.41508825\n",
      "TE: Y->X 0.0002\n",
      "# Samples =  10000 #\n",
      "Epoch [990/1000], Loss: 1.41518035\n",
      "Epoch [990/1000], Loss: 1.01767777\n",
      "TE: X->Y 0.3972\n",
      "Epoch [990/1000], Loss: 1.41298316\n",
      "Epoch [990/1000], Loss: 1.41293045\n",
      "TE: Y->X 0.0\n",
      "# Samples =  50000 #\n",
      "Epoch [990/1000], Loss: 1.41515914\n",
      "Epoch [990/1000], Loss: 1.01472489\n",
      "TE: X->Y 0.4001\n",
      "Epoch [990/1000], Loss: 1.41478164\n",
      "Epoch [990/1000], Loss: 1.41476725\n",
      "TE: Y->X 0.0\n",
      "# Samples =  100000 #\n",
      "Epoch [990/1000], Loss: 1.41542348\n",
      "Epoch [990/1000], Loss: 1.01546019\n",
      "TE: X->Y 0.3997\n",
      "Epoch [990/1000], Loss: 1.41517484\n",
      "Epoch [990/1000], Loss: 1.41517326\n",
      "TE: Y->X 0.0\n",
      "\n",
      "### REPLICATE 5/5 ###\n",
      "\n",
      "# Samples =  500 #\n",
      "Epoch [990/1000], Loss: 1.34519801\n",
      "Epoch [990/1000], Loss: 0.91980152\n",
      "TE: X->Y 0.4226\n",
      "Epoch [990/1000], Loss: 1.35252527\n",
      "Epoch [990/1000], Loss: 1.33957457\n",
      "TE: Y->X 0.0126\n",
      "# Samples =  1000 #\n",
      "Epoch [990/1000], Loss: 1.40169731\n",
      "Epoch [990/1000], Loss: 0.93925119\n",
      "TE: X->Y 0.4614\n",
      "Epoch [990/1000], Loss: 1.40127861\n",
      "Epoch [990/1000], Loss: 1.39675798\n",
      "TE: Y->X 0.0044\n",
      "# Samples =  5000 #\n",
      "Epoch [990/1000], Loss: 1.40406979\n",
      "Epoch [990/1000], Loss: 0.99375387\n",
      "TE: X->Y 0.4098\n",
      "Epoch [990/1000], Loss: 1.42907145\n",
      "Epoch [990/1000], Loss: 1.42873495\n",
      "TE: Y->X 0.0003\n",
      "# Samples =  10000 #\n",
      "Epoch [990/1000], Loss: 1.41025479\n",
      "Epoch [990/1000], Loss: 1.00244278\n",
      "TE: X->Y 0.4075\n",
      "Epoch [990/1000], Loss: 1.41272718\n",
      "Epoch [990/1000], Loss: 1.41239641\n",
      "TE: Y->X 0.0003\n",
      "# Samples =  50000 #\n",
      "Epoch [990/1000], Loss: 1.41226948\n",
      "Epoch [990/1000], Loss: 1.01508456\n",
      "TE: X->Y 0.3967\n",
      "Epoch [990/1000], Loss: 1.41580425\n",
      "Epoch [990/1000], Loss: 1.41580618\n",
      "TE: Y->X -0.0\n",
      "# Samples =  100000 #\n",
      "Epoch [990/1000], Loss: 1.41876505\n",
      "Epoch [990/1000], Loss: 1.01397216\n",
      "TE: X->Y 0.4044\n",
      "Epoch [990/1000], Loss: 1.42021605\n",
      "Epoch [990/1000], Loss: 1.42022877\n",
      "TE: Y->X -0.0\n"
     ]
    }
   ],
   "source": [
    "jp_results_TE_X2Y = Results(columns=['method', 'sample_size'])\n",
    "jp_results_TE_Y2X = Results(columns=['method', 'sample_size'])\n",
    "\n",
    "for r in range(REPLICATES):\n",
    "    print(f\"\\n### REPLICATE {r+1}/{REPLICATES} ###\\n\")\n",
    "    for samples in sample_sizes:\n",
    "        print(\"# Samples = \", samples, \"#\")\n",
    "        # Simulate data\n",
    "        dataset = get_dataset(jp_generator, samples, int(np.round(samples/20)), seed=r)\n",
    "        # Estimate X -> Y\n",
    "        TE_X2Y = TE_agmte(dataset, compute_device, 'X', 'Y', 16, NB)\n",
    "        jp_results_TE_X2Y.write(method='agmte', sample_size=samples, value=TE_X2Y)\n",
    "        # Estimate Y -> X\n",
    "        TE_Y2X = TE_agmte(dataset, compute_device, 'Y', 'X', 16, NB)\n",
    "        jp_results_TE_Y2X.write(method='agmte', sample_size=samples, value=TE_Y2X)\n",
    "\n",
    "jp_results_TE_X2Y.df.to_csv('results/agmte/jp_results_TE_X2Y_ss.csv', index=False)\n",
    "jp_results_TE_Y2X.df.to_csv('results/agmte/jp_results_TE_Y2X_ss.csv', index=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Dimensionality Scaling with redundant dimensions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "dim_range = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]\n",
    "sample_sizes = [10000, 100000]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Linear Gaussian"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize the list of generators with one for each dimension\n",
    "lg_generator_lst = [MVLinearGaussianSimulator(n_dim=dim, n_redundant_dim=dim-1) for dim in dim_range]\n",
    "# Get the reference values\n",
    "lg_TE_X2Y_ref_lst = [generator.analytic_transfer_entropy('X', 'Y') for generator in lg_generator_lst]\n",
    "lg_TE_Y2X_ref_lst = [generator.analytic_transfer_entropy('Y', 'X') for generator in lg_generator_lst]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "### REPLICATE 1/5 ###\n",
      "\n",
      "## Dim =  1 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 0.59994299\n",
      "Epoch [990/1000], Loss: 0.59954118\n",
      "TE: X->Y 0.0004\n",
      "Epoch [990/1000], Loss: 0.73560051\n",
      "Epoch [990/1000], Loss: 0.60894853\n",
      "TE: Y->X 0.1266\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 0.61220766\n",
      "Epoch [990/1000], Loss: 0.61222325\n",
      "TE: X->Y -0.0\n",
      "Epoch [990/1000], Loss: 0.73679992\n",
      "Epoch [990/1000], Loss: 0.61262013\n",
      "TE: Y->X 0.1242\n",
      "## Dim =  2 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 2.02630841\n",
      "Epoch [990/1000], Loss: 2.02473243\n",
      "TE: X->Y 0.0015\n",
      "Epoch [990/1000], Loss: 2.15166856\n",
      "Epoch [990/1000], Loss: 2.02456357\n",
      "TE: Y->X 0.1271\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 2.02991453\n",
      "Epoch [990/1000], Loss: 2.02986212\n",
      "TE: X->Y 0.0001\n",
      "Epoch [990/1000], Loss: 2.15635821\n",
      "Epoch [990/1000], Loss: 2.03223515\n",
      "TE: Y->X 0.1241\n",
      "## Dim =  3 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 3.43234438\n",
      "Epoch [990/1000], Loss: 3.42852699\n",
      "TE: X->Y 0.0037\n",
      "Epoch [990/1000], Loss: 3.57863688\n",
      "Epoch [990/1000], Loss: 3.44956993\n",
      "TE: Y->X 0.129\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 3.45008581\n",
      "Epoch [990/1000], Loss: 3.44993359\n",
      "TE: X->Y 0.0001\n",
      "Epoch [990/1000], Loss: 3.57420527\n",
      "Epoch [990/1000], Loss: 3.44995973\n",
      "TE: Y->X 0.1242\n",
      "## Dim =  4 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 4.83731027\n",
      "Epoch [990/1000], Loss: 4.83171343\n",
      "TE: X->Y 0.0054\n",
      "Epoch [990/1000], Loss: 4.99322946\n",
      "Epoch [990/1000], Loss: 4.86184636\n",
      "TE: Y->X 0.1313\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 4.86715336\n",
      "Epoch [990/1000], Loss: 4.86678358\n",
      "TE: X->Y 0.0004\n",
      "Epoch [990/1000], Loss: 4.99461423\n",
      "Epoch [990/1000], Loss: 4.87031186\n",
      "TE: Y->X 0.1243\n",
      "## Dim =  5 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 6.26214819\n",
      "Epoch [990/1000], Loss: 6.25382231\n",
      "TE: X->Y 0.0079\n",
      "Epoch [990/1000], Loss: 6.40748125\n",
      "Epoch [990/1000], Loss: 6.27086611\n",
      "TE: Y->X 0.1364\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 6.29138175\n",
      "Epoch [990/1000], Loss: 6.29078933\n",
      "TE: X->Y 0.0006\n",
      "Epoch [990/1000], Loss: 6.41211005\n",
      "Epoch [990/1000], Loss: 6.28740297\n",
      "TE: Y->X 0.1247\n",
      "## Dim =  6 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 7.67825506\n",
      "Epoch [990/1000], Loss: 7.66582405\n",
      "TE: X->Y 0.0118\n",
      "Epoch [990/1000], Loss: 7.81888844\n",
      "Epoch [990/1000], Loss: 7.68038541\n",
      "TE: Y->X 0.1381\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 7.70530975\n",
      "Epoch [990/1000], Loss: 7.70478326\n",
      "TE: X->Y 0.0005\n",
      "Epoch [990/1000], Loss: 7.83269791\n",
      "Epoch [990/1000], Loss: 7.70803172\n",
      "TE: Y->X 0.1247\n",
      "## Dim =  7 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 9.10331814\n",
      "Epoch [990/1000], Loss: 9.09188868\n",
      "TE: X->Y 0.0108\n",
      "Epoch [990/1000], Loss: 9.22316927\n",
      "Epoch [990/1000], Loss: 9.07730646\n",
      "TE: Y->X 0.1452\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 9.12276591\n",
      "Epoch [990/1000], Loss: 9.12205839\n",
      "TE: X->Y 0.0007\n",
      "Epoch [990/1000], Loss: 9.24938299\n",
      "Epoch [990/1000], Loss: 9.12417317\n",
      "TE: Y->X 0.1252\n",
      "## Dim =  8 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 10.49592072\n",
      "Epoch [990/1000], Loss: 10.47548991\n",
      "TE: X->Y 0.0191\n",
      "Epoch [990/1000], Loss: 10.64079861\n",
      "Epoch [990/1000], Loss: 10.48069519\n",
      "TE: Y->X 0.1588\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 10.53709207\n",
      "Epoch [990/1000], Loss: 10.53604901\n",
      "TE: X->Y 0.001\n",
      "Epoch [990/1000], Loss: 10.66780413\n",
      "Epoch [990/1000], Loss: 10.54249544\n",
      "TE: Y->X 0.1253\n",
      "## Dim =  9 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 11.89953238\n",
      "Epoch [990/1000], Loss: 11.86325813\n",
      "TE: X->Y 0.0343\n",
      "Epoch [990/1000], Loss: 12.04257583\n",
      "Epoch [990/1000], Loss: 11.88044792\n",
      "TE: Y->X 0.1606\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 11.95124441\n",
      "Epoch [990/1000], Loss: 11.95017977\n",
      "TE: X->Y 0.0011\n",
      "Epoch [990/1000], Loss: 12.09049386\n",
      "Epoch [990/1000], Loss: 11.96475933\n",
      "TE: Y->X 0.1257\n",
      "## Dim =  10 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 13.31365978\n",
      "Epoch [990/1000], Loss: 13.27961007\n",
      "TE: X->Y 0.031\n",
      "Epoch [990/1000], Loss: 13.45891926\n",
      "Epoch [990/1000], Loss: 13.26191133\n",
      "TE: Y->X 0.1932\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 13.37182009\n",
      "Epoch [990/1000], Loss: 13.37028975\n",
      "TE: X->Y 0.0015\n",
      "Epoch [990/1000], Loss: 13.50732553\n",
      "Epoch [990/1000], Loss: 13.38125202\n",
      "TE: Y->X 0.126\n",
      "\n",
      "### REPLICATE 2/5 ###\n",
      "\n",
      "## Dim =  1 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 0.61100639\n",
      "Epoch [990/1000], Loss: 0.61071076\n",
      "TE: X->Y 0.0003\n",
      "Epoch [990/1000], Loss: 0.74024059\n",
      "Epoch [990/1000], Loss: 0.62352657\n",
      "TE: Y->X 0.1167\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 0.60913972\n",
      "Epoch [990/1000], Loss: 0.60906371\n",
      "TE: X->Y 0.0001\n",
      "Epoch [990/1000], Loss: 0.74194503\n",
      "Epoch [990/1000], Loss: 0.61561115\n",
      "TE: Y->X 0.1263\n",
      "## Dim =  2 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 2.03145167\n",
      "Epoch [990/1000], Loss: 2.03053696\n",
      "TE: X->Y 0.0009\n",
      "Epoch [990/1000], Loss: 2.15414201\n",
      "Epoch [990/1000], Loss: 2.03622061\n",
      "TE: Y->X 0.1179\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 2.02822554\n",
      "Epoch [990/1000], Loss: 2.02823949\n",
      "TE: X->Y -0.0\n",
      "Epoch [990/1000], Loss: 2.16380314\n",
      "Epoch [990/1000], Loss: 2.03751645\n",
      "TE: Y->X 0.1263\n",
      "## Dim =  3 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 3.45518934\n",
      "Epoch [990/1000], Loss: 3.45178429\n",
      "TE: X->Y 0.0033\n",
      "Epoch [990/1000], Loss: 3.57438072\n",
      "Epoch [990/1000], Loss: 3.45501101\n",
      "TE: Y->X 0.1193\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 3.44381572\n",
      "Epoch [990/1000], Loss: 3.44371679\n",
      "TE: X->Y 0.0001\n",
      "Epoch [990/1000], Loss: 3.58278131\n",
      "Epoch [990/1000], Loss: 3.45639575\n",
      "TE: Y->X 0.1264\n",
      "## Dim =  4 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 4.85679446\n",
      "Epoch [990/1000], Loss: 4.85088878\n",
      "TE: X->Y 0.0057\n",
      "Epoch [990/1000], Loss: 4.99734485\n",
      "Epoch [990/1000], Loss: 4.87494586\n",
      "TE: Y->X 0.1223\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 4.86643453\n",
      "Epoch [990/1000], Loss: 4.86620656\n",
      "TE: X->Y 0.0002\n",
      "Epoch [990/1000], Loss: 5.00046261\n",
      "Epoch [990/1000], Loss: 4.87396944\n",
      "TE: Y->X 0.1265\n",
      "## Dim =  5 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 6.25279558\n",
      "Epoch [990/1000], Loss: 6.24626023\n",
      "TE: X->Y 0.0063\n",
      "Epoch [990/1000], Loss: 6.41641397\n",
      "Epoch [990/1000], Loss: 6.29009745\n",
      "TE: Y->X 0.126\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 6.28393607\n",
      "Epoch [990/1000], Loss: 6.28358598\n",
      "TE: X->Y 0.0003\n",
      "Epoch [990/1000], Loss: 6.41764808\n",
      "Epoch [990/1000], Loss: 6.29082045\n",
      "TE: Y->X 0.1268\n",
      "## Dim =  6 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 7.67339308\n",
      "Epoch [990/1000], Loss: 7.66362174\n",
      "TE: X->Y 0.0091\n",
      "Epoch [990/1000], Loss: 7.82789776\n",
      "Epoch [990/1000], Loss: 7.69644826\n",
      "TE: Y->X 0.131\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 7.69459698\n",
      "Epoch [990/1000], Loss: 7.69397525\n",
      "TE: X->Y 0.0006\n",
      "Epoch [990/1000], Loss: 7.83819737\n",
      "Epoch [990/1000], Loss: 7.71088779\n",
      "TE: Y->X 0.1273\n",
      "## Dim =  7 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 9.08294228\n",
      "Epoch [990/1000], Loss: 9.06951185\n",
      "TE: X->Y 0.0125\n",
      "Epoch [990/1000], Loss: 9.23253846\n",
      "Epoch [990/1000], Loss: 9.10171103\n",
      "TE: Y->X 0.1304\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 9.11229978\n",
      "Epoch [990/1000], Loss: 9.11160509\n",
      "TE: X->Y 0.0007\n",
      "Epoch [990/1000], Loss: 9.25764099\n",
      "Epoch [990/1000], Loss: 9.13032959\n",
      "TE: Y->X 0.1273\n",
      "## Dim =  8 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 10.50359409\n",
      "Epoch [990/1000], Loss: 10.48115469\n",
      "TE: X->Y 0.0212\n",
      "Epoch [990/1000], Loss: 10.64037754\n",
      "Epoch [990/1000], Loss: 10.49465803\n",
      "TE: Y->X 0.1447\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 10.53746888\n",
      "Epoch [990/1000], Loss: 10.53648965\n",
      "TE: X->Y 0.001\n",
      "Epoch [990/1000], Loss: 10.67518059\n",
      "Epoch [990/1000], Loss: 10.54810555\n",
      "TE: Y->X 0.1271\n",
      "## Dim =  9 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 11.91374249\n",
      "Epoch [990/1000], Loss: 11.87134643\n",
      "TE: X->Y 0.0398\n",
      "Epoch [990/1000], Loss: 12.04354411\n",
      "Epoch [990/1000], Loss: 11.88793301\n",
      "TE: Y->X 0.1537\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 11.95469082\n",
      "Epoch [990/1000], Loss: 11.95371863\n",
      "TE: X->Y 0.001\n",
      "Epoch [990/1000], Loss: 12.09178832\n",
      "Epoch [990/1000], Loss: 11.96367902\n",
      "TE: Y->X 0.1281\n",
      "## Dim =  10 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 13.31455083\n",
      "Epoch [990/1000], Loss: 13.26886955\n",
      "TE: X->Y 0.0425\n",
      "Epoch [990/1000], Loss: 13.44320743\n",
      "Epoch [990/1000], Loss: 13.27251332\n",
      "TE: Y->X 0.1683\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 13.37860431\n",
      "Epoch [990/1000], Loss: 13.37754327\n",
      "TE: X->Y 0.0011\n",
      "Epoch [990/1000], Loss: 13.50374979\n",
      "Epoch [990/1000], Loss: 13.37533699\n",
      "TE: Y->X 0.1284\n",
      "\n",
      "### REPLICATE 3/5 ###\n",
      "\n",
      "## Dim =  1 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 0.60912529\n",
      "Epoch [990/1000], Loss: 0.60904096\n",
      "TE: X->Y 0.0001\n",
      "Epoch [990/1000], Loss: 0.75244499\n",
      "Epoch [990/1000], Loss: 0.62662026\n",
      "TE: Y->X 0.1258\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 0.61192858\n",
      "Epoch [990/1000], Loss: 0.61183867\n",
      "TE: X->Y 0.0001\n",
      "Epoch [990/1000], Loss: 0.73916297\n",
      "Epoch [990/1000], Loss: 0.61306856\n",
      "TE: Y->X 0.1261\n",
      "## Dim =  2 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 2.02467946\n",
      "Epoch [990/1000], Loss: 2.02354172\n",
      "TE: X->Y 0.0011\n",
      "Epoch [990/1000], Loss: 2.17320818\n",
      "Epoch [990/1000], Loss: 2.04720678\n",
      "TE: Y->X 0.126\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 2.02688472\n",
      "Epoch [990/1000], Loss: 2.02690792\n",
      "TE: X->Y -0.0\n",
      "Epoch [990/1000], Loss: 2.15712718\n",
      "Epoch [990/1000], Loss: 2.03099771\n",
      "TE: Y->X 0.1261\n",
      "## Dim =  3 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 3.44091181\n",
      "Epoch [990/1000], Loss: 3.43952249\n",
      "TE: X->Y 0.0014\n",
      "Epoch [990/1000], Loss: 3.58937137\n",
      "Epoch [990/1000], Loss: 3.46130646\n",
      "TE: Y->X 0.1281\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 3.44569864\n",
      "Epoch [990/1000], Loss: 3.44567768\n",
      "TE: X->Y 0.0\n",
      "Epoch [990/1000], Loss: 3.57234863\n",
      "Epoch [990/1000], Loss: 3.44598008\n",
      "TE: Y->X 0.1264\n",
      "## Dim =  4 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 4.87578779\n",
      "Epoch [990/1000], Loss: 4.87278324\n",
      "TE: X->Y 0.0029\n",
      "Epoch [990/1000], Loss: 5.00393212\n",
      "Epoch [990/1000], Loss: 4.87515094\n",
      "TE: Y->X 0.1287\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 4.87269634\n",
      "Epoch [990/1000], Loss: 4.87236166\n",
      "TE: X->Y 0.0003\n",
      "Epoch [990/1000], Loss: 4.98818818\n",
      "Epoch [990/1000], Loss: 4.86192214\n",
      "TE: Y->X 0.1263\n",
      "## Dim =  5 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 6.29606629\n",
      "Epoch [990/1000], Loss: 6.28814759\n",
      "TE: X->Y 0.0076\n",
      "Epoch [990/1000], Loss: 6.41906375\n",
      "Epoch [990/1000], Loss: 6.28447789\n",
      "TE: Y->X 0.1343\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 6.29193858\n",
      "Epoch [990/1000], Loss: 6.29158922\n",
      "TE: X->Y 0.0003\n",
      "Epoch [990/1000], Loss: 6.40629309\n",
      "Epoch [990/1000], Loss: 6.27956375\n",
      "TE: Y->X 0.1267\n",
      "## Dim =  6 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 7.70160464\n",
      "Epoch [990/1000], Loss: 7.69510828\n",
      "TE: X->Y 0.0061\n",
      "Epoch [990/1000], Loss: 7.84511732\n",
      "Epoch [990/1000], Loss: 7.70457176\n",
      "TE: Y->X 0.1401\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 7.71345508\n",
      "Epoch [990/1000], Loss: 7.71287809\n",
      "TE: X->Y 0.0006\n",
      "Epoch [990/1000], Loss: 7.82681149\n",
      "Epoch [990/1000], Loss: 7.69992512\n",
      "TE: Y->X 0.1269\n",
      "## Dim =  7 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 9.09862268\n",
      "Epoch [990/1000], Loss: 9.07846028\n",
      "TE: X->Y 0.0189\n",
      "Epoch [990/1000], Loss: 9.26260978\n",
      "Epoch [990/1000], Loss: 9.11914245\n",
      "TE: Y->X 0.1429\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 9.12342329\n",
      "Epoch [990/1000], Loss: 9.12277308\n",
      "TE: X->Y 0.0006\n",
      "Epoch [990/1000], Loss: 9.24873209\n",
      "Epoch [990/1000], Loss: 9.12174788\n",
      "TE: Y->X 0.127\n",
      "## Dim =  8 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 10.47968082\n",
      "Epoch [990/1000], Loss: 10.45283066\n",
      "TE: X->Y 0.0256\n",
      "Epoch [990/1000], Loss: 10.67365359\n",
      "Epoch [990/1000], Loss: 10.51754382\n",
      "TE: Y->X 0.1551\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 10.53905887\n",
      "Epoch [990/1000], Loss: 10.53836822\n",
      "TE: X->Y 0.0007\n",
      "Epoch [990/1000], Loss: 10.66652555\n",
      "Epoch [990/1000], Loss: 10.53943796\n",
      "TE: Y->X 0.1271\n",
      "## Dim =  9 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 11.89069943\n",
      "Epoch [990/1000], Loss: 11.86277436\n",
      "TE: X->Y 0.0259\n",
      "Epoch [990/1000], Loss: 12.07982246\n",
      "Epoch [990/1000], Loss: 11.90916528\n",
      "TE: Y->X 0.1692\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 11.96094858\n",
      "Epoch [990/1000], Loss: 11.95960488\n",
      "TE: X->Y 0.0013\n",
      "Epoch [990/1000], Loss: 12.08540665\n",
      "Epoch [990/1000], Loss: 11.95781591\n",
      "TE: Y->X 0.1276\n",
      "## Dim =  10 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 13.27246461\n",
      "Epoch [990/1000], Loss: 13.23570039\n",
      "TE: X->Y 0.0341\n",
      "Epoch [990/1000], Loss: 13.49783611\n",
      "Epoch [990/1000], Loss: 13.31227746\n",
      "TE: Y->X 0.1826\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 13.37681816\n",
      "Epoch [990/1000], Loss: 13.37545074\n",
      "TE: X->Y 0.0014\n",
      "Epoch [990/1000], Loss: 13.50462197\n",
      "Epoch [990/1000], Loss: 13.37643794\n",
      "TE: Y->X 0.1281\n",
      "\n",
      "### REPLICATE 4/5 ###\n",
      "\n",
      "## Dim =  1 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 0.61943826\n",
      "Epoch [990/1000], Loss: 0.61917666\n",
      "TE: X->Y 0.0003\n",
      "Epoch [990/1000], Loss: 0.72144256\n",
      "Epoch [990/1000], Loss: 0.60059212\n",
      "TE: Y->X 0.1208\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 0.61098643\n",
      "Epoch [990/1000], Loss: 0.61097055\n",
      "TE: X->Y 0.0\n",
      "Epoch [990/1000], Loss: 0.73678941\n",
      "Epoch [990/1000], Loss: 0.61090818\n",
      "TE: Y->X 0.1259\n",
      "## Dim =  2 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 2.04523392\n",
      "Epoch [990/1000], Loss: 2.04351338\n",
      "TE: X->Y 0.0017\n",
      "Epoch [990/1000], Loss: 2.13286162\n",
      "Epoch [990/1000], Loss: 2.00988809\n",
      "TE: Y->X 0.123\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 2.02936947\n",
      "Epoch [990/1000], Loss: 2.02929057\n",
      "TE: X->Y 0.0001\n",
      "Epoch [990/1000], Loss: 2.15586026\n",
      "Epoch [990/1000], Loss: 2.03009825\n",
      "TE: Y->X 0.1258\n",
      "## Dim =  3 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 3.44972204\n",
      "Epoch [990/1000], Loss: 3.44693107\n",
      "TE: X->Y 0.0027\n",
      "Epoch [990/1000], Loss: 3.55773909\n",
      "Epoch [990/1000], Loss: 3.43361877\n",
      "TE: Y->X 0.1241\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 3.44900186\n",
      "Epoch [990/1000], Loss: 3.44886255\n",
      "TE: X->Y 0.0001\n",
      "Epoch [990/1000], Loss: 3.57424811\n",
      "Epoch [990/1000], Loss: 3.44838032\n",
      "TE: Y->X 0.1259\n",
      "## Dim =  4 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 4.85604175\n",
      "Epoch [990/1000], Loss: 4.85249416\n",
      "TE: X->Y 0.0034\n",
      "Epoch [990/1000], Loss: 4.97262084\n",
      "Epoch [990/1000], Loss: 4.84468312\n",
      "TE: Y->X 0.1278\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 4.86185514\n",
      "Epoch [990/1000], Loss: 4.86171428\n",
      "TE: X->Y 0.0001\n",
      "Epoch [990/1000], Loss: 4.99638463\n",
      "Epoch [990/1000], Loss: 4.87048309\n",
      "TE: Y->X 0.1259\n",
      "## Dim =  5 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 6.27375478\n",
      "Epoch [990/1000], Loss: 6.26231784\n",
      "TE: X->Y 0.011\n",
      "Epoch [990/1000], Loss: 6.38560903\n",
      "Epoch [990/1000], Loss: 6.25735124\n",
      "TE: Y->X 0.1281\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 6.27633049\n",
      "Epoch [990/1000], Loss: 6.27601366\n",
      "TE: X->Y 0.0003\n",
      "Epoch [990/1000], Loss: 6.41202151\n",
      "Epoch [990/1000], Loss: 6.28630253\n",
      "TE: Y->X 0.1257\n",
      "## Dim =  6 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 7.67305285\n",
      "Epoch [990/1000], Loss: 7.66245615\n",
      "TE: X->Y 0.0101\n",
      "Epoch [990/1000], Loss: 7.79701193\n",
      "Epoch [990/1000], Loss: 7.66465321\n",
      "TE: Y->X 0.1321\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 7.69220432\n",
      "Epoch [990/1000], Loss: 7.69162975\n",
      "TE: X->Y 0.0006\n",
      "Epoch [990/1000], Loss: 7.83181334\n",
      "Epoch [990/1000], Loss: 7.70557477\n",
      "TE: Y->X 0.1262\n",
      "## Dim =  7 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 9.09263842\n",
      "Epoch [990/1000], Loss: 9.07790125\n",
      "TE: X->Y 0.014\n",
      "Epoch [990/1000], Loss: 9.20337099\n",
      "Epoch [990/1000], Loss: 9.05796592\n",
      "TE: Y->X 0.1445\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 9.11153986\n",
      "Epoch [990/1000], Loss: 9.11087467\n",
      "TE: X->Y 0.0007\n",
      "Epoch [990/1000], Loss: 9.24707579\n",
      "Epoch [990/1000], Loss: 9.12030158\n",
      "TE: Y->X 0.1268\n",
      "## Dim =  8 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 10.47846531\n",
      "Epoch [990/1000], Loss: 10.44939399\n",
      "TE: X->Y 0.0273\n",
      "Epoch [990/1000], Loss: 10.61295147\n",
      "Epoch [990/1000], Loss: 10.46128119\n",
      "TE: Y->X 0.1509\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 10.53285755\n",
      "Epoch [990/1000], Loss: 10.53187428\n",
      "TE: X->Y 0.001\n",
      "Epoch [990/1000], Loss: 10.66405239\n",
      "Epoch [990/1000], Loss: 10.53709627\n",
      "TE: Y->X 0.1269\n",
      "## Dim =  9 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 11.89547149\n",
      "Epoch [990/1000], Loss: 11.86144603\n",
      "TE: X->Y 0.0316\n",
      "Epoch [990/1000], Loss: 12.02127595\n",
      "Epoch [990/1000], Loss: 11.85585619\n",
      "TE: Y->X 0.1636\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 11.95302524\n",
      "Epoch [990/1000], Loss: 11.95227138\n",
      "TE: X->Y 0.0007\n",
      "Epoch [990/1000], Loss: 12.07724724\n",
      "Epoch [990/1000], Loss: 11.95004633\n",
      "TE: Y->X 0.1272\n",
      "## Dim =  10 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 13.31247928\n",
      "Epoch [990/1000], Loss: 13.28108071\n",
      "TE: X->Y 0.0288\n",
      "Epoch [990/1000], Loss: 13.43143763\n",
      "Epoch [990/1000], Loss: 13.25132225\n",
      "TE: Y->X 0.1773\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 13.37519588\n",
      "Epoch [990/1000], Loss: 13.37377499\n",
      "TE: X->Y 0.0014\n",
      "Epoch [990/1000], Loss: 13.49610117\n",
      "Epoch [990/1000], Loss: 13.36852026\n",
      "TE: Y->X 0.1275\n",
      "\n",
      "### REPLICATE 5/5 ###\n",
      "\n",
      "## Dim =  1 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 0.61320362\n",
      "Epoch [990/1000], Loss: 0.61330719\n",
      "TE: X->Y -0.0001\n",
      "Epoch [990/1000], Loss: 0.73620715\n",
      "Epoch [990/1000], Loss: 0.60395555\n",
      "TE: Y->X 0.1322\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 0.61032295\n",
      "Epoch [990/1000], Loss: 0.61032181\n",
      "TE: X->Y 0.0\n",
      "Epoch [990/1000], Loss: 0.74391941\n",
      "Epoch [990/1000], Loss: 0.61723606\n",
      "TE: Y->X 0.1267\n",
      "## Dim =  2 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 2.03564334\n",
      "Epoch [990/1000], Loss: 2.03516194\n",
      "TE: X->Y 0.0005\n",
      "Epoch [990/1000], Loss: 2.14876433\n",
      "Epoch [990/1000], Loss: 2.01654572\n",
      "TE: Y->X 0.1322\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 2.02985701\n",
      "Epoch [990/1000], Loss: 2.02986378\n",
      "TE: X->Y -0.0\n",
      "Epoch [990/1000], Loss: 2.16153423\n",
      "Epoch [990/1000], Loss: 2.03484095\n",
      "TE: Y->X 0.1267\n",
      "## Dim =  3 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 3.44695179\n",
      "Epoch [990/1000], Loss: 3.44501636\n",
      "TE: X->Y 0.0019\n",
      "Epoch [990/1000], Loss: 3.57002086\n",
      "Epoch [990/1000], Loss: 3.43553531\n",
      "TE: Y->X 0.1345\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 3.44481277\n",
      "Epoch [990/1000], Loss: 3.44459171\n",
      "TE: X->Y 0.0002\n",
      "Epoch [990/1000], Loss: 3.58109008\n",
      "Epoch [990/1000], Loss: 3.45427641\n",
      "TE: Y->X 0.1268\n",
      "## Dim =  4 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 4.85566618\n",
      "Epoch [990/1000], Loss: 4.84976538\n",
      "TE: X->Y 0.0057\n",
      "Epoch [990/1000], Loss: 4.98944674\n",
      "Epoch [990/1000], Loss: 4.85247436\n",
      "TE: Y->X 0.1368\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 4.86853608\n",
      "Epoch [990/1000], Loss: 4.86826206\n",
      "TE: X->Y 0.0003\n",
      "Epoch [990/1000], Loss: 4.99618032\n",
      "Epoch [990/1000], Loss: 4.86942181\n",
      "TE: Y->X 0.1268\n",
      "## Dim =  5 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 6.27065856\n",
      "Epoch [990/1000], Loss: 6.25984229\n",
      "TE: X->Y 0.0104\n",
      "Epoch [990/1000], Loss: 6.40158009\n",
      "Epoch [990/1000], Loss: 6.26225485\n",
      "TE: Y->X 0.1391\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 6.28880119\n",
      "Epoch [990/1000], Loss: 6.28848674\n",
      "TE: X->Y 0.0003\n",
      "Epoch [990/1000], Loss: 6.41565728\n",
      "Epoch [990/1000], Loss: 6.28844733\n",
      "TE: Y->X 0.1272\n",
      "## Dim =  6 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 7.70171965\n",
      "Epoch [990/1000], Loss: 7.69133801\n",
      "TE: X->Y 0.0097\n",
      "Epoch [990/1000], Loss: 7.82002446\n",
      "Epoch [990/1000], Loss: 7.67617684\n",
      "TE: Y->X 0.1435\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 7.70603467\n",
      "Epoch [990/1000], Loss: 7.70556203\n",
      "TE: X->Y 0.0005\n",
      "Epoch [990/1000], Loss: 7.83502394\n",
      "Epoch [990/1000], Loss: 7.70770652\n",
      "TE: Y->X 0.1273\n",
      "## Dim =  7 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 9.13590791\n",
      "Epoch [990/1000], Loss: 9.12236186\n",
      "TE: X->Y 0.0127\n",
      "Epoch [990/1000], Loss: 9.21969889\n",
      "Epoch [990/1000], Loss: 9.06937054\n",
      "TE: Y->X 0.1498\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 9.12583775\n",
      "Epoch [990/1000], Loss: 9.12529418\n",
      "TE: X->Y 0.0005\n",
      "Epoch [990/1000], Loss: 9.25436553\n",
      "Epoch [990/1000], Loss: 9.12664224\n",
      "TE: Y->X 0.1277\n",
      "## Dim =  8 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 10.55710177\n",
      "Epoch [990/1000], Loss: 10.53732831\n",
      "TE: X->Y 0.0183\n",
      "Epoch [990/1000], Loss: 10.63025045\n",
      "Epoch [990/1000], Loss: 10.46655216\n",
      "TE: Y->X 0.1624\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 10.54452993\n",
      "Epoch [990/1000], Loss: 10.54349916\n",
      "TE: X->Y 0.001\n",
      "Epoch [990/1000], Loss: 10.67278249\n",
      "Epoch [990/1000], Loss: 10.54487784\n",
      "TE: Y->X 0.1279\n",
      "## Dim =  9 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 11.96243599\n",
      "Epoch [990/1000], Loss: 11.91935617\n",
      "TE: X->Y 0.0401\n",
      "Epoch [990/1000], Loss: 12.04282612\n",
      "Epoch [990/1000], Loss: 11.87721283\n",
      "TE: Y->X 0.1637\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 11.96457181\n",
      "Epoch [990/1000], Loss: 11.96338493\n",
      "TE: X->Y 0.0012\n",
      "Epoch [990/1000], Loss: 12.09362252\n",
      "Epoch [990/1000], Loss: 11.96534923\n",
      "TE: Y->X 0.1282\n",
      "## Dim =  10 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 13.35490184\n",
      "Epoch [990/1000], Loss: 13.31102603\n",
      "TE: X->Y 0.041\n",
      "Epoch [990/1000], Loss: 13.45506116\n",
      "Epoch [990/1000], Loss: 13.27336576\n",
      "TE: Y->X 0.1795\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 13.38131941\n",
      "Epoch [990/1000], Loss: 13.38015396\n",
      "TE: X->Y 0.0012\n",
      "Epoch [990/1000], Loss: 13.51305046\n",
      "Epoch [990/1000], Loss: 13.38465405\n",
      "TE: Y->X 0.1284\n"
     ]
    }
   ],
   "source": [
    "lg_results_TE_X2Y = Results(columns=['method', 'n_dim', 'sample_size'])\n",
    "lg_results_TE_Y2X = Results(columns=['method', 'n_dim', 'sample_size'])\n",
    "\n",
    "for r in range(REPLICATES):\n",
    "    print(f\"\\n### REPLICATE {r+1}/{REPLICATES} ###\\n\")\n",
    "    for dim, generator in zip(dim_range, lg_generator_lst):\n",
    "        print(\"## Dim = \", dim, \"#\")\n",
    "        for samples in sample_sizes:\n",
    "            print(\"# Sample size = \", samples, \"#\")\n",
    "            # Simulate data\n",
    "            dataset = get_dataset(generator, samples, int(np.round(samples/20)), seed=r)\n",
    "            # Estimate X -> Y\n",
    "            TE_X2Y = TE_agmte(dataset, compute_device, 'X', 'Y', dim*16, NB)\n",
    "            lg_results_TE_X2Y.write(method='agmte', n_dim=dim, sample_size=samples, value=TE_X2Y)\n",
    "            # Estimate Y -> X\n",
    "            TE_Y2X = TE_agmte(dataset, compute_device, 'Y', 'X', dim*16, NB)\n",
    "            lg_results_TE_Y2X.write(method='agmte', n_dim=dim, sample_size=samples, value=TE_Y2X)\n",
    "\n",
    "lg_results_TE_X2Y.df.to_csv('results/agmte/lg_results_TE_X2Y_dimred.csv', index=False)\n",
    "lg_results_TE_Y2X.df.to_csv('results/agmte/lg_results_TE_Y2X_dimred.csv', index=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Joint Process"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize the list of generators with one for each dimension\n",
    "jp_generator_lst = [MVJointProcessSimulator(lam = 0.0, n_dim=dim, n_redundant_dim=dim-1) for dim in dim_range]\n",
    "# Get the reference values\n",
    "jp_TE_X2Y_ref_lst = [generator.analytic_transfer_entropy('X', 'Y') for generator in jp_generator_lst]\n",
    "jp_TE_Y2X_ref_lst = [generator.analytic_transfer_entropy('Y', 'X') for generator in jp_generator_lst]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "### REPLICATE 1/5 ###\n",
      "\n",
      "## Dim =  1 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 1.40366864\n",
      "Epoch [990/1000], Loss: 1.00495748\n",
      "TE: X->Y 0.3984\n",
      "Epoch [990/1000], Loss: 1.41204535\n",
      "Epoch [990/1000], Loss: 1.41186869\n",
      "TE: Y->X 0.0002\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 1.41731446\n",
      "Epoch [990/1000], Loss: 1.01335655\n",
      "TE: X->Y 0.4037\n",
      "Epoch [990/1000], Loss: 1.41761742\n",
      "Epoch [990/1000], Loss: 1.41759189\n",
      "TE: Y->X 0.0\n",
      "## Dim =  2 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 2.83030521\n",
      "Epoch [990/1000], Loss: 2.42918275\n",
      "TE: X->Y 0.4005\n",
      "Epoch [990/1000], Loss: 2.82681051\n",
      "Epoch [990/1000], Loss: 2.82666257\n",
      "TE: Y->X 0.0001\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 2.83480878\n",
      "Epoch [990/1000], Loss: 2.42916013\n",
      "TE: X->Y 0.4053\n",
      "Epoch [990/1000], Loss: 2.83701331\n",
      "Epoch [990/1000], Loss: 2.83696815\n",
      "TE: Y->X 0.0\n",
      "## Dim =  3 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 4.23281361\n",
      "Epoch [990/1000], Loss: 3.83241024\n",
      "TE: X->Y 0.3998\n",
      "Epoch [990/1000], Loss: 4.25172935\n",
      "Epoch [990/1000], Loss: 4.25020096\n",
      "TE: Y->X 0.0015\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 4.25540826\n",
      "Epoch [990/1000], Loss: 3.85009974\n",
      "TE: X->Y 0.405\n",
      "Epoch [990/1000], Loss: 4.25447258\n",
      "Epoch [990/1000], Loss: 4.25432598\n",
      "TE: Y->X 0.0001\n",
      "## Dim =  4 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 5.64146162\n",
      "Epoch [990/1000], Loss: 5.22099415\n",
      "TE: X->Y 0.4189\n",
      "Epoch [990/1000], Loss: 5.66460194\n",
      "Epoch [990/1000], Loss: 5.66048472\n",
      "TE: Y->X 0.0038\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 5.67210103\n",
      "Epoch [990/1000], Loss: 5.26582014\n",
      "TE: X->Y 0.4059\n",
      "Epoch [990/1000], Loss: 5.67507472\n",
      "Epoch [990/1000], Loss: 5.67491187\n",
      "TE: Y->X 0.0002\n",
      "## Dim =  5 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 7.05895184\n",
      "Epoch [990/1000], Loss: 6.61435417\n",
      "TE: X->Y 0.441\n",
      "Epoch [990/1000], Loss: 7.07396621\n",
      "Epoch [990/1000], Loss: 7.06756467\n",
      "TE: Y->X 0.0052\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 7.09615267\n",
      "Epoch [990/1000], Loss: 6.69058916\n",
      "TE: X->Y 0.4051\n",
      "Epoch [990/1000], Loss: 7.09225144\n",
      "Epoch [990/1000], Loss: 7.09200953\n",
      "TE: Y->X 0.0002\n",
      "## Dim =  6 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 8.47392228\n",
      "Epoch [990/1000], Loss: 8.00361722\n",
      "TE: X->Y 0.4651\n",
      "Epoch [990/1000], Loss: 8.47345218\n",
      "Epoch [990/1000], Loss: 8.45453629\n",
      "TE: Y->X 0.0156\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 8.51014786\n",
      "Epoch [990/1000], Loss: 8.10699114\n",
      "TE: X->Y 0.4025\n",
      "Epoch [990/1000], Loss: 8.51301662\n",
      "Epoch [990/1000], Loss: 8.51260615\n",
      "TE: Y->X 0.0004\n",
      "## Dim =  7 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 9.87275959\n",
      "Epoch [990/1000], Loss: 9.36650301\n",
      "TE: X->Y 0.4965\n",
      "Epoch [990/1000], Loss: 9.87475049\n",
      "Epoch [990/1000], Loss: 9.84665448\n",
      "TE: Y->X 0.0228\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 9.92749004\n",
      "Epoch [990/1000], Loss: 9.52166607\n",
      "TE: X->Y 0.4053\n",
      "Epoch [990/1000], Loss: 9.92932329\n",
      "Epoch [990/1000], Loss: 9.92865111\n",
      "TE: Y->X 0.0007\n",
      "## Dim =  8 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 11.24098505\n",
      "Epoch [990/1000], Loss: 10.66175846\n",
      "TE: X->Y 0.5617\n",
      "Epoch [990/1000], Loss: 11.26041158\n",
      "Epoch [990/1000], Loss: 11.22798376\n",
      "TE: Y->X 0.0248\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 11.34191469\n",
      "Epoch [990/1000], Loss: 10.93556985\n",
      "TE: X->Y 0.4058\n",
      "Epoch [990/1000], Loss: 11.34712124\n",
      "Epoch [990/1000], Loss: 11.34664298\n",
      "TE: Y->X 0.0005\n",
      "## Dim =  9 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 12.60308784\n",
      "Epoch [990/1000], Loss: 11.92969075\n",
      "TE: X->Y 0.6471\n",
      "Epoch [990/1000], Loss: 12.63720006\n",
      "Epoch [990/1000], Loss: 12.51449063\n",
      "TE: Y->X 0.1057\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 12.75539753\n",
      "Epoch [990/1000], Loss: 12.34891588\n",
      "TE: X->Y 0.4058\n",
      "Epoch [990/1000], Loss: 12.76994005\n",
      "Epoch [990/1000], Loss: 12.76932794\n",
      "TE: Y->X 0.0006\n",
      "## Dim =  10 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 14.00072636\n",
      "Epoch [990/1000], Loss: 13.15892716\n",
      "TE: X->Y 0.7953\n",
      "Epoch [990/1000], Loss: 13.99178291\n",
      "Epoch [990/1000], Loss: 13.86092856\n",
      "TE: Y->X 0.1055\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 14.17574244\n",
      "Epoch [990/1000], Loss: 13.76865777\n",
      "TE: X->Y 0.4064\n",
      "Epoch [990/1000], Loss: 14.18634683\n",
      "Epoch [990/1000], Loss: 14.18543886\n",
      "TE: Y->X 0.0009\n",
      "\n",
      "### REPLICATE 2/5 ###\n",
      "\n",
      "## Dim =  1 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 1.42549829\n",
      "Epoch [990/1000], Loss: 1.00851152\n",
      "TE: X->Y 0.4167\n",
      "Epoch [990/1000], Loss: 1.41976389\n",
      "Epoch [990/1000], Loss: 1.41956116\n",
      "TE: Y->X 0.0002\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 1.41705711\n",
      "Epoch [990/1000], Loss: 1.01289049\n",
      "TE: X->Y 0.4039\n",
      "Epoch [990/1000], Loss: 1.41591503\n",
      "Epoch [990/1000], Loss: 1.41587605\n",
      "TE: Y->X 0.0\n",
      "## Dim =  2 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 2.84593486\n",
      "Epoch [990/1000], Loss: 2.42583859\n",
      "TE: X->Y 0.4197\n",
      "Epoch [990/1000], Loss: 2.83361779\n",
      "Epoch [990/1000], Loss: 2.83328253\n",
      "TE: Y->X 0.0003\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 2.83607291\n",
      "Epoch [990/1000], Loss: 2.43323019\n",
      "TE: X->Y 0.4024\n",
      "Epoch [990/1000], Loss: 2.83771032\n",
      "Epoch [990/1000], Loss: 2.83762288\n",
      "TE: Y->X 0.0001\n",
      "## Dim =  3 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 4.26881434\n",
      "Epoch [990/1000], Loss: 3.84346199\n",
      "TE: X->Y 0.4246\n",
      "Epoch [990/1000], Loss: 4.25262004\n",
      "Epoch [990/1000], Loss: 4.24925575\n",
      "TE: Y->X 0.0031\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 4.25186172\n",
      "Epoch [990/1000], Loss: 3.84696354\n",
      "TE: X->Y 0.4046\n",
      "Epoch [990/1000], Loss: 4.25655592\n",
      "Epoch [990/1000], Loss: 4.25651769\n",
      "TE: Y->X 0.0\n",
      "## Dim =  4 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 5.66759712\n",
      "Epoch [990/1000], Loss: 5.23234892\n",
      "TE: X->Y 0.4338\n",
      "Epoch [990/1000], Loss: 5.67422922\n",
      "Epoch [990/1000], Loss: 5.67136287\n",
      "TE: Y->X 0.0025\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 5.67414073\n",
      "Epoch [990/1000], Loss: 5.27061008\n",
      "TE: X->Y 0.4031\n",
      "Epoch [990/1000], Loss: 5.67420513\n",
      "Epoch [990/1000], Loss: 5.67390915\n",
      "TE: Y->X 0.0003\n",
      "## Dim =  5 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 7.06005959\n",
      "Epoch [990/1000], Loss: 6.60925261\n",
      "TE: X->Y 0.4485\n",
      "Epoch [990/1000], Loss: 7.08762117\n",
      "Epoch [990/1000], Loss: 7.07968782\n",
      "TE: Y->X 0.0066\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 7.09141572\n",
      "Epoch [990/1000], Loss: 6.68983089\n",
      "TE: X->Y 0.401\n",
      "Epoch [990/1000], Loss: 7.09126134\n",
      "Epoch [990/1000], Loss: 7.09111022\n",
      "TE: Y->X 0.0002\n",
      "## Dim =  6 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 8.47232533\n",
      "Epoch [990/1000], Loss: 7.99481536\n",
      "TE: X->Y 0.4719\n",
      "Epoch [990/1000], Loss: 8.49747433\n",
      "Epoch [990/1000], Loss: 8.48802962\n",
      "TE: Y->X 0.0077\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 8.50212829\n",
      "Epoch [990/1000], Loss: 8.10038132\n",
      "TE: X->Y 0.4012\n",
      "Epoch [990/1000], Loss: 8.51177555\n",
      "Epoch [990/1000], Loss: 8.51142729\n",
      "TE: Y->X 0.0003\n",
      "## Dim =  7 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 9.87080403\n",
      "Epoch [990/1000], Loss: 9.32471118\n",
      "TE: X->Y 0.5342\n",
      "Epoch [990/1000], Loss: 9.88652199\n",
      "Epoch [990/1000], Loss: 9.86434933\n",
      "TE: Y->X 0.018\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 9.91981149\n",
      "Epoch [990/1000], Loss: 9.51667633\n",
      "TE: X->Y 0.4025\n",
      "Epoch [990/1000], Loss: 9.93100346\n",
      "Epoch [990/1000], Loss: 9.93064235\n",
      "TE: Y->X 0.0004\n",
      "## Dim =  8 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 11.26957276\n",
      "Epoch [990/1000], Loss: 10.65244041\n",
      "TE: X->Y 0.6\n",
      "Epoch [990/1000], Loss: 11.26465608\n",
      "Epoch [990/1000], Loss: 11.23456508\n",
      "TE: Y->X 0.0232\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 11.34462954\n",
      "Epoch [990/1000], Loss: 10.93999096\n",
      "TE: X->Y 0.4039\n",
      "Epoch [990/1000], Loss: 11.34809934\n",
      "Epoch [990/1000], Loss: 11.34812671\n",
      "TE: Y->X 0.0\n",
      "## Dim =  9 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 12.64168016\n",
      "Epoch [990/1000], Loss: 11.93844897\n",
      "TE: X->Y 0.679\n",
      "Epoch [990/1000], Loss: 12.64728529\n",
      "Epoch [990/1000], Loss: 12.51194748\n",
      "TE: Y->X 0.1186\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 12.76169127\n",
      "Epoch [990/1000], Loss: 12.35522705\n",
      "TE: X->Y 0.4059\n",
      "Epoch [990/1000], Loss: 12.76492457\n",
      "Epoch [990/1000], Loss: 12.76410385\n",
      "TE: Y->X 0.0008\n",
      "## Dim =  10 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 14.00306768\n",
      "Epoch [990/1000], Loss: 13.17417198\n",
      "TE: X->Y 0.7863\n",
      "Epoch [990/1000], Loss: 14.01956936\n",
      "Epoch [990/1000], Loss: 13.79391118\n",
      "TE: Y->X 0.1964\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 14.18582479\n",
      "Epoch [990/1000], Loss: 13.78275915\n",
      "TE: X->Y 0.4022\n",
      "Epoch [990/1000], Loss: 14.17656011\n",
      "Epoch [990/1000], Loss: 14.17548166\n",
      "TE: Y->X 0.0011\n",
      "\n",
      "### REPLICATE 3/5 ###\n",
      "\n",
      "## Dim =  1 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 1.42632504\n",
      "Epoch [990/1000], Loss: 1.01725904\n",
      "TE: X->Y 0.4086\n",
      "Epoch [990/1000], Loss: 1.42693084\n",
      "Epoch [990/1000], Loss: 1.42675081\n",
      "TE: Y->X 0.0002\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 1.41403699\n",
      "Epoch [990/1000], Loss: 1.01908297\n",
      "TE: X->Y 0.3947\n",
      "Epoch [990/1000], Loss: 1.41313519\n",
      "Epoch [990/1000], Loss: 1.41312221\n",
      "TE: Y->X 0.0\n",
      "## Dim =  2 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 2.83961549\n",
      "Epoch [990/1000], Loss: 2.42919572\n",
      "TE: X->Y 0.4098\n",
      "Epoch [990/1000], Loss: 2.84740848\n",
      "Epoch [990/1000], Loss: 2.84626697\n",
      "TE: Y->X 0.0011\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 2.82911765\n",
      "Epoch [990/1000], Loss: 2.43609509\n",
      "TE: X->Y 0.3926\n",
      "Epoch [990/1000], Loss: 2.83080095\n",
      "Epoch [990/1000], Loss: 2.83071894\n",
      "TE: Y->X 0.0001\n",
      "## Dim =  3 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 4.25700051\n",
      "Epoch [990/1000], Loss: 3.84184876\n",
      "TE: X->Y 0.4143\n",
      "Epoch [990/1000], Loss: 4.26020877\n",
      "Epoch [990/1000], Loss: 4.25875619\n",
      "TE: Y->X 0.0014\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 4.24742851\n",
      "Epoch [990/1000], Loss: 3.85436772\n",
      "TE: X->Y 0.3926\n",
      "Epoch [990/1000], Loss: 4.24589293\n",
      "Epoch [990/1000], Loss: 4.24576882\n",
      "TE: Y->X 0.0001\n",
      "## Dim =  4 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 5.68746422\n",
      "Epoch [990/1000], Loss: 5.25716153\n",
      "TE: X->Y 0.4289\n",
      "Epoch [990/1000], Loss: 5.66971647\n",
      "Epoch [990/1000], Loss: 5.66825042\n",
      "TE: Y->X 0.0012\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 5.67496235\n",
      "Epoch [990/1000], Loss: 5.28021391\n",
      "TE: X->Y 0.3943\n",
      "Epoch [990/1000], Loss: 5.66144081\n",
      "Epoch [990/1000], Loss: 5.66123209\n",
      "TE: Y->X 0.0002\n",
      "## Dim =  5 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 7.10590978\n",
      "Epoch [990/1000], Loss: 6.66125402\n",
      "TE: X->Y 0.442\n",
      "Epoch [990/1000], Loss: 7.08278984\n",
      "Epoch [990/1000], Loss: 7.08323268\n",
      "TE: Y->X -0.0007\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 7.09415036\n",
      "Epoch [990/1000], Loss: 6.70011823\n",
      "TE: X->Y 0.3935\n",
      "Epoch [990/1000], Loss: 7.07913111\n",
      "Epoch [990/1000], Loss: 7.07900081\n",
      "TE: Y->X 0.0001\n",
      "## Dim =  6 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 8.50556867\n",
      "Epoch [990/1000], Loss: 8.03069175\n",
      "TE: X->Y 0.4704\n",
      "Epoch [990/1000], Loss: 8.49790447\n",
      "Epoch [990/1000], Loss: 8.47765569\n",
      "TE: Y->X 0.0167\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 8.51543496\n",
      "Epoch [990/1000], Loss: 8.12282299\n",
      "TE: X->Y 0.392\n",
      "Epoch [990/1000], Loss: 8.49984149\n",
      "Epoch [990/1000], Loss: 8.49956401\n",
      "TE: Y->X 0.0003\n",
      "## Dim =  7 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 9.89588174\n",
      "Epoch [990/1000], Loss: 9.37866255\n",
      "TE: X->Y 0.5078\n",
      "Epoch [990/1000], Loss: 9.90937636\n",
      "Epoch [990/1000], Loss: 9.88211611\n",
      "TE: Y->X 0.0218\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 9.92545079\n",
      "Epoch [990/1000], Loss: 9.53143277\n",
      "TE: X->Y 0.3935\n",
      "Epoch [990/1000], Loss: 9.92191296\n",
      "Epoch [990/1000], Loss: 9.92127148\n",
      "TE: Y->X 0.0006\n",
      "## Dim =  8 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 11.24137702\n",
      "Epoch [990/1000], Loss: 10.63072033\n",
      "TE: X->Y 0.594\n",
      "Epoch [990/1000], Loss: 11.30520682\n",
      "Epoch [990/1000], Loss: 11.25748176\n",
      "TE: Y->X 0.0374\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 11.34074304\n",
      "Epoch [990/1000], Loss: 10.94462769\n",
      "TE: X->Y 0.3955\n",
      "Epoch [990/1000], Loss: 11.33950317\n",
      "Epoch [990/1000], Loss: 11.33906888\n",
      "TE: Y->X 0.0004\n",
      "## Dim =  9 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 12.63682132\n",
      "Epoch [990/1000], Loss: 11.95335927\n",
      "TE: X->Y 0.6585\n",
      "Epoch [990/1000], Loss: 12.69704972\n",
      "Epoch [990/1000], Loss: 12.57737418\n",
      "TE: Y->X 0.1018\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 12.76209274\n",
      "Epoch [990/1000], Loss: 12.36455556\n",
      "TE: X->Y 0.397\n",
      "Epoch [990/1000], Loss: 12.75841256\n",
      "Epoch [990/1000], Loss: 12.75787416\n",
      "TE: Y->X 0.0006\n",
      "## Dim =  10 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 13.94748838\n",
      "Epoch [990/1000], Loss: 13.10186324\n",
      "TE: X->Y 0.8036\n",
      "Epoch [990/1000], Loss: 14.02135252\n",
      "Epoch [990/1000], Loss: 13.88359143\n",
      "TE: Y->X 0.1136\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 14.17852679\n",
      "Epoch [990/1000], Loss: 13.78186974\n",
      "TE: X->Y 0.3958\n",
      "Epoch [990/1000], Loss: 14.17788177\n",
      "Epoch [990/1000], Loss: 14.17662796\n",
      "TE: Y->X 0.0013\n",
      "\n",
      "### REPLICATE 4/5 ###\n",
      "\n",
      "## Dim =  1 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 1.41514817\n",
      "Epoch [990/1000], Loss: 1.01678037\n",
      "TE: X->Y 0.398\n",
      "Epoch [990/1000], Loss: 1.41292658\n",
      "Epoch [990/1000], Loss: 1.41274346\n",
      "TE: Y->X 0.0002\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 1.41542287\n",
      "Epoch [990/1000], Loss: 1.01142283\n",
      "TE: X->Y 0.4038\n",
      "Epoch [990/1000], Loss: 1.41518649\n",
      "Epoch [990/1000], Loss: 1.41516439\n",
      "TE: Y->X 0.0\n",
      "## Dim =  2 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 2.83959508\n",
      "Epoch [990/1000], Loss: 2.43929616\n",
      "TE: X->Y 0.3999\n",
      "Epoch [990/1000], Loss: 2.82416527\n",
      "Epoch [990/1000], Loss: 2.82407898\n",
      "TE: Y->X 0.0001\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 2.83373765\n",
      "Epoch [990/1000], Loss: 2.43192894\n",
      "TE: X->Y 0.4015\n",
      "Epoch [990/1000], Loss: 2.83399233\n",
      "Epoch [990/1000], Loss: 2.83391644\n",
      "TE: Y->X 0.0001\n",
      "## Dim =  3 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 4.24199862\n",
      "Epoch [990/1000], Loss: 3.83669992\n",
      "TE: X->Y 0.4048\n",
      "Epoch [990/1000], Loss: 4.24762868\n",
      "Epoch [990/1000], Loss: 4.24758809\n",
      "TE: Y->X 0.0\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 4.25373586\n",
      "Epoch [990/1000], Loss: 3.85384073\n",
      "TE: X->Y 0.3994\n",
      "Epoch [990/1000], Loss: 4.25223659\n",
      "Epoch [990/1000], Loss: 4.25209261\n",
      "TE: Y->X 0.0001\n",
      "## Dim =  4 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 5.64836294\n",
      "Epoch [990/1000], Loss: 5.23205073\n",
      "TE: X->Y 0.415\n",
      "Epoch [990/1000], Loss: 5.66051644\n",
      "Epoch [990/1000], Loss: 5.65886055\n",
      "TE: Y->X 0.0014\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 5.66611098\n",
      "Epoch [990/1000], Loss: 5.26345681\n",
      "TE: X->Y 0.4023\n",
      "Epoch [990/1000], Loss: 5.67438604\n",
      "Epoch [990/1000], Loss: 5.67423307\n",
      "TE: Y->X 0.0002\n",
      "## Dim =  5 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 7.05814755\n",
      "Epoch [990/1000], Loss: 6.62470012\n",
      "TE: X->Y 0.4304\n",
      "Epoch [990/1000], Loss: 7.06859472\n",
      "Epoch [990/1000], Loss: 7.05943249\n",
      "TE: Y->X 0.008\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 7.08043952\n",
      "Epoch [990/1000], Loss: 6.67924671\n",
      "TE: X->Y 0.4007\n",
      "Epoch [990/1000], Loss: 7.09025478\n",
      "Epoch [990/1000], Loss: 7.09008664\n",
      "TE: Y->X 0.0002\n",
      "## Dim =  6 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 8.45078008\n",
      "Epoch [990/1000], Loss: 7.99113569\n",
      "TE: X->Y 0.4543\n",
      "Epoch [990/1000], Loss: 8.46489338\n",
      "Epoch [990/1000], Loss: 8.44602285\n",
      "TE: Y->X 0.016\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 8.49614683\n",
      "Epoch [990/1000], Loss: 8.09293864\n",
      "TE: X->Y 0.4028\n",
      "Epoch [990/1000], Loss: 8.50966508\n",
      "Epoch [990/1000], Loss: 8.50928069\n",
      "TE: Y->X 0.0004\n",
      "## Dim =  7 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 9.85405964\n",
      "Epoch [990/1000], Loss: 9.33580133\n",
      "TE: X->Y 0.5072\n",
      "Epoch [990/1000], Loss: 9.86228773\n",
      "Epoch [990/1000], Loss: 9.81624447\n",
      "TE: Y->X 0.0396\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 9.91536053\n",
      "Epoch [990/1000], Loss: 9.51554555\n",
      "TE: X->Y 0.3992\n",
      "Epoch [990/1000], Loss: 9.92449636\n",
      "Epoch [990/1000], Loss: 9.92404503\n",
      "TE: Y->X 0.0005\n",
      "## Dim =  8 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 11.21349369\n",
      "Epoch [990/1000], Loss: 10.65533603\n",
      "TE: X->Y 0.5425\n",
      "Epoch [990/1000], Loss: 11.24639209\n",
      "Epoch [990/1000], Loss: 11.17586144\n",
      "TE: Y->X 0.0597\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 11.33654328\n",
      "Epoch [990/1000], Loss: 10.93383573\n",
      "TE: X->Y 0.4022\n",
      "Epoch [990/1000], Loss: 11.34126924\n",
      "Epoch [990/1000], Loss: 11.34078467\n",
      "TE: Y->X 0.0005\n",
      "## Dim =  9 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 12.61171352\n",
      "Epoch [990/1000], Loss: 11.98208344\n",
      "TE: X->Y 0.6051\n",
      "Epoch [990/1000], Loss: 12.62529816\n",
      "Epoch [990/1000], Loss: 12.50702612\n",
      "TE: Y->X 0.1014\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 12.75673494\n",
      "Epoch [990/1000], Loss: 12.35223595\n",
      "TE: X->Y 0.4039\n",
      "Epoch [990/1000], Loss: 12.75454336\n",
      "Epoch [990/1000], Loss: 12.75397726\n",
      "TE: Y->X 0.0006\n",
      "## Dim =  10 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 13.97573476\n",
      "Epoch [990/1000], Loss: 13.19395749\n",
      "TE: X->Y 0.7425\n",
      "Epoch [990/1000], Loss: 13.97161643\n",
      "Epoch [990/1000], Loss: 13.82962746\n",
      "TE: Y->X 0.1163\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 14.17877004\n",
      "Epoch [990/1000], Loss: 13.77428558\n",
      "TE: X->Y 0.4038\n",
      "Epoch [990/1000], Loss: 14.17306757\n",
      "Epoch [990/1000], Loss: 14.17201307\n",
      "TE: Y->X 0.0011\n",
      "\n",
      "### REPLICATE 5/5 ###\n",
      "\n",
      "## Dim =  1 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 1.41032588\n",
      "Epoch [990/1000], Loss: 1.00269928\n",
      "TE: X->Y 0.4073\n",
      "Epoch [990/1000], Loss: 1.41267188\n",
      "Epoch [990/1000], Loss: 1.41256266\n",
      "TE: Y->X 0.0001\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 1.41876111\n",
      "Epoch [990/1000], Loss: 1.01589746\n",
      "TE: X->Y 0.4025\n",
      "Epoch [990/1000], Loss: 1.42020365\n",
      "Epoch [990/1000], Loss: 1.42020954\n",
      "TE: Y->X -0.0\n",
      "## Dim =  2 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 2.83151251\n",
      "Epoch [990/1000], Loss: 2.41968262\n",
      "TE: X->Y 0.4112\n",
      "Epoch [990/1000], Loss: 2.82360342\n",
      "Epoch [990/1000], Loss: 2.82149154\n",
      "TE: Y->X 0.002\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 2.83812123\n",
      "Epoch [990/1000], Loss: 2.43500601\n",
      "TE: X->Y 0.4027\n",
      "Epoch [990/1000], Loss: 2.83756197\n",
      "Epoch [990/1000], Loss: 2.83748083\n",
      "TE: Y->X 0.0001\n",
      "## Dim =  3 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 4.24389356\n",
      "Epoch [990/1000], Loss: 3.82708571\n",
      "TE: X->Y 0.416\n",
      "Epoch [990/1000], Loss: 4.24316788\n",
      "Epoch [990/1000], Loss: 4.24071323\n",
      "TE: Y->X 0.0022\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 4.25324971\n",
      "Epoch [990/1000], Loss: 3.84937618\n",
      "TE: X->Y 0.4035\n",
      "Epoch [990/1000], Loss: 4.25688701\n",
      "Epoch [990/1000], Loss: 4.25675904\n",
      "TE: Y->X 0.0001\n",
      "## Dim =  4 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 5.64771944\n",
      "Epoch [990/1000], Loss: 5.22260199\n",
      "TE: X->Y 0.4237\n",
      "Epoch [990/1000], Loss: 5.66191473\n",
      "Epoch [990/1000], Loss: 5.65269029\n",
      "TE: Y->X 0.0087\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 5.67707253\n",
      "Epoch [990/1000], Loss: 5.27270463\n",
      "TE: X->Y 0.404\n",
      "Epoch [990/1000], Loss: 5.67220463\n",
      "Epoch [990/1000], Loss: 5.67206263\n",
      "TE: Y->X 0.0001\n",
      "## Dim =  5 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 7.05851479\n",
      "Epoch [990/1000], Loss: 6.61316613\n",
      "TE: X->Y 0.4421\n",
      "Epoch [990/1000], Loss: 7.06791351\n",
      "Epoch [990/1000], Loss: 7.06424602\n",
      "TE: Y->X 0.0031\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 7.09682388\n",
      "Epoch [990/1000], Loss: 6.69208771\n",
      "TE: X->Y 0.4043\n",
      "Epoch [990/1000], Loss: 7.09117077\n",
      "Epoch [990/1000], Loss: 7.09114283\n",
      "TE: Y->X 0.0\n",
      "## Dim =  6 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 8.48259808\n",
      "Epoch [990/1000], Loss: 8.00873697\n",
      "TE: X->Y 0.4682\n",
      "Epoch [990/1000], Loss: 8.47851372\n",
      "Epoch [990/1000], Loss: 8.46426785\n",
      "TE: Y->X 0.0115\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 8.51445791\n",
      "Epoch [990/1000], Loss: 8.10893425\n",
      "TE: X->Y 0.4051\n",
      "Epoch [990/1000], Loss: 8.51066054\n",
      "Epoch [990/1000], Loss: 8.51020322\n",
      "TE: Y->X 0.0005\n",
      "## Dim =  7 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 9.90687039\n",
      "Epoch [990/1000], Loss: 9.37825477\n",
      "TE: X->Y 0.5192\n",
      "Epoch [990/1000], Loss: 9.87169954\n",
      "Epoch [990/1000], Loss: 9.82746122\n",
      "TE: Y->X 0.0373\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 9.93375074\n",
      "Epoch [990/1000], Loss: 9.53029926\n",
      "TE: X->Y 0.4029\n",
      "Epoch [990/1000], Loss: 9.92983241\n",
      "Epoch [990/1000], Loss: 9.92954014\n",
      "TE: Y->X 0.0003\n",
      "## Dim =  8 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 11.30467894\n",
      "Epoch [990/1000], Loss: 10.71519979\n",
      "TE: X->Y 0.5724\n",
      "Epoch [990/1000], Loss: 11.25423342\n",
      "Epoch [990/1000], Loss: 11.20885032\n",
      "TE: Y->X 0.0352\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 11.35239946\n",
      "Epoch [990/1000], Loss: 10.94754516\n",
      "TE: X->Y 0.4042\n",
      "Epoch [990/1000], Loss: 11.34790386\n",
      "Epoch [990/1000], Loss: 11.34734793\n",
      "TE: Y->X 0.0006\n",
      "## Dim =  9 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 12.68369911\n",
      "Epoch [990/1000], Loss: 11.97252577\n",
      "TE: X->Y 0.6841\n",
      "Epoch [990/1000], Loss: 12.63567789\n",
      "Epoch [990/1000], Loss: 12.52417551\n",
      "TE: Y->X 0.0934\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 12.77185748\n",
      "Epoch [990/1000], Loss: 12.36548201\n",
      "TE: X->Y 0.4058\n",
      "Epoch [990/1000], Loss: 12.76852688\n",
      "Epoch [990/1000], Loss: 12.76807137\n",
      "TE: Y->X 0.0005\n",
      "## Dim =  10 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 14.01463625\n",
      "Epoch [990/1000], Loss: 13.21034113\n",
      "TE: X->Y 0.7636\n",
      "Epoch [990/1000], Loss: 14.00562848\n",
      "Epoch [990/1000], Loss: 13.81180469\n",
      "TE: Y->X 0.1651\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 14.18885661\n",
      "Epoch [990/1000], Loss: 13.78148645\n",
      "TE: X->Y 0.4067\n",
      "Epoch [990/1000], Loss: 14.18798888\n",
      "Epoch [990/1000], Loss: 14.18715338\n",
      "TE: Y->X 0.0008\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "\n",
    "# One result table per TE direction; every write records the estimator name,\n",
    "# the process dimensionality and the sample size alongside the estimate.\n",
    "jp_results_TE_X2Y = Results(columns=['method', 'n_dim', 'sample_size'])\n",
    "jp_results_TE_Y2X = Results(columns=['method', 'n_dim', 'sample_size'])\n",
    "\n",
    "# DataFrame.to_csv does not create missing parent directories. Make sure the\n",
    "# output folder exists before the sweep starts, so a missing folder cannot\n",
    "# raise OSError only after every model has been trained and discard the run.\n",
    "os.makedirs('results/agmte', exist_ok=True)\n",
    "\n",
    "for r in range(REPLICATES):\n",
    "    print(f\"\\n### REPLICATE {r+1}/{REPLICATES} ###\\n\")\n",
    "    for dim, generator in zip(dim_range, jp_generator_lst):\n",
    "        print(\"## Dim = \", dim, \"#\")\n",
    "        for samples in sample_sizes:\n",
    "            print(\"# Sample size = \", samples, \"#\")\n",
    "            # Simulate data: the batch size is 1/20 of the sample count and the\n",
    "            # replicate index doubles as the simulation seed\n",
    "            dataset = get_dataset(generator, samples, int(np.round(samples/20)), seed=r)\n",
    "            # Estimate X -> Y (dim*16 grows with the dimensionality; presumably the\n",
    "            # model size expected by TE_agmte - TODO confirm)\n",
    "            TE_X2Y = TE_agmte(dataset, compute_device, 'X', 'Y', dim*16, NB)\n",
    "            jp_results_TE_X2Y.write(method='agmte', n_dim=dim, sample_size=samples, value=TE_X2Y)\n",
    "            # Estimate Y -> X on the same simulated dataset\n",
    "            TE_Y2X = TE_agmte(dataset, compute_device, 'Y', 'X', dim*16, NB)\n",
    "            jp_results_TE_Y2X.write(method='agmte', n_dim=dim, sample_size=samples, value=TE_Y2X)\n",
    "\n",
    "# Persist both result tables once the full sweep has finished\n",
    "jp_results_TE_X2Y.df.to_csv('results/agmte/jp_results_TE_X2Y_dimred.csv', index=False)\n",
    "jp_results_TE_Y2X.df.to_csv('results/agmte/jp_results_TE_Y2X_dimred.csv', index=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Dimensionality Scaling without redundant dimensions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "dim_range = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]\n",
    "sample_sizes = [10000, 100000]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Linear Gaussian"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One linear-Gaussian simulator per dimensionality in dim_range\n",
    "lg_generator_lst = [MVLinearGaussianSimulator(n_dim=n_dim) for n_dim in dim_range]\n",
    "# Analytic reference TE per simulator: all X -> Y values first, then all Y -> X\n",
    "lg_TE_X2Y_ref_lst, lg_TE_Y2X_ref_lst = (\n",
    "    [sim.analytic_transfer_entropy(source, target) for sim in lg_generator_lst]\n",
    "    for source, target in (('X', 'Y'), ('Y', 'X'))\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "### REPLICATE 1/5 ###\n",
      "\n",
      "## Dim =  1 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 0.59989418\n",
      "Epoch [990/1000], Loss: 0.59963569\n",
      "TE: X->Y 0.0003\n",
      "Epoch [990/1000], Loss: 0.73614373\n",
      "Epoch [990/1000], Loss: 0.60895358\n",
      "TE: Y->X 0.1272\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 0.61211201\n",
      "Epoch [990/1000], Loss: 0.61210904\n",
      "TE: X->Y 0.0\n",
      "Epoch [990/1000], Loss: 0.73671466\n",
      "Epoch [990/1000], Loss: 0.61274123\n",
      "TE: Y->X 0.124\n",
      "## Dim =  2 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 1.21084309\n",
      "Epoch [990/1000], Loss: 1.20825435\n",
      "TE: X->Y 0.0026\n",
      "Epoch [990/1000], Loss: 1.47636128\n",
      "Epoch [990/1000], Loss: 1.23029444\n",
      "TE: Y->X 0.246\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 1.22123227\n",
      "Epoch [990/1000], Loss: 1.22110523\n",
      "TE: X->Y 0.0001\n",
      "Epoch [990/1000], Loss: 1.47910843\n",
      "Epoch [990/1000], Loss: 1.22908398\n",
      "TE: Y->X 0.25\n",
      "## Dim =  3 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 1.81824296\n",
      "Epoch [990/1000], Loss: 1.81421819\n",
      "TE: X->Y 0.0038\n",
      "Epoch [990/1000], Loss: 2.22549062\n",
      "Epoch [990/1000], Loss: 1.85308528\n",
      "TE: Y->X 0.3722\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 1.83303618\n",
      "Epoch [990/1000], Loss: 1.83275313\n",
      "TE: X->Y 0.0003\n",
      "Epoch [990/1000], Loss: 2.21911905\n",
      "Epoch [990/1000], Loss: 1.84309498\n",
      "TE: Y->X 0.3758\n",
      "## Dim =  4 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 2.43626592\n",
      "Epoch [990/1000], Loss: 2.42287187\n",
      "TE: X->Y 0.0128\n",
      "Epoch [990/1000], Loss: 2.94491743\n",
      "Epoch [990/1000], Loss: 2.44453548\n",
      "TE: Y->X 0.5001\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 2.44395252\n",
      "Epoch [990/1000], Loss: 2.44316268\n",
      "TE: X->Y 0.0008\n",
      "Epoch [990/1000], Loss: 2.95690051\n",
      "Epoch [990/1000], Loss: 2.45396698\n",
      "TE: Y->X 0.5029\n",
      "## Dim =  5 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 3.04768859\n",
      "Epoch [990/1000], Loss: 3.02006927\n",
      "TE: X->Y 0.0263\n",
      "Epoch [990/1000], Loss: 3.66807305\n",
      "Epoch [990/1000], Loss: 3.03289669\n",
      "TE: Y->X 0.6349\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 3.05406404\n",
      "Epoch [990/1000], Loss: 3.05301419\n",
      "TE: X->Y 0.001\n",
      "Epoch [990/1000], Loss: 3.70090032\n",
      "Epoch [990/1000], Loss: 3.07059861\n",
      "TE: Y->X 0.6303\n",
      "## Dim =  6 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 3.66211672\n",
      "Epoch [990/1000], Loss: 3.61028336\n",
      "TE: X->Y 0.0486\n",
      "Epoch [990/1000], Loss: 4.39805276\n",
      "Epoch [990/1000], Loss: 3.62225643\n",
      "TE: Y->X 0.7754\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 3.66900743\n",
      "Epoch [990/1000], Loss: 3.66767351\n",
      "TE: X->Y 0.0013\n",
      "Epoch [990/1000], Loss: 4.44030212\n",
      "Epoch [990/1000], Loss: 3.68533757\n",
      "TE: Y->X 0.7549\n",
      "## Dim =  7 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 4.25331818\n",
      "Epoch [990/1000], Loss: 4.15173379\n",
      "TE: X->Y 0.0954\n",
      "Epoch [990/1000], Loss: 5.12406579\n",
      "Epoch [990/1000], Loss: 4.18295938\n",
      "TE: Y->X 0.9392\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 4.28371295\n",
      "Epoch [990/1000], Loss: 4.28175896\n",
      "TE: X->Y 0.0019\n",
      "Epoch [990/1000], Loss: 5.18382191\n",
      "Epoch [990/1000], Loss: 4.29758512\n",
      "TE: Y->X 0.8861\n",
      "## Dim =  8 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 4.85559839\n",
      "Epoch [990/1000], Loss: 4.66744362\n",
      "TE: X->Y 0.1745\n",
      "Epoch [990/1000], Loss: 5.79226028\n",
      "Epoch [990/1000], Loss: 4.68656655\n",
      "TE: Y->X 1.1012\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 4.89779489\n",
      "Epoch [990/1000], Loss: 4.89524422\n",
      "TE: X->Y 0.0025\n",
      "Epoch [990/1000], Loss: 5.92060049\n",
      "Epoch [990/1000], Loss: 4.90836228\n",
      "TE: Y->X 1.0121\n",
      "## Dim =  9 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 5.46351269\n",
      "Epoch [990/1000], Loss: 5.16328366\n",
      "TE: X->Y 0.2773\n",
      "Epoch [990/1000], Loss: 6.44359568\n",
      "Epoch [990/1000], Loss: 5.16337342\n",
      "TE: Y->X 1.2742\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 5.51241732\n",
      "Epoch [990/1000], Loss: 5.50883925\n",
      "TE: X->Y 0.0034\n",
      "Epoch [990/1000], Loss: 6.66214669\n",
      "Epoch [990/1000], Loss: 5.51891708\n",
      "TE: Y->X 1.1424\n",
      "## Dim =  10 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 6.05962595\n",
      "Epoch [990/1000], Loss: 5.57538172\n",
      "TE: X->Y 0.4465\n",
      "Epoch [990/1000], Loss: 7.07753682\n",
      "Epoch [990/1000], Loss: 5.54664872\n",
      "TE: Y->X 1.5191\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 6.12565398\n",
      "Epoch [990/1000], Loss: 6.12035213\n",
      "TE: X->Y 0.0051\n",
      "Epoch [990/1000], Loss: 7.40688726\n",
      "Epoch [990/1000], Loss: 6.13570302\n",
      "TE: Y->X 1.271\n",
      "\n",
      "### REPLICATE 2/5 ###\n",
      "\n",
      "## Dim =  1 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 0.61108281\n",
      "Epoch [990/1000], Loss: 0.61063857\n",
      "TE: X->Y 0.0004\n",
      "Epoch [990/1000], Loss: 0.74077127\n",
      "Epoch [990/1000], Loss: 0.62222807\n",
      "TE: Y->X 0.1185\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 0.60909524\n",
      "Epoch [990/1000], Loss: 0.60912331\n",
      "TE: X->Y -0.0\n",
      "Epoch [990/1000], Loss: 0.74198853\n",
      "Epoch [990/1000], Loss: 0.61596876\n",
      "TE: Y->X 0.126\n",
      "## Dim =  2 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 1.21940936\n",
      "Epoch [990/1000], Loss: 1.21809849\n",
      "TE: X->Y 0.0013\n",
      "Epoch [990/1000], Loss: 1.49165873\n",
      "Epoch [990/1000], Loss: 1.24693124\n",
      "TE: Y->X 0.2446\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 1.22094614\n",
      "Epoch [990/1000], Loss: 1.22080929\n",
      "TE: X->Y 0.0001\n",
      "Epoch [990/1000], Loss: 1.48141594\n",
      "Epoch [990/1000], Loss: 1.22952839\n",
      "TE: Y->X 0.2519\n",
      "## Dim =  3 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 1.83813393\n",
      "Epoch [990/1000], Loss: 1.83319775\n",
      "TE: X->Y 0.0049\n",
      "Epoch [990/1000], Loss: 2.21177848\n",
      "Epoch [990/1000], Loss: 1.84596267\n",
      "TE: Y->X 0.3658\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 1.83195404\n",
      "Epoch [990/1000], Loss: 1.83157075\n",
      "TE: X->Y 0.0004\n",
      "Epoch [990/1000], Loss: 2.21923613\n",
      "Epoch [990/1000], Loss: 1.84166091\n",
      "TE: Y->X 0.3776\n",
      "## Dim =  4 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 2.45106157\n",
      "Epoch [990/1000], Loss: 2.43853938\n",
      "TE: X->Y 0.0121\n",
      "Epoch [990/1000], Loss: 2.94450314\n",
      "Epoch [990/1000], Loss: 2.44460542\n",
      "TE: Y->X 0.5\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 2.44206007\n",
      "Epoch [990/1000], Loss: 2.44146944\n",
      "TE: X->Y 0.0006\n",
      "Epoch [990/1000], Loss: 2.96399052\n",
      "Epoch [990/1000], Loss: 2.45845603\n",
      "TE: Y->X 0.5055\n",
      "## Dim =  5 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 3.06732774\n",
      "Epoch [990/1000], Loss: 3.04260678\n",
      "TE: X->Y 0.0235\n",
      "Epoch [990/1000], Loss: 3.68329165\n",
      "Epoch [990/1000], Loss: 3.04694783\n",
      "TE: Y->X 0.6363\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 3.05700259\n",
      "Epoch [990/1000], Loss: 3.05625162\n",
      "TE: X->Y 0.0007\n",
      "Epoch [990/1000], Loss: 3.70326753\n",
      "Epoch [990/1000], Loss: 3.07354029\n",
      "TE: Y->X 0.6297\n",
      "## Dim =  6 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 3.65984825\n",
      "Epoch [990/1000], Loss: 3.59968414\n",
      "TE: X->Y 0.0568\n",
      "Epoch [990/1000], Loss: 4.41703381\n",
      "Epoch [990/1000], Loss: 3.64359324\n",
      "TE: Y->X 0.7729\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 3.67192234\n",
      "Epoch [990/1000], Loss: 3.67036886\n",
      "TE: X->Y 0.0015\n",
      "Epoch [990/1000], Loss: 4.44740224\n",
      "Epoch [990/1000], Loss: 3.68616229\n",
      "TE: Y->X 0.7612\n",
      "## Dim =  7 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 4.26352982\n",
      "Epoch [990/1000], Loss: 4.16560038\n",
      "TE: X->Y 0.0909\n",
      "Epoch [990/1000], Loss: 5.10578299\n",
      "Epoch [990/1000], Loss: 4.17484777\n",
      "TE: Y->X 0.9299\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 4.28596081\n",
      "Epoch [990/1000], Loss: 4.28403539\n",
      "TE: X->Y 0.0019\n",
      "Epoch [990/1000], Loss: 5.18391218\n",
      "Epoch [990/1000], Loss: 4.29695129\n",
      "TE: Y->X 0.8869\n",
      "## Dim =  8 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 4.87524562\n",
      "Epoch [990/1000], Loss: 4.70972859\n",
      "TE: X->Y 0.1536\n",
      "Epoch [990/1000], Loss: 5.81453534\n",
      "Epoch [990/1000], Loss: 4.70941397\n",
      "TE: Y->X 1.101\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 4.90066354\n",
      "Epoch [990/1000], Loss: 4.89792238\n",
      "TE: X->Y 0.0027\n",
      "Epoch [990/1000], Loss: 5.92494478\n",
      "Epoch [990/1000], Loss: 4.90787289\n",
      "TE: Y->X 1.0173\n",
      "## Dim =  9 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 5.47443662\n",
      "Epoch [990/1000], Loss: 5.16586087\n",
      "TE: X->Y 0.2846\n",
      "Epoch [990/1000], Loss: 6.48942294\n",
      "Epoch [990/1000], Loss: 5.16832708\n",
      "TE: Y->X 1.3122\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 5.51383911\n",
      "Epoch [990/1000], Loss: 5.51057502\n",
      "TE: X->Y 0.0031\n",
      "Epoch [990/1000], Loss: 6.67111105\n",
      "Epoch [990/1000], Loss: 5.52495607\n",
      "TE: Y->X 1.146\n",
      "## Dim =  10 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 6.06362442\n",
      "Epoch [990/1000], Loss: 5.59353657\n",
      "TE: X->Y 0.4328\n",
      "Epoch [990/1000], Loss: 7.09810119\n",
      "Epoch [990/1000], Loss: 5.56832643\n",
      "TE: Y->X 1.5182\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 6.12422831\n",
      "Epoch [990/1000], Loss: 6.11915849\n",
      "TE: X->Y 0.0049\n",
      "Epoch [990/1000], Loss: 7.41389894\n",
      "Epoch [990/1000], Loss: 6.13904866\n",
      "TE: Y->X 1.2746\n",
      "\n",
      "### REPLICATE 3/5 ###\n",
      "\n",
      "## Dim =  1 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 0.60880091\n",
      "Epoch [990/1000], Loss: 0.60902315\n",
      "TE: X->Y -0.0002\n",
      "Epoch [990/1000], Loss: 0.75317364\n",
      "Epoch [990/1000], Loss: 0.62735875\n",
      "TE: Y->X 0.1258\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 0.61182394\n",
      "Epoch [990/1000], Loss: 0.61181503\n",
      "TE: X->Y 0.0\n",
      "Epoch [990/1000], Loss: 0.73931785\n",
      "Epoch [990/1000], Loss: 0.61327246\n",
      "TE: Y->X 0.126\n",
      "## Dim =  2 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 1.22809118\n",
      "Epoch [990/1000], Loss: 1.22591404\n",
      "TE: X->Y 0.0022\n",
      "Epoch [990/1000], Loss: 1.47416767\n",
      "Epoch [990/1000], Loss: 1.22666465\n",
      "TE: Y->X 0.2475\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 1.22280383\n",
      "Epoch [990/1000], Loss: 1.22271973\n",
      "TE: X->Y 0.0001\n",
      "Epoch [990/1000], Loss: 1.47630912\n",
      "Epoch [990/1000], Loss: 1.22518093\n",
      "TE: Y->X 0.2511\n",
      "## Dim =  3 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 1.84139575\n",
      "Epoch [990/1000], Loss: 1.83614419\n",
      "TE: X->Y 0.0051\n",
      "Epoch [990/1000], Loss: 2.21061247\n",
      "Epoch [990/1000], Loss: 1.83015259\n",
      "TE: Y->X 0.3804\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 1.83293364\n",
      "Epoch [990/1000], Loss: 1.83274963\n",
      "TE: X->Y 0.0002\n",
      "Epoch [990/1000], Loss: 2.22087291\n",
      "Epoch [990/1000], Loss: 1.84288109\n",
      "TE: Y->X 0.378\n",
      "## Dim =  4 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 2.45855918\n",
      "Epoch [990/1000], Loss: 2.44724231\n",
      "TE: X->Y 0.0109\n",
      "Epoch [990/1000], Loss: 2.94859342\n",
      "Epoch [990/1000], Loss: 2.44083648\n",
      "TE: Y->X 0.5079\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 2.44796401\n",
      "Epoch [990/1000], Loss: 2.44750679\n",
      "TE: X->Y 0.0004\n",
      "Epoch [990/1000], Loss: 2.96092172\n",
      "Epoch [990/1000], Loss: 2.45808881\n",
      "TE: Y->X 0.5028\n",
      "## Dim =  5 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 3.05278485\n",
      "Epoch [990/1000], Loss: 3.02723748\n",
      "TE: X->Y 0.0244\n",
      "Epoch [990/1000], Loss: 3.68989867\n",
      "Epoch [990/1000], Loss: 3.03901929\n",
      "TE: Y->X 0.6505\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 3.06296786\n",
      "Epoch [990/1000], Loss: 3.06204064\n",
      "TE: X->Y 0.0009\n",
      "Epoch [990/1000], Loss: 3.70511803\n",
      "Epoch [990/1000], Loss: 3.07040973\n",
      "TE: Y->X 0.6347\n",
      "## Dim =  6 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 3.65773736\n",
      "Epoch [990/1000], Loss: 3.60362737\n",
      "TE: X->Y 0.0509\n",
      "Epoch [990/1000], Loss: 4.40929335\n",
      "Epoch [990/1000], Loss: 3.60394595\n",
      "TE: Y->X 0.8034\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 3.67692526\n",
      "Epoch [990/1000], Loss: 3.67565511\n",
      "TE: X->Y 0.0012\n",
      "Epoch [990/1000], Loss: 4.44204935\n",
      "Epoch [990/1000], Loss: 3.68185362\n",
      "TE: Y->X 0.7603\n",
      "## Dim =  7 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 4.26954658\n",
      "Epoch [990/1000], Loss: 4.16755924\n",
      "TE: X->Y 0.0952\n",
      "Epoch [990/1000], Loss: 5.12896359\n",
      "Epoch [990/1000], Loss: 4.16731697\n",
      "TE: Y->X 0.9588\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 4.29172616\n",
      "Epoch [990/1000], Loss: 4.29018295\n",
      "TE: X->Y 0.0015\n",
      "Epoch [990/1000], Loss: 5.18311905\n",
      "Epoch [990/1000], Loss: 4.29315592\n",
      "TE: Y->X 0.8902\n",
      "## Dim =  8 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 4.87149689\n",
      "Epoch [990/1000], Loss: 4.69744344\n",
      "TE: X->Y 0.162\n",
      "Epoch [990/1000], Loss: 5.83980287\n",
      "Epoch [990/1000], Loss: 4.68057874\n",
      "TE: Y->X 1.1535\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 4.90490621\n",
      "Epoch [990/1000], Loss: 4.90218533\n",
      "TE: X->Y 0.0026\n",
      "Epoch [990/1000], Loss: 5.92944433\n",
      "Epoch [990/1000], Loss: 4.91059232\n",
      "TE: Y->X 1.0187\n",
      "## Dim =  9 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 5.46307761\n",
      "Epoch [990/1000], Loss: 5.18149378\n",
      "TE: X->Y 0.2609\n",
      "Epoch [990/1000], Loss: 6.48595077\n",
      "Epoch [990/1000], Loss: 5.16716038\n",
      "TE: Y->X 1.3122\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 5.51530485\n",
      "Epoch [990/1000], Loss: 5.51140549\n",
      "TE: X->Y 0.0038\n",
      "Epoch [990/1000], Loss: 6.67257349\n",
      "Epoch [990/1000], Loss: 5.52477699\n",
      "TE: Y->X 1.1476\n",
      "## Dim =  10 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 6.06107618\n",
      "Epoch [990/1000], Loss: 5.55937279\n",
      "TE: X->Y 0.4641\n",
      "Epoch [990/1000], Loss: 7.09921863\n",
      "Epoch [990/1000], Loss: 5.52207629\n",
      "TE: Y->X 1.5598\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 6.12606292\n",
      "Epoch [990/1000], Loss: 6.12160002\n",
      "TE: X->Y 0.0043\n",
      "Epoch [990/1000], Loss: 7.41265651\n",
      "Epoch [990/1000], Loss: 6.13811561\n",
      "TE: Y->X 1.2743\n",
      "\n",
      "### REPLICATE 4/5 ###\n",
      "\n",
      "## Dim =  1 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 0.61946545\n",
      "Epoch [990/1000], Loss: 0.61925064\n",
      "TE: X->Y 0.0002\n",
      "Epoch [990/1000], Loss: 0.72131309\n",
      "Epoch [990/1000], Loss: 0.59928024\n",
      "TE: Y->X 0.122\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 0.61095589\n",
      "Epoch [990/1000], Loss: 0.61093201\n",
      "TE: X->Y 0.0\n",
      "Epoch [990/1000], Loss: 0.73680341\n",
      "Epoch [990/1000], Loss: 0.61120163\n",
      "TE: Y->X 0.1256\n",
      "## Dim =  2 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 1.23309767\n",
      "Epoch [990/1000], Loss: 1.23145933\n",
      "TE: X->Y 0.0016\n",
      "Epoch [990/1000], Loss: 1.45802173\n",
      "Epoch [990/1000], Loss: 1.20460982\n",
      "TE: Y->X 0.2533\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 1.22116745\n",
      "Epoch [990/1000], Loss: 1.22107934\n",
      "TE: X->Y 0.0001\n",
      "Epoch [990/1000], Loss: 1.48081055\n",
      "Epoch [990/1000], Loss: 1.22955996\n",
      "TE: Y->X 0.251\n",
      "## Dim =  3 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 1.85127505\n",
      "Epoch [990/1000], Loss: 1.84608714\n",
      "TE: X->Y 0.0051\n",
      "Epoch [990/1000], Loss: 2.19878038\n",
      "Epoch [990/1000], Loss: 1.82072899\n",
      "TE: Y->X 0.3783\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 1.83627453\n",
      "Epoch [990/1000], Loss: 1.83595722\n",
      "TE: X->Y 0.0003\n",
      "Epoch [990/1000], Loss: 2.22015091\n",
      "Epoch [990/1000], Loss: 1.84533441\n",
      "TE: Y->X 0.3748\n",
      "## Dim =  4 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 2.44720462\n",
      "Epoch [990/1000], Loss: 2.43510517\n",
      "TE: X->Y 0.0117\n",
      "Epoch [990/1000], Loss: 2.93987771\n",
      "Epoch [990/1000], Loss: 2.43419775\n",
      "TE: Y->X 0.5059\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 2.45116875\n",
      "Epoch [990/1000], Loss: 2.45077116\n",
      "TE: X->Y 0.0004\n",
      "Epoch [990/1000], Loss: 2.96551581\n",
      "Epoch [990/1000], Loss: 2.45802639\n",
      "TE: Y->X 0.5075\n",
      "## Dim =  5 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 3.05188073\n",
      "Epoch [990/1000], Loss: 3.02556579\n",
      "TE: X->Y 0.0249\n",
      "Epoch [990/1000], Loss: 3.65769017\n",
      "Epoch [990/1000], Loss: 3.01520835\n",
      "TE: Y->X 0.6425\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 3.06528749\n",
      "Epoch [990/1000], Loss: 3.06456456\n",
      "TE: X->Y 0.0007\n",
      "Epoch [990/1000], Loss: 3.70250507\n",
      "Epoch [990/1000], Loss: 3.06989554\n",
      "TE: Y->X 0.6326\n",
      "## Dim =  6 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 3.66678824\n",
      "Epoch [990/1000], Loss: 3.60922369\n",
      "TE: X->Y 0.0543\n",
      "Epoch [990/1000], Loss: 4.39377878\n",
      "Epoch [990/1000], Loss: 3.60350773\n",
      "TE: Y->X 0.7897\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 3.68005651\n",
      "Epoch [990/1000], Loss: 3.67880091\n",
      "TE: X->Y 0.0012\n",
      "Epoch [990/1000], Loss: 4.44373343\n",
      "Epoch [990/1000], Loss: 3.68076056\n",
      "TE: Y->X 0.7629\n",
      "## Dim =  7 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 4.27263269\n",
      "Epoch [990/1000], Loss: 4.16940057\n",
      "TE: X->Y 0.0962\n",
      "Epoch [990/1000], Loss: 5.12796772\n",
      "Epoch [990/1000], Loss: 4.16072088\n",
      "TE: Y->X 0.9648\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 4.29330315\n",
      "Epoch [990/1000], Loss: 4.29121665\n",
      "TE: X->Y 0.002\n",
      "Epoch [990/1000], Loss: 5.19006927\n",
      "Epoch [990/1000], Loss: 4.29846225\n",
      "TE: Y->X 0.8915\n",
      "## Dim =  8 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 4.87155409\n",
      "Epoch [990/1000], Loss: 4.68680661\n",
      "TE: X->Y 0.1721\n",
      "Epoch [990/1000], Loss: 5.82342761\n",
      "Epoch [990/1000], Loss: 4.68107404\n",
      "TE: Y->X 1.1407\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 4.90379386\n",
      "Epoch [990/1000], Loss: 4.90089231\n",
      "TE: X->Y 0.0028\n",
      "Epoch [990/1000], Loss: 5.93399215\n",
      "Epoch [990/1000], Loss: 4.91380256\n",
      "TE: Y->X 1.0201\n",
      "## Dim =  9 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 5.47021385\n",
      "Epoch [990/1000], Loss: 5.16173653\n",
      "TE: X->Y 0.2855\n",
      "Epoch [990/1000], Loss: 6.48219392\n",
      "Epoch [990/1000], Loss: 5.12905489\n",
      "TE: Y->X 1.3442\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 5.51446811\n",
      "Epoch [990/1000], Loss: 5.51063932\n",
      "TE: X->Y 0.0037\n",
      "Epoch [990/1000], Loss: 6.67393772\n",
      "Epoch [990/1000], Loss: 5.52804716\n",
      "TE: Y->X 1.1457\n",
      "## Dim =  10 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 6.06544938\n",
      "Epoch [990/1000], Loss: 5.58527824\n",
      "TE: X->Y 0.4432\n",
      "Epoch [990/1000], Loss: 7.08264485\n",
      "Epoch [990/1000], Loss: 5.52028267\n",
      "TE: Y->X 1.5462\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 6.12700027\n",
      "Epoch [990/1000], Loss: 6.12248036\n",
      "TE: X->Y 0.0043\n",
      "Epoch [990/1000], Loss: 7.41555957\n",
      "Epoch [990/1000], Loss: 6.14068505\n",
      "TE: Y->X 1.2746\n",
      "\n",
      "### REPLICATE 5/5 ###\n",
      "\n",
      "## Dim =  1 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 0.61368827\n",
      "Epoch [990/1000], Loss: 0.61354231\n",
      "TE: X->Y 0.0001\n",
      "Epoch [990/1000], Loss: 0.73660207\n",
      "Epoch [990/1000], Loss: 0.60476839\n",
      "TE: Y->X 0.1318\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 0.61023317\n",
      "Epoch [990/1000], Loss: 0.61024826\n",
      "TE: X->Y -0.0\n",
      "Epoch [990/1000], Loss: 0.74372039\n",
      "Epoch [990/1000], Loss: 0.61738689\n",
      "TE: Y->X 0.1263\n",
      "## Dim =  2 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 1.23190419\n",
      "Epoch [990/1000], Loss: 1.23047297\n",
      "TE: X->Y 0.0014\n",
      "Epoch [990/1000], Loss: 1.47766563\n",
      "Epoch [990/1000], Loss: 1.22274656\n",
      "TE: Y->X 0.2549\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 1.22534484\n",
      "Epoch [990/1000], Loss: 1.22518613\n",
      "TE: X->Y 0.0002\n",
      "Epoch [990/1000], Loss: 1.48289137\n",
      "Epoch [990/1000], Loss: 1.23345958\n",
      "TE: Y->X 0.2494\n",
      "## Dim =  3 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 1.82842089\n",
      "Epoch [990/1000], Loss: 1.82497103\n",
      "TE: X->Y 0.0033\n",
      "Epoch [990/1000], Loss: 2.22138769\n",
      "Epoch [990/1000], Loss: 1.84022132\n",
      "TE: Y->X 0.3812\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 1.84044254\n",
      "Epoch [990/1000], Loss: 1.84013122\n",
      "TE: X->Y 0.0003\n",
      "Epoch [990/1000], Loss: 2.22711218\n",
      "Epoch [990/1000], Loss: 1.84699569\n",
      "TE: Y->X 0.3801\n",
      "## Dim =  4 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 2.43492239\n",
      "Epoch [990/1000], Loss: 2.42097879\n",
      "TE: X->Y 0.0133\n",
      "Epoch [990/1000], Loss: 2.94706146\n",
      "Epoch [990/1000], Loss: 2.42830137\n",
      "TE: Y->X 0.5186\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 2.45456022\n",
      "Epoch [990/1000], Loss: 2.45403904\n",
      "TE: X->Y 0.0005\n",
      "Epoch [990/1000], Loss: 2.96457766\n",
      "Epoch [990/1000], Loss: 2.45920043\n",
      "TE: Y->X 0.5054\n",
      "## Dim =  5 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 3.05056791\n",
      "Epoch [990/1000], Loss: 3.02447455\n",
      "TE: X->Y 0.0248\n",
      "Epoch [990/1000], Loss: 3.68125779\n",
      "Epoch [990/1000], Loss: 3.03591867\n",
      "TE: Y->X 0.6458\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 3.06947475\n",
      "Epoch [990/1000], Loss: 3.06853083\n",
      "TE: X->Y 0.0009\n",
      "Epoch [990/1000], Loss: 3.70637859\n",
      "Epoch [990/1000], Loss: 3.07071444\n",
      "TE: Y->X 0.6356\n",
      "## Dim =  6 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 3.65333152\n",
      "Epoch [990/1000], Loss: 3.60031381\n",
      "TE: X->Y 0.0499\n",
      "Epoch [990/1000], Loss: 4.42776608\n",
      "Epoch [990/1000], Loss: 3.61917169\n",
      "TE: Y->X 0.8076\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 3.68272417\n",
      "Epoch [990/1000], Loss: 3.68145212\n",
      "TE: X->Y 0.0012\n",
      "Epoch [990/1000], Loss: 4.45347021\n",
      "Epoch [990/1000], Loss: 3.68885681\n",
      "TE: Y->X 0.7646\n",
      "## Dim =  7 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 4.25556177\n",
      "Epoch [990/1000], Loss: 4.14618845\n",
      "TE: X->Y 0.1025\n",
      "Epoch [990/1000], Loss: 5.15637215\n",
      "Epoch [990/1000], Loss: 4.17652171\n",
      "TE: Y->X 0.9761\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 4.29323097\n",
      "Epoch [990/1000], Loss: 4.29091512\n",
      "TE: X->Y 0.0023\n",
      "Epoch [990/1000], Loss: 5.19690882\n",
      "Epoch [990/1000], Loss: 4.30437737\n",
      "TE: Y->X 0.8924\n",
      "## Dim =  8 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 4.85700067\n",
      "Epoch [990/1000], Loss: 4.66336852\n",
      "TE: X->Y 0.1807\n",
      "Epoch [990/1000], Loss: 5.84135376\n",
      "Epoch [990/1000], Loss: 4.70124016\n",
      "TE: Y->X 1.1344\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 4.90392491\n",
      "Epoch [990/1000], Loss: 4.90126112\n",
      "TE: X->Y 0.0026\n",
      "Epoch [990/1000], Loss: 5.93691031\n",
      "Epoch [990/1000], Loss: 4.91845045\n",
      "TE: Y->X 1.0183\n",
      "## Dim =  9 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 5.46109962\n",
      "Epoch [990/1000], Loss: 5.14052735\n",
      "TE: X->Y 0.2971\n",
      "Epoch [990/1000], Loss: 6.49749444\n",
      "Epoch [990/1000], Loss: 5.13883848\n",
      "TE: Y->X 1.3471\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 5.51641826\n",
      "Epoch [990/1000], Loss: 5.51292533\n",
      "TE: X->Y 0.0034\n",
      "Epoch [990/1000], Loss: 6.67909868\n",
      "Epoch [990/1000], Loss: 5.53197192\n",
      "TE: Y->X 1.147\n",
      "## Dim =  10 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 6.06442969\n",
      "Epoch [990/1000], Loss: 5.57048416\n",
      "TE: X->Y 0.4581\n",
      "Epoch [990/1000], Loss: 7.08468292\n",
      "Epoch [990/1000], Loss: 5.54510194\n",
      "TE: Y->X 1.5278\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 6.13240046\n",
      "Epoch [990/1000], Loss: 6.12792094\n",
      "TE: X->Y 0.0043\n",
      "Epoch [990/1000], Loss: 7.42233545\n",
      "Epoch [990/1000], Loss: 6.14619236\n",
      "TE: Y->X 1.2759\n"
     ]
    }
   ],
   "source": [
    "# Dimension sweep: for every replicate, dimension and sample size, simulate a\n",
    "# dataset, estimate the transfer entropy in both directions with TE_agmte and\n",
    "# collect the estimates; both result tables are written to CSV at the end.\n",
    "import os\n",
    "\n",
    "# Create the output folder up front. DataFrame.to_csv does not create missing\n",
    "# parent directories, so a missing folder would otherwise only raise after all\n",
    "# replicates have been trained, and every estimate would be lost.\n",
    "os.makedirs('results/agmte', exist_ok=True)\n",
    "\n",
    "# One result table per direction; each row is keyed by method, dimension and sample size\n",
    "lg_results_TE_X2Y = Results(columns=['method', 'n_dim', 'sample_size'])\n",
    "lg_results_TE_Y2X = Results(columns=['method', 'n_dim', 'sample_size'])\n",
    "\n",
    "for r in range(REPLICATES):\n",
    "    print(f\"\\n### REPLICATE {r+1}/{REPLICATES} ###\\n\")\n",
    "    # dim_range and lg_generator_lst are paired element-wise (defined in an earlier cell)\n",
    "    for dim, generator in zip(dim_range, lg_generator_lst):\n",
    "        print(\"## Dim = \", dim, \"#\")\n",
    "        for samples in sample_sizes:\n",
    "            print(\"# Sample size = \", samples, \"#\")\n",
    "            # Simulate data: the replicate index is the seed, and the batch size of\n",
    "            # samples/20 makes get_dataset cut the series into about 20 chunks\n",
    "            dataset = get_dataset(generator, samples, int(np.round(samples/20)), seed=r)\n",
    "            # Estimate X -> Y (dim*16 and NB are forwarded as-is to TE_agmte, which is\n",
    "            # defined in an earlier cell)\n",
    "            TE_X2Y = TE_agmte(dataset, compute_device, 'X', 'Y', dim*16, NB)\n",
    "            lg_results_TE_X2Y.write(method='agmte', n_dim=dim, sample_size=samples, value=TE_X2Y)\n",
    "            # Estimate Y -> X on the same simulated dataset\n",
    "            TE_Y2X = TE_agmte(dataset, compute_device, 'Y', 'X', dim*16, NB)\n",
    "            lg_results_TE_Y2X.write(method='agmte', n_dim=dim, sample_size=samples, value=TE_Y2X)\n",
    "\n",
    "# Persist the collected estimates, one CSV per direction\n",
    "lg_results_TE_X2Y.df.to_csv('results/agmte/lg_results_TE_X2Y_dim.csv', index=False)\n",
    "lg_results_TE_Y2X.df.to_csv('results/agmte/lg_results_TE_Y2X_dim.csv', index=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Joint Process"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build one joint-process simulator per entry of dim_range (lam is fixed to 0.0)\n",
    "jp_generator_lst = [\n",
    "    MVJointProcessSimulator(n_dim=n_dim, lam=0.0)\n",
    "    for n_dim in dim_range\n",
    "]\n",
    "\n",
    "# Analytic transfer entropy of every simulator, kept as reference values\n",
    "# for both directions (X -> Y and Y -> X)\n",
    "jp_TE_X2Y_ref_lst = [sim.analytic_transfer_entropy('X', 'Y') for sim in jp_generator_lst]\n",
    "jp_TE_Y2X_ref_lst = [sim.analytic_transfer_entropy('Y', 'X') for sim in jp_generator_lst]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "### REPLICATE 1/5 ###\n",
      "\n",
      "## Dim =  1 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 1.40368636\n",
      "Epoch [990/1000], Loss: 1.00622459\n",
      "TE: X->Y 0.3969\n",
      "Epoch [990/1000], Loss: 1.41204164\n",
      "Epoch [990/1000], Loss: 1.41193434\n",
      "TE: Y->X 0.0001\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 1.41729889\n",
      "Epoch [990/1000], Loss: 1.01296616\n",
      "TE: X->Y 0.4039\n",
      "Epoch [990/1000], Loss: 1.41759848\n",
      "Epoch [990/1000], Loss: 1.41758547\n",
      "TE: Y->X 0.0\n",
      "## Dim =  2 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 2.82892699\n",
      "Epoch [990/1000], Loss: 2.01031649\n",
      "TE: X->Y 0.8176\n",
      "Epoch [990/1000], Loss: 2.83087721\n",
      "Epoch [990/1000], Loss: 2.83018205\n",
      "TE: Y->X 0.0007\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 2.83434137\n",
      "Epoch [990/1000], Loss: 2.03042294\n",
      "TE: X->Y 0.803\n",
      "Epoch [990/1000], Loss: 2.83347221\n",
      "Epoch [990/1000], Loss: 2.83340401\n",
      "TE: Y->X 0.0001\n",
      "## Dim =  3 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 4.25337291\n",
      "Epoch [990/1000], Loss: 3.00916057\n",
      "TE: X->Y 1.2417\n",
      "Epoch [990/1000], Loss: 4.25667743\n",
      "Epoch [990/1000], Loss: 4.25461478\n",
      "TE: Y->X 0.0019\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 4.24824857\n",
      "Epoch [990/1000], Loss: 3.05529461\n",
      "TE: X->Y 1.1913\n",
      "Epoch [990/1000], Loss: 4.24649386\n",
      "Epoch [990/1000], Loss: 4.24634628\n",
      "TE: Y->X 0.0001\n",
      "## Dim =  4 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 5.66564253\n",
      "Epoch [990/1000], Loss: 3.97631094\n",
      "TE: X->Y 1.6842\n",
      "Epoch [990/1000], Loss: 5.66635198\n",
      "Epoch [990/1000], Loss: 5.66066147\n",
      "TE: Y->X 0.0053\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 5.66361018\n",
      "Epoch [990/1000], Loss: 4.08083146\n",
      "TE: X->Y 1.5814\n",
      "Epoch [990/1000], Loss: 5.66150145\n",
      "Epoch [990/1000], Loss: 5.66132222\n",
      "TE: Y->X 0.0002\n",
      "## Dim =  5 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 7.07051769\n",
      "Epoch [990/1000], Loss: 4.89997383\n",
      "TE: X->Y 2.157\n",
      "Epoch [990/1000], Loss: 7.07247155\n",
      "Epoch [990/1000], Loss: 7.06460087\n",
      "TE: Y->X 0.0069\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 7.08204753\n",
      "Epoch [990/1000], Loss: 5.10499063\n",
      "TE: X->Y 1.974\n",
      "Epoch [990/1000], Loss: 7.08156294\n",
      "Epoch [990/1000], Loss: 7.08111858\n",
      "TE: Y->X 0.0004\n",
      "## Dim =  6 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 8.48497781\n",
      "Epoch [990/1000], Loss: 5.76594253\n",
      "TE: X->Y 2.7001\n",
      "Epoch [990/1000], Loss: 8.48954929\n",
      "Epoch [990/1000], Loss: 8.46394795\n",
      "TE: Y->X 0.0224\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 8.49843439\n",
      "Epoch [990/1000], Loss: 6.11686046\n",
      "TE: X->Y 2.3774\n",
      "Epoch [990/1000], Loss: 8.49905944\n",
      "Epoch [990/1000], Loss: 8.49852834\n",
      "TE: Y->X 0.0005\n",
      "## Dim =  7 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 9.87081579\n",
      "Epoch [990/1000], Loss: 6.58505879\n",
      "TE: X->Y 3.2596\n",
      "Epoch [990/1000], Loss: 9.87758751\n",
      "Epoch [990/1000], Loss: 9.84524439\n",
      "TE: Y->X 0.0264\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 9.91243886\n",
      "Epoch [990/1000], Loss: 7.15352285\n",
      "TE: X->Y 2.7531\n",
      "Epoch [990/1000], Loss: 9.91898299\n",
      "Epoch [990/1000], Loss: 9.91845136\n",
      "TE: Y->X 0.0005\n",
      "## Dim =  8 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 11.24224734\n",
      "Epoch [990/1000], Loss: 7.40902874\n",
      "TE: X->Y 3.7981\n",
      "Epoch [990/1000], Loss: 11.25421955\n",
      "Epoch [990/1000], Loss: 11.21716383\n",
      "TE: Y->X 0.0283\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 11.32960249\n",
      "Epoch [990/1000], Loss: 8.17242338\n",
      "TE: X->Y 3.1522\n",
      "Epoch [990/1000], Loss: 11.33960715\n",
      "Epoch [990/1000], Loss: 11.33880771\n",
      "TE: Y->X 0.0008\n",
      "## Dim =  9 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 12.60452736\n",
      "Epoch [990/1000], Loss: 8.132383183\n",
      "TE: X->Y 4.4195\n",
      "Epoch [990/1000], Loss: 12.62741779\n",
      "Epoch [990/1000], Loss: 12.60076583\n",
      "TE: Y->X 0.0134\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 12.74733593\n",
      "Epoch [990/1000], Loss: 9.211829385\n",
      "TE: X->Y 3.5302\n",
      "Epoch [990/1000], Loss: 12.75628378\n",
      "Epoch [990/1000], Loss: 12.75551951\n",
      "TE: Y->X 0.0008\n",
      "## Dim =  10 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 13.95076446\n",
      "Epoch [990/1000], Loss: 8.776438279\n",
      "TE: X->Y 5.1084\n",
      "Epoch [990/1000], Loss: 13.96550258\n",
      "Epoch [990/1000], Loss: 13.84473662\n",
      "TE: Y->X 0.0982\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 14.16712027\n",
      "Epoch [990/1000], Loss: 10.23824758\n",
      "TE: X->Y 3.92\n",
      "Epoch [990/1000], Loss: 14.17613147\n",
      "Epoch [990/1000], Loss: 14.17529247\n",
      "TE: Y->X 0.0009\n",
      "\n",
      "### REPLICATE 2/5 ###\n",
      "\n",
      "## Dim =  1 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 1.42533749\n",
      "Epoch [990/1000], Loss: 1.00668097\n",
      "TE: X->Y 0.4184\n",
      "Epoch [990/1000], Loss: 1.41964508\n",
      "Epoch [990/1000], Loss: 1.41964654\n",
      "TE: Y->X -0.0\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 1.41706187\n",
      "Epoch [990/1000], Loss: 1.01404106\n",
      "TE: X->Y 0.4027\n",
      "Epoch [990/1000], Loss: 1.41591597\n",
      "Epoch [990/1000], Loss: 1.41586949\n",
      "TE: Y->X 0.0\n",
      "## Dim =  2 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 2.85065339\n",
      "Epoch [990/1000], Loss: 2.01974507\n",
      "TE: X->Y 0.8298\n",
      "Epoch [990/1000], Loss: 2.84578555\n",
      "Epoch [990/1000], Loss: 2.84508193\n",
      "TE: Y->X 0.0007\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 2.83104427\n",
      "Epoch [990/1000], Loss: 2.03406735\n",
      "TE: X->Y 0.7962\n",
      "Epoch [990/1000], Loss: 2.82901878\n",
      "Epoch [990/1000], Loss: 2.82894378\n",
      "TE: Y->X 0.0001\n",
      "## Dim =  3 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 4.26320534\n",
      "Epoch [990/1000], Loss: 3.01807761\n",
      "TE: X->Y 1.2425\n",
      "Epoch [990/1000], Loss: 4.25694273\n",
      "Epoch [990/1000], Loss: 4.25502529\n",
      "TE: Y->X 0.0018\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 4.24637219\n",
      "Epoch [990/1000], Loss: 3.05949462\n",
      "TE: X->Y 1.1845\n",
      "Epoch [990/1000], Loss: 4.24404448\n",
      "Epoch [990/1000], Loss: 4.24401284\n",
      "TE: Y->X 0.0\n",
      "## Dim =  4 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 5.66990472\n",
      "Epoch [990/1000], Loss: 3.96592858\n",
      "TE: X->Y 1.6983\n",
      "Epoch [990/1000], Loss: 5.66480516\n",
      "Epoch [990/1000], Loss: 5.66291074\n",
      "TE: Y->X 0.0015\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 5.66490889\n",
      "Epoch [990/1000], Loss: 4.08640469\n",
      "TE: X->Y 1.5748\n",
      "Epoch [990/1000], Loss: 5.66417083\n",
      "Epoch [990/1000], Loss: 5.66402616\n",
      "TE: Y->X 0.0001\n",
      "## Dim =  5 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 7.08980796\n",
      "Epoch [990/1000], Loss: 4.90225756\n",
      "TE: X->Y 2.1767\n",
      "Epoch [990/1000], Loss: 7.08139341\n",
      "Epoch [990/1000], Loss: 7.07837861\n",
      "TE: Y->X 0.0021\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 7.08167398\n",
      "Epoch [990/1000], Loss: 5.09887278\n",
      "TE: X->Y 1.9776\n",
      "Epoch [990/1000], Loss: 7.08164081\n",
      "Epoch [990/1000], Loss: 7.08145092\n",
      "TE: Y->X 0.0002\n",
      "## Dim =  6 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 8.48808762\n",
      "Epoch [990/1000], Loss: 5.77420313\n",
      "TE: X->Y 2.6945\n",
      "Epoch [990/1000], Loss: 8.48288558\n",
      "Epoch [990/1000], Loss: 8.45867064\n",
      "TE: Y->X 0.0212\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 8.49542618\n",
      "Epoch [990/1000], Loss: 6.14096998\n",
      "TE: X->Y 2.3506\n",
      "Epoch [990/1000], Loss: 8.50190916\n",
      "Epoch [990/1000], Loss: 8.50132445\n",
      "TE: Y->X 0.0006\n",
      "## Dim =  7 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 9.87464955\n",
      "Epoch [990/1000], Loss: 6.65050121\n",
      "TE: X->Y 3.1972\n",
      "Epoch [990/1000], Loss: 9.86728585\n",
      "Epoch [990/1000], Loss: 9.83550025\n",
      "TE: Y->X 0.0263\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 9.91258412\n",
      "Epoch [990/1000], Loss: 7.16058766\n",
      "TE: X->Y 2.7466\n",
      "Epoch [990/1000], Loss: 9.92244508\n",
      "Epoch [990/1000], Loss: 9.92175448\n",
      "TE: Y->X 0.0007\n",
      "## Dim =  8 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 11.27206112\n",
      "Epoch [990/1000], Loss: 7.40799946\n",
      "TE: X->Y 3.8291\n",
      "Epoch [990/1000], Loss: 11.26543009\n",
      "Epoch [990/1000], Loss: 11.21513587\n",
      "TE: Y->X 0.0412\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 11.33062672\n",
      "Epoch [990/1000], Loss: 8.19400944\n",
      "TE: X->Y 3.134\n",
      "Epoch [990/1000], Loss: 11.33909529\n",
      "Epoch [990/1000], Loss: 11.33841204\n",
      "TE: Y->X 0.0007\n",
      "## Dim =  9 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 12.60864969\n",
      "Epoch [990/1000], Loss: 8.149738639\n",
      "TE: X->Y 4.4081\n",
      "Epoch [990/1000], Loss: 12.63676638\n",
      "Epoch [990/1000], Loss: 12.52997216\n",
      "TE: Y->X 0.0885\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 12.75060193\n",
      "Epoch [990/1000], Loss: 9.215577658\n",
      "TE: X->Y 3.5249\n",
      "Epoch [990/1000], Loss: 12.75929045\n",
      "Epoch [990/1000], Loss: 12.75857074\n",
      "TE: Y->X 0.0007\n",
      "## Dim =  10 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 13.99471955\n",
      "Epoch [990/1000], Loss: 8.777292751\n",
      "TE: X->Y 5.1522\n",
      "Epoch [990/1000], Loss: 14.01550979\n",
      "Epoch [990/1000], Loss: 13.85757022\n",
      "TE: Y->X 0.1331\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 14.16672273\n",
      "Epoch [990/1000], Loss: 10.23319209\n",
      "TE: X->Y 3.9235\n",
      "Epoch [990/1000], Loss: 14.17628845\n",
      "Epoch [990/1000], Loss: 14.17524417\n",
      "TE: Y->X 0.001\n",
      "\n",
      "### REPLICATE 3/5 ###\n",
      "\n",
      "## Dim =  1 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 1.42637168\n",
      "Epoch [990/1000], Loss: 1.01828471\n",
      "TE: X->Y 0.4076\n",
      "Epoch [990/1000], Loss: 1.42696268\n",
      "Epoch [990/1000], Loss: 1.42684056\n",
      "TE: Y->X 0.0001\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 1.41404826\n",
      "Epoch [990/1000], Loss: 1.02007094\n",
      "TE: X->Y 0.3936\n",
      "Epoch [990/1000], Loss: 1.41314144\n",
      "Epoch [990/1000], Loss: 1.41311533\n",
      "TE: Y->X 0.0\n",
      "## Dim =  2 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 2.84040404\n",
      "Epoch [990/1000], Loss: 2.02704094\n",
      "TE: X->Y 0.8122\n",
      "Epoch [990/1000], Loss: 2.83887945\n",
      "Epoch [990/1000], Loss: 2.83850566\n",
      "TE: Y->X 0.0004\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 2.82947083\n",
      "Epoch [990/1000], Loss: 2.03593092\n",
      "TE: X->Y 0.7928\n",
      "Epoch [990/1000], Loss: 2.82824975\n",
      "Epoch [990/1000], Loss: 2.82823088\n",
      "TE: Y->X 0.0\n",
      "## Dim =  3 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 4.24874928\n",
      "Epoch [990/1000], Loss: 3.01503961\n",
      "TE: X->Y 1.2306\n",
      "Epoch [990/1000], Loss: 4.25011085\n",
      "Epoch [990/1000], Loss: 4.24880028\n",
      "TE: Y->X 0.0012\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 4.24808102\n",
      "Epoch [990/1000], Loss: 3.06611585\n",
      "TE: X->Y 1.1824\n",
      "Epoch [990/1000], Loss: 4.24841647\n",
      "Epoch [990/1000], Loss: 4.24827358\n",
      "TE: Y->X 0.0001\n",
      "## Dim =  4 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 5.67283076\n",
      "Epoch [990/1000], Loss: 3.98301432\n",
      "TE: X->Y 1.6853\n",
      "Epoch [990/1000], Loss: 5.66944328\n",
      "Epoch [990/1000], Loss: 5.66585939\n",
      "TE: Y->X 0.0031\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 5.66480894\n",
      "Epoch [990/1000], Loss: 4.08750066\n",
      "TE: X->Y 1.5757\n",
      "Epoch [990/1000], Loss: 5.66598021\n",
      "Epoch [990/1000], Loss: 5.66573369\n",
      "TE: Y->X 0.0002\n",
      "## Dim =  5 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 7.07960398\n",
      "Epoch [990/1000], Loss: 4.91346738\n",
      "TE: X->Y 2.1557\n",
      "Epoch [990/1000], Loss: 7.07538834\n",
      "Epoch [990/1000], Loss: 7.06683086\n",
      "TE: Y->X 0.0072\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 7.07852684\n",
      "Epoch [990/1000], Loss: 5.12351362\n",
      "TE: X->Y 1.953\n",
      "Epoch [990/1000], Loss: 7.08607128\n",
      "Epoch [990/1000], Loss: 7.08579646\n",
      "TE: Y->X 0.0003\n",
      "## Dim =  6 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 8.46814754\n",
      "Epoch [990/1000], Loss: 5.80924761\n",
      "TE: X->Y 2.6419\n",
      "Epoch [990/1000], Loss: 8.46550177\n",
      "Epoch [990/1000], Loss: 8.44537437\n",
      "TE: Y->X 0.017\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 8.49614689\n",
      "Epoch [990/1000], Loss: 6.13639975\n",
      "TE: X->Y 2.3547\n",
      "Epoch [990/1000], Loss: 8.50678999\n",
      "Epoch [990/1000], Loss: 8.50644835\n",
      "TE: Y->X 0.0003\n",
      "## Dim =  7 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 9.86452387\n",
      "Epoch [990/1000], Loss: 6.63360889\n",
      "TE: X->Y 3.2042\n",
      "Epoch [990/1000], Loss: 9.87277172\n",
      "Epoch [990/1000], Loss: 9.85223297\n",
      "TE: Y->X 0.0164\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 9.91411446\n",
      "Epoch [990/1000], Loss: 7.15125634\n",
      "TE: X->Y 2.758\n",
      "Epoch [990/1000], Loss: 9.92374285\n",
      "Epoch [990/1000], Loss: 9.92324222\n",
      "TE: Y->X 0.0005\n",
      "## Dim =  8 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 11.24893046\n",
      "Epoch [990/1000], Loss: 7.46828803\n",
      "TE: X->Y 3.7457\n",
      "Epoch [990/1000], Loss: 11.25460069\n",
      "Epoch [990/1000], Loss: 11.22320941\n",
      "TE: Y->X 0.0225\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 11.33397206\n",
      "Epoch [990/1000], Loss: 8.19830493\n",
      "TE: X->Y 3.1288\n",
      "Epoch [990/1000], Loss: 11.34363496\n",
      "Epoch [990/1000], Loss: 11.34332612\n",
      "TE: Y->X 0.0003\n",
      "## Dim =  9 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 12.61260517\n",
      "Epoch [990/1000], Loss: 8.120282929\n",
      "TE: X->Y 4.4419\n",
      "Epoch [990/1000], Loss: 12.65663846\n",
      "Epoch [990/1000], Loss: 12.54134049\n",
      "TE: Y->X 0.0977\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 12.75014239\n",
      "Epoch [990/1000], Loss: 9.227242852\n",
      "TE: X->Y 3.5149\n",
      "Epoch [990/1000], Loss: 12.76110401\n",
      "Epoch [990/1000], Loss: 12.76030914\n",
      "TE: Y->X 0.0008\n",
      "## Dim =  10 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 13.95782921\n",
      "Epoch [990/1000], Loss: 8.753501058\n",
      "TE: X->Y 5.1437\n",
      "Epoch [990/1000], Loss: 13.98036927\n",
      "Epoch [990/1000], Loss: 13.89653035\n",
      "TE: Y->X 0.0612\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 14.16782957\n",
      "Epoch [990/1000], Loss: 10.21620818\n",
      "TE: X->Y 3.943\n",
      "Epoch [990/1000], Loss: 14.17788137\n",
      "Epoch [990/1000], Loss: 14.17697161\n",
      "TE: Y->X 0.0009\n",
      "\n",
      "### REPLICATE 4/5 ###\n",
      "\n",
      "## Dim =  1 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 1.41519911\n",
      "Epoch [990/1000], Loss: 1.01755467\n",
      "TE: X->Y 0.3973\n",
      "Epoch [990/1000], Loss: 1.41298305\n",
      "Epoch [990/1000], Loss: 1.41278695\n",
      "TE: Y->X 0.0002\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 1.41542994\n",
      "Epoch [990/1000], Loss: 1.01238946\n",
      "TE: X->Y 0.4028\n",
      "Epoch [990/1000], Loss: 1.41518714\n",
      "Epoch [990/1000], Loss: 1.41512385\n",
      "TE: Y->X 0.0001\n",
      "## Dim =  2 #\n",
      "# Sample size =  10000 #\n",
      "Epoch [990/1000], Loss: 2.82509824\n",
      "Epoch [990/1000], Loss: 2.01615843\n",
      "TE: X->Y 0.8077\n",
      "Epoch [990/1000], Loss: 2.82494933\n",
      "Epoch [990/1000], Loss: 2.82431839\n",
      "TE: Y->X 0.0006\n",
      "# Sample size =  100000 #\n",
      "Epoch [990/1000], Loss: 2.83414381\n",
      "Epoch [990/1000], Loss: 2.03133808\n",
      "TE: X->Y 0.8019\n",
      "Epoch [990/1000], Loss: 2.83535053\n",
      "Epoch [460/1000], Loss: 2.83532696\r"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[29], line 16\u001b[0m\n\u001b[1;32m     14\u001b[0m             jp_results_TE_X2Y\u001b[38;5;241m.\u001b[39mwrite(method\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124magmte\u001b[39m\u001b[38;5;124m'\u001b[39m, n_dim\u001b[38;5;241m=\u001b[39mdim, sample_size\u001b[38;5;241m=\u001b[39msamples, value\u001b[38;5;241m=\u001b[39mTE_X2Y)\n\u001b[1;32m     15\u001b[0m             \u001b[38;5;66;03m# Estimate Y -> X\u001b[39;00m\n\u001b[0;32m---> 16\u001b[0m             TE_Y2X \u001b[38;5;241m=\u001b[39m TE_agmte(dataset, compute_device, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mY\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mX\u001b[39m\u001b[38;5;124m'\u001b[39m, dim\u001b[38;5;241m*\u001b[39m\u001b[38;5;241m16\u001b[39m, NB)\n\u001b[1;32m     17\u001b[0m             jp_results_TE_Y2X\u001b[38;5;241m.\u001b[39mwrite(method\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124magmte\u001b[39m\u001b[38;5;124m'\u001b[39m, n_dim\u001b[38;5;241m=\u001b[39mdim, sample_size\u001b[38;5;241m=\u001b[39msamples, value\u001b[38;5;241m=\u001b[39mTE_Y2X)\n\u001b[1;32m     19\u001b[0m jp_results_TE_X2Y\u001b[38;5;241m.\u001b[39mdf\u001b[38;5;241m.\u001b[39mto_csv(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mresults/agmte/jp_results_TE_X2Y_dim.csv\u001b[39m\u001b[38;5;124m'\u001b[39m, index\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m)\n",
      "Cell \u001b[0;32mIn[8], line 11\u001b[0m, in \u001b[0;36mTE_agmte\u001b[0;34m(dataset, device, var_from, var_to, hidden_size, batch_size, plot_loss)\u001b[0m\n\u001b[1;32m      8\u001b[0m tl_model_1, loss_1 \u001b[38;5;241m=\u001b[39m _train_agm(tl_model_1, tl_dataloader_1,\n\u001b[1;32m      9\u001b[0m                                 batch_size\u001b[38;5;241m=\u001b[39mbatch_size, epochs\u001b[38;5;241m=\u001b[39mEPOCHS, learning_rate\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m0.01\u001b[39m, optimize\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124msgd\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m     10\u001b[0m \u001b[38;5;28mprint\u001b[39m()\n\u001b[0;32m---> 11\u001b[0m tl_model_2, loss_2 \u001b[38;5;241m=\u001b[39m _train_agm(tl_model_2, tl_dataloader_2, \n\u001b[1;32m     12\u001b[0m                                 batch_size\u001b[38;5;241m=\u001b[39mbatch_size, epochs\u001b[38;5;241m=\u001b[39mEPOCHS, learning_rate\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m0.01\u001b[39m, optimize\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124msgd\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m     13\u001b[0m \u001b[38;5;28mprint\u001b[39m()\n\u001b[1;32m     14\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m plot_loss:\n",
      "File \u001b[0;32m~/Cloud/Projects/2023_UCL_Causal_Direction/AGM-TE/agm_te/model.py:332\u001b[0m, in \u001b[0;36m_train_agm\u001b[0;34m(model, data, batch_size, epochs, learning_rate, lr_decay_step, lr_decay_gamma, optimize, l2_penalty)\u001b[0m\n\u001b[1;32m    330\u001b[0m \u001b[38;5;66;03m# get the gradients\u001b[39;00m\n\u001b[1;32m    331\u001b[0m optimizer\u001b[38;5;241m.\u001b[39mzero_grad()\n\u001b[0;32m--> 332\u001b[0m loss\u001b[38;5;241m.\u001b[39mbackward()\n\u001b[1;32m    333\u001b[0m \u001b[38;5;66;03m# update the weights\u001b[39;00m\n\u001b[1;32m    334\u001b[0m optimizer\u001b[38;5;241m.\u001b[39mstep()\n",
      "File \u001b[0;32m~/anaconda3/lib/python3.11/site-packages/torch/_tensor.py:522\u001b[0m, in \u001b[0;36mTensor.backward\u001b[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[1;32m    512\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m has_torch_function_unary(\u001b[38;5;28mself\u001b[39m):\n\u001b[1;32m    513\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m handle_torch_function(\n\u001b[1;32m    514\u001b[0m         Tensor\u001b[38;5;241m.\u001b[39mbackward,\n\u001b[1;32m    515\u001b[0m         (\u001b[38;5;28mself\u001b[39m,),\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    520\u001b[0m         inputs\u001b[38;5;241m=\u001b[39minputs,\n\u001b[1;32m    521\u001b[0m     )\n\u001b[0;32m--> 522\u001b[0m torch\u001b[38;5;241m.\u001b[39mautograd\u001b[38;5;241m.\u001b[39mbackward(\n\u001b[1;32m    523\u001b[0m     \u001b[38;5;28mself\u001b[39m, gradient, retain_graph, create_graph, inputs\u001b[38;5;241m=\u001b[39minputs\n\u001b[1;32m    524\u001b[0m )\n",
      "File \u001b[0;32m~/anaconda3/lib/python3.11/site-packages/torch/autograd/__init__.py:266\u001b[0m, in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[1;32m    261\u001b[0m     retain_graph \u001b[38;5;241m=\u001b[39m create_graph\n\u001b[1;32m    263\u001b[0m \u001b[38;5;66;03m# The reason we repeat the same comment below is that\u001b[39;00m\n\u001b[1;32m    264\u001b[0m \u001b[38;5;66;03m# some Python versions print out the first line of a multi-line function\u001b[39;00m\n\u001b[1;32m    265\u001b[0m \u001b[38;5;66;03m# calls in the traceback and some print out the last line\u001b[39;00m\n\u001b[0;32m--> 266\u001b[0m Variable\u001b[38;5;241m.\u001b[39m_execution_engine\u001b[38;5;241m.\u001b[39mrun_backward(  \u001b[38;5;66;03m# Calls into the C++ engine to run the backward pass\u001b[39;00m\n\u001b[1;32m    267\u001b[0m     tensors,\n\u001b[1;32m    268\u001b[0m     grad_tensors_,\n\u001b[1;32m    269\u001b[0m     retain_graph,\n\u001b[1;32m    270\u001b[0m     create_graph,\n\u001b[1;32m    271\u001b[0m     inputs,\n\u001b[1;32m    272\u001b[0m     allow_unreachable\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m,\n\u001b[1;32m    273\u001b[0m     accumulate_grad\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m,\n\u001b[1;32m    274\u001b[0m )\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "# AGM-TE benchmark on the joint-process simulators: for every replicate,\n",
    "# dimension and sample size, estimate transfer entropy in both directions.\n",
    "# One Results table per direction; rows are tagged with method, n_dim and sample_size.\n",
    "jp_results_TE_X2Y = Results(columns=['method', 'n_dim', 'sample_size'])\n",
    "jp_results_TE_Y2X = Results(columns=['method', 'n_dim', 'sample_size'])\n",
    "\n",
    "try:\n",
    "    for r in range(REPLICATES):\n",
    "        print(f\"\\n### REPLICATE {r+1}/{REPLICATES} ###\\n\")\n",
    "        for dim, generator in zip(dim_range, jp_generator_lst):\n",
    "            print(\"## Dim = \", dim, \"#\")\n",
    "            for samples in sample_sizes:\n",
    "                print(\"# Sample size = \", samples, \"#\")\n",
    "                # Simulate data: the replicate index is used as the simulation seed,\n",
    "                # and the series is split into chunks of samples/20 points.\n",
    "                dataset = get_dataset(generator, samples, int(np.round(samples/20)), seed=r)\n",
    "                # Estimate X -> Y (hidden size grows with the dimension: dim*16)\n",
    "                TE_X2Y = TE_agmte(dataset, compute_device, 'X', 'Y', dim*16, NB)\n",
    "                jp_results_TE_X2Y.write(method='agmte', n_dim=dim, sample_size=samples, value=TE_X2Y)\n",
    "                # Estimate Y -> X with the same dataset and model size\n",
    "                TE_Y2X = TE_agmte(dataset, compute_device, 'Y', 'X', dim*16, NB)\n",
    "                jp_results_TE_Y2X.write(method='agmte', n_dim=dim, sample_size=samples, value=TE_Y2X)\n",
    "finally:\n",
    "    # Save in a finally block so estimates already collected are kept when the\n",
    "    # run is interrupted (e.g. KeyboardInterrupt) or a training call fails; the\n",
    "    # exception still propagates. After an interruption the CSVs hold a partial\n",
    "    # grid, and the X->Y table may contain one more row than the Y->X table.\n",
    "    jp_results_TE_X2Y.df.to_csv('results/agmte/jp_results_TE_X2Y_dim.csv', index=False)\n",
    "    jp_results_TE_Y2X.df.to_csv('results/agmte/jp_results_TE_Y2X_dim.csv', index=False)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
