{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "626550fa",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "from openhawkes.simulation import ExpHawkes\n",
    "from openhawkes.inference    import fit_hawkes\n",
    "from openhawkes.utils        import generate_alphas_with_features\n",
    "from tqdm import tqdm\n",
    "import itertools, math"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4b21050d",
   "metadata": {},
   "source": [
    "### Experiment (main text): parameter-estimation errors vs. number of sequences, with features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b288bab",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Sequences:   0%|          | 0/5 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[   0] loss=734.1300\n",
      "[ 100] loss=449.9457\n",
      "[ 200] loss=441.5269\n",
      "[ 300] loss=437.6582\n",
      "[ 400] loss=436.4583\n",
      "We stop because the loss is not changing anymore.\n",
      "[   0] loss=1539.4879\n",
      "[ 100] loss=1156.2527\n",
      "[ 200] loss=1146.0522\n",
      "[ 300] loss=1142.0020\n",
      "[ 400] loss=1140.0957\n",
      "We stop because the loss is not changing anymore.\n",
      "[   0] loss=2049.1826\n",
      "[ 100] loss=1646.1121\n",
      "[ 200] loss=1642.4673\n",
      "[ 300] loss=1642.0477\n",
      "We stop because the loss is not changing anymore.\n",
      "[   0] loss=1772.8820\n",
      "[ 100] loss=1430.9990\n",
      "[ 200] loss=1425.2699\n",
      "[ 300] loss=1422.6846\n",
      "We stop because the loss is not changing anymore.\n",
      "[   0] loss=1837.5553\n",
      "[ 100] loss=1476.8821\n",
      "[ 200] loss=1466.4723\n",
      "[ 300] loss=1463.2784\n",
      "[ 400] loss=1462.8073\n",
      "We stop because the loss is not changing anymore.\n",
      "[   0] loss=1830.9020\n",
      "[ 100] loss=1431.3071\n",
      "[ 200] loss=1426.6941\n",
      "[ 300] loss=1425.6832\n",
      "We stop because the loss is not changing anymore.\n",
      "[   0] loss=1599.3304\n",
      "[ 100] loss=1300.8890\n",
      "[ 200] loss=1297.3369\n",
      "We stop because the loss is not changing anymore.\n",
      "[   0] loss=961.7937\n",
      "[ 100] loss=712.5977\n",
      "[ 200] loss=705.8531\n",
      "[ 300] loss=704.1980\n",
      "We stop because the loss is not changing anymore.\n",
      "[   0] loss=1128.2795\n",
      "[ 100] loss=837.1211\n",
      "[ 200] loss=829.9998\n",
      "[ 300] loss=828.2171\n",
      "We stop because the loss is not changing anymore.\n",
      "[   0] loss=2044.1465\n",
      "[ 100] loss=1634.1968\n",
      "[ 200] loss=1625.8827\n",
      "[ 300] loss=1623.7793\n",
      "We stop because the loss is not changing anymore.\n",
      "[   0] loss=1878.3291\n",
      "[ 100] loss=1523.6699\n",
      "[ 200] loss=1517.4873\n",
      "[ 300] loss=1516.7321\n",
      "We stop because the loss is not changing anymore.\n",
      "[   0] loss=1527.2897\n",
      "[ 100] loss=1145.1627\n",
      "[ 200] loss=1137.3669\n",
      "[ 300] loss=1135.6428\n",
      "We stop because the loss is not changing anymore.\n",
      "[   0] loss=1261.8783\n",
      "[ 100] loss=984.1727\n",
      "[ 200] loss=976.6704\n",
      "[ 300] loss=975.0809\n",
      "We stop because the loss is not changing anymore.\n",
      "[   0] loss=1249.7029\n",
      "[ 100] loss=986.4651\n",
      "[ 200] loss=982.2550\n",
      "[ 300] loss=980.1289\n",
      "We stop because the loss is not changing anymore.\n",
      "[   0] loss=1694.3447\n",
      "[ 100] loss=1318.8104\n",
      "[ 200] loss=1313.0444\n",
      "[ 300] loss=1310.7639\n",
      "We stop because the loss is not changing anymore.\n",
      "[   0] loss=313.2814\n",
      "[ 100] loss=21.4373\n",
      "[ 200] loss=10.2814\n",
      "[ 300] loss=3.4806\n",
      "[ 400] loss=-0.1989\n",
      "[ 500] loss=-1.8679\n",
      "We stop because the loss is not changing anymore.\n",
      "[   0] loss=993.5511\n",
      "[ 100] loss=704.9542\n",
      "[ 200] loss=696.8025\n",
      "[ 300] loss=694.6746\n",
      "We stop because the loss is not changing anymore.\n",
      "[   0] loss=649.2051\n",
      "[ 100] loss=369.5949\n",
      "[ 200] loss=363.6772\n",
      "[ 300] loss=361.2128\n",
      "[ 400] loss=358.7071\n",
      "[ 500] loss=351.6906\n",
      "[ 600] loss=97.2291\n",
      "[ 700] loss=-5744.9844\n",
      "[ 800] loss=-5901.4570\n",
      "[ 900] loss=-5933.2910\n",
      "[1000] loss=-5951.1436\n",
      "[1100] loss=-5963.9438\n",
      "[1200] loss=-5974.5273\n",
      "[1300] loss=-5984.0957\n",
      "[1400] loss=-5993.2539\n",
      "[1500] loss=-6002.3442\n",
      "[1600] loss=-6011.5537\n",
      "[1700] loss=-6020.8945\n",
      "[1800] loss=-6030.0649\n",
      "[1900] loss=-6038.2588\n",
      "[2000] loss=-6044.1284\n",
      "[2100] loss=-6047.0771\n",
      "We stop because the loss is not changing anymore.\n",
      "[   0] loss=1871.2799\n",
      "[ 100] loss=1492.3384\n",
      "[ 200] loss=1484.8850\n",
      "[ 300] loss=1481.7837\n",
      "[ 400] loss=1480.1156\n",
      "We stop because the loss is not changing anymore.\n",
      "[   0] loss=1174.4514\n",
      "[ 100] loss=816.7811\n",
      "[ 200] loss=813.3511\n",
      "[ 300] loss=811.8203\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Sequences:  20%|██        | 1/5 [01:03<04:15, 63.93s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "We stop because the loss is not changing anymore.\n",
      "[   0] loss=5587.6953\n",
      "[ 100] loss=4320.5547\n",
      "[ 200] loss=4295.6079\n",
      "[ 300] loss=4289.0029\n",
      "[ 400] loss=4287.0190\n",
      "We stop because the loss is not changing anymore.\n",
      "[   0] loss=4380.4844\n",
      "[ 100] loss=3122.9600\n",
      "[ 200] loss=3098.3584\n",
      "[ 300] loss=3090.5459\n",
      "[ 400] loss=3087.9561\n",
      "We stop because the loss is not changing anymore.\n",
      "[   0] loss=4678.0957\n",
      "[ 100] loss=3354.5952\n",
      "[ 200] loss=3327.8301\n",
      "[ 300] loss=3318.5676\n",
      "[ 400] loss=3314.7373\n",
      "[ 500] loss=3313.4614\n",
      "We stop because the loss is not changing anymore.\n",
      "[   0] loss=5446.7896\n",
      "[ 100] loss=4112.5884\n",
      "[ 200] loss=4098.9360\n",
      "[ 300] loss=4092.7854\n",
      "[ 400] loss=4089.1777\n",
      "[ 500] loss=4087.4585\n",
      "We stop because the loss is not changing anymore.\n",
      "[   0] loss=5293.9355\n",
      "[ 100] loss=3872.5752\n",
      "[ 200] loss=3853.7544\n",
      "[ 300] loss=3847.7300\n",
      "[ 400] loss=3844.9980\n",
      "We stop because the loss is not changing anymore.\n",
      "[   0] loss=5070.6416\n",
      "[ 100] loss=3826.7959\n",
      "[ 200] loss=3809.5317\n",
      "[ 300] loss=3801.6982\n",
      "[ 400] loss=3797.0278\n",
      "[ 500] loss=3794.3042\n",
      "We stop because the loss is not changing anymore.\n",
      "[   0] loss=6319.7305\n",
      "[ 100] loss=5012.1069\n",
      "[ 200] loss=4986.2109\n",
      "[ 300] loss=4978.4033\n",
      "[ 400] loss=4976.0430\n",
      "We stop because the loss is not changing anymore.\n",
      "[   0] loss=5324.1465\n",
      "[ 100] loss=4107.0498\n",
      "[ 200] loss=4084.5947\n",
      "[ 300] loss=4078.8516\n",
      "[ 400] loss=4076.1240\n",
      "[ 500] loss=4074.8743\n",
      "We stop because the loss is not changing anymore.\n",
      "[   0] loss=5351.2354\n",
      "[ 100] loss=4116.9272\n",
      "[ 200] loss=4099.0781\n",
      "[ 300] loss=4094.6943\n",
      "[ 400] loss=4093.5044\n",
      "We stop because the loss is not changing anymore.\n",
      "[   0] loss=5408.5615\n",
      "[ 100] loss=4084.5657\n",
      "[ 200] loss=4049.8315\n",
      "[ 300] loss=4040.2378\n",
      "[ 400] loss=4036.6367\n",
      "[ 500] loss=4035.5178\n",
      "We stop because the loss is not changing anymore.\n",
      "[   0] loss=5057.9512\n",
      "[ 100] loss=3675.8013\n",
      "[ 200] loss=3649.8987\n",
      "[ 300] loss=3642.7744\n",
      "[ 400] loss=3640.2939\n",
      "We stop because the loss is not changing anymore.\n",
      "[   0] loss=6350.3018\n",
      "[ 100] loss=4976.9111\n",
      "[ 200] loss=4955.5586\n",
      "[ 300] loss=4951.0220\n",
      "[ 400] loss=4948.8945\n",
      "We stop because the loss is not changing anymore.\n",
      "[   0] loss=5561.6528\n",
      "[ 100] loss=4254.4092\n",
      "[ 200] loss=4230.5059\n",
      "[ 300] loss=4223.9170\n",
      "[ 400] loss=4221.2876\n",
      "We stop because the loss is not changing anymore.\n",
      "[   0] loss=4976.6460\n",
      "[ 100] loss=3807.9766\n",
      "[ 200] loss=3787.5012\n",
      "[ 300] loss=3781.0166\n",
      "[ 400] loss=3777.8779\n",
      "[ 500] loss=3776.4773\n",
      "We stop because the loss is not changing anymore.\n",
      "[   0] loss=4465.6572\n",
      "[ 100] loss=3239.3843\n",
      "[ 200] loss=3221.8318\n",
      "[ 300] loss=3215.7568\n",
      "[ 400] loss=3212.8726\n",
      "We stop because the loss is not changing anymore.\n",
      "[   0] loss=4559.5825\n",
      "[ 100] loss=3298.5513\n",
      "[ 200] loss=3267.2744\n",
      "[ 300] loss=3255.1929\n",
      "[ 400] loss=3248.7480\n",
      "[ 500] loss=3245.2847\n",
      "[ 600] loss=3243.8958\n",
      "We stop because the loss is not changing anymore.\n",
      "[   0] loss=4733.0781\n",
      "[ 100] loss=3534.2231\n",
      "[ 200] loss=3508.3518\n",
      "[ 300] loss=3500.0474\n",
      "[ 400] loss=3495.9006\n",
      "[ 500] loss=3493.8875\n",
      "We stop because the loss is not changing anymore.\n",
      "[   0] loss=5096.6191\n",
      "[ 100] loss=3837.8833\n",
      "[ 200] loss=3809.7664\n",
      "[ 300] loss=3802.8250\n",
      "[ 400] loss=3800.7856\n",
      "We stop because the loss is not changing anymore.\n",
      "[   0] loss=5911.8345\n",
      "[ 100] loss=4515.8770\n",
      "[ 200] loss=4491.7861\n",
      "[ 300] loss=4483.6816\n",
      "[ 400] loss=4479.6851\n",
      "[ 500] loss=4477.6377\n",
      "We stop because the loss is not changing anymore.\n",
      "[   0] loss=5297.2061\n",
      "[ 100] loss=4035.4692\n",
      "[ 200] loss=4018.1838\n",
      "[ 300] loss=4013.7119\n",
      "[ 400] loss=4011.8342\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Sequences:  40%|████      | 2/5 [02:39<04:08, 82.82s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "We stop because the loss is not changing anymore.\n",
      "[   0] loss=14734.0273\n",
      "[ 100] loss=11422.9834\n",
      "[ 200] loss=11352.1895\n",
      "[ 300] loss=11332.2080\n",
      "[ 400] loss=11325.3574\n",
      "[ 500] loss=11322.2656\n",
      "We stop because the loss is not changing anymore.\n",
      "[   0] loss=13478.2852\n",
      "[ 100] loss=10232.0840\n",
      "[ 200] loss=10177.8721\n",
      "[ 300] loss=10164.1064\n",
      "[ 400] loss=10157.7705\n",
      "[ 500] loss=10154.1162\n",
      "[ 600] loss=10152.5273\n",
      "We stop because the loss is not changing anymore.\n",
      "[   0] loss=13821.1797\n",
      "[ 100] loss=10787.0537\n",
      "[ 200] loss=10722.2939\n",
      "[ 300] loss=10707.8682\n",
      "[ 400] loss=10703.8613\n",
      "We stop because the loss is not changing anymore.\n",
      "[   0] loss=16100.7031\n",
      "[ 100] loss=12556.3066\n",
      "[ 200] loss=12503.7334\n",
      "[ 300] loss=12487.7188\n",
      "[ 400] loss=12480.6895\n",
      "[ 500] loss=12477.3818\n",
      "We stop because the loss is not changing anymore.\n",
      "[   0] loss=12786.4785\n",
      "[ 100] loss=9402.2070\n",
      "[ 200] loss=9347.0254\n",
      "[ 300] loss=9332.1172\n",
      "[ 400] loss=9325.2988\n",
      "[ 500] loss=9321.7969\n",
      "[ 600] loss=9320.4766\n",
      "We stop because the loss is not changing anymore.\n",
      "[   0] loss=13385.7490\n",
      "[ 100] loss=10283.7041\n",
      "[ 200] loss=10225.3145\n",
      "[ 300] loss=10207.2900\n",
      "[ 400] loss=10200.1680\n",
      "[ 500] loss=10196.6758\n",
      "[ 600] loss=10195.3154\n",
      "We stop because the loss is not changing anymore.\n",
      "[   0] loss=14341.0312\n",
      "[ 100] loss=11008.5635\n",
      "[ 200] loss=10945.5400\n",
      "[ 300] loss=10927.6709\n",
      "[ 400] loss=10920.0684\n",
      "[ 500] loss=10916.0850\n",
      "[ 600] loss=10914.0918\n",
      "We stop because the loss is not changing anymore.\n",
      "[   0] loss=13463.4551\n",
      "[ 100] loss=10153.9473\n",
      "[ 200] loss=10090.4785\n",
      "[ 300] loss=10076.8623\n",
      "[ 400] loss=10071.5352\n",
      "[ 500] loss=10068.6426\n",
      "We stop because the loss is not changing anymore.\n",
      "[   0] loss=11510.9346\n",
      "[ 100] loss=8565.2266\n",
      "[ 200] loss=8505.2520\n",
      "[ 300] loss=8487.4082\n",
      "[ 400] loss=8478.8184\n",
      "[ 500] loss=8473.7324\n",
      "[ 600] loss=8470.9150\n",
      "We stop because the loss is not changing anymore.\n",
      "[   0] loss=12597.4141\n",
      "[ 100] loss=9539.8789\n",
      "[ 200] loss=9485.4473\n",
      "[ 300] loss=9471.3359\n",
      "[ 400] loss=9465.3086\n",
      "[ 500] loss=9461.9375\n",
      "[ 600] loss=9460.1992\n",
      "We stop because the loss is not changing anymore.\n",
      "[   0] loss=15130.2412\n",
      "[ 100] loss=11594.4219\n",
      "[ 200] loss=11537.8945\n",
      "[ 300] loss=11521.3047\n",
      "[ 400] loss=11512.6426\n",
      "[ 500] loss=11507.2441\n",
      "[ 600] loss=11503.7754\n",
      "[ 700] loss=11501.6367\n",
      "We stop because the loss is not changing anymore.\n",
      "[   0] loss=17381.2070\n",
      "[ 100] loss=13659.2197\n",
      "[ 200] loss=13593.5996\n",
      "[ 300] loss=13579.8799\n",
      "[ 400] loss=13575.0879\n",
      "[ 500] loss=13572.4229\n",
      "We stop because the loss is not changing anymore.\n",
      "[   0] loss=11883.6592\n",
      "[ 100] loss=8679.6064\n",
      "[ 200] loss=8607.7363\n",
      "[ 300] loss=8589.5225\n",
      "[ 400] loss=8582.5391\n",
      "[ 500] loss=8579.2617\n",
      "We stop because the loss is not changing anymore.\n",
      "[   0] loss=13036.0488\n",
      "[ 100] loss=10024.4795\n",
      "[ 200] loss=9967.8535\n",
      "[ 300] loss=9950.9189\n",
      "[ 400] loss=9943.8613\n",
      "[ 500] loss=9940.1094\n",
      "[ 600] loss=9938.2832\n",
      "We stop because the loss is not changing anymore.\n",
      "[   0] loss=11377.5391\n",
      "[ 100] loss=8345.2900\n",
      "[ 200] loss=8287.7725\n",
      "[ 300] loss=8269.0225\n",
      "[ 400] loss=8259.9258\n",
      "[ 500] loss=8254.9111\n",
      "[ 600] loss=8252.0645\n",
      "We stop because the loss is not changing anymore.\n",
      "[   0] loss=14180.9199\n",
      "[ 100] loss=10927.3105\n",
      "[ 200] loss=10872.3359\n",
      "[ 300] loss=10857.9004\n",
      "[ 400] loss=10852.8086\n",
      "[ 500] loss=10850.6562\n",
      "We stop because the loss is not changing anymore.\n",
      "[   0] loss=14817.0859\n",
      "[ 100] loss=11489.8984\n",
      "[ 200] loss=11438.1211\n",
      "[ 300] loss=11419.8730\n",
      "[ 400] loss=11409.8770\n",
      "[ 500] loss=11403.8936\n",
      "[ 600] loss=11400.4102\n",
      "[ 700] loss=11398.8428\n",
      "We stop because the loss is not changing anymore.\n",
      "[   0] loss=13858.5537\n",
      "[ 100] loss=10684.6367\n",
      "[ 200] loss=10620.8525\n",
      "[ 300] loss=10602.0879\n",
      "[ 400] loss=10594.0176\n",
      "[ 500] loss=10589.5996\n",
      "[ 600] loss=10587.1074\n",
      "We stop because the loss is not changing anymore.\n",
      "[   0] loss=15708.2402\n",
      "[ 100] loss=12203.1113\n",
      "[ 200] loss=12145.3203\n",
      "[ 300] loss=12130.6641\n",
      "[ 400] loss=12124.8564\n",
      "[ 500] loss=12121.8604\n",
      "We stop because the loss is not changing anymore.\n",
      "[   0] loss=12006.0928\n",
      "[ 100] loss=8841.4570\n",
      "[ 200] loss=8786.1523\n",
      "[ 300] loss=8773.6152\n",
      "[ 400] loss=8768.4707\n",
      "[ 500] loss=8766.0879\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Sequences:  60%|██████    | 3/5 [06:00<04:32, 136.42s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "We stop because the loss is not changing anymore.\n",
      "[   0] loss=22166.7500\n",
      "[ 100] loss=17499.8633\n",
      "[ 200] loss=17411.1172\n",
      "[ 300] loss=17386.7793\n",
      "[ 400] loss=17377.1680\n",
      "[ 500] loss=17371.9688\n",
      "[ 600] loss=17368.9102\n",
      "We stop because the loss is not changing anymore.\n",
      "[   0] loss=19020.6953\n",
      "[ 100] loss=14519.7920\n",
      "[ 200] loss=14431.4834\n",
      "[ 300] loss=14403.5137\n",
      "[ 400] loss=14391.2002\n",
      "[ 500] loss=14385.1475\n",
      "[ 600] loss=14382.1074\n",
      "We stop because the loss is not changing anymore.\n",
      "[   0] loss=16669.8477\n",
      "[ 100] loss=12390.9521\n",
      "[ 200] loss=12308.2549\n",
      "[ 300] loss=12286.1504\n",
      "[ 400] loss=12277.6191\n",
      "[ 500] loss=12273.3350\n",
      "[ 600] loss=12271.3623\n",
      "We stop because the loss is not changing anymore.\n",
      "[   0] loss=19315.1641\n",
      "[ 100] loss=14834.0195\n",
      "[ 200] loss=14742.8350\n",
      "[ 300] loss=14719.8203\n",
      "[ 400] loss=14710.8340\n",
      "[ 500] loss=14706.0732\n",
      "[ 600] loss=14703.5449\n",
      "We stop because the loss is not changing anymore.\n",
      "[   0] loss=18585.2031\n",
      "[ 100] loss=14233.6211\n",
      "[ 200] loss=14141.4199\n",
      "[ 300] loss=14115.7090\n",
      "[ 400] loss=14105.5361\n",
      "[ 500] loss=14100.2363\n",
      "[ 600] loss=14097.2031\n",
      "We stop because the loss is not changing anymore.\n",
      "[   0] loss=17581.4766\n",
      "[ 100] loss=13273.7188\n",
      "[ 200] loss=13184.4951\n",
      "[ 300] loss=13157.3213\n",
      "[ 400] loss=13145.5449\n",
      "[ 500] loss=13139.1680\n",
      "[ 600] loss=13135.6367\n",
      "[ 700] loss=13134.2090\n",
      "We stop because the loss is not changing anymore.\n",
      "[   0] loss=21055.9648\n",
      "[ 100] loss=16419.1016\n",
      "[ 200] loss=16324.7051\n",
      "[ 300] loss=16301.9111\n",
      "[ 400] loss=16293.4971\n",
      "[ 500] loss=16289.2373\n",
      "[ 600] loss=16286.9658\n",
      "We stop because the loss is not changing anymore.\n",
      "[   0] loss=21632.9492\n",
      "[ 100] loss=16769.9258\n",
      "[ 200] loss=16688.0820\n",
      "[ 300] loss=16669.1406\n",
      "[ 400] loss=16659.3008\n",
      "[ 500] loss=16652.8633\n",
      "[ 600] loss=16648.6699\n",
      "[ 700] loss=16645.9609\n",
      "We stop because the loss is not changing anymore.\n",
      "[   0] loss=19218.2891\n",
      "[ 100] loss=14741.9883\n",
      "[ 200] loss=14662.0107\n",
      "[ 300] loss=14643.9043\n",
      "[ 400] loss=14635.5684\n",
      "[ 500] loss=14630.5322\n",
      "[ 600] loss=14627.4639\n",
      "We stop because the loss is not changing anymore.\n",
      "[   0] loss=17204.9062\n",
      "[ 100] loss=12657.2305\n",
      "[ 200] loss=12565.4307\n",
      "[ 300] loss=12543.2373\n",
      "[ 400] loss=12533.7793\n",
      "[ 500] loss=12528.3516\n",
      "[ 600] loss=12525.1299\n",
      "[ 700] loss=12523.5557\n",
      "We stop because the loss is not changing anymore.\n",
      "[   0] loss=20986.3418\n",
      "[ 100] loss=16201.0186\n",
      "[ 200] loss=16114.6260\n",
      "[ 300] loss=16089.7363\n",
      "[ 400] loss=16076.5498\n",
      "[ 500] loss=16068.3652\n",
      "[ 600] loss=16063.1152\n",
      "[ 700] loss=16059.7520\n",
      "[ 800] loss=16057.7812\n",
      "We stop because the loss is not changing anymore.\n",
      "[   0] loss=20327.8242\n",
      "[ 100] loss=15669.0801\n",
      "[ 200] loss=15595.5010\n",
      "[ 300] loss=15579.5928\n",
      "[ 400] loss=15572.6904\n",
      "[ 500] loss=15568.5518\n",
      "[ 600] loss=15566.0020\n",
      "We stop because the loss is not changing anymore.\n",
      "[   0] loss=19328.6953\n",
      "[ 100] loss=14909.6592\n",
      "[ 200] loss=14821.4971\n",
      "[ 300] loss=14796.7168\n",
      "[ 400] loss=14785.9062\n",
      "[ 500] loss=14780.1123\n",
      "[ 600] loss=14776.8828\n",
      "We stop because the loss is not changing anymore.\n",
      "[   0] loss=18276.1387\n",
      "[ 100] loss=13863.9453\n",
      "[ 200] loss=13784.0225\n",
      "[ 300] loss=13762.4863\n",
      "[ 400] loss=13752.5918\n",
      "[ 500] loss=13747.1875\n",
      "[ 600] loss=13744.2754\n",
      "We stop because the loss is not changing anymore.\n",
      "[   0] loss=17530.8828\n",
      "[ 100] loss=13033.3320\n",
      "[ 200] loss=12937.2861\n",
      "[ 300] loss=12908.7686\n",
      "[ 400] loss=12896.0996\n",
      "[ 500] loss=12889.6045\n",
      "[ 600] loss=12885.9277\n",
      "[ 700] loss=12883.9648\n",
      "We stop because the loss is not changing anymore.\n",
      "[   0] loss=19071.8164\n",
      "[ 100] loss=14739.4980\n",
      "[ 200] loss=14658.2539\n",
      "[ 300] loss=14634.4922\n",
      "[ 400] loss=14623.6270\n",
      "[ 500] loss=14617.5410\n",
      "[ 600] loss=14614.0586\n",
      "[ 700] loss=14612.5381\n",
      "We stop because the loss is not changing anymore.\n",
      "[   0] loss=18914.0625\n",
      "[ 100] loss=14399.0762\n",
      "[ 200] loss=14297.2637\n",
      "[ 300] loss=14268.6875\n",
      "[ 400] loss=14258.1016\n",
      "[ 500] loss=14253.3281\n",
      "[ 600] loss=14251.1914\n",
      "We stop because the loss is not changing anymore.\n",
      "[   0] loss=18657.5527\n",
      "[ 100] loss=14242.8428\n",
      "[ 200] loss=14146.8135\n",
      "[ 300] loss=14121.1689\n",
      "[ 400] loss=14109.4912\n",
      "[ 500] loss=14102.7598\n",
      "[ 600] loss=14098.7744\n",
      "[ 700] loss=14096.4951\n",
      "We stop because the loss is not changing anymore.\n",
      "[   0] loss=20888.1797\n",
      "[ 100] loss=16069.4297\n",
      "[ 200] loss=15997.8115\n",
      "[ 300] loss=15979.9414\n",
      "[ 400] loss=15970.5068\n",
      "[ 500] loss=15964.4668\n",
      "[ 600] loss=15960.7998\n",
      "[ 700] loss=15958.9434\n",
      "We stop because the loss is not changing anymore.\n",
      "[   0] loss=16877.0117\n",
      "[ 100] loss=12342.5840\n",
      "[ 200] loss=12260.1914\n",
      "[ 300] loss=12241.6221\n",
      "[ 400] loss=12233.2441\n",
      "[ 500] loss=12228.6348\n",
      "[ 600] loss=12226.0469\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Sequences:  80%|████████  | 4/5 [10:48<03:16, 196.41s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "We stop because the loss is not changing anymore.\n",
      "[   0] loss=28728.4883\n",
      "[ 100] loss=22300.4766\n",
      "[ 200] loss=22175.6699\n",
      "[ 300] loss=22143.1289\n",
      "[ 400] loss=22129.6152\n",
      "[ 500] loss=22122.0742\n",
      "[ 600] loss=22117.7695\n",
      "[ 700] loss=22115.5039\n",
      "We stop because the loss is not changing anymore.\n",
      "[   0] loss=27127.4609\n",
      "[ 100] loss=20615.8984\n",
      "[ 200] loss=20492.9492\n",
      "[ 300] loss=20458.5586\n",
      "[ 400] loss=20443.1816\n",
      "[ 500] loss=20434.8828\n",
      "[ 600] loss=20430.3301\n",
      "[ 700] loss=20427.7461\n",
      "We stop because the loss is not changing anymore.\n",
      "[   0] loss=26430.8633\n",
      "[ 100] loss=19829.8008\n",
      "[ 200] loss=19674.8125\n",
      "[ 300] loss=19631.7500\n",
      "[ 400] loss=19617.6016\n",
      "[ 500] loss=19611.0977\n",
      "[ 600] loss=19607.5703\n",
      "[ 700] loss=19605.8887\n",
      "We stop because the loss is not changing anymore.\n",
      "[   0] loss=25859.9180\n",
      "[ 100] loss=19440.3750\n",
      "[ 200] loss=19322.9160\n",
      "[ 300] loss=19294.0312\n",
      "[ 400] loss=19279.1445\n",
      "[ 500] loss=19269.5039\n",
      "[ 600] loss=19263.2109\n",
      "[ 700] loss=19259.1484\n",
      "[ 800] loss=19256.5078\n",
      "We stop because the loss is not changing anymore.\n",
      "[   0] loss=26393.5156\n",
      "[ 100] loss=20010.9570\n",
      "[ 200] loss=19900.6855\n",
      "[ 300] loss=19869.5723\n",
      "[ 400] loss=19854.3672\n",
      "[ 500] loss=19845.7227\n",
      "[ 600] loss=19840.7461\n",
      "[ 700] loss=19837.9824\n",
      "We stop because the loss is not changing anymore.\n",
      "[   0] loss=26324.2891\n",
      "[ 100] loss=20135.5547\n",
      "[ 200] loss=20023.1074\n",
      "[ 300] loss=19986.4219\n",
      "[ 400] loss=19970.6973\n",
      "[ 500] loss=19962.6758\n",
      "[ 600] loss=19958.4766\n",
      "[ 700] loss=19956.5977\n",
      "We stop because the loss is not changing anymore.\n",
      "[   0] loss=26998.0312\n",
      "[ 100] loss=20609.5156\n",
      "[ 200] loss=20482.5898\n",
      "[ 300] loss=20444.2227\n",
      "[ 400] loss=20427.3887\n",
      "[ 500] loss=20418.6172\n",
      "[ 600] loss=20413.9258\n",
      "[ 700] loss=20411.5215\n",
      "We stop because the loss is not changing anymore.\n",
      "[   0] loss=29239.3008\n",
      "[ 100] loss=22473.6973\n",
      "[ 200] loss=22351.2773\n",
      "[ 300] loss=22326.0664\n",
      "[ 400] loss=22315.9785\n",
      "[ 500] loss=22310.2383\n",
      "[ 600] loss=22306.6016\n",
      "[ 700] loss=22304.2930\n",
      "We stop because the loss is not changing anymore.\n",
      "[   0] loss=26039.1719\n",
      "[ 100] loss=19569.1914\n",
      "[ 200] loss=19443.8828\n",
      "[ 300] loss=19408.8320\n",
      "[ 400] loss=19391.8770\n",
      "[ 500] loss=19382.8203\n",
      "[ 600] loss=19378.1680\n",
      "[ 700] loss=19375.9922\n",
      "We stop because the loss is not changing anymore.\n",
      "[   0] loss=29830.3887\n",
      "[ 100] loss=23141.2930\n",
      "[ 200] loss=23017.1543\n",
      "[ 300] loss=22987.3438\n",
      "[ 400] loss=22975.7012\n",
      "[ 500] loss=22968.6270\n",
      "[ 600] loss=22963.9473\n",
      "[ 700] loss=22960.8398\n",
      "[ 800] loss=22959.0000\n",
      "We stop because the loss is not changing anymore.\n",
      "[   0] loss=28614.3008\n",
      "[ 100] loss=21897.5527\n",
      "[ 200] loss=21772.7832\n",
      "[ 300] loss=21739.8477\n",
      "[ 400] loss=21722.7695\n",
      "[ 500] loss=21711.8320\n",
      "[ 600] loss=21704.6855\n",
      "[ 700] loss=21700.1289\n",
      "[ 800] loss=21697.3047\n",
      "We stop because the loss is not changing anymore.\n",
      "[   0] loss=28014.7305\n",
      "[ 100] loss=21345.4141\n",
      "[ 200] loss=21224.6602\n",
      "[ 300] loss=21196.4414\n",
      "[ 400] loss=21182.2598\n",
      "[ 500] loss=21173.4238\n",
      "[ 600] loss=21167.9395\n",
      "[ 700] loss=21164.5664\n",
      "[ 800] loss=21162.8789\n",
      "We stop because the loss is not changing anymore.\n",
      "[   0] loss=25223.8398\n",
      "[ 100] loss=18874.2422\n",
      "[ 200] loss=18741.6758\n",
      "[ 300] loss=18706.4766\n",
      "[ 400] loss=18692.7285\n",
      "[ 500] loss=18685.7500\n",
      "[ 600] loss=18681.9961\n",
      "We stop because the loss is not changing anymore.\n",
      "[   0] loss=26207.2617\n",
      "[ 100] loss=19991.3047\n",
      "[ 200] loss=19868.8945\n",
      "[ 300] loss=19838.6152\n",
      "[ 400] loss=19824.7480\n",
      "[ 500] loss=19816.8203\n",
      "[ 600] loss=19812.2363\n",
      "[ 700] loss=19809.6055\n",
      "We stop because the loss is not changing anymore.\n",
      "[   0] loss=21023.0352\n",
      "[ 100] loss=14760.7354\n",
      "[ 200] loss=14644.9053\n",
      "[ 300] loss=14610.8320\n",
      "[ 400] loss=14597.3945\n",
      "[ 500] loss=14590.9434\n",
      "[ 600] loss=14587.5293\n",
      "We stop because the loss is not changing anymore.\n",
      "[   0] loss=29621.6289\n",
      "[ 100] loss=22740.3184\n",
      "[ 200] loss=22646.2031\n",
      "[ 300] loss=22621.1875\n",
      "[ 400] loss=22609.8516\n",
      "[ 500] loss=22603.5469\n",
      "[ 600] loss=22599.7969\n",
      "[ 700] loss=22597.6465\n",
      "We stop because the loss is not changing anymore.\n",
      "[   0] loss=23002.1836\n",
      "[ 100] loss=16871.1992\n",
      "[ 200] loss=16758.3516\n",
      "[ 300] loss=16726.5547\n",
      "[ 400] loss=16711.8340\n",
      "[ 500] loss=16703.9805\n",
      "[ 600] loss=16699.8008\n",
      "[ 700] loss=16697.7773\n",
      "We stop because the loss is not changing anymore.\n",
      "[   0] loss=25520.4590\n",
      "[ 100] loss=19226.9141\n",
      "[ 200] loss=19102.4062\n",
      "[ 300] loss=19070.1250\n",
      "[ 400] loss=19057.0742\n",
      "[ 500] loss=19050.3105\n",
      "[ 600] loss=19046.6738\n",
      "We stop because the loss is not changing anymore.\n",
      "[   0] loss=31245.3027\n",
      "[ 100] loss=24344.9375\n",
      "[ 200] loss=24255.2598\n",
      "[ 300] loss=24229.5781\n",
      "[ 400] loss=24212.9766\n",
      "[ 500] loss=24200.6484\n",
      "[ 600] loss=24191.7656\n",
      "[ 700] loss=24185.5469\n",
      "[ 800] loss=24181.2480\n",
      "[ 900] loss=24178.2812\n",
      "[1000] loss=24176.6602\n",
      "We stop because the loss is not changing anymore.\n",
      "[   0] loss=28882.3594\n",
      "[ 100] loss=22224.4668\n",
      "[ 200] loss=22114.2695\n",
      "[ 300] loss=22081.2148\n",
      "[ 400] loss=22064.0273\n",
      "[ 500] loss=22054.3691\n",
      "[ 600] loss=22049.0859\n",
      "[ 700] loss=22046.2500\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Sequences: 100%|██████████| 5/5 [17:49<00:00, 213.97s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "We stop because the loss is not changing anymore.\n",
      "   nb_sequences  mean_err_mu  mean_err_beta  mean_err_B0  mean_err_B1_off  \\\n",
      "0             5     0.008573       0.239008     6.700572         0.925944   \n",
      "1            20     0.001492       0.065501     1.067524         0.101676   \n",
      "2            50     0.001076       0.047723     0.566479         0.053737   \n",
      "3            70     0.000727       0.045181     0.363964         0.027448   \n",
      "4           100     0.000900       0.050596     0.290306         0.021551   \n",
      "\n",
      "   mean_err_B2_off  mean_err_B12_diag  mean_err_total  \n",
      "0         1.206047           2.062818       11.142962  \n",
      "1         0.094911           0.047656        1.378760  \n",
      "2         0.068054           0.020751        0.757821  \n",
      "3         0.044556           0.006294        0.488170  \n",
      "4         0.041234           0.010588        0.415175  \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# ------------------ Initialisation ------------------\n",
    "# Experiment: sum of squared estimation errors of the fitted parameters as a\n",
    "# function of the number of simulated sequences, averaged over num_repeats\n",
    "# independent repetitions per setting. Results are written to CSV and printed.\n",
    "D = 3\n",
    "mu_true    = np.array([0.1,  0.15, 0.2])\n",
    "beta_true  = np.array([0.8,1,1.2])    # distinct β per dimension (fitted below with beta_unique=False)\n",
    "\n",
    "# Ground truth compared against model.gamma (B0), model.theta1 (B1) and model.theta2 (B2);\n",
    "# also passed to generate_alphas_with_features to draw each sequence's alpha matrix and features.\n",
    "B0 = np.array([\n",
    "    [-1.0, -0.5, -1.0],\n",
    "    [-1.0, -0.5, -1.0],\n",
    "    [-0.5,  0.0,  0.0]\n",
    "])\n",
    "B1 = np.array([\n",
    "    [ 0.0,   0.0,   0.0],\n",
    "    [ 0.0,  -0.75,  0.50],\n",
    "    [ 0.0,  -1.00,  0.0]\n",
    "])\n",
    "B2 = np.array([\n",
    "    [ 0.0,   0.0,   0.75],\n",
    "    [ 0.50,  0.0,  -0.50],\n",
    "    [ 0.75, -0.25,  0.0]\n",
    "])\n",
    "\n",
    "max_events    = 300   # events per simulated sequence (simulate(max_events=...))\n",
    "seq_list     = [5, 20, 50, 70 ,100]   # number of sequences per fit\n",
    "num_repeats  = 20     # independent repetitions per seq_list entry\n",
    "\n",
    "results = []\n",
    "\n",
    "# ------------------ Principal loop ------------------\n",
    "for nb_sequences in tqdm(seq_list, desc=\"Sequences\"):\n",
    "    # per-repetition errors, one list per metric\n",
    "    acc = {\n",
    "        \"err_mu\": [], \"err_beta\": [], \"err_B0\": [],\n",
    "        \"err_B1_off\": [], \"err_B2_off\": [], \"err_B12_diag\": [],\n",
    "        \"err_total\": []\n",
    "    }\n",
    "\n",
    "    for rep in range(num_repeats):\n",
    "        all_times, all_types, all_x = [], [], []\n",
    "        # the seed depends only on rep, so every nb_sequences setting reuses the same seeds\n",
    "        rng = np.random.default_rng(1234 + rep)\n",
    "\n",
    "        # generate and accumulate nb_sequences trajectories: event times/types are\n",
    "        # concatenated into flat lists, one x_vec per sequence is appended to all_x\n",
    "        for _ in range(nb_sequences):\n",
    "            alpha_mat, x_vec = generate_alphas_with_features(rng, B0, B1, B2)\n",
    "            proc = ExpHawkes(mu_true, alpha_mat, beta_true)\n",
    "            ev   = proc.simulate(max_events=max_events)\n",
    "\n",
    "            all_times.extend(ev.times)\n",
    "            all_types.extend(ev.types)\n",
    "            all_x.append(x_vec)\n",
    "\n",
    "        # estimation\n",
    "        model = fit_hawkes(\n",
    "            nb_types    = D,\n",
    "            times        = all_times,\n",
    "            types        = all_types,\n",
    "            # NOTE(review): the later cells pass the features as `seq_feats=`;\n",
    "            # confirm which keyword fit_hawkes expects.\n",
    "            features     = all_x,\n",
    "            nb_batches  = nb_sequences,   # presumably one batch per simulated sequence — TODO confirm\n",
    "            lr           = 0.01,\n",
    "            nb_of_features = 1,\n",
    "            beta_unique  = False\n",
    "        )\n",
    "\n",
    "        # recover estimated parameters\n",
    "        mu_hat   = model.mu.detach().cpu().numpy()\n",
    "        beta_hat = model.beta.detach().cpu().numpy()\n",
    "        B0_hat   = model.gamma.detach().cpu().numpy()\n",
    "        B1_hat   = model.theta1.detach().cpu().numpy()\n",
    "        B2_hat   = model.theta2.detach().cpu().numpy()\n",
    "\n",
    "        # 1) errors on μ, β, B0 (sums of squared differences, not means)\n",
    "        err_mu   = np.sum((mu_hat   - mu_true)   ** 2)\n",
    "        err_beta = np.sum((beta_hat - beta_true) ** 2)\n",
    "        err_B0   = np.sum((B0_hat   - B0)        ** 2)\n",
    "\n",
    "        # drop singleton axes (presumably the nb_of_features=1 axis) so that the\n",
    "        # (D, D) masks below can index B1_hat / B2_hat\n",
    "        B1_hat = B1_hat.squeeze()      \n",
    "        B2_hat = B2_hat.squeeze()\n",
    "        # 2) off-diagonal errors\n",
    "        mask_off    = ~np.eye(D, dtype=bool)\n",
    "        err_B1_off  = np.sum((B1_hat[mask_off] - B1[mask_off]) ** 2)\n",
    "        err_B2_off  = np.sum((B2_hat[mask_off] - B2[mask_off]) ** 2)\n",
    "\n",
    "        # 3) diagonal errors: only the sum B1[i,i] + B2[i,i] is compared (presumably\n",
    "        #    because both diagonal terms multiply the same feature and are not\n",
    "        #    separately identifiable — TODO confirm)\n",
    "        true_diag   = np.diag(B1) + np.diag(B2)\n",
    "        est_diag    = np.diag(B1_hat) + np.diag(B2_hat)\n",
    "        err_B12_diag = np.sum((est_diag - true_diag) ** 2)\n",
    "\n",
    "        # 4) total\n",
    "        err_total = (\n",
    "            err_mu + err_beta + err_B0\n",
    "            + err_B1_off + err_B2_off + err_B12_diag\n",
    "        )\n",
    "\n",
    "        # accumulate\n",
    "        acc[\"err_mu\"].append(err_mu)\n",
    "        acc[\"err_beta\"].append(err_beta)\n",
    "        acc[\"err_B0\"].append(err_B0)\n",
    "        acc[\"err_B1_off\"].append(err_B1_off)\n",
    "        acc[\"err_B2_off\"].append(err_B2_off)\n",
    "        acc[\"err_B12_diag\"].append(err_B12_diag)\n",
    "        acc[\"err_total\"].append(err_total)\n",
    "\n",
    "    # mean over the num_repeats repetitions\n",
    "    results.append({\n",
    "        \"nb_sequences\":     nb_sequences,\n",
    "        \"mean_err_mu\":      np.mean(acc[\"err_mu\"]),\n",
    "        \"mean_err_beta\":    np.mean(acc[\"err_beta\"]),\n",
    "        \"mean_err_B0\":      np.mean(acc[\"err_B0\"]),\n",
    "        \"mean_err_B1_off\":  np.mean(acc[\"err_B1_off\"]),\n",
    "        \"mean_err_B2_off\":  np.mean(acc[\"err_B2_off\"]),\n",
    "        \"mean_err_B12_diag\":np.mean(acc[\"err_B12_diag\"]),\n",
    "        \"mean_err_total\":   np.mean(acc[\"err_total\"]),\n",
    "    })\n",
    "\n",
    "# ------------------ Export CSV ------------------\n",
    "df = pd.DataFrame(results)\n",
    "df.to_csv(\"errors_with_features.csv\", index=False)\n",
    "print(df)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "48c8be5b",
   "metadata": {},
   "source": [
    "### Experiment for Appendix D2 (estimation errors on α and on the branching matrix D)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "369039cd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ==========================================================\n",
    "# Estimation errors on α and D = α (I - α)^(-1) vs ground truth\n",
    "# with feature-based logistic Hawkes (D=3), across seq_list.\n",
    "# ==========================================================\n",
    "\n",
    "# ------------------ Helpers ------------------\n",
    "def branching_matrix(alpha: np.ndarray, clip_rho: float = 0.995, jitter: float = 1e-10) -> np.ndarray:\n",
    "    \"\"\"\n",
    "    Compute D = α (I - α)^(-1) in a numerically stable way.\n",
    "    If spectral radius ρ(α) >= clip_rho, scale α <- α * (clip_rho / ρ(α)).\n",
    "    Add a small jitter on the diagonal if (I-α) is ill-conditioned.\n",
    "    \"\"\"\n",
    "    # Stability guard\n",
    "    rho = np.max(np.abs(np.linalg.eigvals(alpha)))\n",
    "    if rho >= clip_rho:\n",
    "        alpha = alpha * (clip_rho / rho)\n",
    "\n",
    "    Ddim = alpha.shape[0]\n",
    "    I = np.eye(Ddim)\n",
    "\n",
    "    M = I - alpha\n",
    "    # Jitter if ill-conditioned\n",
    "    if np.linalg.cond(M) > 1e10:\n",
    "        M = M + jitter * I\n",
    "    # Solve for α (I-α)^(-1) without explicit inverse:\n",
    "    # α (I-α)^(-1) = ( (I-α)^{-T} α^T )^T\n",
    "    return (np.linalg.solve(M, alpha.T).T)\n",
    "\n",
    "def mse(A: np.ndarray, B: np.ndarray) -> float:\n",
    "    return float(np.mean((A - B) ** 2))\n",
    "\n",
    "def mse_off(A: np.ndarray, B: np.ndarray) -> float:\n",
    "    mask_off = ~np.eye(A.shape[0], dtype=bool)\n",
    "    return float(np.mean((A[mask_off] - B[mask_off]) ** 2))\n",
    "\n",
    "def mse_diag(A: np.ndarray, B: np.ndarray) -> float:\n",
    "    return float(np.mean((np.diag(A) - np.diag(B)) ** 2))\n",
    "\n",
    "def compute_alpha_and_D_hat(features_list, gamma_hat, theta1_hat, theta2_hat):\n",
    "    \"\"\"\n",
    "    Rebuild, for every sequence, the estimated excitation matrix α̂ and its branching matrix D̂.\n",
    "\n",
    "    For each entry of features_list the logits are\n",
    "        z_ij = γ_ij + Σ_k feats[k, i] θ1[k, i, j] + Σ_k feats[k, j] θ2[k, i, j]\n",
    "    and α̂ = sigmoid(z), D̂ = branching_matrix(α̂).\n",
    "\n",
    "    features_list : iterable of arrays, (K, D), or (D,) which is tiled to (K, D)\n",
    "    gamma_hat     : (D, D)\n",
    "    theta1_hat    : (K, D, D), or (D, D) which is treated as K = 1\n",
    "    theta2_hat    : same layout as theta1_hat\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    (alphas_hat, Ds_hat) : two lists of (D, D) arrays, aligned with features_list\n",
    "    \"\"\"\n",
    "    gamma_arr = np.asarray(gamma_hat)\n",
    "    th1 = np.asarray(theta1_hat)\n",
    "    th2 = np.asarray(theta2_hat)\n",
    "\n",
    "    # a squeezed single-feature parameter gets its K axis back\n",
    "    if th1.ndim == 2:\n",
    "        th1, th2 = th1[np.newaxis], th2[np.newaxis]\n",
    "    K, _, _ = th1.shape\n",
    "\n",
    "    def _alpha_for(feats):\n",
    "        f = np.asarray(feats)\n",
    "        if f.ndim == 1:\n",
    "            f = np.tile(f[None, :], (K, 1))\n",
    "        logits = (gamma_arr\n",
    "                  + np.einsum('ki,kij->ij', f, th1)\n",
    "                  + np.einsum('kj,kij->ij', f, th2))\n",
    "        return 1.0 / (1.0 + np.exp(-logits))\n",
    "\n",
    "    alphas_hat = [_alpha_for(feats) for feats in features_list]\n",
    "    Ds_hat = [branching_matrix(a) for a in alphas_hat]\n",
    "    return alphas_hat, Ds_hat\n",
    "\n",
    "# ------------------ Ground-truth parameters ------------------\n",
    "D = 3\n",
    "mu_true   = np.array([0.1, 0.15, 0.2])\n",
    "beta_true = np.array([1.0, 1.0, 1.0])  # all components equal, so the shared β of beta_unique=True below can represent it\n",
    "\n",
    "# Ground truth for model.gamma (B0), model.theta1 (B1), model.theta2 (B2). These are\n",
    "# logit-scale coefficients: compute_alpha_and_D_hat maps them to α through a sigmoid.\n",
    "B0 = np.array([\n",
    "    [-1.0, -0.5, -1.0],\n",
    "    [-1.0, -0.5, -1.0],\n",
    "    [-0.5,  0.0,  0.0]\n",
    "])\n",
    "B1 = np.array([\n",
    "    [ 0.0,   0.0,   0.0],\n",
    "    [ 0.0,  -0.75,  0.50],\n",
    "    [ 0.0,  -1.00,  0.0]\n",
    "])\n",
    "B2 = np.array([\n",
    "    [ 0.0,   0.0,   0.75],\n",
    "    [ 0.50,  0.0,  -0.50],\n",
    "    [ 0.75, -0.25,  0.0]\n",
    "])\n",
    "\n",
    "max_events  = 300   # events per simulated sequence\n",
    "seq_list    = [5, 20, 50, 70]   # number of sequences per fit\n",
    "num_repeats = 20    # independent repetitions per seq_list entry\n",
    "\n",
    "results = []\n",
    "\n",
    "# ------------------ Main loop ------------------\n",
    "for nb_sequences in tqdm(seq_list, desc=\"Sequences\"):\n",
    "    # per-repetition errors, one list per metric\n",
    "    acc = {\n",
    "        \"err_mu\": [], \"err_beta\": [], \"err_B0\": [],\n",
    "        \"err_B1_off\": [], \"err_B2_off\": [], \"err_B12_diag\": [],\n",
    "        \"err_alpha\": [], \"err_alpha_off\": [], \"err_alpha_diag\": [],\n",
    "        \"err_D\": [], \"err_total\": []\n",
    "    }\n",
    "\n",
    "    for rep in range(num_repeats):\n",
    "        all_times, all_types, all_x = [], [], []\n",
    "        all_alpha_true, all_D_true = [], []\n",
    "        # the seed depends only on rep, so every nb_sequences setting reuses the same seeds\n",
    "        rng = np.random.default_rng(1234 + rep)\n",
    "\n",
    "        # --- Generate nb_sequences trajectories ---\n",
    "        for _ in range(nb_sequences):\n",
    "            alpha_mat, x_vec = generate_alphas_with_features(rng, B0, B1, B2)\n",
    "            proc = ExpHawkes(mu_true, alpha_mat, beta_true)\n",
    "            ev   = proc.simulate(max_events=max_events)\n",
    "\n",
    "            all_times.extend(ev.times)\n",
    "            all_types.extend(ev.types)\n",
    "            all_x.append(x_vec)\n",
    "\n",
    "            # keep the true per-sequence α and its branching matrix for later comparison\n",
    "            all_alpha_true.append(alpha_mat)\n",
    "            all_D_true.append(branching_matrix(alpha_mat))\n",
    "\n",
    "        # --- Fit model ---\n",
    "        model = fit_hawkes(\n",
    "            nb_types       = D,\n",
    "            times          = all_times,\n",
    "            types          = all_types,\n",
    "            # NOTE(review): the first experiment cell passes the features as `features=`;\n",
    "            # confirm which keyword fit_hawkes expects.\n",
    "            seq_feats      = all_x,\n",
    "            nb_batches     = nb_sequences,\n",
    "            lr             = 0.001,\n",
    "            nb_of_features = 1,\n",
    "            beta_unique    = True\n",
    "        )\n",
    "\n",
    "        # --- Extract estimated parameters ---\n",
    "        mu_hat   = model.mu.detach().cpu().numpy()\n",
    "        beta_hat = model.beta.detach().cpu().numpy()   # with beta_unique=True this may hold a single value; mse broadcasts it\n",
    "        B0_hat   = model.gamma.detach().cpu().numpy()\n",
    "        B1_hat   = np.squeeze(model.theta1.detach().cpu().numpy())\n",
    "        B2_hat   = np.squeeze(model.theta2.detach().cpu().numpy())\n",
    "\n",
    "        # --- Parameter errors (MSE: means, unlike the sums of squares in the first experiment cell) ---\n",
    "        err_mu   = mse(mu_hat,   mu_true)\n",
    "        err_beta = mse(beta_hat, beta_true)\n",
    "        err_B0   = mse(B0_hat,   B0)\n",
    "\n",
    "        err_B1_off = mse_off(B1_hat, B1)\n",
    "        err_B2_off = mse_off(B2_hat, B2)\n",
    "\n",
    "        # diagonal: only the sum B1[i,i] + B2[i,i] is compared, because in the logit\n",
    "        # z_ii both terms multiply the same feature (see compute_alpha_and_D_hat)\n",
    "        true_diag = np.diag(B1) + np.diag(B2)\n",
    "        est_diag  = np.diag(B1_hat) + np.diag(B2_hat)\n",
    "        err_B12_diag = float(np.mean((est_diag - true_diag) ** 2))\n",
    "\n",
    "        # --- α̂ and D̂ per sequence ---\n",
    "        alphas_hat, Ds_hat = compute_alpha_and_D_hat(all_x, B0_hat, B1_hat, B2_hat)\n",
    "\n",
    "        # --- α errors averaged over sequences ---\n",
    "        ea = eo = ed = 0.0\n",
    "        for A_true, A_hat in zip(all_alpha_true, alphas_hat):\n",
    "            ea += mse(A_hat, A_true)\n",
    "            eo += mse_off(A_hat, A_true)\n",
    "            ed += mse_diag(A_hat, A_true)\n",
    "        S = len(all_alpha_true)\n",
    "        ea /= S; eo /= S; ed /= S\n",
    "\n",
    "        # --- D errors averaged over sequences ---\n",
    "        # both D_true and D̂ come from branching_matrix (including its spectral-radius clipping)\n",
    "        eD = 0.0\n",
    "        for D_true, D_hat in zip(all_D_true, Ds_hat):\n",
    "            eD += mse(D_hat, D_true)\n",
    "        eD /= S\n",
    "\n",
    "        # --- Total error (composite) ---\n",
    "        # eo / ed are an off-diagonal / diagonal split of ea, so only ea enters the total\n",
    "        err_total = (err_mu + err_beta + err_B0\n",
    "                     + err_B1_off + err_B2_off + err_B12_diag\n",
    "                     + ea + eD)\n",
    "\n",
    "        # Accumulate\n",
    "        acc[\"err_mu\"].append(err_mu)\n",
    "        acc[\"err_beta\"].append(err_beta)\n",
    "        acc[\"err_B0\"].append(err_B0)\n",
    "        acc[\"err_B1_off\"].append(err_B1_off)\n",
    "        acc[\"err_B2_off\"].append(err_B2_off)\n",
    "        acc[\"err_B12_diag\"].append(err_B12_diag)\n",
    "        acc[\"err_alpha\"].append(ea)\n",
    "        acc[\"err_alpha_off\"].append(eo)\n",
    "        acc[\"err_alpha_diag\"].append(ed)\n",
    "        acc[\"err_D\"].append(eD)\n",
    "        acc[\"err_total\"].append(err_total)\n",
    "\n",
    "    # --- Mean over repetitions ---\n",
    "    results.append({\n",
    "        \"nb_sequences\":        nb_sequences,\n",
    "        \"mean_err_mu\":         float(np.mean(acc[\"err_mu\"])),\n",
    "        \"mean_err_beta\":       float(np.mean(acc[\"err_beta\"])),\n",
    "        \"mean_err_B0\":         float(np.mean(acc[\"err_B0\"])),\n",
    "        \"mean_err_B1_off\":     float(np.mean(acc[\"err_B1_off\"])),\n",
    "        \"mean_err_B2_off\":     float(np.mean(acc[\"err_B2_off\"])),\n",
    "        \"mean_err_B12_diag\":   float(np.mean(acc[\"err_B12_diag\"])),\n",
    "        \"mean_err_alpha\":      float(np.mean(acc[\"err_alpha\"])),\n",
    "        \"mean_err_alpha_off\":  float(np.mean(acc[\"err_alpha_off\"])),\n",
    "        \"mean_err_alpha_diag\": float(np.mean(acc[\"err_alpha_diag\"])),\n",
    "        \"mean_err_D\":          float(np.mean(acc[\"err_D\"])),\n",
    "        \"mean_err_total\":      float(np.mean(acc[\"err_total\"])),\n",
    "    })\n",
    "\n",
    "# ------------------ Export & Display ------------------\n",
    "df = pd.DataFrame(results)\n",
    "df.to_csv(\"errors_with_features_and_alpha.csv\", index=False)\n",
    "print(df.to_string(index=False))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5cb23028",
   "metadata": {},
   "source": [
    "### Experiment for Appendix D2 (Shapley values and uplift)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "de53d8a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ============================================\n",
    "# Uplift & Shapley on simulated football Hawkes data (normalized-only)\n",
    "# 12 dims: 11 players + 1 \"dangerous surface\" node\n",
    "# ============================================\n",
    "\n",
    "# ---------- Roles (for display only) ----------\n",
    "roles = {\n",
    "    0: \"Goalkeeper\",\n",
    "    1: \"Right Fullback\",\n",
    "    2: \"Left Fullback\",\n",
    "    3: \"Right Centre-Back\",\n",
    "    4: \"Left Centre-Back\",\n",
    "    5: \"Defensive Midfielder\",\n",
    "    6: \"Right Central Midfielder\",\n",
    "    7: \"Left Central Midfielder\",\n",
    "    8: \"Right Winger\",\n",
    "    9: \"Striker\",\n",
    "    10:\"Left Winger\",\n",
    "    11:\"Dangerous Surface\",\n",
    "}\n",
    "\n",
    "# ---------- Problem size / indexing ----------\n",
    "D = 12\n",
    "surface_idx = 11\n",
    "players_idx = list(range(11))  # 0..10 (players only)\n",
    "\n",
    "# ---------- Feature scales (source/target) ----------\n",
    "w_src = np.array([0.01, 0.30, 0.30, 0.03, 0.03, 0.45, 0.60, 0.60, 0.95, 1.25, 0.95, 0.0])\n",
    "w_tgt = np.array([0.02, 0.15, 0.15, 0.05, 0.05, 0.30, 0.45, 0.45, 0.70, 0.95, 0.70, 0.0])\n",
    "\n",
    "# ---------- Connectivity mask M (j -> i) ----------\n",
    "M = np.zeros((D, D), dtype=float)\n",
    "\n",
    "def connect(srcs, tgts, w=1.0):\n",
    "    \"\"\"Write weight w into the global mask M for every edge src -> tgt (stored as M[tgt, src]).\"\"\"\n",
    "    rows = list(tgts)\n",
    "    for col in srcs:\n",
    "        M[rows, col] = w\n",
    "\n",
    "# Position groups (node indices)\n",
    "GK   = [0]\n",
    "LAT  = [1, 2]\n",
    "DC   = [3, 4]\n",
    "MID6 = [5]\n",
    "MID8 = [6, 7]\n",
    "WING = [8, 10]\n",
    "STR  = [9]\n",
    "\n",
    "# Build-up structure: each call adds edges src -> tgt with the given weight in M.\n",
    "connect(GK,  DC + LAT,                w=0.2)\n",
    "connect(DC,  LAT + MID6,              w=0.3)\n",
    "connect(LAT, MID6 + MID8 + WING + STR, w=0.6)\n",
    "connect(LAT, [surface_idx],           w=0.3)\n",
    "connect(MID6, MID8 + WING + STR,      w=0.65)\n",
    "connect(MID8, WING + STR + [surface_idx], w=0.8)\n",
    "connect(WING + STR, [surface_idx],    w=1.0)\n",
    "\n",
    "# ---------- Logistic parameters (gamma, theta1, theta2) ----------\n",
    "# These parameters are deterministic; the simulation's randomness comes from rng_seq further down.\n",
    "# Baseline logit: -2 on edges present in M, -8 on every other pair\n",
    "# (≈ 3e-4 after a sigmoid link, presumably as in compute_alpha_hat below).\n",
    "gamma = -2.0 * np.ones((D, D))\n",
    "gamma[M == 0] = -8.0\n",
    "\n",
    "theta1_scale, theta2_scale = 1.0, 1.8\n",
    "# theta1[i, j] = theta1_scale * w_src[j] * M[i, j]: depends on the source j (row-replicated),\n",
    "# weighted by the edge weight (zero where there is no edge)\n",
    "theta1 = theta1_scale * np.tile(w_src, (D, 1)) * M\n",
    "# theta2[i, j] = theta2_scale * w_tgt[i] * M[i, j]: depends on the receiver i (column-replicated)\n",
    "theta2 = theta2_scale * np.tile(w_tgt.reshape(-1, 1), (1, D)) * M\n",
    "\n",
    "# no outgoing edges from the surface: column surface_idx holds edges whose source is the surface\n",
    "theta1[:, surface_idx] = 0.0\n",
    "theta2[:, surface_idx] = 0.0\n",
    "\n",
    "# ---------- Helpers ----------\n",
    "def sigmoid(x):\n",
    "    return 1.0 / (1.0 + np.exp(-x))\n",
    "\n",
    "def branching_matrix(alpha):\n",
    "    \"\"\" D = alpha (I - alpha)^{-1} (solve-based, stable for rho(alpha)<1). \"\"\"\n",
    "    I = np.eye(alpha.shape[0])\n",
    "    return (np.linalg.solve(I - alpha, alpha.T).T)\n",
    "\n",
    "def normalize_1(v, eps=1e-12):\n",
    "    s = v.sum()\n",
    "    return v / (s + eps)\n",
    "\n",
    "# ---------- Simulation settings ----------\n",
    "max_events   = 500   # events per simulated sequence\n",
    "nb_sequences = 150   # number of simulated sequences (also used as nb_batches below)\n",
    "mu_true  = np.linspace(0.05, 0.15, D)   # baselines increasing from node 0 to node D-1\n",
    "beta_true = np.full(D, 0.6)             # common decay, consistent with beta_unique=True below\n",
    "\n",
    "# Containers\n",
    "all_times, all_types, all_x = [], [], []\n",
    "all_alpha_true = []\n",
    "# RNG passed to generate_alphas_with_features (simulate() is called without an rng argument)\n",
    "rng_seq = np.random.default_rng(1234)\n",
    "\n",
    "# NOTE: This cell relies on the names imported in the first cell:\n",
    "#   - ExpHawkes(mu, alpha, beta) with .simulate(max_events)\n",
    "#   - generate_alphas_with_features(rng, gamma, theta1, theta2) -> (alpha, x)\n",
    "#   - fit_hawkes(...) returning model with attributes: mu, beta, gamma, theta1, theta2\n",
    "# If you prefer the hand-crafted \"foot\" generator, replace the call below by your function.\n",
    "\n",
    "# ---------- Generate sequences ----------\n",
    "# Event times/types of all sequences are concatenated; one x_vec and one true alpha per sequence.\n",
    "for _ in tqdm(range(nb_sequences), desc=\"Sim sequences\"):\n",
    "    alpha_mat, x_vec = generate_alphas_with_features(rng_seq, gamma, theta1, theta2)\n",
    "    proc = ExpHawkes(mu_true, alpha_mat, beta_true)\n",
    "    ev   = proc.simulate(max_events=max_events)\n",
    "    all_times.extend(ev.times)\n",
    "    all_types.extend(ev.types)\n",
    "    all_x.append(x_vec)\n",
    "    all_alpha_true.append(alpha_mat)\n",
    "\n",
    "# ---------- Fit Hawkes with features ----------\n",
    "# NOTE(review): the meaning of log_b is not visible here (presumably a log-parametrised β) — TODO confirm.\n",
    "model = fit_hawkes(\n",
    "    nb_types        = D,\n",
    "    times           = all_times,\n",
    "    types           = all_types,\n",
    "    seq_feats       = all_x,\n",
    "    nb_batches      = nb_sequences,\n",
    "    lr              = 0.003,\n",
    "    nb_of_features  = 1,\n",
    "    beta_unique     = True,\n",
    "    log_b           = True\n",
    ")\n",
    "\n",
    "# ---------- Reconstruct alpha-hat per sequence, then average ----------\n",
    "# squeeze drops singleton axes (presumably the nb_of_features=1 axis)\n",
    "B0_hat = model.gamma.detach().cpu().numpy()\n",
    "B1_hat = np.squeeze(model.theta1.detach().cpu().numpy())\n",
    "B2_hat = np.squeeze(model.theta2.detach().cpu().numpy())\n",
    "\n",
    "def compute_alpha_hat(features_list, gamma_hat, theta1_hat, theta2_hat):\n",
    "    \"\"\"\n",
    "    Per-sequence estimated excitation matrices sigmoid(z), with\n",
    "        z_ij = gamma_ij + sum_k feats[k, i] theta1[k, i, j] + sum_k feats[k, j] theta2[k, i, j].\n",
    "\n",
    "    theta1_hat / theta2_hat may be (K, D, D) or squeezed (D, D), which is treated as K = 1;\n",
    "    a 1-D feature vector is tiled to (K, len). Returns a list aligned with features_list.\n",
    "    \"\"\"\n",
    "    g = np.asarray(gamma_hat)\n",
    "    t1 = np.asarray(theta1_hat)\n",
    "    t2 = np.asarray(theta2_hat)\n",
    "    if t1.ndim == 2:\n",
    "        t1 = np.expand_dims(t1, 0)\n",
    "        t2 = np.expand_dims(t2, 0)\n",
    "    n_feat, _, _ = t1.shape\n",
    "\n",
    "    def _logits(feats):\n",
    "        f = np.asarray(feats)\n",
    "        if f.ndim == 1:\n",
    "            f = np.tile(f[None, :], (n_feat, 1))\n",
    "        return g + np.einsum('ki,kij->ij', f, t1) + np.einsum('kj,kij->ij', f, t2)\n",
    "\n",
    "    return [sigmoid(_logits(feats)) for feats in features_list]\n",
    "\n",
    "# Average the per-sequence α̂ (the ground truth further down is averaged the same way).\n",
    "alphas_hat_list = compute_alpha_hat(all_x, B0_hat, B1_hat, B2_hat)\n",
    "hat_alpha = np.mean(alphas_hat_list, axis=0)\n",
    "\n",
    "# Stabilize if needed: rescale to spectral radius 0.99 so (I - α̂) is invertible\n",
    "rho_hat = np.max(np.abs(np.linalg.eigvals(hat_alpha)))\n",
    "if rho_hat >= 1.0:\n",
    "    hat_alpha *= 0.99 / rho_hat\n",
    "\n",
    "# ---------- D_hat and normalized D12 (estimated) ----------\n",
    "# D12_hat[j] = D_hat[surface, j]: surface row of the branching matrix, player columns only\n",
    "D_hat = branching_matrix(hat_alpha)\n",
    "D12_hat = D_hat[surface_idx, players_idx]\n",
    "D12_hat_norm = normalize_1(D12_hat)\n",
    "\n",
    "# ---------- Uplift (estimated) ----------\n",
    "# Uplift of player j = drop in the surface row-sum of D when node j is removed from the network.\n",
    "s_hat = D_hat[surface_idx, :].sum()\n",
    "uplift_est = []\n",
    "for j in players_idx:\n",
    "    keep = [i for i in range(D) if i != j]\n",
    "    sub_alpha = hat_alpha[np.ix_(keep, keep)]\n",
    "    # stability guard on sub_alpha\n",
    "    rho_sub = np.max(np.abs(np.linalg.eigvals(sub_alpha)))\n",
    "    if rho_sub >= 1.0:\n",
    "        sub_alpha = sub_alpha * (0.99 / rho_sub)\n",
    "    D_j = branching_matrix(sub_alpha)\n",
    "    # position of the surface inside the reduced index list\n",
    "    surf_pos = keep.index(surface_idx)\n",
    "    s_j = D_j[surf_pos, :].sum()\n",
    "    uplift_est.append(s_hat - s_j)\n",
    "uplift_est = np.array(uplift_est)\n",
    "uplift_est_norm = normalize_1(uplift_est)\n",
    "\n",
    "# ---------- Shapley (estimated) ----------\n",
    "def value_of_subset(keep_players_set, alpha_full, surface_idx_full):\n",
    "    \"\"\"\n",
    "    Coalition value v(S) used by the Shapley computation.\n",
    "\n",
    "    Restricts alpha_full to the players in keep_players_set plus the surface node\n",
    "    (players in sorted order, surface last), rescales the restriction to spectral\n",
    "    radius 0.99 if its radius is >= 1, and returns the sum of the surface row of\n",
    "    the resulting branching matrix.\n",
    "    \"\"\"\n",
    "    nodes = sorted(keep_players_set)\n",
    "    nodes.append(surface_idx_full)\n",
    "    restricted = alpha_full[np.ix_(nodes, nodes)]\n",
    "    spectral_radius = np.abs(np.linalg.eigvals(restricted)).max()\n",
    "    if spectral_radius >= 1.0:\n",
    "        restricted = restricted * (0.99 / spectral_radius)\n",
    "    branching = branching_matrix(restricted)\n",
    "    return branching[nodes.index(surface_idx_full), :].sum()\n",
    "\n",
    "def shapley_values(alpha_full, players_idx, surface_idx_full):\n",
    "    \"\"\"\n",
    "    Exact Shapley value of each player in players_idx for the game v = value_of_subset.\n",
    "\n",
    "    Enumerates every coalition S of the other players and accumulates\n",
    "    |S|! (n - |S| - 1)! / n! * (v(S ∪ {j}) - v(S)). Coalition values are memoised\n",
    "    by tuple. Returns an array aligned with players_idx.\n",
    "    \"\"\"\n",
    "    n = len(players_idx)\n",
    "    memo = {}\n",
    "\n",
    "    def v(S_tuple):\n",
    "        if S_tuple not in memo:\n",
    "            memo[S_tuple] = value_of_subset(set(S_tuple), alpha_full, surface_idx_full)\n",
    "        return memo[S_tuple]\n",
    "\n",
    "    phi = np.zeros(n)\n",
    "    for pos, player in enumerate(players_idx):\n",
    "        rest = [q for q in players_idx if q != player]\n",
    "        for size in range(len(rest) + 1):\n",
    "            weight = math.factorial(size) * math.factorial(n - size - 1) / math.factorial(n)\n",
    "            for S in itertools.combinations(rest, size):\n",
    "                gain = v(tuple(sorted(S + (player,)))) - v(S)\n",
    "                phi[pos] += weight * gain\n",
    "    return phi\n",
    "\n",
    "phi_est = shapley_values(hat_alpha, players_idx, surface_idx)\n",
    "phi_est_norm = normalize_1(phi_est)\n",
    "\n",
    "# ---------- Ground-truth normalized D12 ----------\n",
    "# Same pipeline as the estimate: average the per-sequence true alphas, stabilise, then solve.\n",
    "alpha_true = np.mean(np.stack(all_alpha_true, axis=0), axis=0)\n",
    "rho_true = np.max(np.abs(np.linalg.eigvals(alpha_true)))\n",
    "if rho_true >= 1.0:\n",
    "    alpha_true *= 0.99 / rho_true\n",
    "\n",
    "D_true = branching_matrix(alpha_true)\n",
    "D12_true = D_true[surface_idx, players_idx]\n",
    "D12_true_norm = normalize_1(D12_true)\n",
    "\n",
    "# ---------- Errors (normalized-only) & MSE ----------\n",
    "err_D12_norm        = D12_hat_norm      - D12_true_norm\n",
    "err_shapley_norm    = phi_est_norm      - D12_true_norm\n",
    "err_uplift_norm     = uplift_est_norm   - D12_true_norm\n",
    "\n",
    "mse_D12_norm     = float(np.mean(err_D12_norm**2))\n",
    "mse_shapley_norm = float(np.mean(err_shapley_norm**2))\n",
    "mse_uplift_norm  = float(np.mean(err_uplift_norm**2))\n",
    "\n",
    "# ---------- Final table (normalized-only) ----------\n",
    "df_comp = pd.DataFrame({\n",
    "    \"player_idx\": players_idx,\n",
    "    \"role\": [roles[i] for i in players_idx],\n",
    "    \"D12_true_norm\": D12_true_norm,\n",
    "    \"D12_norm\": D12_hat_norm,\n",
    "    \"Delta_D12\": err_D12_norm,\n",
    "    \"Shapley_norm\": phi_est_norm,\n",
    "    \"Delta_Shapley\": err_shapley_norm,\n",
    "    \"Uplift_norm\": uplift_est_norm,\n",
    "    \"Delta_Uplift\": err_uplift_norm,\n",
    "}).sort_values(\"D12_norm\", ascending=False).reset_index(drop=True)\n",
    "\n",
    "print(\"\\n=== Normalized comparison vs true D12 (surface -> players) ===\")\n",
    "print(df_comp.to_string(index=False))\n",
    "\n",
    "print(\"\\n=== Global MSEs (normalized-only, vs true D12) ===\")\n",
    "print(f\"MSE(D12_norm)     = {mse_D12_norm:.6g}\")\n",
    "print(f\"MSE(Shapley_norm) = {mse_shapley_norm:.6g}\")\n",
    "print(f\"MSE(Uplift_norm)  = {mse_uplift_norm:.6g}\")\n",
    "\n",
    "# ---------- Plot (normalized-only) ----------\n",
    "# matplotlib is not imported in the notebook's import cell; import it here so this\n",
    "# cell runs on a fresh kernel instead of raising NameError on `plt`.\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "x = np.arange(len(df_comp))\n",
    "w = 0.22\n",
    "bar_series = [\n",
    "    (\"D12_true_norm\", \"True D12 (norm.)\"),\n",
    "    (\"D12_norm\",      \"D12 (norm., est.)\"),\n",
    "    (\"Shapley_norm\",  \"Shapley (norm.)\"),\n",
    "    (\"Uplift_norm\",   \"Uplift (norm.)\"),\n",
    "]\n",
    "fig, ax = plt.subplots(figsize=(9, 4.5))\n",
    "# offsets -1.5w, -0.5w, +0.5w, +1.5w keep each group of four bars centred on its tick\n",
    "for bar_pos, (col, label) in enumerate(bar_series):\n",
    "    ax.bar(x + (bar_pos - 1.5) * w, df_comp[col], width=w, label=label)\n",
    "ax.set_xticks(x)\n",
    "ax.set_xticklabels(df_comp[\"role\"], rotation=45, ha='right')\n",
    "ax.set_ylabel(\"Normalized contribution (sums to 1)\")\n",
    "ax.set_title(\"Per-player normalized attributions vs true D12\")\n",
    "ax.legend(ncol=2)\n",
    "fig.tight_layout()\n",
    "plt.show()\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
