{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "31408e0a-71b9-4bc7-af4a-7286faa79ebf",
   "metadata": {},
   "source": [
    "# Processing for Spotify data\n",
    "\n",
    "#### This version is based on \"**WSDM Cup:The Music Streaming Sessions Dataset**\" at https://research.atspotify.com/datasets/. Using \"Test_Set.tar.gz (14G)\". After unzip, we randome choose a file for creating the simulator (\"log_prehistory_20180918_000000000000.csv\", 648.37MB)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "908b93a9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import torch\n",
    "import random"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "1b5db932",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Read 3942478 rows.\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>session_id</th>\n",
       "      <th>session_position</th>\n",
       "      <th>session_length</th>\n",
       "      <th>track_id_clean</th>\n",
       "      <th>skip_3</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>65_00000cda-eb14-42b5-be70-64afd00b159e</td>\n",
       "      <td>1</td>\n",
       "      <td>20</td>\n",
       "      <td>t_442fb6aa-1c6a-48db-81f2-f820df5bc4c6</td>\n",
       "      <td>False</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>65_00000cda-eb14-42b5-be70-64afd00b159e</td>\n",
       "      <td>2</td>\n",
       "      <td>20</td>\n",
       "      <td>t_f8fc3210-d734-416a-aac9-0ad43c78e9b4</td>\n",
       "      <td>False</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>65_00000cda-eb14-42b5-be70-64afd00b159e</td>\n",
       "      <td>3</td>\n",
       "      <td>20</td>\n",
       "      <td>t_7e9d453e-035c-46b4-a05d-dc631ec42eff</td>\n",
       "      <td>True</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>65_00000cda-eb14-42b5-be70-64afd00b159e</td>\n",
       "      <td>4</td>\n",
       "      <td>20</td>\n",
       "      <td>t_442fb6aa-1c6a-48db-81f2-f820df5bc4c6</td>\n",
       "      <td>False</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>65_00000cda-eb14-42b5-be70-64afd00b159e</td>\n",
       "      <td>5</td>\n",
       "      <td>20</td>\n",
       "      <td>t_f8fc3210-d734-416a-aac9-0ad43c78e9b4</td>\n",
       "      <td>False</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>65_00000cda-eb14-42b5-be70-64afd00b159e</td>\n",
       "      <td>6</td>\n",
       "      <td>20</td>\n",
       "      <td>t_7e9d453e-035c-46b4-a05d-dc631ec42eff</td>\n",
       "      <td>True</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>65_00000cda-eb14-42b5-be70-64afd00b159e</td>\n",
       "      <td>7</td>\n",
       "      <td>20</td>\n",
       "      <td>t_2426ed3d-aa65-49cf-b323-66d5cd7bedae</td>\n",
       "      <td>True</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>65_00000cda-eb14-42b5-be70-64afd00b159e</td>\n",
       "      <td>8</td>\n",
       "      <td>20</td>\n",
       "      <td>t_5c1d21a3-0d09-4098-b783-09c180e2dcb5</td>\n",
       "      <td>True</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>65_00000cda-eb14-42b5-be70-64afd00b159e</td>\n",
       "      <td>9</td>\n",
       "      <td>20</td>\n",
       "      <td>t_2426ed3d-aa65-49cf-b323-66d5cd7bedae</td>\n",
       "      <td>False</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>65_00000cda-eb14-42b5-be70-64afd00b159e</td>\n",
       "      <td>10</td>\n",
       "      <td>20</td>\n",
       "      <td>t_889887f2-02b4-4cb7-8207-c7df6b115854</td>\n",
       "      <td>True</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>65_0000715a-50d0-413c-b977-a46c76f22748</td>\n",
       "      <td>1</td>\n",
       "      <td>20</td>\n",
       "      <td>t_5266cf52-3ffe-4dc4-9a1e-7ead18ddc8ad</td>\n",
       "      <td>True</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>65_0000715a-50d0-413c-b977-a46c76f22748</td>\n",
       "      <td>2</td>\n",
       "      <td>20</td>\n",
       "      <td>t_6b18af74-691c-4974-9aad-317383a1c392</td>\n",
       "      <td>False</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>65_0000715a-50d0-413c-b977-a46c76f22748</td>\n",
       "      <td>3</td>\n",
       "      <td>20</td>\n",
       "      <td>t_33654513-24f3-452b-b3ed-a2562d75168f</td>\n",
       "      <td>False</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>65_0000715a-50d0-413c-b977-a46c76f22748</td>\n",
       "      <td>4</td>\n",
       "      <td>20</td>\n",
       "      <td>t_8bcddd46-15f6-40ad-9100-5ea9e0b62fb5</td>\n",
       "      <td>False</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>65_0000715a-50d0-413c-b977-a46c76f22748</td>\n",
       "      <td>5</td>\n",
       "      <td>20</td>\n",
       "      <td>t_5ce63b97-5d66-4bcc-8527-44b4bc9e8c4f</td>\n",
       "      <td>True</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15</th>\n",
       "      <td>65_0000715a-50d0-413c-b977-a46c76f22748</td>\n",
       "      <td>6</td>\n",
       "      <td>20</td>\n",
       "      <td>t_9879ec58-8688-44e8-85bd-fb93b3a9524f</td>\n",
       "      <td>True</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16</th>\n",
       "      <td>65_0000715a-50d0-413c-b977-a46c76f22748</td>\n",
       "      <td>7</td>\n",
       "      <td>20</td>\n",
       "      <td>t_cfae3822-2bd1-4fe6-ba31-a8a715b40a20</td>\n",
       "      <td>True</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17</th>\n",
       "      <td>65_0000715a-50d0-413c-b977-a46c76f22748</td>\n",
       "      <td>8</td>\n",
       "      <td>20</td>\n",
       "      <td>t_ae38ec16-ef66-4ee9-bf3d-b69ebc14d08a</td>\n",
       "      <td>True</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>18</th>\n",
       "      <td>65_0000715a-50d0-413c-b977-a46c76f22748</td>\n",
       "      <td>9</td>\n",
       "      <td>20</td>\n",
       "      <td>t_c2feb7b1-c4f3-4e68-80b6-72762dae6f06</td>\n",
       "      <td>True</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19</th>\n",
       "      <td>65_0000715a-50d0-413c-b977-a46c76f22748</td>\n",
       "      <td>10</td>\n",
       "      <td>20</td>\n",
       "      <td>t_d09faaf2-ed4d-4e47-b27e-9c4906a29771</td>\n",
       "      <td>True</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                 session_id  session_position  session_length  \\\n",
       "0   65_00000cda-eb14-42b5-be70-64afd00b159e                 1              20   \n",
       "1   65_00000cda-eb14-42b5-be70-64afd00b159e                 2              20   \n",
       "2   65_00000cda-eb14-42b5-be70-64afd00b159e                 3              20   \n",
       "3   65_00000cda-eb14-42b5-be70-64afd00b159e                 4              20   \n",
       "4   65_00000cda-eb14-42b5-be70-64afd00b159e                 5              20   \n",
       "5   65_00000cda-eb14-42b5-be70-64afd00b159e                 6              20   \n",
       "6   65_00000cda-eb14-42b5-be70-64afd00b159e                 7              20   \n",
       "7   65_00000cda-eb14-42b5-be70-64afd00b159e                 8              20   \n",
       "8   65_00000cda-eb14-42b5-be70-64afd00b159e                 9              20   \n",
       "9   65_00000cda-eb14-42b5-be70-64afd00b159e                10              20   \n",
       "10  65_0000715a-50d0-413c-b977-a46c76f22748                 1              20   \n",
       "11  65_0000715a-50d0-413c-b977-a46c76f22748                 2              20   \n",
       "12  65_0000715a-50d0-413c-b977-a46c76f22748                 3              20   \n",
       "13  65_0000715a-50d0-413c-b977-a46c76f22748                 4              20   \n",
       "14  65_0000715a-50d0-413c-b977-a46c76f22748                 5              20   \n",
       "15  65_0000715a-50d0-413c-b977-a46c76f22748                 6              20   \n",
       "16  65_0000715a-50d0-413c-b977-a46c76f22748                 7              20   \n",
       "17  65_0000715a-50d0-413c-b977-a46c76f22748                 8              20   \n",
       "18  65_0000715a-50d0-413c-b977-a46c76f22748                 9              20   \n",
       "19  65_0000715a-50d0-413c-b977-a46c76f22748                10              20   \n",
       "\n",
       "                            track_id_clean  skip_3  \n",
       "0   t_442fb6aa-1c6a-48db-81f2-f820df5bc4c6   False  \n",
       "1   t_f8fc3210-d734-416a-aac9-0ad43c78e9b4   False  \n",
       "2   t_7e9d453e-035c-46b4-a05d-dc631ec42eff    True  \n",
       "3   t_442fb6aa-1c6a-48db-81f2-f820df5bc4c6   False  \n",
       "4   t_f8fc3210-d734-416a-aac9-0ad43c78e9b4   False  \n",
       "5   t_7e9d453e-035c-46b4-a05d-dc631ec42eff    True  \n",
       "6   t_2426ed3d-aa65-49cf-b323-66d5cd7bedae    True  \n",
       "7   t_5c1d21a3-0d09-4098-b783-09c180e2dcb5    True  \n",
       "8   t_2426ed3d-aa65-49cf-b323-66d5cd7bedae   False  \n",
       "9   t_889887f2-02b4-4cb7-8207-c7df6b115854    True  \n",
       "10  t_5266cf52-3ffe-4dc4-9a1e-7ead18ddc8ad    True  \n",
       "11  t_6b18af74-691c-4974-9aad-317383a1c392   False  \n",
       "12  t_33654513-24f3-452b-b3ed-a2562d75168f   False  \n",
       "13  t_8bcddd46-15f6-40ad-9100-5ea9e0b62fb5   False  \n",
       "14  t_5ce63b97-5d66-4bcc-8527-44b4bc9e8c4f    True  \n",
       "15  t_9879ec58-8688-44e8-85bd-fb93b3a9524f    True  \n",
       "16  t_cfae3822-2bd1-4fe6-ba31-a8a715b40a20    True  \n",
       "17  t_ae38ec16-ef66-4ee9-bf3d-b69ebc14d08a    True  \n",
       "18  t_c2feb7b1-c4f3-4e68-80b6-72762dae6f06    True  \n",
       "19  t_d09faaf2-ed4d-4e47-b27e-9c4906a29771    True  "
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df = pd.read_csv(filepath_or_buffer='./log_prehistory_20180918_000000000000.csv')\n",
    "\n",
    "\n",
    "df = df[['session_id', 'session_position', 'session_length', 'track_id_clean', 'skip_3']]\n",
    "print(f\"Read {len(df)} rows.\")\n",
    "df.head(20)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4d2adbd5",
   "metadata": {},
   "source": [
    "#### Note for some reason, the session_length is 20 but the session_position is only halfed. This is consistent along the dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "373ee9e8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Before filter out sequence length, we have [20 15 10 11 13 14 17 12 18 19 16] possible sequence length.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "count    483798.000000\n",
       "mean         16.544361\n",
       "std           3.835032\n",
       "min          10.000000\n",
       "25%          13.000000\n",
       "50%          18.000000\n",
       "75%          20.000000\n",
       "max          20.000000\n",
       "Name: session_length, dtype: float64"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "unique_sessions = df.drop_duplicates(subset=['session_id'])\n",
    "print(f\"Before filter out sequence length, we have {unique_sessions['session_length'].unique()} possible sequence length.\")\n",
    "unique_sessions['session_length'].describe()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4d8f7afc",
   "metadata": {},
   "source": [
    "#### It turns out, a sequence length of 20 is the most frequent, so we choose 20 as the default sequence length."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "76e5c71b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Keep 2284600 rows, which is 57.95% of initial length.\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>session_id</th>\n",
       "      <th>session_position</th>\n",
       "      <th>session_length</th>\n",
       "      <th>track_id_clean</th>\n",
       "      <th>skip_3</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>65_00000cda-eb14-42b5-be70-64afd00b159e</td>\n",
       "      <td>1</td>\n",
       "      <td>20</td>\n",
       "      <td>t_442fb6aa-1c6a-48db-81f2-f820df5bc4c6</td>\n",
       "      <td>False</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>65_00000cda-eb14-42b5-be70-64afd00b159e</td>\n",
       "      <td>2</td>\n",
       "      <td>20</td>\n",
       "      <td>t_f8fc3210-d734-416a-aac9-0ad43c78e9b4</td>\n",
       "      <td>False</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>65_00000cda-eb14-42b5-be70-64afd00b159e</td>\n",
       "      <td>3</td>\n",
       "      <td>20</td>\n",
       "      <td>t_7e9d453e-035c-46b4-a05d-dc631ec42eff</td>\n",
       "      <td>True</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>65_00000cda-eb14-42b5-be70-64afd00b159e</td>\n",
       "      <td>4</td>\n",
       "      <td>20</td>\n",
       "      <td>t_442fb6aa-1c6a-48db-81f2-f820df5bc4c6</td>\n",
       "      <td>False</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>65_00000cda-eb14-42b5-be70-64afd00b159e</td>\n",
       "      <td>5</td>\n",
       "      <td>20</td>\n",
       "      <td>t_f8fc3210-d734-416a-aac9-0ad43c78e9b4</td>\n",
       "      <td>False</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>65_00000cda-eb14-42b5-be70-64afd00b159e</td>\n",
       "      <td>6</td>\n",
       "      <td>20</td>\n",
       "      <td>t_7e9d453e-035c-46b4-a05d-dc631ec42eff</td>\n",
       "      <td>True</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>65_00000cda-eb14-42b5-be70-64afd00b159e</td>\n",
       "      <td>7</td>\n",
       "      <td>20</td>\n",
       "      <td>t_2426ed3d-aa65-49cf-b323-66d5cd7bedae</td>\n",
       "      <td>True</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>65_00000cda-eb14-42b5-be70-64afd00b159e</td>\n",
       "      <td>8</td>\n",
       "      <td>20</td>\n",
       "      <td>t_5c1d21a3-0d09-4098-b783-09c180e2dcb5</td>\n",
       "      <td>True</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>65_00000cda-eb14-42b5-be70-64afd00b159e</td>\n",
       "      <td>9</td>\n",
       "      <td>20</td>\n",
       "      <td>t_2426ed3d-aa65-49cf-b323-66d5cd7bedae</td>\n",
       "      <td>False</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>65_00000cda-eb14-42b5-be70-64afd00b159e</td>\n",
       "      <td>10</td>\n",
       "      <td>20</td>\n",
       "      <td>t_889887f2-02b4-4cb7-8207-c7df6b115854</td>\n",
       "      <td>True</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>65_0000715a-50d0-413c-b977-a46c76f22748</td>\n",
       "      <td>1</td>\n",
       "      <td>20</td>\n",
       "      <td>t_5266cf52-3ffe-4dc4-9a1e-7ead18ddc8ad</td>\n",
       "      <td>True</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>65_0000715a-50d0-413c-b977-a46c76f22748</td>\n",
       "      <td>2</td>\n",
       "      <td>20</td>\n",
       "      <td>t_6b18af74-691c-4974-9aad-317383a1c392</td>\n",
       "      <td>False</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>65_0000715a-50d0-413c-b977-a46c76f22748</td>\n",
       "      <td>3</td>\n",
       "      <td>20</td>\n",
       "      <td>t_33654513-24f3-452b-b3ed-a2562d75168f</td>\n",
       "      <td>False</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>65_0000715a-50d0-413c-b977-a46c76f22748</td>\n",
       "      <td>4</td>\n",
       "      <td>20</td>\n",
       "      <td>t_8bcddd46-15f6-40ad-9100-5ea9e0b62fb5</td>\n",
       "      <td>False</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>65_0000715a-50d0-413c-b977-a46c76f22748</td>\n",
       "      <td>5</td>\n",
       "      <td>20</td>\n",
       "      <td>t_5ce63b97-5d66-4bcc-8527-44b4bc9e8c4f</td>\n",
       "      <td>True</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15</th>\n",
       "      <td>65_0000715a-50d0-413c-b977-a46c76f22748</td>\n",
       "      <td>6</td>\n",
       "      <td>20</td>\n",
       "      <td>t_9879ec58-8688-44e8-85bd-fb93b3a9524f</td>\n",
       "      <td>True</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16</th>\n",
       "      <td>65_0000715a-50d0-413c-b977-a46c76f22748</td>\n",
       "      <td>7</td>\n",
       "      <td>20</td>\n",
       "      <td>t_cfae3822-2bd1-4fe6-ba31-a8a715b40a20</td>\n",
       "      <td>True</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17</th>\n",
       "      <td>65_0000715a-50d0-413c-b977-a46c76f22748</td>\n",
       "      <td>8</td>\n",
       "      <td>20</td>\n",
       "      <td>t_ae38ec16-ef66-4ee9-bf3d-b69ebc14d08a</td>\n",
       "      <td>True</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>18</th>\n",
       "      <td>65_0000715a-50d0-413c-b977-a46c76f22748</td>\n",
       "      <td>9</td>\n",
       "      <td>20</td>\n",
       "      <td>t_c2feb7b1-c4f3-4e68-80b6-72762dae6f06</td>\n",
       "      <td>True</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19</th>\n",
       "      <td>65_0000715a-50d0-413c-b977-a46c76f22748</td>\n",
       "      <td>10</td>\n",
       "      <td>20</td>\n",
       "      <td>t_d09faaf2-ed4d-4e47-b27e-9c4906a29771</td>\n",
       "      <td>True</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                 session_id  session_position  session_length  \\\n",
       "0   65_00000cda-eb14-42b5-be70-64afd00b159e                 1              20   \n",
       "1   65_00000cda-eb14-42b5-be70-64afd00b159e                 2              20   \n",
       "2   65_00000cda-eb14-42b5-be70-64afd00b159e                 3              20   \n",
       "3   65_00000cda-eb14-42b5-be70-64afd00b159e                 4              20   \n",
       "4   65_00000cda-eb14-42b5-be70-64afd00b159e                 5              20   \n",
       "5   65_00000cda-eb14-42b5-be70-64afd00b159e                 6              20   \n",
       "6   65_00000cda-eb14-42b5-be70-64afd00b159e                 7              20   \n",
       "7   65_00000cda-eb14-42b5-be70-64afd00b159e                 8              20   \n",
       "8   65_00000cda-eb14-42b5-be70-64afd00b159e                 9              20   \n",
       "9   65_00000cda-eb14-42b5-be70-64afd00b159e                10              20   \n",
       "10  65_0000715a-50d0-413c-b977-a46c76f22748                 1              20   \n",
       "11  65_0000715a-50d0-413c-b977-a46c76f22748                 2              20   \n",
       "12  65_0000715a-50d0-413c-b977-a46c76f22748                 3              20   \n",
       "13  65_0000715a-50d0-413c-b977-a46c76f22748                 4              20   \n",
       "14  65_0000715a-50d0-413c-b977-a46c76f22748                 5              20   \n",
       "15  65_0000715a-50d0-413c-b977-a46c76f22748                 6              20   \n",
       "16  65_0000715a-50d0-413c-b977-a46c76f22748                 7              20   \n",
       "17  65_0000715a-50d0-413c-b977-a46c76f22748                 8              20   \n",
       "18  65_0000715a-50d0-413c-b977-a46c76f22748                 9              20   \n",
       "19  65_0000715a-50d0-413c-b977-a46c76f22748                10              20   \n",
       "\n",
       "                            track_id_clean  skip_3  \n",
       "0   t_442fb6aa-1c6a-48db-81f2-f820df5bc4c6   False  \n",
       "1   t_f8fc3210-d734-416a-aac9-0ad43c78e9b4   False  \n",
       "2   t_7e9d453e-035c-46b4-a05d-dc631ec42eff    True  \n",
       "3   t_442fb6aa-1c6a-48db-81f2-f820df5bc4c6   False  \n",
       "4   t_f8fc3210-d734-416a-aac9-0ad43c78e9b4   False  \n",
       "5   t_7e9d453e-035c-46b4-a05d-dc631ec42eff    True  \n",
       "6   t_2426ed3d-aa65-49cf-b323-66d5cd7bedae    True  \n",
       "7   t_5c1d21a3-0d09-4098-b783-09c180e2dcb5    True  \n",
       "8   t_2426ed3d-aa65-49cf-b323-66d5cd7bedae   False  \n",
       "9   t_889887f2-02b4-4cb7-8207-c7df6b115854    True  \n",
       "10  t_5266cf52-3ffe-4dc4-9a1e-7ead18ddc8ad    True  \n",
       "11  t_6b18af74-691c-4974-9aad-317383a1c392   False  \n",
       "12  t_33654513-24f3-452b-b3ed-a2562d75168f   False  \n",
       "13  t_8bcddd46-15f6-40ad-9100-5ea9e0b62fb5   False  \n",
       "14  t_5ce63b97-5d66-4bcc-8527-44b4bc9e8c4f    True  \n",
       "15  t_9879ec58-8688-44e8-85bd-fb93b3a9524f    True  \n",
       "16  t_cfae3822-2bd1-4fe6-ba31-a8a715b40a20    True  \n",
       "17  t_ae38ec16-ef66-4ee9-bf3d-b69ebc14d08a    True  \n",
       "18  t_c2feb7b1-c4f3-4e68-80b6-72762dae6f06    True  \n",
       "19  t_d09faaf2-ed4d-4e47-b27e-9c4906a29771    True  "
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_20 = df[df['session_length'] == 20]\n",
    "print(f\"Keep {len(df_20)} rows, which is {np.round(len(df_20)/len(df)*100,decimals=2)}% of initial length.\")\n",
    "df_20.head(20)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "245160f2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "After filter out sequence length, we have 228460 session\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "After filter out sequence length, we have 301783 track_id\n"
     ]
    }
   ],
   "source": [
    "print(f\"After filter out sequence length, we have {df_20['session_id'].nunique()} session\")\n",
    "print(f\"After filter out sequence length, we have {df_20['track_id_clean'].nunique()} track_id\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "36440826",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set random seed for reproducibility\n",
    "global_seed = 42\n",
    "\n",
    "def set_seed(seed):\n",
    "    torch.manual_seed(seed)\n",
    "    random.seed(seed)\n",
    "    np.random.seed(seed)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e0830b59",
   "metadata": {},
   "source": [
    "#### Dealing with a very sparse and large matrix is memory consuming, we wish to handle a smaller matrix. To limit the matrix size, we sample 6000 unique session_id to create the sparse matrix. This is not necessary but help us to down size the matrix.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "e821e34f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "With a subpopulation of 6000, we have 6000 session and 26707 track_id.\n"
     ]
    }
   ],
   "source": [
    "population_size = 6000\n",
    "\n",
    "sampled_session_ids = df_20['session_id'].drop_duplicates().sample(n=population_size, random_state=global_seed)\n",
    "df_20_filtered = df_20[df_20['session_id'].isin(sampled_session_ids)]\n",
    "df_20_filtered.reset_index(drop=True, inplace=True)\n",
    "print(f\"With a subpopulation of {population_size}, we have {df_20_filtered['session_id'].nunique()} session and {df_20_filtered['track_id_clean'].nunique()} track_id.\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "9297f08c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th>track_id_clean</th>\n",
       "      <th>t_0000496c-869f-4350-83f1-bd1c14ea79ba</th>\n",
       "      <th>t_00042d9b-e795-41a9-89ad-504373dd4287</th>\n",
       "      <th>t_000518e0-996c-46b0-9167-f831f5f8f513</th>\n",
       "      <th>t_00059b93-803a-4c3c-8ea6-d9938bd17264</th>\n",
       "      <th>t_0007a9bf-2faf-4345-b005-388b1d9e3d94</th>\n",
       "      <th>t_000b1102-79f5-497b-848b-6e793089b0ea</th>\n",
       "      <th>t_00174427-afef-4a1c-be96-38c2b6f01396</th>\n",
       "      <th>t_0017689d-4a22-480d-b313-93e38bb63ac1</th>\n",
       "      <th>t_0019b9be-3ea4-4722-800f-2b37cbe5a4b8</th>\n",
       "      <th>t_001e01c6-e635-4a72-8837-dad6c2389b8d</th>\n",
       "      <th>...</th>\n",
       "      <th>t_ffdf73bf-7b3b-476a-9031-e262fa3339e6</th>\n",
       "      <th>t_ffe1e63d-7f6d-47d9-9ecf-65db05ae7ccb</th>\n",
       "      <th>t_ffe3816d-baf0-4ddf-81ba-7598a03b2973</th>\n",
       "      <th>t_ffe4f45a-c85b-487e-b909-c9a401ab9130</th>\n",
       "      <th>t_ffeb7c85-d909-4b98-8fa4-1d7488faea99</th>\n",
       "      <th>t_ffecf968-2ea1-4fda-98e0-ebd43e313425</th>\n",
       "      <th>t_ffefb779-813a-4cd4-ba52-9fa2fa2bdf38</th>\n",
       "      <th>t_fff12ea1-83ef-468b-9fa4-eb80c8c1f8d4</th>\n",
       "      <th>t_fff9b035-3943-4557-bdd0-15fc2c5dc03e</th>\n",
       "      <th>t_fffe4528-f29c-467c-a751-e9db03964faa</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>session_id</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>65_0008dac3-8379-4369-90a1-b7ceccbb2d51</th>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>65_0025a3cc-485b-407e-9fc6-2b58fc07b957</th>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>65_0026980c-c0de-46bd-860b-02b4050bf72c</th>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>65_00285fa4-2c95-45af-a57e-1b73de06b895</th>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>65_0028933d-c089-494a-8d17-900f12b7af10</th>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>65_ffe2d9f1-1ffa-4a39-bbff-53c7b7a382ee</th>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>65_ffe678b7-a111-4f84-b3a5-ad559509fc7f</th>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>65_ffeb15fb-1e48-4469-808d-ca069506f21d</th>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>65_ffefaee3-5f51-4cd3-bb0b-68143cd4cec4</th>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>65_fff05e64-c7ac-4aec-ade1-65dee71e4d14</th>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>6000 rows × 26707 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "track_id_clean                           t_0000496c-869f-4350-83f1-bd1c14ea79ba  \\\n",
       "session_id                                                                        \n",
       "65_0008dac3-8379-4369-90a1-b7ceccbb2d51                                       0   \n",
       "65_0025a3cc-485b-407e-9fc6-2b58fc07b957                                       0   \n",
       "65_0026980c-c0de-46bd-860b-02b4050bf72c                                       0   \n",
       "65_00285fa4-2c95-45af-a57e-1b73de06b895                                       0   \n",
       "65_0028933d-c089-494a-8d17-900f12b7af10                                       0   \n",
       "...                                                                         ...   \n",
       "65_ffe2d9f1-1ffa-4a39-bbff-53c7b7a382ee                                       0   \n",
       "65_ffe678b7-a111-4f84-b3a5-ad559509fc7f                                       0   \n",
       "65_ffeb15fb-1e48-4469-808d-ca069506f21d                                       0   \n",
       "65_ffefaee3-5f51-4cd3-bb0b-68143cd4cec4                                       0   \n",
       "65_fff05e64-c7ac-4aec-ade1-65dee71e4d14                                       0   \n",
       "\n",
       "track_id_clean                           t_00042d9b-e795-41a9-89ad-504373dd4287  \\\n",
       "session_id                                                                        \n",
       "65_0008dac3-8379-4369-90a1-b7ceccbb2d51                                       0   \n",
       "65_0025a3cc-485b-407e-9fc6-2b58fc07b957                                       0   \n",
       "65_0026980c-c0de-46bd-860b-02b4050bf72c                                       0   \n",
       "65_00285fa4-2c95-45af-a57e-1b73de06b895                                       0   \n",
       "65_0028933d-c089-494a-8d17-900f12b7af10                                       0   \n",
       "...                                                                         ...   \n",
       "65_ffe2d9f1-1ffa-4a39-bbff-53c7b7a382ee                                       0   \n",
       "65_ffe678b7-a111-4f84-b3a5-ad559509fc7f                                       0   \n",
       "65_ffeb15fb-1e48-4469-808d-ca069506f21d                                       0   \n",
       "65_ffefaee3-5f51-4cd3-bb0b-68143cd4cec4                                       0   \n",
       "65_fff05e64-c7ac-4aec-ade1-65dee71e4d14                                       0   \n",
       "\n",
       "track_id_clean                           t_000518e0-996c-46b0-9167-f831f5f8f513  \\\n",
       "session_id                                                                        \n",
       "65_0008dac3-8379-4369-90a1-b7ceccbb2d51                                       0   \n",
       "65_0025a3cc-485b-407e-9fc6-2b58fc07b957                                       0   \n",
       "65_0026980c-c0de-46bd-860b-02b4050bf72c                                       0   \n",
       "65_00285fa4-2c95-45af-a57e-1b73de06b895                                       0   \n",
       "65_0028933d-c089-494a-8d17-900f12b7af10                                       0   \n",
       "...                                                                         ...   \n",
       "65_ffe2d9f1-1ffa-4a39-bbff-53c7b7a382ee                                       0   \n",
       "65_ffe678b7-a111-4f84-b3a5-ad559509fc7f                                       0   \n",
       "65_ffeb15fb-1e48-4469-808d-ca069506f21d                                       0   \n",
       "65_ffefaee3-5f51-4cd3-bb0b-68143cd4cec4                                       0   \n",
       "65_fff05e64-c7ac-4aec-ade1-65dee71e4d14                                       0   \n",
       "\n",
       "track_id_clean                           t_00059b93-803a-4c3c-8ea6-d9938bd17264  \\\n",
       "session_id                                                                        \n",
       "65_0008dac3-8379-4369-90a1-b7ceccbb2d51                                       0   \n",
       "65_0025a3cc-485b-407e-9fc6-2b58fc07b957                                       0   \n",
       "65_0026980c-c0de-46bd-860b-02b4050bf72c                                       0   \n",
       "65_00285fa4-2c95-45af-a57e-1b73de06b895                                       0   \n",
       "65_0028933d-c089-494a-8d17-900f12b7af10                                       0   \n",
       "...                                                                         ...   \n",
       "65_ffe2d9f1-1ffa-4a39-bbff-53c7b7a382ee                                       0   \n",
       "65_ffe678b7-a111-4f84-b3a5-ad559509fc7f                                       0   \n",
       "65_ffeb15fb-1e48-4469-808d-ca069506f21d                                       0   \n",
       "65_ffefaee3-5f51-4cd3-bb0b-68143cd4cec4                                       0   \n",
       "65_fff05e64-c7ac-4aec-ade1-65dee71e4d14                                       0   \n",
       "\n",
       "track_id_clean                           t_0007a9bf-2faf-4345-b005-388b1d9e3d94  \\\n",
       "session_id                                                                        \n",
       "65_0008dac3-8379-4369-90a1-b7ceccbb2d51                                       0   \n",
       "65_0025a3cc-485b-407e-9fc6-2b58fc07b957                                       0   \n",
       "65_0026980c-c0de-46bd-860b-02b4050bf72c                                       0   \n",
       "65_00285fa4-2c95-45af-a57e-1b73de06b895                                       0   \n",
       "65_0028933d-c089-494a-8d17-900f12b7af10                                       0   \n",
       "...                                                                         ...   \n",
       "65_ffe2d9f1-1ffa-4a39-bbff-53c7b7a382ee                                       0   \n",
       "65_ffe678b7-a111-4f84-b3a5-ad559509fc7f                                       0   \n",
       "65_ffeb15fb-1e48-4469-808d-ca069506f21d                                       0   \n",
       "65_ffefaee3-5f51-4cd3-bb0b-68143cd4cec4                                       0   \n",
       "65_fff05e64-c7ac-4aec-ade1-65dee71e4d14                                       0   \n",
       "\n",
       "track_id_clean                           t_000b1102-79f5-497b-848b-6e793089b0ea  \\\n",
       "session_id                                                                        \n",
       "65_0008dac3-8379-4369-90a1-b7ceccbb2d51                                       0   \n",
       "65_0025a3cc-485b-407e-9fc6-2b58fc07b957                                       0   \n",
       "65_0026980c-c0de-46bd-860b-02b4050bf72c                                       0   \n",
       "65_00285fa4-2c95-45af-a57e-1b73de06b895                                       0   \n",
       "65_0028933d-c089-494a-8d17-900f12b7af10                                       0   \n",
       "...                                                                         ...   \n",
       "65_ffe2d9f1-1ffa-4a39-bbff-53c7b7a382ee                                       0   \n",
       "65_ffe678b7-a111-4f84-b3a5-ad559509fc7f                                       0   \n",
       "65_ffeb15fb-1e48-4469-808d-ca069506f21d                                       0   \n",
       "65_ffefaee3-5f51-4cd3-bb0b-68143cd4cec4                                       0   \n",
       "65_fff05e64-c7ac-4aec-ade1-65dee71e4d14                                       0   \n",
       "\n",
       "track_id_clean                           t_00174427-afef-4a1c-be96-38c2b6f01396  \\\n",
       "session_id                                                                        \n",
       "65_0008dac3-8379-4369-90a1-b7ceccbb2d51                                       0   \n",
       "65_0025a3cc-485b-407e-9fc6-2b58fc07b957                                       0   \n",
       "65_0026980c-c0de-46bd-860b-02b4050bf72c                                       0   \n",
       "65_00285fa4-2c95-45af-a57e-1b73de06b895                                       0   \n",
       "65_0028933d-c089-494a-8d17-900f12b7af10                                       0   \n",
       "...                                                                         ...   \n",
       "65_ffe2d9f1-1ffa-4a39-bbff-53c7b7a382ee                                       0   \n",
       "65_ffe678b7-a111-4f84-b3a5-ad559509fc7f                                       0   \n",
       "65_ffeb15fb-1e48-4469-808d-ca069506f21d                                       0   \n",
       "65_ffefaee3-5f51-4cd3-bb0b-68143cd4cec4                                       0   \n",
       "65_fff05e64-c7ac-4aec-ade1-65dee71e4d14                                       0   \n",
       "\n",
       "track_id_clean                           t_0017689d-4a22-480d-b313-93e38bb63ac1  \\\n",
       "session_id                                                                        \n",
       "65_0008dac3-8379-4369-90a1-b7ceccbb2d51                                       0   \n",
       "65_0025a3cc-485b-407e-9fc6-2b58fc07b957                                       0   \n",
       "65_0026980c-c0de-46bd-860b-02b4050bf72c                                       0   \n",
       "65_00285fa4-2c95-45af-a57e-1b73de06b895                                       0   \n",
       "65_0028933d-c089-494a-8d17-900f12b7af10                                       0   \n",
       "...                                                                         ...   \n",
       "65_ffe2d9f1-1ffa-4a39-bbff-53c7b7a382ee                                       0   \n",
       "65_ffe678b7-a111-4f84-b3a5-ad559509fc7f                                       0   \n",
       "65_ffeb15fb-1e48-4469-808d-ca069506f21d                                       0   \n",
       "65_ffefaee3-5f51-4cd3-bb0b-68143cd4cec4                                       0   \n",
       "65_fff05e64-c7ac-4aec-ade1-65dee71e4d14                                       0   \n",
       "\n",
       "track_id_clean                           t_0019b9be-3ea4-4722-800f-2b37cbe5a4b8  \\\n",
       "session_id                                                                        \n",
       "65_0008dac3-8379-4369-90a1-b7ceccbb2d51                                       0   \n",
       "65_0025a3cc-485b-407e-9fc6-2b58fc07b957                                       0   \n",
       "65_0026980c-c0de-46bd-860b-02b4050bf72c                                       0   \n",
       "65_00285fa4-2c95-45af-a57e-1b73de06b895                                       0   \n",
       "65_0028933d-c089-494a-8d17-900f12b7af10                                       0   \n",
       "...                                                                         ...   \n",
       "65_ffe2d9f1-1ffa-4a39-bbff-53c7b7a382ee                                       0   \n",
       "65_ffe678b7-a111-4f84-b3a5-ad559509fc7f                                       0   \n",
       "65_ffeb15fb-1e48-4469-808d-ca069506f21d                                       0   \n",
       "65_ffefaee3-5f51-4cd3-bb0b-68143cd4cec4                                       0   \n",
       "65_fff05e64-c7ac-4aec-ade1-65dee71e4d14                                       0   \n",
       "\n",
       "track_id_clean                           t_001e01c6-e635-4a72-8837-dad6c2389b8d  \\\n",
       "session_id                                                                        \n",
       "65_0008dac3-8379-4369-90a1-b7ceccbb2d51                                       0   \n",
       "65_0025a3cc-485b-407e-9fc6-2b58fc07b957                                       0   \n",
       "65_0026980c-c0de-46bd-860b-02b4050bf72c                                       0   \n",
       "65_00285fa4-2c95-45af-a57e-1b73de06b895                                       0   \n",
       "65_0028933d-c089-494a-8d17-900f12b7af10                                       0   \n",
       "...                                                                         ...   \n",
       "65_ffe2d9f1-1ffa-4a39-bbff-53c7b7a382ee                                       0   \n",
       "65_ffe678b7-a111-4f84-b3a5-ad559509fc7f                                       0   \n",
       "65_ffeb15fb-1e48-4469-808d-ca069506f21d                                       0   \n",
       "65_ffefaee3-5f51-4cd3-bb0b-68143cd4cec4                                       0   \n",
       "65_fff05e64-c7ac-4aec-ade1-65dee71e4d14                                       0   \n",
       "\n",
       "track_id_clean                           ...  \\\n",
       "session_id                               ...   \n",
       "65_0008dac3-8379-4369-90a1-b7ceccbb2d51  ...   \n",
       "65_0025a3cc-485b-407e-9fc6-2b58fc07b957  ...   \n",
       "65_0026980c-c0de-46bd-860b-02b4050bf72c  ...   \n",
       "65_00285fa4-2c95-45af-a57e-1b73de06b895  ...   \n",
       "65_0028933d-c089-494a-8d17-900f12b7af10  ...   \n",
       "...                                      ...   \n",
       "65_ffe2d9f1-1ffa-4a39-bbff-53c7b7a382ee  ...   \n",
       "65_ffe678b7-a111-4f84-b3a5-ad559509fc7f  ...   \n",
       "65_ffeb15fb-1e48-4469-808d-ca069506f21d  ...   \n",
       "65_ffefaee3-5f51-4cd3-bb0b-68143cd4cec4  ...   \n",
       "65_fff05e64-c7ac-4aec-ade1-65dee71e4d14  ...   \n",
       "\n",
       "track_id_clean                           t_ffdf73bf-7b3b-476a-9031-e262fa3339e6  \\\n",
       "session_id                                                                        \n",
       "65_0008dac3-8379-4369-90a1-b7ceccbb2d51                                       0   \n",
       "65_0025a3cc-485b-407e-9fc6-2b58fc07b957                                       0   \n",
       "65_0026980c-c0de-46bd-860b-02b4050bf72c                                       0   \n",
       "65_00285fa4-2c95-45af-a57e-1b73de06b895                                       0   \n",
       "65_0028933d-c089-494a-8d17-900f12b7af10                                       0   \n",
       "...                                                                         ...   \n",
       "65_ffe2d9f1-1ffa-4a39-bbff-53c7b7a382ee                                       0   \n",
       "65_ffe678b7-a111-4f84-b3a5-ad559509fc7f                                       0   \n",
       "65_ffeb15fb-1e48-4469-808d-ca069506f21d                                       0   \n",
       "65_ffefaee3-5f51-4cd3-bb0b-68143cd4cec4                                       0   \n",
       "65_fff05e64-c7ac-4aec-ade1-65dee71e4d14                                       0   \n",
       "\n",
       "track_id_clean                           t_ffe1e63d-7f6d-47d9-9ecf-65db05ae7ccb  \\\n",
       "session_id                                                                        \n",
       "65_0008dac3-8379-4369-90a1-b7ceccbb2d51                                       0   \n",
       "65_0025a3cc-485b-407e-9fc6-2b58fc07b957                                       0   \n",
       "65_0026980c-c0de-46bd-860b-02b4050bf72c                                       0   \n",
       "65_00285fa4-2c95-45af-a57e-1b73de06b895                                       0   \n",
       "65_0028933d-c089-494a-8d17-900f12b7af10                                       0   \n",
       "...                                                                         ...   \n",
       "65_ffe2d9f1-1ffa-4a39-bbff-53c7b7a382ee                                       0   \n",
       "65_ffe678b7-a111-4f84-b3a5-ad559509fc7f                                       0   \n",
       "65_ffeb15fb-1e48-4469-808d-ca069506f21d                                       0   \n",
       "65_ffefaee3-5f51-4cd3-bb0b-68143cd4cec4                                       0   \n",
       "65_fff05e64-c7ac-4aec-ade1-65dee71e4d14                                       0   \n",
       "\n",
       "track_id_clean                           t_ffe3816d-baf0-4ddf-81ba-7598a03b2973  \\\n",
       "session_id                                                                        \n",
       "65_0008dac3-8379-4369-90a1-b7ceccbb2d51                                       0   \n",
       "65_0025a3cc-485b-407e-9fc6-2b58fc07b957                                       0   \n",
       "65_0026980c-c0de-46bd-860b-02b4050bf72c                                       0   \n",
       "65_00285fa4-2c95-45af-a57e-1b73de06b895                                       0   \n",
       "65_0028933d-c089-494a-8d17-900f12b7af10                                       0   \n",
       "...                                                                         ...   \n",
       "65_ffe2d9f1-1ffa-4a39-bbff-53c7b7a382ee                                       0   \n",
       "65_ffe678b7-a111-4f84-b3a5-ad559509fc7f                                       0   \n",
       "65_ffeb15fb-1e48-4469-808d-ca069506f21d                                       0   \n",
       "65_ffefaee3-5f51-4cd3-bb0b-68143cd4cec4                                       0   \n",
       "65_fff05e64-c7ac-4aec-ade1-65dee71e4d14                                       0   \n",
       "\n",
       "track_id_clean                           t_ffe4f45a-c85b-487e-b909-c9a401ab9130  \\\n",
       "session_id                                                                        \n",
       "65_0008dac3-8379-4369-90a1-b7ceccbb2d51                                       0   \n",
       "65_0025a3cc-485b-407e-9fc6-2b58fc07b957                                       0   \n",
       "65_0026980c-c0de-46bd-860b-02b4050bf72c                                       0   \n",
       "65_00285fa4-2c95-45af-a57e-1b73de06b895                                       0   \n",
       "65_0028933d-c089-494a-8d17-900f12b7af10                                       0   \n",
       "...                                                                         ...   \n",
       "65_ffe2d9f1-1ffa-4a39-bbff-53c7b7a382ee                                       0   \n",
       "65_ffe678b7-a111-4f84-b3a5-ad559509fc7f                                       0   \n",
       "65_ffeb15fb-1e48-4469-808d-ca069506f21d                                       0   \n",
       "65_ffefaee3-5f51-4cd3-bb0b-68143cd4cec4                                       0   \n",
       "65_fff05e64-c7ac-4aec-ade1-65dee71e4d14                                       0   \n",
       "\n",
       "track_id_clean                           t_ffeb7c85-d909-4b98-8fa4-1d7488faea99  \\\n",
       "session_id                                                                        \n",
       "65_0008dac3-8379-4369-90a1-b7ceccbb2d51                                       0   \n",
       "65_0025a3cc-485b-407e-9fc6-2b58fc07b957                                       0   \n",
       "65_0026980c-c0de-46bd-860b-02b4050bf72c                                       0   \n",
       "65_00285fa4-2c95-45af-a57e-1b73de06b895                                       0   \n",
       "65_0028933d-c089-494a-8d17-900f12b7af10                                       0   \n",
       "...                                                                         ...   \n",
       "65_ffe2d9f1-1ffa-4a39-bbff-53c7b7a382ee                                       0   \n",
       "65_ffe678b7-a111-4f84-b3a5-ad559509fc7f                                       0   \n",
       "65_ffeb15fb-1e48-4469-808d-ca069506f21d                                       0   \n",
       "65_ffefaee3-5f51-4cd3-bb0b-68143cd4cec4                                       0   \n",
       "65_fff05e64-c7ac-4aec-ade1-65dee71e4d14                                       0   \n",
       "\n",
       "track_id_clean                           t_ffecf968-2ea1-4fda-98e0-ebd43e313425  \\\n",
       "session_id                                                                        \n",
       "65_0008dac3-8379-4369-90a1-b7ceccbb2d51                                       0   \n",
       "65_0025a3cc-485b-407e-9fc6-2b58fc07b957                                       0   \n",
       "65_0026980c-c0de-46bd-860b-02b4050bf72c                                       0   \n",
       "65_00285fa4-2c95-45af-a57e-1b73de06b895                                       0   \n",
       "65_0028933d-c089-494a-8d17-900f12b7af10                                       0   \n",
       "...                                                                         ...   \n",
       "65_ffe2d9f1-1ffa-4a39-bbff-53c7b7a382ee                                       0   \n",
       "65_ffe678b7-a111-4f84-b3a5-ad559509fc7f                                       0   \n",
       "65_ffeb15fb-1e48-4469-808d-ca069506f21d                                       0   \n",
       "65_ffefaee3-5f51-4cd3-bb0b-68143cd4cec4                                       0   \n",
       "65_fff05e64-c7ac-4aec-ade1-65dee71e4d14                                       0   \n",
       "\n",
       "track_id_clean                           t_ffefb779-813a-4cd4-ba52-9fa2fa2bdf38  \\\n",
       "session_id                                                                        \n",
       "65_0008dac3-8379-4369-90a1-b7ceccbb2d51                                       0   \n",
       "65_0025a3cc-485b-407e-9fc6-2b58fc07b957                                       0   \n",
       "65_0026980c-c0de-46bd-860b-02b4050bf72c                                       0   \n",
       "65_00285fa4-2c95-45af-a57e-1b73de06b895                                       0   \n",
       "65_0028933d-c089-494a-8d17-900f12b7af10                                       0   \n",
       "...                                                                         ...   \n",
       "65_ffe2d9f1-1ffa-4a39-bbff-53c7b7a382ee                                       0   \n",
       "65_ffe678b7-a111-4f84-b3a5-ad559509fc7f                                       0   \n",
       "65_ffeb15fb-1e48-4469-808d-ca069506f21d                                       0   \n",
       "65_ffefaee3-5f51-4cd3-bb0b-68143cd4cec4                                       0   \n",
       "65_fff05e64-c7ac-4aec-ade1-65dee71e4d14                                       0   \n",
       "\n",
       "track_id_clean                           t_fff12ea1-83ef-468b-9fa4-eb80c8c1f8d4  \\\n",
       "session_id                                                                        \n",
       "65_0008dac3-8379-4369-90a1-b7ceccbb2d51                                       0   \n",
       "65_0025a3cc-485b-407e-9fc6-2b58fc07b957                                       0   \n",
       "65_0026980c-c0de-46bd-860b-02b4050bf72c                                       0   \n",
       "65_00285fa4-2c95-45af-a57e-1b73de06b895                                       0   \n",
       "65_0028933d-c089-494a-8d17-900f12b7af10                                       0   \n",
       "...                                                                         ...   \n",
       "65_ffe2d9f1-1ffa-4a39-bbff-53c7b7a382ee                                       0   \n",
       "65_ffe678b7-a111-4f84-b3a5-ad559509fc7f                                       0   \n",
       "65_ffeb15fb-1e48-4469-808d-ca069506f21d                                       0   \n",
       "65_ffefaee3-5f51-4cd3-bb0b-68143cd4cec4                                       0   \n",
       "65_fff05e64-c7ac-4aec-ade1-65dee71e4d14                                       0   \n",
       "\n",
       "track_id_clean                           t_fff9b035-3943-4557-bdd0-15fc2c5dc03e  \\\n",
       "session_id                                                                        \n",
       "65_0008dac3-8379-4369-90a1-b7ceccbb2d51                                       0   \n",
       "65_0025a3cc-485b-407e-9fc6-2b58fc07b957                                       0   \n",
       "65_0026980c-c0de-46bd-860b-02b4050bf72c                                       0   \n",
       "65_00285fa4-2c95-45af-a57e-1b73de06b895                                       0   \n",
       "65_0028933d-c089-494a-8d17-900f12b7af10                                       0   \n",
       "...                                                                         ...   \n",
       "65_ffe2d9f1-1ffa-4a39-bbff-53c7b7a382ee                                       0   \n",
       "65_ffe678b7-a111-4f84-b3a5-ad559509fc7f                                       0   \n",
       "65_ffeb15fb-1e48-4469-808d-ca069506f21d                                       0   \n",
       "65_ffefaee3-5f51-4cd3-bb0b-68143cd4cec4                                       0   \n",
       "65_fff05e64-c7ac-4aec-ade1-65dee71e4d14                                       0   \n",
       "\n",
       "track_id_clean                           t_fffe4528-f29c-467c-a751-e9db03964faa  \n",
       "session_id                                                                       \n",
       "65_0008dac3-8379-4369-90a1-b7ceccbb2d51                                       0  \n",
       "65_0025a3cc-485b-407e-9fc6-2b58fc07b957                                       0  \n",
       "65_0026980c-c0de-46bd-860b-02b4050bf72c                                       0  \n",
       "65_00285fa4-2c95-45af-a57e-1b73de06b895                                       0  \n",
       "65_0028933d-c089-494a-8d17-900f12b7af10                                       0  \n",
       "...                                                                         ...  \n",
       "65_ffe2d9f1-1ffa-4a39-bbff-53c7b7a382ee                                       0  \n",
       "65_ffe678b7-a111-4f84-b3a5-ad559509fc7f                                       0  \n",
       "65_ffeb15fb-1e48-4469-808d-ca069506f21d                                       0  \n",
       "65_ffefaee3-5f51-4cd3-bb0b-68143cd4cec4                                       0  \n",
       "65_fff05e64-c7ac-4aec-ade1-65dee71e4d14                                       0  \n",
       "\n",
       "[6000 rows x 26707 columns]"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "session_track_matrix = df_20_filtered.pivot_table(index='session_id', columns='track_id_clean', aggfunc='size', fill_value=0)\n",
    "session_track_matrix"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "77384bcb",
   "metadata": {},
   "outputs": [],
   "source": [
    "session_track_matrix_ = session_track_matrix.to_numpy()\n",
    "\n",
    "# Thin SVD: columns of u are left singular vectors (one row per session), s holds the\n",
    "# singular values in descending order, and rows of vt are right singular vectors\n",
    "# (one column per track). This computes every component on the dense matrix, although\n",
    "# the next cells keep only the leading 50 of them.\n",
    "u, s, vt = np.linalg.svd(session_track_matrix_, full_matrices=False, compute_uv=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "62d57ec8",
   "metadata": {},
   "source": [
    "#### Looking at the data, the raw SVD values span very different ranges, so we standardize each column (zero mean, unit variance) first. This also helps speed up clustering later."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "dca3f0d6",
   "metadata": {},
   "outputs": [],
   "source": [
    "def standardize(embeddings):\n",
    "    # Compute the mean and standard deviation for each feature\n",
    "    mean_vals = np.mean(embeddings, axis=0)\n",
    "    std_vals = np.std(embeddings, axis=0)\n",
    "    # Avoid division by zero by setting any zero std values to 1\n",
    "    std_vals[std_vals == 0] = 1\n",
    "    # Standardize the embeddings\n",
    "    return (embeddings - mean_vals) / std_vals"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "963d1e7c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep the 50 leading SVD components and z-score each one.\n",
    "# Rows of normalised_u correspond to sessions; rows of normalised_v (from vt.T) correspond to tracks.\n",
    "normalised_u = standardize(u[:, :50])\n",
    "normalised_v = standardize(vt.T[:, :50])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "8d350250",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Track embedding shape of (26707, 10) and session embedding shape of (6000, 10)\n"
     ]
    }
   ],
   "source": [
    "# Embedding dimension; later cells also reuse it as the number of clusters.\n",
    "l = 10\n",
    "\n",
    "# Keep only the first l standardized components for sessions and tracks.\n",
    "session_embeddings = normalised_u[:, :l]\n",
    "track_embeddings = normalised_v[:, :l]\n",
    "\n",
    "print(f\"Track embedding shape of {track_embeddings.shape} and session embedding shape of {session_embeddings.shape}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "f8c46405",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([ 0.3865183 , -0.04937551,  0.07676365, -0.08082001,  0.12558457,\n",
       "       -0.02435866, -0.05186017,  0.07997957, -0.11185562, -0.11709471])"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Peek at the embedding vector of the first session.\n",
    "session_embeddings[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "942508c7",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([ 0.04759581, -0.08667425,  0.04949918, -0.00480566,  0.02343548,\n",
       "        0.03090921,  0.0017227 , -0.01937166, -0.00455287, -0.02484083])"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Peek at the embedding vector of the first track.\n",
    "track_embeddings[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "ba5775a7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(26707,)\n",
      "Cluster distribution: [26661    12    11     1     1     6     1     4     6     4]\n"
     ]
    }
   ],
   "source": [
    "from sklearn.cluster import KMeans\n",
    "\n",
    "# Baseline: unconstrained k-means on the track embeddings, with as many clusters as embedding dimensions.\n",
    "num_clusters = l\n",
    "\n",
    "kmeans = KMeans(n_init='auto', n_clusters=num_clusters, random_state=1234)\n",
    "kmeans.fit(track_embeddings)\n",
    "track_clusters = kmeans.labels_\n",
    "print(track_clusters.shape)\n",
    "\n",
    "# Number of tracks per cluster; minlength keeps empty clusters visible as 0.\n",
    "cluster_distribution = np.bincount(track_clusters, minlength=num_clusters)\n",
    "print(\"Cluster distribution:\", cluster_distribution)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0c1a1319",
   "metadata": {},
   "source": [
    "#### We observe that a standard K-means algorithm without constraints produces highly imbalanced clusters. Adding constraints on the cluster sizes gives more reasonable results."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "f741dd50",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Balanced cluster distribution: [2720 2627 2720 2720 2720 2620 2620 2620 2720 2620]\n"
     ]
    }
   ],
   "source": [
    "from k_means_constrained import KMeansConstrained\n",
    "\n",
    "num_clusters = l\n",
    "# Each cluster may deviate from an even split by at most SIZE_SLACK tracks.\n",
    "SIZE_SLACK = 50\n",
    "min_size = len(track_embeddings) // num_clusters - SIZE_SLACK\n",
    "max_size = len(track_embeddings) // num_clusters + SIZE_SLACK\n",
    "\n",
    "clf = KMeansConstrained(\n",
    "    n_clusters=num_clusters,\n",
    "    size_min=min_size,\n",
    "    size_max=max_size,\n",
    "    random_state=1234\n",
    ")\n",
    "\n",
    "clf.fit(track_embeddings)\n",
    "track_clusters_balanced = clf.labels_\n",
    "\n",
    "# minlength matches the unconstrained baseline, so an empty cluster would still show up as 0.\n",
    "cluster_distribution = np.bincount(track_clusters_balanced, minlength=num_clusters)\n",
    "# Sanity check that the solver honoured the requested size bounds.\n",
    "assert cluster_distribution.min() >= min_size and cluster_distribution.max() <= max_size\n",
    "print(\"Balanced cluster distribution:\", cluster_distribution)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "6da8c2a1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Centroid vectors for each cluster:\n",
      " for instance, the first one: \n",
      " [ 0.06953203 -0.01117762  0.01909637 -0.02113779  0.03288418 -0.00680944\n",
      " -0.01537758  0.02478403 -0.03482736 -0.036354  ]\n"
     ]
    }
   ],
   "source": [
    "# Cluster centres from the balanced clustering, in the same space as track_embeddings.\n",
    "centroid_vectors = clf.cluster_centers_\n",
    "print(\"Centroid vectors for each cluster:\\n for instance, the first one: \\n\", centroid_vectors[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "08a6fd83",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "26707"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Map every track id (a column of the session-track matrix) to its balanced cluster label.\n",
    "# zip() would silently truncate on a length mismatch, so check the alignment explicitly.\n",
    "assert len(session_track_matrix.columns) == len(track_clusters_balanced)\n",
    "track_id_to_cluster = dict(zip(session_track_matrix.columns, track_clusters_balanced))\n",
    "len(track_id_to_cluster)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "0dc0cf5d",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_2207913/146496183.py:9: SettingWithCopyWarning: \n",
      "A value is trying to be set on a copy of a slice from a DataFrame.\n",
      "Try using .loc[row_indexer,col_indexer] = value instead\n",
      "\n",
      "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
      "  df_20_filtered['F'] = df_20_filtered['track_id_clean'].apply(lambda x: track_to_cluster_vector(x, track_id_to_cluster, num_clusters))\n"
     ]
    }
   ],
   "source": [
    "def track_to_cluster_vector(track_id, track_id_to_cluster, num_clusters):\n",
    "    \"\"\"Return a one-hot int vector of length num_clusters marking the track's cluster.\n",
    "\n",
    "    Raises:\n",
    "        KeyError: if track_id has no cluster assignment. (With dict.get the\n",
    "            index would be None, and `vector[None] = 1` silently sets every\n",
    "            entry to 1 instead of failing.)\n",
    "    \"\"\"\n",
    "    cluster_vector = np.zeros(num_clusters, dtype=int)\n",
    "    cluster_vector[track_id_to_cluster[track_id]] = 1\n",
    "    return cluster_vector\n",
    "\n",
    "# assign() returns a new frame, which avoids the SettingWithCopyWarning raised\n",
    "# when writing a column into a slice of another DataFrame.\n",
    "df_20_filtered = df_20_filtered.assign(\n",
    "    F=df_20_filtered['track_id_clean'].apply(lambda x: track_to_cluster_vector(x, track_id_to_cluster, num_clusters))\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "116a82e9",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_2207913/1198334475.py:5: SettingWithCopyWarning: \n",
      "A value is trying to be set on a copy of a slice from a DataFrame.\n",
      "Try using .loc[row_indexer,col_indexer] = value instead\n",
      "\n",
      "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
      "  df_20_filtered['F_vector'] = df_20_filtered['track_id_clean'].apply(lambda x: get_centroid_vector(x, track_id_to_cluster, centroid_vectors))\n"
     ]
    }
   ],
   "source": [
    "def get_centroid_vector(track_id, track_id_to_cluster, centroid_vectors):\n",
    "    \"\"\"Return the centroid row of the cluster that track_id belongs to.\n",
    "\n",
    "    Raises:\n",
    "        KeyError: if track_id has no cluster assignment. (dict.get would return\n",
    "            None, and `centroid_vectors[None]` silently yields the whole\n",
    "            centroid array with an extra leading axis.)\n",
    "    \"\"\"\n",
    "    return centroid_vectors[track_id_to_cluster[track_id]]\n",
    "\n",
    "# assign() returns a new frame instead of writing into a possible slice copy.\n",
    "df_20_filtered = df_20_filtered.assign(\n",
    "    F_vector=df_20_filtered['track_id_clean'].apply(lambda x: get_centroid_vector(x, track_id_to_cluster, centroid_vectors))\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "8fac6844",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>session_id</th>\n",
       "      <th>session_position</th>\n",
       "      <th>session_length</th>\n",
       "      <th>track_id_clean</th>\n",
       "      <th>skip_3</th>\n",
       "      <th>F</th>\n",
       "      <th>F_vector</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>65_0008dac3-8379-4369-90a1-b7ceccbb2d51</td>\n",
       "      <td>1</td>\n",
       "      <td>20</td>\n",
       "      <td>t_8eecd43d-3146-47ab-9ff6-3d300abe3216</td>\n",
       "      <td>True</td>\n",
       "      <td>[1, 0, 0, 0, 0, 0, 0, 0, 0, 0]</td>\n",
       "      <td>[0.0695320313643396, -0.011177622068879926, 0....</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>65_0008dac3-8379-4369-90a1-b7ceccbb2d51</td>\n",
       "      <td>2</td>\n",
       "      <td>20</td>\n",
       "      <td>t_47040682-2572-4b22-a3c3-7a3426634371</td>\n",
       "      <td>True</td>\n",
       "      <td>[0, 0, 1, 0, 0, 0, 0, 0, 0, 0]</td>\n",
       "      <td>[0.06953126222753003, -0.011178203525205396, 0...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>65_0008dac3-8379-4369-90a1-b7ceccbb2d51</td>\n",
       "      <td>3</td>\n",
       "      <td>20</td>\n",
       "      <td>t_c3e3086e-b38f-467b-b05a-9aa511a8aef5</td>\n",
       "      <td>True</td>\n",
       "      <td>[0, 0, 0, 0, 0, 0, 0, 0, 1, 0]</td>\n",
       "      <td>[0.06953172399542647, -0.011177516303562769, 0...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>65_0008dac3-8379-4369-90a1-b7ceccbb2d51</td>\n",
       "      <td>4</td>\n",
       "      <td>20</td>\n",
       "      <td>t_22435160-213e-4aac-903e-9e60d643d761</td>\n",
       "      <td>True</td>\n",
       "      <td>[0, 0, 0, 0, 1, 0, 0, 0, 0, 0]</td>\n",
       "      <td>[0.0695004329733046, -0.011181794135149432, 0....</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>65_0008dac3-8379-4369-90a1-b7ceccbb2d51</td>\n",
       "      <td>5</td>\n",
       "      <td>20</td>\n",
       "      <td>t_8163fb61-4a3f-4b45-9b3f-f125da29c86b</td>\n",
       "      <td>True</td>\n",
       "      <td>[1, 0, 0, 0, 0, 0, 0, 0, 0, 0]</td>\n",
       "      <td>[0.0695320313643396, -0.011177622068879926, 0....</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>65_0008dac3-8379-4369-90a1-b7ceccbb2d51</td>\n",
       "      <td>6</td>\n",
       "      <td>20</td>\n",
       "      <td>t_e241fa6f-b682-4114-98ad-70f79a0dd99b</td>\n",
       "      <td>True</td>\n",
       "      <td>[0, 0, 0, 0, 0, 0, 0, 1, 0, 0]</td>\n",
       "      <td>[0.05367344872695791, -0.0060483777861048615, ...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>65_0008dac3-8379-4369-90a1-b7ceccbb2d51</td>\n",
       "      <td>7</td>\n",
       "      <td>20</td>\n",
       "      <td>t_7acef85a-fc13-4038-afe6-00d31c373f4b</td>\n",
       "      <td>True</td>\n",
       "      <td>[1, 0, 0, 0, 0, 0, 0, 0, 0, 0]</td>\n",
       "      <td>[0.0695320313643396, -0.011177622068879926, 0....</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>65_0008dac3-8379-4369-90a1-b7ceccbb2d51</td>\n",
       "      <td>8</td>\n",
       "      <td>20</td>\n",
       "      <td>t_2bb41382-4b9c-4d5e-98ee-b321b675a6ff</td>\n",
       "      <td>True</td>\n",
       "      <td>[0, 0, 0, 0, 1, 0, 0, 0, 0, 0]</td>\n",
       "      <td>[0.0695004329733046, -0.011181794135149432, 0....</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>65_0008dac3-8379-4369-90a1-b7ceccbb2d51</td>\n",
       "      <td>9</td>\n",
       "      <td>20</td>\n",
       "      <td>t_758b880c-dbb8-403e-b7c5-e3932f707aa9</td>\n",
       "      <td>True</td>\n",
       "      <td>[1, 0, 0, 0, 0, 0, 0, 0, 0, 0]</td>\n",
       "      <td>[0.0695320313643396, -0.011177622068879926, 0....</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>65_0008dac3-8379-4369-90a1-b7ceccbb2d51</td>\n",
       "      <td>10</td>\n",
       "      <td>20</td>\n",
       "      <td>t_27da1c72-64d5-432b-b520-2071e99e8ffe</td>\n",
       "      <td>False</td>\n",
       "      <td>[0, 0, 0, 0, 0, 0, 1, 0, 0, 0]</td>\n",
       "      <td>[0.05197004493717251, -0.011174131555338772, -...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>65_0025a3cc-485b-407e-9fc6-2b58fc07b957</td>\n",
       "      <td>1</td>\n",
       "      <td>20</td>\n",
       "      <td>t_e8041aa0-11f1-4147-bdc3-0b889ba101f2</td>\n",
       "      <td>False</td>\n",
       "      <td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 1]</td>\n",
       "      <td>[-0.5843973841788984, 0.08490661836005894, -0....</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>65_0025a3cc-485b-407e-9fc6-2b58fc07b957</td>\n",
       "      <td>2</td>\n",
       "      <td>20</td>\n",
       "      <td>t_8b30c326-5f53-4e99-adb6-994704a52f05</td>\n",
       "      <td>False</td>\n",
       "      <td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 1]</td>\n",
       "      <td>[-0.5843973841788984, 0.08490661836005894, -0....</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>65_0025a3cc-485b-407e-9fc6-2b58fc07b957</td>\n",
       "      <td>3</td>\n",
       "      <td>20</td>\n",
       "      <td>t_a633d9e8-1342-4dbf-83de-86e87901efe0</td>\n",
       "      <td>False</td>\n",
       "      <td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 1]</td>\n",
       "      <td>[-0.5843973841788984, 0.08490661836005894, -0....</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>65_0025a3cc-485b-407e-9fc6-2b58fc07b957</td>\n",
       "      <td>4</td>\n",
       "      <td>20</td>\n",
       "      <td>t_35eaec20-2bdb-4c38-8f10-e7b3c42ab688</td>\n",
       "      <td>True</td>\n",
       "      <td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 1]</td>\n",
       "      <td>[-0.5843973841788984, 0.08490661836005894, -0....</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>65_0025a3cc-485b-407e-9fc6-2b58fc07b957</td>\n",
       "      <td>5</td>\n",
       "      <td>20</td>\n",
       "      <td>t_e0ce1b77-c912-4bd7-af66-caae072f3b7b</td>\n",
       "      <td>True</td>\n",
       "      <td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 1]</td>\n",
       "      <td>[-0.5843973841788984, 0.08490661836005894, -0....</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15</th>\n",
       "      <td>65_0025a3cc-485b-407e-9fc6-2b58fc07b957</td>\n",
       "      <td>6</td>\n",
       "      <td>20</td>\n",
       "      <td>t_f3c4c677-5b5d-49cf-b2ba-8028d63a29b4</td>\n",
       "      <td>True</td>\n",
       "      <td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 1]</td>\n",
       "      <td>[-0.5843973841788984, 0.08490661836005894, -0....</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16</th>\n",
       "      <td>65_0025a3cc-485b-407e-9fc6-2b58fc07b957</td>\n",
       "      <td>7</td>\n",
       "      <td>20</td>\n",
       "      <td>t_cbd23b21-2693-49cc-872f-c061de8c93c4</td>\n",
       "      <td>True</td>\n",
       "      <td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 1]</td>\n",
       "      <td>[-0.5843973841788984, 0.08490661836005894, -0....</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17</th>\n",
       "      <td>65_0025a3cc-485b-407e-9fc6-2b58fc07b957</td>\n",
       "      <td>8</td>\n",
       "      <td>20</td>\n",
       "      <td>t_925fe3a1-e8be-4062-9dd4-2a41857e308e</td>\n",
       "      <td>False</td>\n",
       "      <td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 1]</td>\n",
       "      <td>[-0.5843973841788984, 0.08490661836005894, -0....</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>18</th>\n",
       "      <td>65_0025a3cc-485b-407e-9fc6-2b58fc07b957</td>\n",
       "      <td>9</td>\n",
       "      <td>20</td>\n",
       "      <td>t_32598e62-6456-47a8-8c1e-b93f7c68a2b6</td>\n",
       "      <td>True</td>\n",
       "      <td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 1]</td>\n",
       "      <td>[-0.5843973841788984, 0.08490661836005894, -0....</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19</th>\n",
       "      <td>65_0025a3cc-485b-407e-9fc6-2b58fc07b957</td>\n",
       "      <td>10</td>\n",
       "      <td>20</td>\n",
       "      <td>t_c7fc8984-9666-48f9-9081-ffc4a0a1310b</td>\n",
       "      <td>False</td>\n",
       "      <td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 1]</td>\n",
       "      <td>[-0.5843973841788984, 0.08490661836005894, -0....</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                 session_id  session_position  session_length  \\\n",
       "0   65_0008dac3-8379-4369-90a1-b7ceccbb2d51                 1              20   \n",
       "1   65_0008dac3-8379-4369-90a1-b7ceccbb2d51                 2              20   \n",
       "2   65_0008dac3-8379-4369-90a1-b7ceccbb2d51                 3              20   \n",
       "3   65_0008dac3-8379-4369-90a1-b7ceccbb2d51                 4              20   \n",
       "4   65_0008dac3-8379-4369-90a1-b7ceccbb2d51                 5              20   \n",
       "5   65_0008dac3-8379-4369-90a1-b7ceccbb2d51                 6              20   \n",
       "6   65_0008dac3-8379-4369-90a1-b7ceccbb2d51                 7              20   \n",
       "7   65_0008dac3-8379-4369-90a1-b7ceccbb2d51                 8              20   \n",
       "8   65_0008dac3-8379-4369-90a1-b7ceccbb2d51                 9              20   \n",
       "9   65_0008dac3-8379-4369-90a1-b7ceccbb2d51                10              20   \n",
       "10  65_0025a3cc-485b-407e-9fc6-2b58fc07b957                 1              20   \n",
       "11  65_0025a3cc-485b-407e-9fc6-2b58fc07b957                 2              20   \n",
       "12  65_0025a3cc-485b-407e-9fc6-2b58fc07b957                 3              20   \n",
       "13  65_0025a3cc-485b-407e-9fc6-2b58fc07b957                 4              20   \n",
       "14  65_0025a3cc-485b-407e-9fc6-2b58fc07b957                 5              20   \n",
       "15  65_0025a3cc-485b-407e-9fc6-2b58fc07b957                 6              20   \n",
       "16  65_0025a3cc-485b-407e-9fc6-2b58fc07b957                 7              20   \n",
       "17  65_0025a3cc-485b-407e-9fc6-2b58fc07b957                 8              20   \n",
       "18  65_0025a3cc-485b-407e-9fc6-2b58fc07b957                 9              20   \n",
       "19  65_0025a3cc-485b-407e-9fc6-2b58fc07b957                10              20   \n",
       "\n",
       "                            track_id_clean  skip_3  \\\n",
       "0   t_8eecd43d-3146-47ab-9ff6-3d300abe3216    True   \n",
       "1   t_47040682-2572-4b22-a3c3-7a3426634371    True   \n",
       "2   t_c3e3086e-b38f-467b-b05a-9aa511a8aef5    True   \n",
       "3   t_22435160-213e-4aac-903e-9e60d643d761    True   \n",
       "4   t_8163fb61-4a3f-4b45-9b3f-f125da29c86b    True   \n",
       "5   t_e241fa6f-b682-4114-98ad-70f79a0dd99b    True   \n",
       "6   t_7acef85a-fc13-4038-afe6-00d31c373f4b    True   \n",
       "7   t_2bb41382-4b9c-4d5e-98ee-b321b675a6ff    True   \n",
       "8   t_758b880c-dbb8-403e-b7c5-e3932f707aa9    True   \n",
       "9   t_27da1c72-64d5-432b-b520-2071e99e8ffe   False   \n",
       "10  t_e8041aa0-11f1-4147-bdc3-0b889ba101f2   False   \n",
       "11  t_8b30c326-5f53-4e99-adb6-994704a52f05   False   \n",
       "12  t_a633d9e8-1342-4dbf-83de-86e87901efe0   False   \n",
       "13  t_35eaec20-2bdb-4c38-8f10-e7b3c42ab688    True   \n",
       "14  t_e0ce1b77-c912-4bd7-af66-caae072f3b7b    True   \n",
       "15  t_f3c4c677-5b5d-49cf-b2ba-8028d63a29b4    True   \n",
       "16  t_cbd23b21-2693-49cc-872f-c061de8c93c4    True   \n",
       "17  t_925fe3a1-e8be-4062-9dd4-2a41857e308e   False   \n",
       "18  t_32598e62-6456-47a8-8c1e-b93f7c68a2b6    True   \n",
       "19  t_c7fc8984-9666-48f9-9081-ffc4a0a1310b   False   \n",
       "\n",
       "                                 F  \\\n",
       "0   [1, 0, 0, 0, 0, 0, 0, 0, 0, 0]   \n",
       "1   [0, 0, 1, 0, 0, 0, 0, 0, 0, 0]   \n",
       "2   [0, 0, 0, 0, 0, 0, 0, 0, 1, 0]   \n",
       "3   [0, 0, 0, 0, 1, 0, 0, 0, 0, 0]   \n",
       "4   [1, 0, 0, 0, 0, 0, 0, 0, 0, 0]   \n",
       "5   [0, 0, 0, 0, 0, 0, 0, 1, 0, 0]   \n",
       "6   [1, 0, 0, 0, 0, 0, 0, 0, 0, 0]   \n",
       "7   [0, 0, 0, 0, 1, 0, 0, 0, 0, 0]   \n",
       "8   [1, 0, 0, 0, 0, 0, 0, 0, 0, 0]   \n",
       "9   [0, 0, 0, 0, 0, 0, 1, 0, 0, 0]   \n",
       "10  [0, 0, 0, 0, 0, 0, 0, 0, 0, 1]   \n",
       "11  [0, 0, 0, 0, 0, 0, 0, 0, 0, 1]   \n",
       "12  [0, 0, 0, 0, 0, 0, 0, 0, 0, 1]   \n",
       "13  [0, 0, 0, 0, 0, 0, 0, 0, 0, 1]   \n",
       "14  [0, 0, 0, 0, 0, 0, 0, 0, 0, 1]   \n",
       "15  [0, 0, 0, 0, 0, 0, 0, 0, 0, 1]   \n",
       "16  [0, 0, 0, 0, 0, 0, 0, 0, 0, 1]   \n",
       "17  [0, 0, 0, 0, 0, 0, 0, 0, 0, 1]   \n",
       "18  [0, 0, 0, 0, 0, 0, 0, 0, 0, 1]   \n",
       "19  [0, 0, 0, 0, 0, 0, 0, 0, 0, 1]   \n",
       "\n",
       "                                             F_vector  \n",
       "0   [0.0695320313643396, -0.011177622068879926, 0....  \n",
       "1   [0.06953126222753003, -0.011178203525205396, 0...  \n",
       "2   [0.06953172399542647, -0.011177516303562769, 0...  \n",
       "3   [0.0695004329733046, -0.011181794135149432, 0....  \n",
       "4   [0.0695320313643396, -0.011177622068879926, 0....  \n",
       "5   [0.05367344872695791, -0.0060483777861048615, ...  \n",
       "6   [0.0695320313643396, -0.011177622068879926, 0....  \n",
       "7   [0.0695004329733046, -0.011181794135149432, 0....  \n",
       "8   [0.0695320313643396, -0.011177622068879926, 0....  \n",
       "9   [0.05197004493717251, -0.011174131555338772, -...  \n",
       "10  [-0.5843973841788984, 0.08490661836005894, -0....  \n",
       "11  [-0.5843973841788984, 0.08490661836005894, -0....  \n",
       "12  [-0.5843973841788984, 0.08490661836005894, -0....  \n",
       "13  [-0.5843973841788984, 0.08490661836005894, -0....  \n",
       "14  [-0.5843973841788984, 0.08490661836005894, -0....  \n",
       "15  [-0.5843973841788984, 0.08490661836005894, -0....  \n",
       "16  [-0.5843973841788984, 0.08490661836005894, -0....  \n",
       "17  [-0.5843973841788984, 0.08490661836005894, -0....  \n",
       "18  [-0.5843973841788984, 0.08490661836005894, -0....  \n",
       "19  [-0.5843973841788984, 0.08490661836005894, -0....  "
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Preview the first 20 rows, including the one-hot cluster (F) and centroid (F_vector) columns.\n",
    "df_20_filtered.head(20)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "03cc83ef",
   "metadata": {},
   "source": [
    "#### We now create the X value from skip_3. The official documentation describes skip_3 as \"**Skip3: Boolean indicating if most of the track was played**\". To reflect this, we convert it to a reward: the reward is 1 if skip_3 is True and 0 otherwise."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "e82ed1f1",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_2207913/1575398399.py:8: SettingWithCopyWarning: \n",
      "A value is trying to be set on a copy of a slice from a DataFrame.\n",
      "Try using .loc[row_indexer,col_indexer] = value instead\n",
      "\n",
      "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
      "  df_20_filtered['skip_value'] = df_20_filtered.apply(assign_skip_value, axis=1)\n"
     ]
    }
   ],
   "source": [
    "def assign_skip_value(row):\n",
    "    \"\"\"Map a row's skip_3 flag to a 0/1 reward: 1 if the flag is truthy, else 0.\"\"\"\n",
    "    return 1 if row['skip_3'] else 0\n",
    "\n",
    "# NOTE(review): nothing in this cell draws random numbers; the reseed is kept so\n",
    "# the RNG state seen by later cells is unchanged.\n",
    "set_seed(global_seed)\n",
    "# assign() returns a new frame instead of writing into a possible slice copy.\n",
    "df_20_filtered = df_20_filtered.assign(skip_value=df_20_filtered.apply(assign_skip_value, axis=1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "40bd5561",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>session_id</th>\n",
       "      <th>session_position</th>\n",
       "      <th>session_length</th>\n",
       "      <th>track_id_clean</th>\n",
       "      <th>skip_3</th>\n",
       "      <th>F</th>\n",
       "      <th>F_vector</th>\n",
       "      <th>skip_value</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>65_0008dac3-8379-4369-90a1-b7ceccbb2d51</td>\n",
       "      <td>1</td>\n",
       "      <td>20</td>\n",
       "      <td>t_8eecd43d-3146-47ab-9ff6-3d300abe3216</td>\n",
       "      <td>True</td>\n",
       "      <td>[1, 0, 0, 0, 0, 0, 0, 0, 0, 0]</td>\n",
       "      <td>[0.0695320313643396, -0.011177622068879926, 0....</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>65_0008dac3-8379-4369-90a1-b7ceccbb2d51</td>\n",
       "      <td>2</td>\n",
       "      <td>20</td>\n",
       "      <td>t_47040682-2572-4b22-a3c3-7a3426634371</td>\n",
       "      <td>True</td>\n",
       "      <td>[0, 0, 1, 0, 0, 0, 0, 0, 0, 0]</td>\n",
       "      <td>[0.06953126222753003, -0.011178203525205396, 0...</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>65_0008dac3-8379-4369-90a1-b7ceccbb2d51</td>\n",
       "      <td>3</td>\n",
       "      <td>20</td>\n",
       "      <td>t_c3e3086e-b38f-467b-b05a-9aa511a8aef5</td>\n",
       "      <td>True</td>\n",
       "      <td>[0, 0, 0, 0, 0, 0, 0, 0, 1, 0]</td>\n",
       "      <td>[0.06953172399542647, -0.011177516303562769, 0...</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>65_0008dac3-8379-4369-90a1-b7ceccbb2d51</td>\n",
       "      <td>4</td>\n",
       "      <td>20</td>\n",
       "      <td>t_22435160-213e-4aac-903e-9e60d643d761</td>\n",
       "      <td>True</td>\n",
       "      <td>[0, 0, 0, 0, 1, 0, 0, 0, 0, 0]</td>\n",
       "      <td>[0.0695004329733046, -0.011181794135149432, 0....</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>65_0008dac3-8379-4369-90a1-b7ceccbb2d51</td>\n",
       "      <td>5</td>\n",
       "      <td>20</td>\n",
       "      <td>t_8163fb61-4a3f-4b45-9b3f-f125da29c86b</td>\n",
       "      <td>True</td>\n",
       "      <td>[1, 0, 0, 0, 0, 0, 0, 0, 0, 0]</td>\n",
       "      <td>[0.0695320313643396, -0.011177622068879926, 0....</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>65_0008dac3-8379-4369-90a1-b7ceccbb2d51</td>\n",
       "      <td>6</td>\n",
       "      <td>20</td>\n",
       "      <td>t_e241fa6f-b682-4114-98ad-70f79a0dd99b</td>\n",
       "      <td>True</td>\n",
       "      <td>[0, 0, 0, 0, 0, 0, 0, 1, 0, 0]</td>\n",
       "      <td>[0.05367344872695791, -0.0060483777861048615, ...</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>65_0008dac3-8379-4369-90a1-b7ceccbb2d51</td>\n",
       "      <td>7</td>\n",
       "      <td>20</td>\n",
       "      <td>t_7acef85a-fc13-4038-afe6-00d31c373f4b</td>\n",
       "      <td>True</td>\n",
       "      <td>[1, 0, 0, 0, 0, 0, 0, 0, 0, 0]</td>\n",
       "      <td>[0.0695320313643396, -0.011177622068879926, 0....</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>65_0008dac3-8379-4369-90a1-b7ceccbb2d51</td>\n",
       "      <td>8</td>\n",
       "      <td>20</td>\n",
       "      <td>t_2bb41382-4b9c-4d5e-98ee-b321b675a6ff</td>\n",
       "      <td>True</td>\n",
       "      <td>[0, 0, 0, 0, 1, 0, 0, 0, 0, 0]</td>\n",
       "      <td>[0.0695004329733046, -0.011181794135149432, 0....</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>65_0008dac3-8379-4369-90a1-b7ceccbb2d51</td>\n",
       "      <td>9</td>\n",
       "      <td>20</td>\n",
       "      <td>t_758b880c-dbb8-403e-b7c5-e3932f707aa9</td>\n",
       "      <td>True</td>\n",
       "      <td>[1, 0, 0, 0, 0, 0, 0, 0, 0, 0]</td>\n",
       "      <td>[0.0695320313643396, -0.011177622068879926, 0....</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>65_0008dac3-8379-4369-90a1-b7ceccbb2d51</td>\n",
       "      <td>10</td>\n",
       "      <td>20</td>\n",
       "      <td>t_27da1c72-64d5-432b-b520-2071e99e8ffe</td>\n",
       "      <td>False</td>\n",
       "      <td>[0, 0, 0, 0, 0, 0, 1, 0, 0, 0]</td>\n",
       "      <td>[0.05197004493717251, -0.011174131555338772, -...</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>65_0025a3cc-485b-407e-9fc6-2b58fc07b957</td>\n",
       "      <td>1</td>\n",
       "      <td>20</td>\n",
       "      <td>t_e8041aa0-11f1-4147-bdc3-0b889ba101f2</td>\n",
       "      <td>False</td>\n",
       "      <td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 1]</td>\n",
       "      <td>[-0.5843973841788984, 0.08490661836005894, -0....</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>65_0025a3cc-485b-407e-9fc6-2b58fc07b957</td>\n",
       "      <td>2</td>\n",
       "      <td>20</td>\n",
       "      <td>t_8b30c326-5f53-4e99-adb6-994704a52f05</td>\n",
       "      <td>False</td>\n",
       "      <td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 1]</td>\n",
       "      <td>[-0.5843973841788984, 0.08490661836005894, -0....</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>65_0025a3cc-485b-407e-9fc6-2b58fc07b957</td>\n",
       "      <td>3</td>\n",
       "      <td>20</td>\n",
       "      <td>t_a633d9e8-1342-4dbf-83de-86e87901efe0</td>\n",
       "      <td>False</td>\n",
       "      <td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 1]</td>\n",
       "      <td>[-0.5843973841788984, 0.08490661836005894, -0....</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>65_0025a3cc-485b-407e-9fc6-2b58fc07b957</td>\n",
       "      <td>4</td>\n",
       "      <td>20</td>\n",
       "      <td>t_35eaec20-2bdb-4c38-8f10-e7b3c42ab688</td>\n",
       "      <td>True</td>\n",
       "      <td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 1]</td>\n",
       "      <td>[-0.5843973841788984, 0.08490661836005894, -0....</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>65_0025a3cc-485b-407e-9fc6-2b58fc07b957</td>\n",
       "      <td>5</td>\n",
       "      <td>20</td>\n",
       "      <td>t_e0ce1b77-c912-4bd7-af66-caae072f3b7b</td>\n",
       "      <td>True</td>\n",
       "      <td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 1]</td>\n",
       "      <td>[-0.5843973841788984, 0.08490661836005894, -0....</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15</th>\n",
       "      <td>65_0025a3cc-485b-407e-9fc6-2b58fc07b957</td>\n",
       "      <td>6</td>\n",
       "      <td>20</td>\n",
       "      <td>t_f3c4c677-5b5d-49cf-b2ba-8028d63a29b4</td>\n",
       "      <td>True</td>\n",
       "      <td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 1]</td>\n",
       "      <td>[-0.5843973841788984, 0.08490661836005894, -0....</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16</th>\n",
       "      <td>65_0025a3cc-485b-407e-9fc6-2b58fc07b957</td>\n",
       "      <td>7</td>\n",
       "      <td>20</td>\n",
       "      <td>t_cbd23b21-2693-49cc-872f-c061de8c93c4</td>\n",
       "      <td>True</td>\n",
       "      <td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 1]</td>\n",
       "      <td>[-0.5843973841788984, 0.08490661836005894, -0....</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17</th>\n",
       "      <td>65_0025a3cc-485b-407e-9fc6-2b58fc07b957</td>\n",
       "      <td>8</td>\n",
       "      <td>20</td>\n",
       "      <td>t_925fe3a1-e8be-4062-9dd4-2a41857e308e</td>\n",
       "      <td>False</td>\n",
       "      <td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 1]</td>\n",
       "      <td>[-0.5843973841788984, 0.08490661836005894, -0....</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>18</th>\n",
       "      <td>65_0025a3cc-485b-407e-9fc6-2b58fc07b957</td>\n",
       "      <td>9</td>\n",
       "      <td>20</td>\n",
       "      <td>t_32598e62-6456-47a8-8c1e-b93f7c68a2b6</td>\n",
       "      <td>True</td>\n",
       "      <td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 1]</td>\n",
       "      <td>[-0.5843973841788984, 0.08490661836005894, -0....</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19</th>\n",
       "      <td>65_0025a3cc-485b-407e-9fc6-2b58fc07b957</td>\n",
       "      <td>10</td>\n",
       "      <td>20</td>\n",
       "      <td>t_c7fc8984-9666-48f9-9081-ffc4a0a1310b</td>\n",
       "      <td>False</td>\n",
       "      <td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 1]</td>\n",
       "      <td>[-0.5843973841788984, 0.08490661836005894, -0....</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>20</th>\n",
       "      <td>65_0026980c-c0de-46bd-860b-02b4050bf72c</td>\n",
       "      <td>1</td>\n",
       "      <td>20</td>\n",
       "      <td>t_05e8145e-363e-4fec-a089-cbe30132df9f</td>\n",
       "      <td>True</td>\n",
       "      <td>[0, 0, 0, 0, 1, 0, 0, 0, 0, 0]</td>\n",
       "      <td>[0.0695004329733046, -0.011181794135149432, 0....</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>21</th>\n",
       "      <td>65_0026980c-c0de-46bd-860b-02b4050bf72c</td>\n",
       "      <td>2</td>\n",
       "      <td>20</td>\n",
       "      <td>t_9d487a00-0b49-4ab2-8faa-ecf5c9d4c8eb</td>\n",
       "      <td>True</td>\n",
       "      <td>[1, 0, 0, 0, 0, 0, 0, 0, 0, 0]</td>\n",
       "      <td>[0.0695320313643396, -0.011177622068879926, 0....</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>22</th>\n",
       "      <td>65_0026980c-c0de-46bd-860b-02b4050bf72c</td>\n",
       "      <td>3</td>\n",
       "      <td>20</td>\n",
       "      <td>t_9d92c8c1-f4b8-4bfd-95eb-2fe7eea0dde1</td>\n",
       "      <td>True</td>\n",
       "      <td>[1, 0, 0, 0, 0, 0, 0, 0, 0, 0]</td>\n",
       "      <td>[0.0695320313643396, -0.011177622068879926, 0....</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>23</th>\n",
       "      <td>65_0026980c-c0de-46bd-860b-02b4050bf72c</td>\n",
       "      <td>4</td>\n",
       "      <td>20</td>\n",
       "      <td>t_42454184-4358-4221-a2bf-e6847e1cd277</td>\n",
       "      <td>True</td>\n",
       "      <td>[0, 0, 0, 1, 0, 0, 0, 0, 0, 0]</td>\n",
       "      <td>[0.06932655678539505, -0.011112521948605592, 0...</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>24</th>\n",
       "      <td>65_0026980c-c0de-46bd-860b-02b4050bf72c</td>\n",
       "      <td>5</td>\n",
       "      <td>20</td>\n",
       "      <td>t_d1de1372-c06a-41d5-aaf5-cde7f6710269</td>\n",
       "      <td>True</td>\n",
       "      <td>[0, 0, 0, 0, 0, 0, 0, 0, 1, 0]</td>\n",
       "      <td>[0.06953172399542647, -0.011177516303562769, 0...</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>25</th>\n",
       "      <td>65_0026980c-c0de-46bd-860b-02b4050bf72c</td>\n",
       "      <td>6</td>\n",
       "      <td>20</td>\n",
       "      <td>t_bef4ba44-da3c-4af6-9826-43ee27489fc3</td>\n",
       "      <td>True</td>\n",
       "      <td>[0, 0, 0, 0, 1, 0, 0, 0, 0, 0]</td>\n",
       "      <td>[0.0695004329733046, -0.011181794135149432, 0....</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>26</th>\n",
       "      <td>65_0026980c-c0de-46bd-860b-02b4050bf72c</td>\n",
       "      <td>7</td>\n",
       "      <td>20</td>\n",
       "      <td>t_d4374dbf-dc65-4458-8d52-f5fb9298319e</td>\n",
       "      <td>True</td>\n",
       "      <td>[0, 0, 0, 0, 0, 0, 0, 0, 1, 0]</td>\n",
       "      <td>[0.06953172399542647, -0.011177516303562769, 0...</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>27</th>\n",
       "      <td>65_0026980c-c0de-46bd-860b-02b4050bf72c</td>\n",
       "      <td>8</td>\n",
       "      <td>20</td>\n",
       "      <td>t_e89c6128-7905-48e0-85dc-7a8f2221283a</td>\n",
       "      <td>True</td>\n",
       "      <td>[0, 0, 0, 0, 0, 0, 0, 0, 1, 0]</td>\n",
       "      <td>[0.06953172399542647, -0.011177516303562769, 0...</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>28</th>\n",
       "      <td>65_0026980c-c0de-46bd-860b-02b4050bf72c</td>\n",
       "      <td>9</td>\n",
       "      <td>20</td>\n",
       "      <td>t_d52e98fa-c226-446e-99f4-2041d3864d98</td>\n",
       "      <td>True</td>\n",
       "      <td>[0, 0, 0, 0, 0, 0, 0, 0, 1, 0]</td>\n",
       "      <td>[0.06953172399542647, -0.011177516303562769, 0...</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>29</th>\n",
       "      <td>65_0026980c-c0de-46bd-860b-02b4050bf72c</td>\n",
       "      <td>10</td>\n",
       "      <td>20</td>\n",
       "      <td>t_0099502e-9e11-4edd-8673-f94b4cfeec26</td>\n",
       "      <td>True</td>\n",
       "      <td>[0, 0, 0, 0, 1, 0, 0, 0, 0, 0]</td>\n",
       "      <td>[0.0695004329733046, -0.011181794135149432, 0....</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>30</th>\n",
       "      <td>65_00285fa4-2c95-45af-a57e-1b73de06b895</td>\n",
       "      <td>1</td>\n",
       "      <td>20</td>\n",
       "      <td>t_01af8214-7b66-4b7e-95ac-2a9981808e11</td>\n",
       "      <td>True</td>\n",
       "      <td>[0, 1, 0, 0, 0, 0, 0, 0, 0, 0]</td>\n",
       "      <td>[0.06144204097217993, -0.007169639929709159, 0...</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>31</th>\n",
       "      <td>65_00285fa4-2c95-45af-a57e-1b73de06b895</td>\n",
       "      <td>2</td>\n",
       "      <td>20</td>\n",
       "      <td>t_6d72cff5-8f27-412c-bd62-b4dd0022eb13</td>\n",
       "      <td>True</td>\n",
       "      <td>[0, 1, 0, 0, 0, 0, 0, 0, 0, 0]</td>\n",
       "      <td>[0.06144204097217993, -0.007169639929709159, 0...</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>32</th>\n",
       "      <td>65_00285fa4-2c95-45af-a57e-1b73de06b895</td>\n",
       "      <td>3</td>\n",
       "      <td>20</td>\n",
       "      <td>t_6dc02fe6-4d78-414a-8439-31240ee02270</td>\n",
       "      <td>False</td>\n",
       "      <td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 1]</td>\n",
       "      <td>[-0.5843973841788984, 0.08490661836005894, -0....</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>33</th>\n",
       "      <td>65_00285fa4-2c95-45af-a57e-1b73de06b895</td>\n",
       "      <td>4</td>\n",
       "      <td>20</td>\n",
       "      <td>t_d26d6571-300b-4e1d-8e0a-1392fa8931c0</td>\n",
       "      <td>True</td>\n",
       "      <td>[0, 1, 0, 0, 0, 0, 0, 0, 0, 0]</td>\n",
       "      <td>[0.06144204097217993, -0.007169639929709159, 0...</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>34</th>\n",
       "      <td>65_00285fa4-2c95-45af-a57e-1b73de06b895</td>\n",
       "      <td>5</td>\n",
       "      <td>20</td>\n",
       "      <td>t_ab0572ba-637d-4460-a98b-43a9b7423382</td>\n",
       "      <td>True</td>\n",
       "      <td>[0, 1, 0, 0, 0, 0, 0, 0, 0, 0]</td>\n",
       "      <td>[0.06144204097217993, -0.007169639929709159, 0...</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>35</th>\n",
       "      <td>65_00285fa4-2c95-45af-a57e-1b73de06b895</td>\n",
       "      <td>6</td>\n",
       "      <td>20</td>\n",
       "      <td>t_17163c42-4b91-46a8-9666-965a3c231d3d</td>\n",
       "      <td>True</td>\n",
       "      <td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 1]</td>\n",
       "      <td>[-0.5843973841788984, 0.08490661836005894, -0....</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>36</th>\n",
       "      <td>65_00285fa4-2c95-45af-a57e-1b73de06b895</td>\n",
       "      <td>7</td>\n",
       "      <td>20</td>\n",
       "      <td>t_65efbe00-7405-4f8a-b19a-314bad4340de</td>\n",
       "      <td>True</td>\n",
       "      <td>[0, 1, 0, 0, 0, 0, 0, 0, 0, 0]</td>\n",
       "      <td>[0.06144204097217993, -0.007169639929709159, 0...</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>37</th>\n",
       "      <td>65_00285fa4-2c95-45af-a57e-1b73de06b895</td>\n",
       "      <td>8</td>\n",
       "      <td>20</td>\n",
       "      <td>t_307b7c86-7474-4d2f-a755-fe3f25f8f9bc</td>\n",
       "      <td>True</td>\n",
       "      <td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 1]</td>\n",
       "      <td>[-0.5843973841788984, 0.08490661836005894, -0....</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>38</th>\n",
       "      <td>65_00285fa4-2c95-45af-a57e-1b73de06b895</td>\n",
       "      <td>9</td>\n",
       "      <td>20</td>\n",
       "      <td>t_6068ee9d-cd8a-45ac-9c27-adac46e2069f</td>\n",
       "      <td>True</td>\n",
       "      <td>[0, 1, 0, 0, 0, 0, 0, 0, 0, 0]</td>\n",
       "      <td>[0.06144204097217993, -0.007169639929709159, 0...</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>39</th>\n",
       "      <td>65_00285fa4-2c95-45af-a57e-1b73de06b895</td>\n",
       "      <td>10</td>\n",
       "      <td>20</td>\n",
       "      <td>t_a157ba98-7eae-4e7d-99d5-ed760b2c0978</td>\n",
       "      <td>False</td>\n",
       "      <td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 1]</td>\n",
       "      <td>[-0.5843973841788984, 0.08490661836005894, -0....</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                 session_id  session_position  session_length  \\\n",
       "0   65_0008dac3-8379-4369-90a1-b7ceccbb2d51                 1              20   \n",
       "1   65_0008dac3-8379-4369-90a1-b7ceccbb2d51                 2              20   \n",
       "2   65_0008dac3-8379-4369-90a1-b7ceccbb2d51                 3              20   \n",
       "3   65_0008dac3-8379-4369-90a1-b7ceccbb2d51                 4              20   \n",
       "4   65_0008dac3-8379-4369-90a1-b7ceccbb2d51                 5              20   \n",
       "5   65_0008dac3-8379-4369-90a1-b7ceccbb2d51                 6              20   \n",
       "6   65_0008dac3-8379-4369-90a1-b7ceccbb2d51                 7              20   \n",
       "7   65_0008dac3-8379-4369-90a1-b7ceccbb2d51                 8              20   \n",
       "8   65_0008dac3-8379-4369-90a1-b7ceccbb2d51                 9              20   \n",
       "9   65_0008dac3-8379-4369-90a1-b7ceccbb2d51                10              20   \n",
       "10  65_0025a3cc-485b-407e-9fc6-2b58fc07b957                 1              20   \n",
       "11  65_0025a3cc-485b-407e-9fc6-2b58fc07b957                 2              20   \n",
       "12  65_0025a3cc-485b-407e-9fc6-2b58fc07b957                 3              20   \n",
       "13  65_0025a3cc-485b-407e-9fc6-2b58fc07b957                 4              20   \n",
       "14  65_0025a3cc-485b-407e-9fc6-2b58fc07b957                 5              20   \n",
       "15  65_0025a3cc-485b-407e-9fc6-2b58fc07b957                 6              20   \n",
       "16  65_0025a3cc-485b-407e-9fc6-2b58fc07b957                 7              20   \n",
       "17  65_0025a3cc-485b-407e-9fc6-2b58fc07b957                 8              20   \n",
       "18  65_0025a3cc-485b-407e-9fc6-2b58fc07b957                 9              20   \n",
       "19  65_0025a3cc-485b-407e-9fc6-2b58fc07b957                10              20   \n",
       "20  65_0026980c-c0de-46bd-860b-02b4050bf72c                 1              20   \n",
       "21  65_0026980c-c0de-46bd-860b-02b4050bf72c                 2              20   \n",
       "22  65_0026980c-c0de-46bd-860b-02b4050bf72c                 3              20   \n",
       "23  65_0026980c-c0de-46bd-860b-02b4050bf72c                 4              20   \n",
       "24  65_0026980c-c0de-46bd-860b-02b4050bf72c                 5              20   \n",
       "25  65_0026980c-c0de-46bd-860b-02b4050bf72c                 6              20   \n",
       "26  65_0026980c-c0de-46bd-860b-02b4050bf72c                 7              20   \n",
       "27  65_0026980c-c0de-46bd-860b-02b4050bf72c                 8              20   \n",
       "28  65_0026980c-c0de-46bd-860b-02b4050bf72c                 9              20   \n",
       "29  65_0026980c-c0de-46bd-860b-02b4050bf72c                10              20   \n",
       "30  65_00285fa4-2c95-45af-a57e-1b73de06b895                 1              20   \n",
       "31  65_00285fa4-2c95-45af-a57e-1b73de06b895                 2              20   \n",
       "32  65_00285fa4-2c95-45af-a57e-1b73de06b895                 3              20   \n",
       "33  65_00285fa4-2c95-45af-a57e-1b73de06b895                 4              20   \n",
       "34  65_00285fa4-2c95-45af-a57e-1b73de06b895                 5              20   \n",
       "35  65_00285fa4-2c95-45af-a57e-1b73de06b895                 6              20   \n",
       "36  65_00285fa4-2c95-45af-a57e-1b73de06b895                 7              20   \n",
       "37  65_00285fa4-2c95-45af-a57e-1b73de06b895                 8              20   \n",
       "38  65_00285fa4-2c95-45af-a57e-1b73de06b895                 9              20   \n",
       "39  65_00285fa4-2c95-45af-a57e-1b73de06b895                10              20   \n",
       "\n",
       "                            track_id_clean  skip_3  \\\n",
       "0   t_8eecd43d-3146-47ab-9ff6-3d300abe3216    True   \n",
       "1   t_47040682-2572-4b22-a3c3-7a3426634371    True   \n",
       "2   t_c3e3086e-b38f-467b-b05a-9aa511a8aef5    True   \n",
       "3   t_22435160-213e-4aac-903e-9e60d643d761    True   \n",
       "4   t_8163fb61-4a3f-4b45-9b3f-f125da29c86b    True   \n",
       "5   t_e241fa6f-b682-4114-98ad-70f79a0dd99b    True   \n",
       "6   t_7acef85a-fc13-4038-afe6-00d31c373f4b    True   \n",
       "7   t_2bb41382-4b9c-4d5e-98ee-b321b675a6ff    True   \n",
       "8   t_758b880c-dbb8-403e-b7c5-e3932f707aa9    True   \n",
       "9   t_27da1c72-64d5-432b-b520-2071e99e8ffe   False   \n",
       "10  t_e8041aa0-11f1-4147-bdc3-0b889ba101f2   False   \n",
       "11  t_8b30c326-5f53-4e99-adb6-994704a52f05   False   \n",
       "12  t_a633d9e8-1342-4dbf-83de-86e87901efe0   False   \n",
       "13  t_35eaec20-2bdb-4c38-8f10-e7b3c42ab688    True   \n",
       "14  t_e0ce1b77-c912-4bd7-af66-caae072f3b7b    True   \n",
       "15  t_f3c4c677-5b5d-49cf-b2ba-8028d63a29b4    True   \n",
       "16  t_cbd23b21-2693-49cc-872f-c061de8c93c4    True   \n",
       "17  t_925fe3a1-e8be-4062-9dd4-2a41857e308e   False   \n",
       "18  t_32598e62-6456-47a8-8c1e-b93f7c68a2b6    True   \n",
       "19  t_c7fc8984-9666-48f9-9081-ffc4a0a1310b   False   \n",
       "20  t_05e8145e-363e-4fec-a089-cbe30132df9f    True   \n",
       "21  t_9d487a00-0b49-4ab2-8faa-ecf5c9d4c8eb    True   \n",
       "22  t_9d92c8c1-f4b8-4bfd-95eb-2fe7eea0dde1    True   \n",
       "23  t_42454184-4358-4221-a2bf-e6847e1cd277    True   \n",
       "24  t_d1de1372-c06a-41d5-aaf5-cde7f6710269    True   \n",
       "25  t_bef4ba44-da3c-4af6-9826-43ee27489fc3    True   \n",
       "26  t_d4374dbf-dc65-4458-8d52-f5fb9298319e    True   \n",
       "27  t_e89c6128-7905-48e0-85dc-7a8f2221283a    True   \n",
       "28  t_d52e98fa-c226-446e-99f4-2041d3864d98    True   \n",
       "29  t_0099502e-9e11-4edd-8673-f94b4cfeec26    True   \n",
       "30  t_01af8214-7b66-4b7e-95ac-2a9981808e11    True   \n",
       "31  t_6d72cff5-8f27-412c-bd62-b4dd0022eb13    True   \n",
       "32  t_6dc02fe6-4d78-414a-8439-31240ee02270   False   \n",
       "33  t_d26d6571-300b-4e1d-8e0a-1392fa8931c0    True   \n",
       "34  t_ab0572ba-637d-4460-a98b-43a9b7423382    True   \n",
       "35  t_17163c42-4b91-46a8-9666-965a3c231d3d    True   \n",
       "36  t_65efbe00-7405-4f8a-b19a-314bad4340de    True   \n",
       "37  t_307b7c86-7474-4d2f-a755-fe3f25f8f9bc    True   \n",
       "38  t_6068ee9d-cd8a-45ac-9c27-adac46e2069f    True   \n",
       "39  t_a157ba98-7eae-4e7d-99d5-ed760b2c0978   False   \n",
       "\n",
       "                                 F  \\\n",
       "0   [1, 0, 0, 0, 0, 0, 0, 0, 0, 0]   \n",
       "1   [0, 0, 1, 0, 0, 0, 0, 0, 0, 0]   \n",
       "2   [0, 0, 0, 0, 0, 0, 0, 0, 1, 0]   \n",
       "3   [0, 0, 0, 0, 1, 0, 0, 0, 0, 0]   \n",
       "4   [1, 0, 0, 0, 0, 0, 0, 0, 0, 0]   \n",
       "5   [0, 0, 0, 0, 0, 0, 0, 1, 0, 0]   \n",
       "6   [1, 0, 0, 0, 0, 0, 0, 0, 0, 0]   \n",
       "7   [0, 0, 0, 0, 1, 0, 0, 0, 0, 0]   \n",
       "8   [1, 0, 0, 0, 0, 0, 0, 0, 0, 0]   \n",
       "9   [0, 0, 0, 0, 0, 0, 1, 0, 0, 0]   \n",
       "10  [0, 0, 0, 0, 0, 0, 0, 0, 0, 1]   \n",
       "11  [0, 0, 0, 0, 0, 0, 0, 0, 0, 1]   \n",
       "12  [0, 0, 0, 0, 0, 0, 0, 0, 0, 1]   \n",
       "13  [0, 0, 0, 0, 0, 0, 0, 0, 0, 1]   \n",
       "14  [0, 0, 0, 0, 0, 0, 0, 0, 0, 1]   \n",
       "15  [0, 0, 0, 0, 0, 0, 0, 0, 0, 1]   \n",
       "16  [0, 0, 0, 0, 0, 0, 0, 0, 0, 1]   \n",
       "17  [0, 0, 0, 0, 0, 0, 0, 0, 0, 1]   \n",
       "18  [0, 0, 0, 0, 0, 0, 0, 0, 0, 1]   \n",
       "19  [0, 0, 0, 0, 0, 0, 0, 0, 0, 1]   \n",
       "20  [0, 0, 0, 0, 1, 0, 0, 0, 0, 0]   \n",
       "21  [1, 0, 0, 0, 0, 0, 0, 0, 0, 0]   \n",
       "22  [1, 0, 0, 0, 0, 0, 0, 0, 0, 0]   \n",
       "23  [0, 0, 0, 1, 0, 0, 0, 0, 0, 0]   \n",
       "24  [0, 0, 0, 0, 0, 0, 0, 0, 1, 0]   \n",
       "25  [0, 0, 0, 0, 1, 0, 0, 0, 0, 0]   \n",
       "26  [0, 0, 0, 0, 0, 0, 0, 0, 1, 0]   \n",
       "27  [0, 0, 0, 0, 0, 0, 0, 0, 1, 0]   \n",
       "28  [0, 0, 0, 0, 0, 0, 0, 0, 1, 0]   \n",
       "29  [0, 0, 0, 0, 1, 0, 0, 0, 0, 0]   \n",
       "30  [0, 1, 0, 0, 0, 0, 0, 0, 0, 0]   \n",
       "31  [0, 1, 0, 0, 0, 0, 0, 0, 0, 0]   \n",
       "32  [0, 0, 0, 0, 0, 0, 0, 0, 0, 1]   \n",
       "33  [0, 1, 0, 0, 0, 0, 0, 0, 0, 0]   \n",
       "34  [0, 1, 0, 0, 0, 0, 0, 0, 0, 0]   \n",
       "35  [0, 0, 0, 0, 0, 0, 0, 0, 0, 1]   \n",
       "36  [0, 1, 0, 0, 0, 0, 0, 0, 0, 0]   \n",
       "37  [0, 0, 0, 0, 0, 0, 0, 0, 0, 1]   \n",
       "38  [0, 1, 0, 0, 0, 0, 0, 0, 0, 0]   \n",
       "39  [0, 0, 0, 0, 0, 0, 0, 0, 0, 1]   \n",
       "\n",
       "                                             F_vector  skip_value  \n",
       "0   [0.0695320313643396, -0.011177622068879926, 0....           1  \n",
       "1   [0.06953126222753003, -0.011178203525205396, 0...           1  \n",
       "2   [0.06953172399542647, -0.011177516303562769, 0...           1  \n",
       "3   [0.0695004329733046, -0.011181794135149432, 0....           1  \n",
       "4   [0.0695320313643396, -0.011177622068879926, 0....           1  \n",
       "5   [0.05367344872695791, -0.0060483777861048615, ...           1  \n",
       "6   [0.0695320313643396, -0.011177622068879926, 0....           1  \n",
       "7   [0.0695004329733046, -0.011181794135149432, 0....           1  \n",
       "8   [0.0695320313643396, -0.011177622068879926, 0....           1  \n",
       "9   [0.05197004493717251, -0.011174131555338772, -...           0  \n",
       "10  [-0.5843973841788984, 0.08490661836005894, -0....           0  \n",
       "11  [-0.5843973841788984, 0.08490661836005894, -0....           0  \n",
       "12  [-0.5843973841788984, 0.08490661836005894, -0....           0  \n",
       "13  [-0.5843973841788984, 0.08490661836005894, -0....           1  \n",
       "14  [-0.5843973841788984, 0.08490661836005894, -0....           1  \n",
       "15  [-0.5843973841788984, 0.08490661836005894, -0....           1  \n",
       "16  [-0.5843973841788984, 0.08490661836005894, -0....           1  \n",
       "17  [-0.5843973841788984, 0.08490661836005894, -0....           0  \n",
       "18  [-0.5843973841788984, 0.08490661836005894, -0....           1  \n",
       "19  [-0.5843973841788984, 0.08490661836005894, -0....           0  \n",
       "20  [0.0695004329733046, -0.011181794135149432, 0....           1  \n",
       "21  [0.0695320313643396, -0.011177622068879926, 0....           1  \n",
       "22  [0.0695320313643396, -0.011177622068879926, 0....           1  \n",
       "23  [0.06932655678539505, -0.011112521948605592, 0...           1  \n",
       "24  [0.06953172399542647, -0.011177516303562769, 0...           1  \n",
       "25  [0.0695004329733046, -0.011181794135149432, 0....           1  \n",
       "26  [0.06953172399542647, -0.011177516303562769, 0...           1  \n",
       "27  [0.06953172399542647, -0.011177516303562769, 0...           1  \n",
       "28  [0.06953172399542647, -0.011177516303562769, 0...           1  \n",
       "29  [0.0695004329733046, -0.011181794135149432, 0....           1  \n",
       "30  [0.06144204097217993, -0.007169639929709159, 0...           1  \n",
       "31  [0.06144204097217993, -0.007169639929709159, 0...           1  \n",
       "32  [-0.5843973841788984, 0.08490661836005894, -0....           0  \n",
       "33  [0.06144204097217993, -0.007169639929709159, 0...           1  \n",
       "34  [0.06144204097217993, -0.007169639929709159, 0...           1  \n",
       "35  [-0.5843973841788984, 0.08490661836005894, -0....           1  \n",
       "36  [0.06144204097217993, -0.007169639929709159, 0...           1  \n",
       "37  [-0.5843973841788984, 0.08490661836005894, -0....           1  \n",
       "38  [0.06144204097217993, -0.007169639929709159, 0...           1  \n",
       "39  [-0.5843973841788984, 0.08490661836005894, -0....           0  "
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Peek at the first 40 rows to eyeball the per-track F / F_vector encodings and skip_value\n",
    "df_20_filtered.iloc[:40]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eacf2da8",
   "metadata": {},
   "source": [
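    "#### We further turn the skip_3 signal (encoded as `skip_value`) into an aggregated effect over the listening sequence. Within each session_id, X is the running (cumulative) sum of `skip_value` up to the current track plus Gaussian noise with standard deviation 0.5; no decay is applied (equivalently, the decay factor is 1)."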
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "388a72f6",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_X(group):\n",
    "    # Compute the cumulative sum of X up to the current time point\n",
    "    X_cumsum = group['skip_value'].cumsum()\n",
    "    # Add Gaussian noise to the cumulative sum\n",
    "    X_prime = X_cumsum + np.random.normal(0, 0.5, size=len(X_cumsum))\n",
    "    # Add X' to the group\n",
    "    group[\"X\"] = X_prime\n",
    "    return group\n",
    "\n",
    "set_seed(global_seed)\n",
    "df_20_filtered = df_20_filtered.groupby('session_id').apply(compute_X).reset_index(drop=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "dbec2bef",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>session_id</th>\n",
       "      <th>session_position</th>\n",
       "      <th>session_length</th>\n",
       "      <th>track_id_clean</th>\n",
       "      <th>skip_3</th>\n",
       "      <th>F</th>\n",
       "      <th>F_vector</th>\n",
       "      <th>skip_value</th>\n",
       "      <th>X</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>65_0008dac3-8379-4369-90a1-b7ceccbb2d51</td>\n",
       "      <td>1</td>\n",
       "      <td>20</td>\n",
       "      <td>t_8eecd43d-3146-47ab-9ff6-3d300abe3216</td>\n",
       "      <td>True</td>\n",
       "      <td>[1, 0, 0, 0, 0, 0, 0, 0, 0, 0]</td>\n",
       "      <td>[0.0695320313643396, -0.011177622068879926, 0....</td>\n",
       "      <td>1</td>\n",
       "      <td>1.248357</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>65_0008dac3-8379-4369-90a1-b7ceccbb2d51</td>\n",
       "      <td>2</td>\n",
       "      <td>20</td>\n",
       "      <td>t_47040682-2572-4b22-a3c3-7a3426634371</td>\n",
       "      <td>True</td>\n",
       "      <td>[0, 0, 1, 0, 0, 0, 0, 0, 0, 0]</td>\n",
       "      <td>[0.06953126222753003, -0.011178203525205396, 0...</td>\n",
       "      <td>1</td>\n",
       "      <td>1.930868</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>65_0008dac3-8379-4369-90a1-b7ceccbb2d51</td>\n",
       "      <td>3</td>\n",
       "      <td>20</td>\n",
       "      <td>t_c3e3086e-b38f-467b-b05a-9aa511a8aef5</td>\n",
       "      <td>True</td>\n",
       "      <td>[0, 0, 0, 0, 0, 0, 0, 0, 1, 0]</td>\n",
       "      <td>[0.06953172399542647, -0.011177516303562769, 0...</td>\n",
       "      <td>1</td>\n",
       "      <td>3.323844</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>65_0008dac3-8379-4369-90a1-b7ceccbb2d51</td>\n",
       "      <td>4</td>\n",
       "      <td>20</td>\n",
       "      <td>t_22435160-213e-4aac-903e-9e60d643d761</td>\n",
       "      <td>True</td>\n",
       "      <td>[0, 0, 0, 0, 1, 0, 0, 0, 0, 0]</td>\n",
       "      <td>[0.0695004329733046, -0.011181794135149432, 0....</td>\n",
       "      <td>1</td>\n",
       "      <td>4.761515</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>65_0008dac3-8379-4369-90a1-b7ceccbb2d51</td>\n",
       "      <td>5</td>\n",
       "      <td>20</td>\n",
       "      <td>t_8163fb61-4a3f-4b45-9b3f-f125da29c86b</td>\n",
       "      <td>True</td>\n",
       "      <td>[1, 0, 0, 0, 0, 0, 0, 0, 0, 0]</td>\n",
       "      <td>[0.0695320313643396, -0.011177622068879926, 0....</td>\n",
       "      <td>1</td>\n",
       "      <td>4.882923</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                session_id  session_position  session_length  \\\n",
       "0  65_0008dac3-8379-4369-90a1-b7ceccbb2d51                 1              20   \n",
       "1  65_0008dac3-8379-4369-90a1-b7ceccbb2d51                 2              20   \n",
       "2  65_0008dac3-8379-4369-90a1-b7ceccbb2d51                 3              20   \n",
       "3  65_0008dac3-8379-4369-90a1-b7ceccbb2d51                 4              20   \n",
       "4  65_0008dac3-8379-4369-90a1-b7ceccbb2d51                 5              20   \n",
       "\n",
       "                           track_id_clean  skip_3  \\\n",
       "0  t_8eecd43d-3146-47ab-9ff6-3d300abe3216    True   \n",
       "1  t_47040682-2572-4b22-a3c3-7a3426634371    True   \n",
       "2  t_c3e3086e-b38f-467b-b05a-9aa511a8aef5    True   \n",
       "3  t_22435160-213e-4aac-903e-9e60d643d761    True   \n",
       "4  t_8163fb61-4a3f-4b45-9b3f-f125da29c86b    True   \n",
       "\n",
       "                                F  \\\n",
       "0  [1, 0, 0, 0, 0, 0, 0, 0, 0, 0]   \n",
       "1  [0, 0, 1, 0, 0, 0, 0, 0, 0, 0]   \n",
       "2  [0, 0, 0, 0, 0, 0, 0, 0, 1, 0]   \n",
       "3  [0, 0, 0, 0, 1, 0, 0, 0, 0, 0]   \n",
       "4  [1, 0, 0, 0, 0, 0, 0, 0, 0, 0]   \n",
       "\n",
       "                                            F_vector  skip_value         X  \n",
       "0  [0.0695320313643396, -0.011177622068879926, 0....           1  1.248357  \n",
       "1  [0.06953126222753003, -0.011178203525205396, 0...           1  1.930868  \n",
       "2  [0.06953172399542647, -0.011177516303562769, 0...           1  3.323844  \n",
       "3  [0.0695004329733046, -0.011181794135149432, 0....           1  4.761515  \n",
       "4  [0.0695320313643396, -0.011177622068879926, 0....           1  4.882923  "
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_20_filtered.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "9fa1758e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>F</th>\n",
       "      <th>X</th>\n",
       "      <th>F_vector</th>\n",
       "      <th>session_id</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>[1, 0, 0, 0, 0, 0, 0, 0, 0, 0]</td>\n",
       "      <td>1.248357</td>\n",
       "      <td>[0.0695320313643396, -0.011177622068879926, 0....</td>\n",
       "      <td>65_0008dac3-8379-4369-90a1-b7ceccbb2d51</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>[0, 0, 1, 0, 0, 0, 0, 0, 0, 0]</td>\n",
       "      <td>1.930868</td>\n",
       "      <td>[0.06953126222753003, -0.011178203525205396, 0...</td>\n",
       "      <td>65_0008dac3-8379-4369-90a1-b7ceccbb2d51</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>[0, 0, 0, 0, 0, 0, 0, 0, 1, 0]</td>\n",
       "      <td>3.323844</td>\n",
       "      <td>[0.06953172399542647, -0.011177516303562769, 0...</td>\n",
       "      <td>65_0008dac3-8379-4369-90a1-b7ceccbb2d51</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>[0, 0, 0, 0, 1, 0, 0, 0, 0, 0]</td>\n",
       "      <td>4.761515</td>\n",
       "      <td>[0.0695004329733046, -0.011181794135149432, 0....</td>\n",
       "      <td>65_0008dac3-8379-4369-90a1-b7ceccbb2d51</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>[1, 0, 0, 0, 0, 0, 0, 0, 0, 0]</td>\n",
       "      <td>4.882923</td>\n",
       "      <td>[0.0695320313643396, -0.011177622068879926, 0....</td>\n",
       "      <td>65_0008dac3-8379-4369-90a1-b7ceccbb2d51</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                F         X  \\\n",
       "0  [1, 0, 0, 0, 0, 0, 0, 0, 0, 0]  1.248357   \n",
       "1  [0, 0, 1, 0, 0, 0, 0, 0, 0, 0]  1.930868   \n",
       "2  [0, 0, 0, 0, 0, 0, 0, 0, 1, 0]  3.323844   \n",
       "3  [0, 0, 0, 0, 1, 0, 0, 0, 0, 0]  4.761515   \n",
       "4  [1, 0, 0, 0, 0, 0, 0, 0, 0, 0]  4.882923   \n",
       "\n",
       "                                            F_vector  \\\n",
       "0  [0.0695320313643396, -0.011177622068879926, 0....   \n",
       "1  [0.06953126222753003, -0.011178203525205396, 0...   \n",
       "2  [0.06953172399542647, -0.011177516303562769, 0...   \n",
       "3  [0.0695004329733046, -0.011181794135149432, 0....   \n",
       "4  [0.0695320313643396, -0.011177622068879926, 0....   \n",
       "\n",
       "                                session_id  \n",
       "0  65_0008dac3-8379-4369-90a1-b7ceccbb2d51  \n",
       "1  65_0008dac3-8379-4369-90a1-b7ceccbb2d51  \n",
       "2  65_0008dac3-8379-4369-90a1-b7ceccbb2d51  \n",
       "3  65_0008dac3-8379-4369-90a1-b7ceccbb2d51  \n",
       "4  65_0008dac3-8379-4369-90a1-b7ceccbb2d51  "
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Keep only the columns needed downstream. `.copy()` makes df_final an independent\n",
    "# frame rather than a slice of df_20_filtered, so later column assignments on it\n",
    "# are unambiguous and do not trigger SettingWithCopyWarning.\n",
    "df_final = df_20_filtered[['F', 'X', 'F_vector', 'session_id']].copy()\n",
    "df_final.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "8ceab04c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dimensionality of the per-session latent vector Z (a truncated session embedding).\n",
    "z_dim = 5\n",
    "\n",
    "# Map each session to the first z_dim components of its embedding.\n",
    "# NOTE(review): sessions are paired with rows of session_embeddings by position in\n",
    "# df_final['session_id'].unique() (order of first appearance); confirm that\n",
    "# session_embeddings was built in that same session order.\n",
    "unique_sessions = df_final['session_id'].unique()\n",
    "session_to_Z = {session: session_embeddings[i][:z_dim] for i, session in enumerate(unique_sessions)}\n",
    "\n",
    "# assign() returns a new frame instead of setting a column on a possible slice,\n",
    "# which avoids SettingWithCopyWarning and keeps this cell safe to re-run.\n",
    "df_final = df_final.assign(Z=df_final['session_id'].map(session_to_Z))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "09a82951",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>F</th>\n",
       "      <th>X</th>\n",
       "      <th>F_vector</th>\n",
       "      <th>session_id</th>\n",
       "      <th>Z</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>[1, 0, 0, 0, 0, 0, 0, 0, 0, 0]</td>\n",
       "      <td>1.248357</td>\n",
       "      <td>[0.0695320313643396, -0.011177622068879926, 0....</td>\n",
       "      <td>65_0008dac3-8379-4369-90a1-b7ceccbb2d51</td>\n",
       "      <td>[0.3865182998978819, -0.04937550520932488, 0.0...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>[0, 0, 1, 0, 0, 0, 0, 0, 0, 0]</td>\n",
       "      <td>1.930868</td>\n",
       "      <td>[0.06953126222753003, -0.011178203525205396, 0...</td>\n",
       "      <td>65_0008dac3-8379-4369-90a1-b7ceccbb2d51</td>\n",
       "      <td>[0.3865182998978819, -0.04937550520932488, 0.0...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>[0, 0, 0, 0, 0, 0, 0, 0, 1, 0]</td>\n",
       "      <td>3.323844</td>\n",
       "      <td>[0.06953172399542647, -0.011177516303562769, 0...</td>\n",
       "      <td>65_0008dac3-8379-4369-90a1-b7ceccbb2d51</td>\n",
       "      <td>[0.3865182998978819, -0.04937550520932488, 0.0...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>[0, 0, 0, 0, 1, 0, 0, 0, 0, 0]</td>\n",
       "      <td>4.761515</td>\n",
       "      <td>[0.0695004329733046, -0.011181794135149432, 0....</td>\n",
       "      <td>65_0008dac3-8379-4369-90a1-b7ceccbb2d51</td>\n",
       "      <td>[0.3865182998978819, -0.04937550520932488, 0.0...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>[1, 0, 0, 0, 0, 0, 0, 0, 0, 0]</td>\n",
       "      <td>4.882923</td>\n",
       "      <td>[0.0695320313643396, -0.011177622068879926, 0....</td>\n",
       "      <td>65_0008dac3-8379-4369-90a1-b7ceccbb2d51</td>\n",
       "      <td>[0.3865182998978819, -0.04937550520932488, 0.0...</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                F         X  \\\n",
       "0  [1, 0, 0, 0, 0, 0, 0, 0, 0, 0]  1.248357   \n",
       "1  [0, 0, 1, 0, 0, 0, 0, 0, 0, 0]  1.930868   \n",
       "2  [0, 0, 0, 0, 0, 0, 0, 0, 1, 0]  3.323844   \n",
       "3  [0, 0, 0, 0, 1, 0, 0, 0, 0, 0]  4.761515   \n",
       "4  [1, 0, 0, 0, 0, 0, 0, 0, 0, 0]  4.882923   \n",
       "\n",
       "                                            F_vector  \\\n",
       "0  [0.0695320313643396, -0.011177622068879926, 0....   \n",
       "1  [0.06953126222753003, -0.011178203525205396, 0...   \n",
       "2  [0.06953172399542647, -0.011177516303562769, 0...   \n",
       "3  [0.0695004329733046, -0.011181794135149432, 0....   \n",
       "4  [0.0695320313643396, -0.011177622068879926, 0....   \n",
       "\n",
       "                                session_id  \\\n",
       "0  65_0008dac3-8379-4369-90a1-b7ceccbb2d51   \n",
       "1  65_0008dac3-8379-4369-90a1-b7ceccbb2d51   \n",
       "2  65_0008dac3-8379-4369-90a1-b7ceccbb2d51   \n",
       "3  65_0008dac3-8379-4369-90a1-b7ceccbb2d51   \n",
       "4  65_0008dac3-8379-4369-90a1-b7ceccbb2d51   \n",
       "\n",
       "                                                   Z  \n",
       "0  [0.3865182998978819, -0.04937550520932488, 0.0...  \n",
       "1  [0.3865182998978819, -0.04937550520932488, 0.0...  \n",
       "2  [0.3865182998978819, -0.04937550520932488, 0.0...  \n",
       "3  [0.3865182998978819, -0.04937550520932488, 0.0...  \n",
       "4  [0.3865182998978819, -0.04937550520932488, 0.0...  "
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_final.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "dc82a356",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "We have a time series of 6000 X 10 X 10. and z of dimension 5\n"
     ]
    }
   ],
   "source": [
    "N = df_final['session_id'].nunique()\n",
    "# Sessions are written into [:len(group)] of a length-T axis when the arrays are filled,\n",
    "# so T must be the longest session; the first group's size is too small whenever\n",
    "# session lengths differ (shorter sessions are then zero-padded).\n",
    "session_lengths = df_final.groupby('session_id').size()\n",
    "T = int(session_lengths.max())\n",
    "l = len(df_final['F'].iloc[0])\n",
    "z_dim = len(df_final['Z'].iloc[0])\n",
    "\n",
    "if session_lengths.nunique() > 1:\n",
    "    print(f'Note: session lengths vary ({session_lengths.min()}-{T}); shorter sessions are zero-padded.')\n",
    "\n",
    "print(f'We have a time series of {N} X {T} X {l}. and z of dimension {z_dim}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "2c0af203",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dense arrays indexed by session (axis 0, in groupby's sorted session_id order) and\n",
    "# position within the session (axis 1); sessions shorter than T stay zero-padded.\n",
    "# Size F_vector_array from the F_vector data itself rather than assuming its\n",
    "# entries have the same length as F.\n",
    "f_vec_dim = len(df_final['F_vector'].iloc[0])\n",
    "\n",
    "F_array = np.zeros((N, T, l))\n",
    "F_vector_array = np.zeros((N, T, f_vec_dim))\n",
    "X_array = np.zeros((N, T))\n",
    "Z_array = np.zeros((N, z_dim))\n",
    "\n",
    "for i, (session_id, group) in enumerate(df_final.groupby('session_id')):\n",
    "    # Order the session's rows by index before copying them in.\n",
    "    group = group.sort_index()\n",
    "    n_steps = len(group)\n",
    "    F_array[i, :n_steps, :] = np.array(group['F'].tolist())\n",
    "    F_vector_array[i, :n_steps, :] = np.array(group['F_vector'].tolist())\n",
    "    X_array[i, :n_steps] = group['X'].values\n",
    "    # Z was mapped from session_id, so every row of a session carries the same value.\n",
    "    Z_array[i] = group['Z'].iloc[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "999a68fc",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "# np.save does not create missing directories, so make sure ./data exists first.\n",
    "data_dir = Path('./data')\n",
    "data_dir.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "np.save(data_dir / 'x.npy', X_array)\n",
    "np.save(data_dir / 'f_vec.npy', F_vector_array)\n",
    "np.save(data_dir / 'f.npy', F_array)\n",
    "np.save(data_dir / 'z.npy', Z_array)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "d3e9af61",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "def _to_json_list(value):\n",
    "    \"\"\"Serialize an array-like cell value (numpy array or list) as a JSON list string.\"\"\"\n",
    "    return json.dumps(np.asarray(value).tolist())\n",
    "\n",
    "# Serialize the array-valued columns into a separate export frame. df_final keeps its\n",
    "# arrays, so this cell can be re-run (converting already-serialized strings would fail)\n",
    "# and no SettingWithCopyWarning is raised.\n",
    "df_export = df_final.assign(\n",
    "    F=df_final['F'].apply(_to_json_list),\n",
    "    F_vector=df_final['F_vector'].apply(_to_json_list),\n",
    "    Z=df_final['Z'].apply(_to_json_list),\n",
    ")\n",
    "\n",
    "df_export.to_csv('spotify_data.csv', index=False)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
