{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### This is for loading the filtered data in case the original file is either too large or can not be found"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import torch\n",
    "import random\n",
    "import ast\n",
    "import json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>F</th>\n",
       "      <th>X</th>\n",
       "      <th>F_vector</th>\n",
       "      <th>session_id</th>\n",
       "      <th>Z</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>[0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]</td>\n",
       "      <td>1.248357</td>\n",
       "      <td>[0.06952813690749203, -0.011178137654304811, 0...</td>\n",
       "      <td>65_0008dac3-8379-4369-90a1-b7ceccbb2d51</td>\n",
       "      <td>[0.3865182998978819, -0.04937550520932488, 0.0...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>[0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0]</td>\n",
       "      <td>1.930868</td>\n",
       "      <td>[0.06952381845489436, -0.0111804461479361, 0.0...</td>\n",
       "      <td>65_0008dac3-8379-4369-90a1-b7ceccbb2d51</td>\n",
       "      <td>[0.3865182998978819, -0.04937550520932488, 0.0...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>[0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]</td>\n",
       "      <td>3.323844</td>\n",
       "      <td>[0.06952657157587336, -0.01118143003243, 0.019...</td>\n",
       "      <td>65_0008dac3-8379-4369-90a1-b7ceccbb2d51</td>\n",
       "      <td>[0.3865182998978819, -0.04937550520932488, 0.0...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>[0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0]</td>\n",
       "      <td>4.761515</td>\n",
       "      <td>[0.06952381845489436, -0.0111804461479361, 0.0...</td>\n",
       "      <td>65_0008dac3-8379-4369-90a1-b7ceccbb2d51</td>\n",
       "      <td>[0.3865182998978819, -0.04937550520932488, 0.0...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0]</td>\n",
       "      <td>4.882923</td>\n",
       "      <td>[0.06946601984841431, -0.011168818418563512, 0...</td>\n",
       "      <td>65_0008dac3-8379-4369-90a1-b7ceccbb2d51</td>\n",
       "      <td>[0.3865182998978819, -0.04937550520932488, 0.0...</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                               F         X  \\\n",
       "0  [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]  1.248357   \n",
       "1  [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0]  1.930868   \n",
       "2  [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]  3.323844   \n",
       "3  [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0]  4.761515   \n",
       "4  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0]  4.882923   \n",
       "\n",
       "                                            F_vector  \\\n",
       "0  [0.06952813690749203, -0.011178137654304811, 0...   \n",
       "1  [0.06952381845489436, -0.0111804461479361, 0.0...   \n",
       "2  [0.06952657157587336, -0.01118143003243, 0.019...   \n",
       "3  [0.06952381845489436, -0.0111804461479361, 0.0...   \n",
       "4  [0.06946601984841431, -0.011168818418563512, 0...   \n",
       "\n",
       "                                session_id  \\\n",
       "0  65_0008dac3-8379-4369-90a1-b7ceccbb2d51   \n",
       "1  65_0008dac3-8379-4369-90a1-b7ceccbb2d51   \n",
       "2  65_0008dac3-8379-4369-90a1-b7ceccbb2d51   \n",
       "3  65_0008dac3-8379-4369-90a1-b7ceccbb2d51   \n",
       "4  65_0008dac3-8379-4369-90a1-b7ceccbb2d51   \n",
       "\n",
       "                                                   Z  \n",
       "0  [0.3865182998978819, -0.04937550520932488, 0.0...  \n",
       "1  [0.3865182998978819, -0.04937550520932488, 0.0...  \n",
       "2  [0.3865182998978819, -0.04937550520932488, 0.0...  \n",
       "3  [0.3865182998978819, -0.04937550520932488, 0.0...  \n",
       "4  [0.3865182998978819, -0.04937550520932488, 0.0...  "
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_final = pd.read_csv(filepath_or_buffer='./spotify_data.csv')\n",
    "df_final.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>F</th>\n",
       "      <th>X</th>\n",
       "      <th>F_vector</th>\n",
       "      <th>session_id</th>\n",
       "      <th>Z</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>[0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]</td>\n",
       "      <td>1.248357</td>\n",
       "      <td>[0.06952813690749203, -0.011178137654304811, 0...</td>\n",
       "      <td>65_0008dac3-8379-4369-90a1-b7ceccbb2d51</td>\n",
       "      <td>[0.3865182998978819, -0.04937550520932488, 0.0...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>[0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0]</td>\n",
       "      <td>1.930868</td>\n",
       "      <td>[0.06952381845489436, -0.0111804461479361, 0.0...</td>\n",
       "      <td>65_0008dac3-8379-4369-90a1-b7ceccbb2d51</td>\n",
       "      <td>[0.3865182998978819, -0.04937550520932488, 0.0...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>[0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]</td>\n",
       "      <td>3.323844</td>\n",
       "      <td>[0.06952657157587336, -0.01118143003243, 0.019...</td>\n",
       "      <td>65_0008dac3-8379-4369-90a1-b7ceccbb2d51</td>\n",
       "      <td>[0.3865182998978819, -0.04937550520932488, 0.0...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>[0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0]</td>\n",
       "      <td>4.761515</td>\n",
       "      <td>[0.06952381845489436, -0.0111804461479361, 0.0...</td>\n",
       "      <td>65_0008dac3-8379-4369-90a1-b7ceccbb2d51</td>\n",
       "      <td>[0.3865182998978819, -0.04937550520932488, 0.0...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0]</td>\n",
       "      <td>4.882923</td>\n",
       "      <td>[0.06946601984841431, -0.011168818418563512, 0...</td>\n",
       "      <td>65_0008dac3-8379-4369-90a1-b7ceccbb2d51</td>\n",
       "      <td>[0.3865182998978819, -0.04937550520932488, 0.0...</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                               F         X  \\\n",
       "0  [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]  1.248357   \n",
       "1  [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0]  1.930868   \n",
       "2  [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]  3.323844   \n",
       "3  [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0]  4.761515   \n",
       "4  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0]  4.882923   \n",
       "\n",
       "                                            F_vector  \\\n",
       "0  [0.06952813690749203, -0.011178137654304811, 0...   \n",
       "1  [0.06952381845489436, -0.0111804461479361, 0.0...   \n",
       "2  [0.06952657157587336, -0.01118143003243, 0.019...   \n",
       "3  [0.06952381845489436, -0.0111804461479361, 0.0...   \n",
       "4  [0.06946601984841431, -0.011168818418563512, 0...   \n",
       "\n",
       "                                session_id  \\\n",
       "0  65_0008dac3-8379-4369-90a1-b7ceccbb2d51   \n",
       "1  65_0008dac3-8379-4369-90a1-b7ceccbb2d51   \n",
       "2  65_0008dac3-8379-4369-90a1-b7ceccbb2d51   \n",
       "3  65_0008dac3-8379-4369-90a1-b7ceccbb2d51   \n",
       "4  65_0008dac3-8379-4369-90a1-b7ceccbb2d51   \n",
       "\n",
       "                                                   Z  \n",
       "0  [0.3865182998978819, -0.04937550520932488, 0.0...  \n",
       "1  [0.3865182998978819, -0.04937550520932488, 0.0...  \n",
       "2  [0.3865182998978819, -0.04937550520932488, 0.0...  \n",
       "3  [0.3865182998978819, -0.04937550520932488, 0.0...  \n",
       "4  [0.3865182998978819, -0.04937550520932488, 0.0...  "
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_final['F'] = df_final['F'].apply(json.loads)\n",
    "df_final['F_vector'] = df_final['F_vector'].apply(json.loads)\n",
    "df_final['Z'] = df_final['Z'].apply(json.loads)\n",
    "\n",
    "df_final.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_final['F'] = df_final['F'].apply(np.array)\n",
    "df_final['F_vector'] = df_final['F_vector'].apply(np.array)\n",
    "df_final['Z'] = df_final['Z'].apply(np.array)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "We have a time series of 6000 X 10 X 15. and z of dimension 5\n"
     ]
    }
   ],
   "source": [
    "N = df_final['session_id'].nunique()\n",
    "T = df_final.groupby('session_id').size().iloc[0]\n",
    "l = len(df_final['F'].iloc[0])\n",
    "z_dim = len(df_final['Z'].iloc[0])\n",
    "\n",
    "print(f'We have a time series of {N} X {T} X {l}. and z of dimension {z_dim}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "F_array = np.zeros((N, T, l))\n",
    "F_vector_array = np.zeros((N, T, l))\n",
    "X_array = np.zeros((N, T))\n",
    "Z_array = np.zeros((N, z_dim))\n",
    "\n",
    "\n",
    "for i, (session_id, group) in enumerate(df_final.groupby('session_id')):\n",
    "    # Ensure the group is sorted\n",
    "    group = group.sort_index()\n",
    "    F_array[i, :len(group), :] = np.array(group['F'].tolist())\n",
    "    F_vector_array[i, :len(group), :] = np.array(group['F_vector'].tolist())\n",
    "    X_array[i, :len(group)] = group['X'].values\n",
    "    Z_array[i] = group['Z'].iloc[0]\n",
    "\n",
    "# Z_array = session_embeddings[:, :z_dim]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Read X of shape (6000, 10), F of shape (6000, 10, 15), F_vector of shape (6000, 10, 15) and Z of shape (6000, 5)\n"
     ]
    }
   ],
   "source": [
    "print(f'Read X of shape {X_array.shape}, F of shape {F_array.shape}, F_vector of shape {F_vector_array.shape} and Z of shape {Z_array.shape}')\n",
    "\n",
    "np.save('./data/x.npy', X_array)\n",
    "np.save('./data/f_vec.npy', F_vector_array)\n",
    "np.save('./data/f.npy', F_array)\n",
    "np.save('./data/z.npy', Z_array)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
       "       [0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],\n",
       "       [0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "F_array[0, :3]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ 0.06952814, -0.01117814,  0.01909687, -0.02114447,  0.032886  ,\n",
       "        -0.00681013, -0.01538108,  0.02479554, -0.03485195, -0.03635439,\n",
       "        -0.01490247,  0.04607297, -0.00842362,  0.00661841, -0.02079743],\n",
       "       [ 0.06952382, -0.01118045,  0.01910232, -0.02114167,  0.03288622,\n",
       "        -0.00680267, -0.01537809,  0.024803  , -0.03485931, -0.03635278,\n",
       "        -0.01490233,  0.04606793, -0.00841987,  0.00661901, -0.02080396],\n",
       "       [ 0.06952657, -0.01118143,  0.0190919 , -0.02113861,  0.03288055,\n",
       "        -0.00681184, -0.01537577,  0.02477858, -0.0348463 , -0.03632605,\n",
       "        -0.01490079,  0.04606742, -0.00842519,  0.00661654, -0.02078746]])"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "F_vector_array[0, :3]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([1.24835708, 1.93086785, 3.32384427])"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X_array[0, :3]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([ 0.3865183 , -0.04937551,  0.07676365, -0.08082001,  0.12558457])"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "Z_array[0]"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "spotify",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
