{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Relative prefix that is prepended to data file locations\n",
    "path = \"../../\"\n",
    "\n",
    "# Read the small interaction matrix CSV into `data`\n",
    "data = pd.read_csv(\n",
    "    path + 'data/small_matrix.csv'\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(4676570, 8)"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data.shape\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "7988155"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# find the biggest play_duration in the data\n",
    "data['play_duration'].max()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>user_id</th>\n",
       "      <th>video_id</th>\n",
       "      <th>play_duration</th>\n",
       "      <th>video_duration</th>\n",
       "      <th>time</th>\n",
       "      <th>date</th>\n",
       "      <th>timestamp</th>\n",
       "      <th>watch_ratio</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>14</td>\n",
       "      <td>148</td>\n",
       "      <td>4381</td>\n",
       "      <td>6067</td>\n",
       "      <td>2020-07-05 05:27:48.378</td>\n",
       "      <td>20200705.0</td>\n",
       "      <td>1.593898e+09</td>\n",
       "      <td>0.722103</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>14</td>\n",
       "      <td>183</td>\n",
       "      <td>11635</td>\n",
       "      <td>6100</td>\n",
       "      <td>2020-07-05 05:28:00.057</td>\n",
       "      <td>20200705.0</td>\n",
       "      <td>1.593898e+09</td>\n",
       "      <td>1.907377</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>14</td>\n",
       "      <td>3649</td>\n",
       "      <td>22422</td>\n",
       "      <td>10867</td>\n",
       "      <td>2020-07-05 05:29:09.479</td>\n",
       "      <td>20200705.0</td>\n",
       "      <td>1.593898e+09</td>\n",
       "      <td>2.063311</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>14</td>\n",
       "      <td>5262</td>\n",
       "      <td>4479</td>\n",
       "      <td>7908</td>\n",
       "      <td>2020-07-05 05:30:43.285</td>\n",
       "      <td>20200705.0</td>\n",
       "      <td>1.593898e+09</td>\n",
       "      <td>0.566388</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>14</td>\n",
       "      <td>8234</td>\n",
       "      <td>4602</td>\n",
       "      <td>11000</td>\n",
       "      <td>2020-07-05 05:35:43.459</td>\n",
       "      <td>20200705.0</td>\n",
       "      <td>1.593899e+09</td>\n",
       "      <td>0.418364</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4676565</th>\n",
       "      <td>7162</td>\n",
       "      <td>2267</td>\n",
       "      <td>11908</td>\n",
       "      <td>5467</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>2.178160</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4676566</th>\n",
       "      <td>7162</td>\n",
       "      <td>2065</td>\n",
       "      <td>11919</td>\n",
       "      <td>6067</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>1.964562</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4676567</th>\n",
       "      <td>7162</td>\n",
       "      <td>1296</td>\n",
       "      <td>16690</td>\n",
       "      <td>19870</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>0.839960</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4676568</th>\n",
       "      <td>7162</td>\n",
       "      <td>4822</td>\n",
       "      <td>11862</td>\n",
       "      <td>24400</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>0.486148</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4676569</th>\n",
       "      <td>7162</td>\n",
       "      <td>4364</td>\n",
       "      <td>2182</td>\n",
       "      <td>19367</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>0.112666</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>4676570 rows × 8 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "         user_id  video_id  play_duration  video_duration  \\\n",
       "0             14       148           4381            6067   \n",
       "1             14       183          11635            6100   \n",
       "2             14      3649          22422           10867   \n",
       "3             14      5262           4479            7908   \n",
       "4             14      8234           4602           11000   \n",
       "...          ...       ...            ...             ...   \n",
       "4676565     7162      2267          11908            5467   \n",
       "4676566     7162      2065          11919            6067   \n",
       "4676567     7162      1296          16690           19870   \n",
       "4676568     7162      4822          11862           24400   \n",
       "4676569     7162      4364           2182           19367   \n",
       "\n",
       "                            time        date     timestamp  watch_ratio  \n",
       "0        2020-07-05 05:27:48.378  20200705.0  1.593898e+09     0.722103  \n",
       "1        2020-07-05 05:28:00.057  20200705.0  1.593898e+09     1.907377  \n",
       "2        2020-07-05 05:29:09.479  20200705.0  1.593898e+09     2.063311  \n",
       "3        2020-07-05 05:30:43.285  20200705.0  1.593898e+09     0.566388  \n",
       "4        2020-07-05 05:35:43.459  20200705.0  1.593899e+09     0.418364  \n",
       "...                          ...         ...           ...          ...  \n",
       "4676565                      NaN         NaN           NaN     2.178160  \n",
       "4676566                      NaN         NaN           NaN     1.964562  \n",
       "4676567                      NaN         NaN           NaN     0.839960  \n",
       "4676568                      NaN         NaN           NaN     0.486148  \n",
       "4676569                      NaN         NaN           NaN     0.112666  \n",
       "\n",
       "[4676570 rows x 8 columns]"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "         user_id  video_id  play_duration  video_duration  \\\n",
      "1             14       183          11635            6100   \n",
      "5             14      6789           8607           13267   \n",
      "7             14       175          11640           46514   \n",
      "8             14      1973           4572            7400   \n",
      "9             14       171           8518            5217   \n",
      "...          ...       ...            ...             ...   \n",
      "4676561     7162      4634          16653            8900   \n",
      "4676562     7162      9138           2218           10240   \n",
      "4676563     7162      7736          11948          140156   \n",
      "4676564     7162       530          11906           15866   \n",
      "4676565     7162      2267          11908            5467   \n",
      "\n",
      "                            time        date     timestamp  watch_ratio  \n",
      "1        2020-07-05 05:28:00.057  20200705.0  1.593898e+09     1.907377  \n",
      "5        2020-07-05 05:36:00.773  20200705.0  1.593899e+09     0.648753  \n",
      "7        2020-07-05 05:49:27.965  20200705.0  1.593899e+09     0.250247  \n",
      "8        2020-07-05 05:49:41.762  20200705.0  1.593899e+09     0.617838  \n",
      "9        2020-07-05 05:57:26.581  20200705.0  1.593900e+09     1.632739  \n",
      "...                          ...         ...           ...          ...  \n",
      "4676561                      NaN         NaN           NaN     1.871124  \n",
      "4676562                      NaN         NaN           NaN     0.216602  \n",
      "4676563                      NaN         NaN           NaN     0.085248  \n",
      "4676564                      NaN         NaN           NaN     0.750410  \n",
      "4676565                      NaN         NaN           NaN     2.178160  \n",
      "\n",
      "[2909482 rows x 8 columns]\n"
     ]
    }
   ],
   "source": [
    "# Group data by user and list all items each user has interacted with\n",
    "user_items = data.groupby('user_id')['video_id'].apply(set)\n",
    "\n",
    "# Find common items that all users have interacted with\n",
    "common_items = set.intersection(*user_items)\n",
    "\n",
    "# Filter the original data to only include interactions involving these common items\n",
    "filtered_data = data[data['video_id'].isin(common_items)]\n",
    "\n",
    "# Print or save the filtered data\n",
    "print(filtered_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((4676570, 8), (2909482, 8))"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data.shape, filtered_data.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the filtered interactions to CSV without the row index. The target is\n",
    "# built from `path` (set in the data-loading cell) so the root prefix is\n",
    "# defined in a single place instead of being repeated as a literal.\n",
    "filtered_data.to_csv(path + \"data/edited_data.csv\", index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fixed seed and sample size so the saved subset is reproducible across runs\n",
    "SAMPLE_SEED = 42\n",
    "N_SAMPLED_VIDEOS = 10\n",
    "\n",
    "# Candidate videos, in order of first appearance in filtered_data\n",
    "candidate_video_ids = filtered_data['video_id'].unique().tolist()\n",
    "\n",
    "# A private Random instance makes the draw independent of the global RNG state.\n",
    "# random.sample raises ValueError if fewer than N_SAMPLED_VIDEOS candidates exist.\n",
    "random_video_ids = random.Random(SAMPLE_SEED).sample(\n",
    "    candidate_video_ids, N_SAMPLED_VIDEOS\n",
    ")\n",
    "\n",
    "# Keep only the rows that belong to the sampled videos\n",
    "filtered_data = filtered_data[filtered_data['video_id'].isin(random_video_ids)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>user_id</th>\n",
       "      <th>video_id</th>\n",
       "      <th>play_duration</th>\n",
       "      <th>video_duration</th>\n",
       "      <th>time</th>\n",
       "      <th>date</th>\n",
       "      <th>timestamp</th>\n",
       "      <th>watch_ratio</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>313</th>\n",
       "      <td>14</td>\n",
       "      <td>170</td>\n",
       "      <td>8584</td>\n",
       "      <td>10644</td>\n",
       "      <td>2020-07-12 02:58:22.542</td>\n",
       "      <td>20200712.0</td>\n",
       "      <td>1.594494e+09</td>\n",
       "      <td>0.806464</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>671</th>\n",
       "      <td>14</td>\n",
       "      <td>2332</td>\n",
       "      <td>4498</td>\n",
       "      <td>7084</td>\n",
       "      <td>2020-07-18 04:47:09.087</td>\n",
       "      <td>20200718.0</td>\n",
       "      <td>1.595019e+09</td>\n",
       "      <td>0.634952</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>721</th>\n",
       "      <td>14</td>\n",
       "      <td>2038</td>\n",
       "      <td>8495</td>\n",
       "      <td>8800</td>\n",
       "      <td>2020-07-18 23:28:55.471</td>\n",
       "      <td>20200718.0</td>\n",
       "      <td>1.595086e+09</td>\n",
       "      <td>0.965341</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>722</th>\n",
       "      <td>14</td>\n",
       "      <td>3996</td>\n",
       "      <td>4449</td>\n",
       "      <td>7900</td>\n",
       "      <td>2020-07-18 23:37:53.61</td>\n",
       "      <td>20200718.0</td>\n",
       "      <td>1.595087e+09</td>\n",
       "      <td>0.563165</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1379</th>\n",
       "      <td>14</td>\n",
       "      <td>5654</td>\n",
       "      <td>4487</td>\n",
       "      <td>7367</td>\n",
       "      <td>2020-07-29 03:16:50.236</td>\n",
       "      <td>20200729.0</td>\n",
       "      <td>1.595964e+09</td>\n",
       "      <td>0.609067</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4674577</th>\n",
       "      <td>7162</td>\n",
       "      <td>2332</td>\n",
       "      <td>5528</td>\n",
       "      <td>7084</td>\n",
       "      <td>2020-07-27 05:30:00.165</td>\n",
       "      <td>20200727.0</td>\n",
       "      <td>1.595799e+09</td>\n",
       "      <td>0.780350</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4674913</th>\n",
       "      <td>7162</td>\n",
       "      <td>9995</td>\n",
       "      <td>11870</td>\n",
       "      <td>6567</td>\n",
       "      <td>2020-08-01 06:33:35.72</td>\n",
       "      <td>20200801.0</td>\n",
       "      <td>1.596235e+09</td>\n",
       "      <td>1.807522</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4676063</th>\n",
       "      <td>7162</td>\n",
       "      <td>4440</td>\n",
       "      <td>11914</td>\n",
       "      <td>17183</td>\n",
       "      <td>2020-08-20 07:51:11.205</td>\n",
       "      <td>20200820.0</td>\n",
       "      <td>1.597881e+09</td>\n",
       "      <td>0.693360</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4676125</th>\n",
       "      <td>7162</td>\n",
       "      <td>9144</td>\n",
       "      <td>5205</td>\n",
       "      <td>7301</td>\n",
       "      <td>2020-08-21 09:25:30.646</td>\n",
       "      <td>20200821.0</td>\n",
       "      <td>1.597973e+09</td>\n",
       "      <td>0.712916</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4676491</th>\n",
       "      <td>7162</td>\n",
       "      <td>6296</td>\n",
       "      <td>14650</td>\n",
       "      <td>9900</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>1.479798</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>14110 rows × 8 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "         user_id  video_id  play_duration  video_duration  \\\n",
       "313           14       170           8584           10644   \n",
       "671           14      2332           4498            7084   \n",
       "721           14      2038           8495            8800   \n",
       "722           14      3996           4449            7900   \n",
       "1379          14      5654           4487            7367   \n",
       "...          ...       ...            ...             ...   \n",
       "4674577     7162      2332           5528            7084   \n",
       "4674913     7162      9995          11870            6567   \n",
       "4676063     7162      4440          11914           17183   \n",
       "4676125     7162      9144           5205            7301   \n",
       "4676491     7162      6296          14650            9900   \n",
       "\n",
       "                            time        date     timestamp  watch_ratio  \n",
       "313      2020-07-12 02:58:22.542  20200712.0  1.594494e+09     0.806464  \n",
       "671      2020-07-18 04:47:09.087  20200718.0  1.595019e+09     0.634952  \n",
       "721      2020-07-18 23:28:55.471  20200718.0  1.595086e+09     0.965341  \n",
       "722       2020-07-18 23:37:53.61  20200718.0  1.595087e+09     0.563165  \n",
       "1379     2020-07-29 03:16:50.236  20200729.0  1.595964e+09     0.609067  \n",
       "...                          ...         ...           ...          ...  \n",
       "4674577  2020-07-27 05:30:00.165  20200727.0  1.595799e+09     0.780350  \n",
       "4674913   2020-08-01 06:33:35.72  20200801.0  1.596235e+09     1.807522  \n",
       "4676063  2020-08-20 07:51:11.205  20200820.0  1.597881e+09     0.693360  \n",
       "4676125  2020-08-21 09:25:30.646  20200821.0  1.597973e+09     0.712916  \n",
       "4676491                      NaN         NaN           NaN     1.479798  \n",
       "\n",
       "[14110 rows x 8 columns]"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "filtered_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the current filtered_data frame to CSV without the row index. The target\n",
    "# is built from `path` (set in the data-loading cell) so the root prefix is\n",
    "# defined in a single place instead of being repeated as a literal.\n",
    "filtered_data.to_csv(path + \"data/filtered_small_matrix.csv\", index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>user_id</th>\n",
       "      <th>user_active_degree</th>\n",
       "      <th>is_lowactive_period</th>\n",
       "      <th>is_live_streamer</th>\n",
       "      <th>is_video_author</th>\n",
       "      <th>follow_user_num</th>\n",
       "      <th>fans_user_num</th>\n",
       "      <th>friend_user_num</th>\n",
       "      <th>register_days</th>\n",
       "      <th>onehot_feat0</th>\n",
       "      <th>...</th>\n",
       "      <th>onehot_feat8</th>\n",
       "      <th>onehot_feat9</th>\n",
       "      <th>onehot_feat10</th>\n",
       "      <th>onehot_feat11</th>\n",
       "      <th>onehot_feat12</th>\n",
       "      <th>onehot_feat13</th>\n",
       "      <th>onehot_feat14</th>\n",
       "      <th>onehot_feat15</th>\n",
       "      <th>onehot_feat16</th>\n",
       "      <th>onehot_feat17</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>3</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>5</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>107</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>184</td>\n",
       "      <td>6</td>\n",
       "      <td>3</td>\n",
       "      <td>0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>2</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>386</td>\n",
       "      <td>4</td>\n",
       "      <td>2</td>\n",
       "      <td>327</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>186</td>\n",
       "      <td>6</td>\n",
       "      <td>2</td>\n",
       "      <td>0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2</td>\n",
       "      <td>2</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>27</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>116</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>51</td>\n",
       "      <td>2</td>\n",
       "      <td>3</td>\n",
       "      <td>0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>3</td>\n",
       "      <td>2</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>16</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>105</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>251</td>\n",
       "      <td>3</td>\n",
       "      <td>2</td>\n",
       "      <td>0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>4</td>\n",
       "      <td>2</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>122</td>\n",
       "      <td>4</td>\n",
       "      <td>0</td>\n",
       "      <td>225</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>99</td>\n",
       "      <td>4</td>\n",
       "      <td>2</td>\n",
       "      <td>0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7171</th>\n",
       "      <td>7171</td>\n",
       "      <td>2</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>52</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>283</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>259</td>\n",
       "      <td>1</td>\n",
       "      <td>4</td>\n",
       "      <td>0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7172</th>\n",
       "      <td>7172</td>\n",
       "      <td>2</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>45</td>\n",
       "      <td>2</td>\n",
       "      <td>2</td>\n",
       "      <td>109</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>11</td>\n",
       "      <td>2</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7173</th>\n",
       "      <td>7173</td>\n",
       "      <td>2</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>615</td>\n",
       "      <td>3</td>\n",
       "      <td>2</td>\n",
       "      <td>167</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>51</td>\n",
       "      <td>2</td>\n",
       "      <td>2</td>\n",
       "      <td>0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7174</th>\n",
       "      <td>7174</td>\n",
       "      <td>2</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>959</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>241</td>\n",
       "      <td>1</td>\n",
       "      <td>...</td>\n",
       "      <td>107</td>\n",
       "      <td>3</td>\n",
       "      <td>2</td>\n",
       "      <td>0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7175</th>\n",
       "      <td>7175</td>\n",
       "      <td>2</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>98</td>\n",
       "      <td>35</td>\n",
       "      <td>33</td>\n",
       "      <td>167</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>132</td>\n",
       "      <td>5</td>\n",
       "      <td>2</td>\n",
       "      <td>0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>7176 rows × 27 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "      user_id  user_active_degree  is_lowactive_period  is_live_streamer  \\\n",
       "0           0                   3                    0                 0   \n",
       "1           1                   2                    0                 0   \n",
       "2           2                   2                    0                 0   \n",
       "3           3                   2                    0                 0   \n",
       "4           4                   2                    0                 0   \n",
       "...       ...                 ...                  ...               ...   \n",
       "7171     7171                   2                    0                 0   \n",
       "7172     7172                   2                    0                 0   \n",
       "7173     7173                   2                    0                 0   \n",
       "7174     7174                   2                    0                 0   \n",
       "7175     7175                   2                    0                 0   \n",
       "\n",
       "      is_video_author  follow_user_num  fans_user_num  friend_user_num  \\\n",
       "0                   0                5              0                0   \n",
       "1                   0              386              4                2   \n",
       "2                   0               27              0                0   \n",
       "3                   0               16              0                0   \n",
       "4                   0              122              4                0   \n",
       "...               ...              ...            ...              ...   \n",
       "7171                1               52              1                0   \n",
       "7172                0               45              2                2   \n",
       "7173                0              615              3                2   \n",
       "7174                0              959              0                0   \n",
       "7175                1               98             35               33   \n",
       "\n",
       "      register_days  onehot_feat0  ...  onehot_feat8  onehot_feat9  \\\n",
       "0               107             0  ...           184             6   \n",
       "1               327             0  ...           186             6   \n",
       "2               116             0  ...            51             2   \n",
       "3               105             0  ...           251             3   \n",
       "4               225             0  ...            99             4   \n",
       "...             ...           ...  ...           ...           ...   \n",
       "7171            283             0  ...           259             1   \n",
       "7172            109             0  ...            11             2   \n",
       "7173            167             0  ...            51             2   \n",
       "7174            241             1  ...           107             3   \n",
       "7175            167             0  ...           132             5   \n",
       "\n",
       "      onehot_feat10  onehot_feat11  onehot_feat12  onehot_feat13  \\\n",
       "0                 3              0            0.0            0.0   \n",
       "1                 2              0            0.0            0.0   \n",
       "2                 3              0            0.0            0.0   \n",
       "3                 2              0            0.0            0.0   \n",
       "4                 2              0            0.0            0.0   \n",
       "...             ...            ...            ...            ...   \n",
       "7171              4              0            1.0            0.0   \n",
       "7172              0              0            1.0            0.0   \n",
       "7173              2              0            1.0            0.0   \n",
       "7174              2              0            0.0            0.0   \n",
       "7175              2              0            0.0            0.0   \n",
       "\n",
       "      onehot_feat14  onehot_feat15  onehot_feat16  onehot_feat17  \n",
       "0               0.0            0.0            0.0            0.0  \n",
       "1               0.0            0.0            0.0            0.0  \n",
       "2               0.0            0.0            0.0            0.0  \n",
       "3               0.0            0.0            0.0            0.0  \n",
       "4               0.0            0.0            0.0            0.0  \n",
       "...             ...            ...            ...            ...  \n",
       "7171            0.0            0.0            0.0            0.0  \n",
       "7172            0.0            0.0            0.0            0.0  \n",
       "7173            0.0            0.0            0.0            0.0  \n",
       "7174            0.0            0.0            0.0            0.0  \n",
       "7175            0.0            0.0            0.0            0.0  \n",
       "\n",
       "[7176 rows x 27 columns]"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "path = \"../../data/user_features.csv\"\n",
    "data = pd.read_csv(path)\n",
    "\n",
    "# Map the 'user_active_degree' values to integers\n",
    "activity_mapping = {'high_active': 3, 'full_active': 2, 'middle_active': 1, 'UNKNOWN': 0}\n",
    "data['user_active_degree'] = data['user_active_degree'].map(activity_mapping)\n",
    "\n",
    "# Remove \"range\" features\n",
    "data = data.drop(columns=[col for col in data.columns if 'range' in col])\n",
    "\n",
    "# Check for columns that do not change between users\n",
    "constant_columns = [col for col in data.columns if data[col].nunique() == 1]\n",
    "\n",
    "# Delete the top 5 features that don't change between users, if there are at least 5 such features\n",
    "constant_columns_to_delete = constant_columns[:5] if len(constant_columns) >= 5 else constant_columns\n",
    "data = data.drop(columns=constant_columns_to_delete)\n",
    "\n",
    "data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "data.to_csv(\"../../data/filtered_features.csv\", index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "from dataclasses import dataclass\n",
    "import enum\n",
    "from dataclasses import dataclass\n",
    "from obp.dataset import SyntheticBanditDataset\n",
    "from obp.dataset import BaseBanditDataset\n",
    "import numpy as np\n",
    "from sklearn.preprocessing import OneHotEncoder\n",
    "from typing import Callable, Tuple\n",
    "from typing import Optional\n",
    "from obp.dataset import linear_behavior_policy\n",
    "import numpy as np\n",
    "from obp.dataset import linear_behavior_policy\n",
    "from obp.dataset import linear_reward_function\n",
    "from obp.dataset import logistic_reward_function\n",
    "from obp.dataset import logistic_polynomial_reward_function\n",
    "from obp.dataset import polynomial_reward_function\n",
    "from obp.dataset import SyntheticBanditDataset\n",
    "from obp.types import BanditFeedback\n",
    "from obp.utils import sample_action_fast\n",
    "from obp.utils import check_array\n",
    "from obp.utils import softmax\n",
    "from scipy.stats import rankdata\n",
    "from sklearn.preprocessing import OneHotEncoder\n",
    "from sklearn.utils import check_random_state\n",
    "from sklearn.utils import check_scalar\n",
    "from sklearn.preprocessing import PolynomialFeatures\n",
    "from scipy.stats import norm\n",
    "from scipy.stats import truncnorm\n",
    "import pandas as pd\n",
    "import math\n",
    "from typing import Optional, List\n",
    "from sklearn.preprocessing import MinMaxScaler\n",
    "from sklearn.decomposition import PCA\n",
    "from sklearn.cluster import KMeans\n",
    "from sklearn.preprocessing import OneHotEncoder, LabelEncoder\n",
    "\n",
    "path = \"../../data/\"\n",
    "expected_rewards_df = pd.read_csv(path + \"filtered_matrix.csv\")\n",
    "user_features_df = pd.read_csv(path + \"filtered_features.csv\")\n",
    "item_features_df = pd.read_csv(path + \"item_daily_features.csv\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "user_features_df.dropna(inplace=True)\n",
    "\n",
    "# Scale the features, excluding user_id\n",
    "scaler = MinMaxScaler()\n",
    "features_to_scale = user_features_df.columns[1:]\n",
    "user_features_df[features_to_scale] = scaler.fit_transform(user_features_df[features_to_scale])\n",
    "\n",
    "# Remove low variance features\n",
    "variances = user_features_df.var()\n",
    "low_variance_cols = variances[variances < 0.02].index\n",
    "user_features_df.drop(columns=low_variance_cols, inplace=True)\n",
    "\n",
    "# Remove high proportion features\n",
    "def proportion_of_most_common(df):\n",
    "    return df.apply(lambda x: x.value_counts(normalize=True).iloc[0])\n",
    "\n",
    "proportions = proportion_of_most_common(user_features_df)\n",
    "high_proportion_cols = proportions[proportions > 0.95].index\n",
    "user_features_df.drop(columns=high_proportion_cols, inplace=True)\n",
    "\n",
    "# Identify one-hot features\n",
    "all_one_hot_features = [f'onehot_feat{i}' for i in range(18)]  # Adjust this if there are more or fewer one-hot features\n",
    "one_hot_features = [feat for feat in all_one_hot_features if feat in user_features_df.columns]\n",
    "\n",
    "# Apply OneHotEncoder to the one-hot features\n",
    "encoder = OneHotEncoder(sparse=False)\n",
    "encoded_features = encoder.fit_transform(user_features_df[one_hot_features])\n",
    "\n",
    "# Combine the encoded features with the rest of the DataFrame\n",
    "encoded_df = pd.DataFrame(encoded_features, columns=encoder.get_feature_names_out(one_hot_features))\n",
    "user_features_df = pd.concat([user_features_df.reset_index(drop=True), encoded_df], axis=1)\n",
    "user_features_df.drop(columns=one_hot_features, inplace=True)\n",
    "\n",
    "# Apply PCA with automatic number of dimensions\n",
    "pca = PCA(n_components=50)  # Retain 95% of the variance\n",
    "pca_features = pca.fit_transform(user_features_df.drop(columns=['user_id']))\n",
    "\n",
    "# Create a DataFrame for the PCA features\n",
    "pca_df = pd.DataFrame(pca_features, columns=[f'PC{i+1}' for i in range(pca_features.shape[1])])\n",
    "\n",
    "# Combine PCA features with user_id\n",
    "final_df = pd.concat([user_features_df[['user_id']].reset_index(drop=True), pca_df], axis=1)\n",
    "\n",
    "# Cluster users to find patterns\n",
    "kmeans = KMeans(n_clusters=5, random_state=0)\n",
    "clusters = kmeans.fit_predict(final_df.drop(columns=['user_id']))\n",
    "final_df['cluster'] = clusters\n",
    "\n",
    "# Save the final DataFrame to CSV\n",
    "final_df.to_csv(\"../../data/user_features_df.csv\", index=False)\n",
    "\n",
    "# valid_user_ids is a list of user_ids in user_features_df\n",
    "valid_user_ids = final_df['user_id'].tolist()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "item_features_df['upload_dt'] = pd.to_datetime(item_features_df['upload_dt'])\n",
    "\n",
    "# Select only the needed columns\n",
    "item_features_subset = item_features_df[['video_id', 'upload_dt']]\n",
    "\n",
    "# Keep only the first occurrence of each video_id\n",
    "item_features_first = item_features_subset.groupby('video_id').first().reset_index()\n",
    "\n",
    "# Merge the features with expected_rewards_df on video_id\n",
    "merged_df = pd.merge(expected_rewards_df, item_features_first, on='video_id', how='left')\n",
    "\n",
    "# If you want to convert upload_dt to the integer format:\n",
    "merged_df['upload_dt'] = merged_df['upload_dt'].dt.strftime('%Y%m%d').astype(int)\n",
    "\n",
    "expected_rewards_df = merged_df\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/var/folders/yh/bx3tt9_n2wn0y0_2vj8_xpkw002bwh/T/ipykernel_32471/603123574.py:2: SettingWithCopyWarning: \n",
      "A value is trying to be set on a copy of a slice from a DataFrame\n",
      "\n",
      "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
      "  expected_rewards_df.sort_values(by=['user_id', 'video_id'], inplace=True)\n",
      "/var/folders/yh/bx3tt9_n2wn0y0_2vj8_xpkw002bwh/T/ipykernel_32471/603123574.py:3: SettingWithCopyWarning: \n",
      "A value is trying to be set on a copy of a slice from a DataFrame.\n",
      "Try using .loc[row_indexer,col_indexer] = value instead\n",
      "\n",
      "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
      "  expected_rewards_df['video_id'] = expected_rewards_df.groupby('user_id').cumcount()\n"
     ]
    }
   ],
   "source": [
    "expected_rewards_df = expected_rewards_df[expected_rewards_df['user_id'].isin(valid_user_ids)]\n",
    "expected_rewards_df.sort_values(by=['user_id', 'video_id'], inplace=True)\n",
    "expected_rewards_df['video_id'] = expected_rewards_df.groupby('user_id').cumcount()\n",
    "\n",
    "random_user_ids = expected_rewards_df['user_id'].unique()\n",
    "expected_rewards_df = expected_rewards_df[expected_rewards_df['user_id'].isin(random_user_ids)]\n",
    "\n",
    "columns = user_features_df.columns[1:]\n",
    "\n",
    "\n",
    "def cluster_mean(duration):\n",
    "    if duration == 0:\n",
    "        return 0\n",
    "    remainder = duration % 5000\n",
    "    if remainder == 0:\n",
    "        return duration \n",
    "    return duration + 5000 - remainder\n",
    "\n",
    "\n",
    "# expected_rewards_df['clustered_play_duration'] = expected_rewards_df['play_duration'].apply(cluster_mean)\n",
    "# expected_rewards_df[\"clustered_watch_ratio\"] = expected_rewards_df['clustered_play_duration']/expected_rewards_df['video_duration']\n",
    "# high = expected_rewards_df['clustered_watch_ratio'].quantile(0.95)\n",
    "# expected_rewards_df['clustered_watch_ratio'] = expected_rewards_df['clustered_watch_ratio'].clip(upper=high)\n",
    "# mean = expected_rewards_df['clustered_watch_ratio'].mean()\n",
    "# std = expected_rewards_df['clustered_watch_ratio'].std()\n",
    "# expected_rewards_df['clustered_watch_ratio'] = (expected_rewards_df['clustered_watch_ratio'] - mean) / std\n",
    "# new_min = 0.0\n",
    "# new_max = 1.5\n",
    "# min_watch_ratio = expected_rewards_df['clustered_watch_ratio'].min()\n",
    "# max_watch_ratio = expected_rewards_df['clustered_watch_ratio'].max()\n",
    "# expected_rewards_df['clustered_watch_ratio'] = ((expected_rewards_df['clustered_watch_ratio'] - min_watch_ratio) / (max_watch_ratio - min_watch_ratio)) * (new_max - new_min) + new_min\n",
    "\n",
    "# high=expected_rewards_df['clustered_play_duration'].quantile(0.95)\n",
    "# expected_rewards_df['clustered_play_duration'] = expected_rewards_df['clustered_play_duration'].clip(upper=high)\n",
    "\n",
    "# mean = expected_rewards_df['clustered_play_duration'].mean()\n",
    "# std = expected_rewards_df['clustered_play_duration'].std()\n",
    "# expected_rewards_df['clustered_play_duration'] = (expected_rewards_df['clustered_play_duration'] - mean) / std\n",
    "# new_min = 0.0\n",
    "# new_max = 1.5\n",
    "# min_watch_ratio = expected_rewards_df['clustered_play_duration'].min()\n",
    "# max_watch_ratio = expected_rewards_df['clustered_play_duration'].max()\n",
    "# expected_rewards_df['clustered_play_duration'] = ((expected_rewards_df['clustered_play_duration'] - min_watch_ratio) / (max_watch_ratio - min_watch_ratio)) * (new_max - new_min) + new_min\n",
    "\n",
    "\n",
    "\n",
    "thresholds = [100, 200, 300, 400, 500, 800, 1000, 1500, 2000, 2500, 3000, 3500, 4000, 4500, 5000, 6000, 7000, 8000, 9000, 10000, 15000, 20000, 25000, 30000, 40000, 50000]\n",
    "for threshold in thresholds:\n",
    "    column_name = f\"clustered_watch_ratio{int(threshold / 100)}\"\n",
    "    expected_rewards_df[column_name] = np.where(\n",
    "        expected_rewards_df['play_duration'] > threshold,\n",
    "        (threshold) / expected_rewards_df['video_duration'],\n",
    "        0\n",
    "    )# (threshold/2) / expected_rewards_df['video_duration']\n",
    "    mean = expected_rewards_df[column_name].mean()\n",
    "    std = expected_rewards_df[column_name].std()\n",
    "    expected_rewards_df[column_name] = (expected_rewards_df[column_name] - mean) / std\n",
    "    new_min = 0.0\n",
    "    new_max = 1.5\n",
    "    min_watch_ratio = expected_rewards_df[column_name].min()\n",
    "    max_watch_ratio = expected_rewards_df[column_name].max()\n",
    "    expected_rewards_df[column_name] = ((expected_rewards_df[column_name] - min_watch_ratio) / (max_watch_ratio - min_watch_ratio)) * (new_max - new_min) + new_min\n",
    "    \n",
    "# mean = expected_rewards_df['clustered_watch_ratio'].mean()\n",
    "# std = expected_rewards_df['clustered_watch_ratio'].std()\n",
    "# expected_rewards_df['clustered_watch_ratio'] = (expected_rewards_df['clustered_watch_ratio'] - mean) / std\n",
    "# new_min = 0.0\n",
    "# new_max = 1.5\n",
    "# min_watch_ratio = expected_rewards_df['clustered_watch_ratio'].min()\n",
    "# max_watch_ratio = expected_rewards_df['clustered_watch_ratio'].max()\n",
    "# expected_rewards_df['clustered_watch_ratio'] = ((expected_rewards_df['clustered_watch_ratio'] - min_watch_ratio) / (max_watch_ratio - min_watch_ratio)) * (new_max - new_min) + new_min\n",
    "\n",
    "upper_limit = expected_rewards_df['play_duration'].quantile(0.95)\n",
    "expected_rewards_df['play_duration'] = expected_rewards_df['play_duration'].clip(upper=upper_limit)\n",
    "upper_limit = expected_rewards_df['video_duration'].quantile(0.95)\n",
    "expected_rewards_df['video_duration'] = expected_rewards_df['video_duration'].clip(upper=upper_limit)\n",
    "\n",
    "mean = expected_rewards_df['play_duration'].mean()\n",
    "std = expected_rewards_df['play_duration'].std()\n",
    "expected_rewards_df['play_duration'] = (expected_rewards_df['play_duration'] - mean) / std\n",
    "new_min = 0.0\n",
    "new_max = 1.5\n",
    "\n",
    "min_watch_ratio = expected_rewards_df['play_duration'].min()\n",
    "max_watch_ratio = expected_rewards_df['play_duration'].max()\n",
    "\n",
    "expected_rewards_df['play_duration'] = ((expected_rewards_df['play_duration'] - min_watch_ratio) / (max_watch_ratio - min_watch_ratio)) * (new_max - new_min) + new_min\n",
    "\n",
    "upper_limit = expected_rewards_df['watch_ratio'].quantile(0.95)\n",
    "\n",
    "expected_rewards_df['watch_ratio'] = expected_rewards_df['watch_ratio'].clip(upper=upper_limit)\n",
    "\n",
    "mean = expected_rewards_df['watch_ratio'].mean()\n",
    "std = expected_rewards_df['watch_ratio'].std()\n",
    "expected_rewards_df['watch_ratio'] = (expected_rewards_df['watch_ratio'] - mean) / std\n",
    "\n",
    "new_min = 0.0\n",
    "new_max = 1.5\n",
    "\n",
    "min_watch_ratio = expected_rewards_df['watch_ratio'].min()\n",
    "max_watch_ratio = expected_rewards_df['watch_ratio'].max()\n",
    "\n",
    "expected_rewards_df['watch_ratio'] = ((expected_rewards_df['watch_ratio'] - min_watch_ratio) / (max_watch_ratio - min_watch_ratio)) * (new_max - new_min) + new_min\n",
    "\n",
    "# inv_watch_ratio05 is watch_ratio but if bigger than 0.5, it is 0, else -watch_ratio\n",
    "expected_rewards_df['inv_watch_ratio05'] = np.where(\n",
    "    expected_rewards_df['watch_ratio'] > 0.5,\n",
    "    0,\n",
    "    expected_rewards_df['watch_ratio'] - 0.5\n",
    ")\n",
    "mean = expected_rewards_df['inv_watch_ratio05'].mean()\n",
    "std = expected_rewards_df['inv_watch_ratio05'].std()\n",
    "expected_rewards_df['inv_watch_ratio05'] = (expected_rewards_df['inv_watch_ratio05'] - mean) / std\n",
    "new_min = 0.0\n",
    "new_max = 1.5\n",
    "min_watch_ratio = expected_rewards_df['inv_watch_ratio05'].min()\n",
    "max_watch_ratio = expected_rewards_df['inv_watch_ratio05'].max()\n",
    "expected_rewards_df['inv_watch_ratio05'] = ((expected_rewards_df['inv_watch_ratio05'] - min_watch_ratio) / (max_watch_ratio - min_watch_ratio)) * (new_max - new_min) + new_min\n",
    "\n",
    "# inv_watch_ratio05 is watch_ratio but if bigger than 0.5, it is 0, else watch_ratio -0.7\n",
    "expected_rewards_df['inv_watch_ratio07'] = np.where(\n",
    "    expected_rewards_df['watch_ratio'] > 0.7,\n",
    "    0,\n",
    "    expected_rewards_df['watch_ratio'] - 0.7\n",
    ")\n",
    "mean = expected_rewards_df['inv_watch_ratio07'].mean()\n",
    "std = expected_rewards_df['inv_watch_ratio07'].std()\n",
    "expected_rewards_df['inv_watch_ratio07'] = (expected_rewards_df['inv_watch_ratio07'] - mean) / std\n",
    "new_min = 0.0\n",
    "new_max = 1.5\n",
    "min_watch_ratio = expected_rewards_df['inv_watch_ratio07'].min()\n",
    "max_watch_ratio = expected_rewards_df['inv_watch_ratio07'].max()\n",
    "expected_rewards_df['inv_watch_ratio07'] = ((expected_rewards_df['inv_watch_ratio07'] - min_watch_ratio) / (max_watch_ratio - min_watch_ratio)) * (new_max - new_min) + new_min\n",
    "\n",
    "# inv_watch_ratio05 is watch_ratio but if bigger than 0.5, it is 0, else watch_ratio - 0.5\n",
    "expected_rewards_df['inv_watch_ratio03'] = np.where(\n",
    "    expected_rewards_df['watch_ratio'] > 0.3,\n",
    "    0,\n",
    "    expected_rewards_df['watch_ratio'] - 0.3\n",
    ")\n",
    "mean = expected_rewards_df['inv_watch_ratio03'].mean()\n",
    "std = expected_rewards_df['inv_watch_ratio03'].std()\n",
    "expected_rewards_df['inv_watch_ratio03'] = (expected_rewards_df['inv_watch_ratio03'] - mean) / std\n",
    "new_min = 0.0\n",
    "new_max = 1.5\n",
    "min_watch_ratio = expected_rewards_df['inv_watch_ratio03'].min()\n",
    "max_watch_ratio = expected_rewards_df['inv_watch_ratio03'].max()\n",
    "expected_rewards_df['inv_watch_ratio03'] = ((expected_rewards_df['inv_watch_ratio03'] - min_watch_ratio) / (max_watch_ratio - min_watch_ratio)) * (new_max - new_min) + new_min\n",
    "    \n",
    "\n",
    "\n",
    "# big_watch_ratio13 is watch ratio but if smaller than 1.3, it is 0, else watch_ratio\n",
    "expected_rewards_df['big_watch_ratio11'] = np.where(\n",
    "    expected_rewards_df['watch_ratio'] < 1.1,\n",
    "    0,\n",
    "    expected_rewards_df['watch_ratio']\n",
    ")\n",
    "mean = expected_rewards_df['big_watch_ratio11'].mean()\n",
    "std = expected_rewards_df['big_watch_ratio11'].std()\n",
    "expected_rewards_df['big_watch_ratio11'] = (expected_rewards_df['big_watch_ratio11'] - mean) / std\n",
    "new_min = 0.0\n",
    "new_max = 1.5\n",
    "min_watch_ratio = expected_rewards_df['big_watch_ratio11'].min()\n",
    "max_watch_ratio = expected_rewards_df['big_watch_ratio11'].max()\n",
    "expected_rewards_df['big_watch_ratio11'] = ((expected_rewards_df['big_watch_ratio11'] - min_watch_ratio) / (max_watch_ratio - min_watch_ratio)) * (new_max - new_min) + new_min\n",
    "\n",
    "# big_watch_ratio13 is watch ratio but if smaller than 1.3, it is 0, else watch_ratio\n",
    "expected_rewards_df['big_watch_ratio13'] = np.where(\n",
    "    expected_rewards_df['watch_ratio'] < 1.3,\n",
    "    0,\n",
    "    expected_rewards_df['watch_ratio']\n",
    ")\n",
    "mean = expected_rewards_df['big_watch_ratio13'].mean()\n",
    "std = expected_rewards_df['big_watch_ratio13'].std()\n",
    "expected_rewards_df['big_watch_ratio13'] = (expected_rewards_df['big_watch_ratio13'] - mean) / std\n",
    "new_min = 0.0\n",
    "new_max = 1.5\n",
    "min_watch_ratio = expected_rewards_df['big_watch_ratio13'].min()\n",
    "max_watch_ratio = expected_rewards_df['big_watch_ratio13'].max()\n",
    "expected_rewards_df['big_watch_ratio13'] = ((expected_rewards_df['big_watch_ratio13'] - min_watch_ratio) / (max_watch_ratio - min_watch_ratio)) * (new_max - new_min) + new_min\n",
    "\n",
    "# big_watch_ratio13 is watch ratio but if smaller than 1.3, it is 0, else watch_ratio\n",
    "expected_rewards_df['big_watch_ratio10'] = np.where(\n",
    "    expected_rewards_df['watch_ratio'] < 1.0,\n",
    "    0,\n",
    "    expected_rewards_df['watch_ratio']\n",
    ")\n",
    "mean = expected_rewards_df['big_watch_ratio10'].mean()\n",
    "std = expected_rewards_df['big_watch_ratio10'].std()\n",
    "expected_rewards_df['big_watch_ratio10'] = (expected_rewards_df['big_watch_ratio10'] - mean) / std\n",
    "new_min = 0.0\n",
    "new_max = 1.5\n",
    "min_watch_ratio = expected_rewards_df['big_watch_ratio10'].min()\n",
    "max_watch_ratio = expected_rewards_df['big_watch_ratio10'].max()\n",
    "expected_rewards_df['big_watch_ratio10'] = ((expected_rewards_df['big_watch_ratio10'] - min_watch_ratio) / (max_watch_ratio - min_watch_ratio)) * (new_max - new_min) + new_min\n",
    "\n",
    "expected_rewards_df['big_watch_ratio08'] = np.where(\n",
    "    expected_rewards_df['watch_ratio'] < 0.8,\n",
    "    0,\n",
    "    expected_rewards_df['watch_ratio']\n",
    ")\n",
    "mean = expected_rewards_df['big_watch_ratio08'].mean()\n",
    "std = expected_rewards_df['big_watch_ratio08'].std()\n",
    "expected_rewards_df['big_watch_ratio08'] = (expected_rewards_df['big_watch_ratio08'] - mean) / std\n",
    "new_min = 0.0\n",
    "new_max = 1.5\n",
    "min_watch_ratio = expected_rewards_df['big_watch_ratio08'].min()\n",
    "max_watch_ratio = expected_rewards_df['big_watch_ratio08'].max()\n",
    "expected_rewards_df['big_watch_ratio08'] = ((expected_rewards_df['big_watch_ratio08'] - min_watch_ratio) / (max_watch_ratio - min_watch_ratio)) * (new_max - new_min) + new_min\n",
    "\n",
    "\n",
    "mean = expected_rewards_df['video_duration'].mean()\n",
    "std = expected_rewards_df['video_duration'].std()\n",
    "expected_rewards_df['video_duration'] = (expected_rewards_df['video_duration'] - mean) / std\n",
    "new_min = 0.0  \n",
    "new_max = 1.5\n",
    "min_watch_ratio = expected_rewards_df['video_duration'].min()\n",
    "max_watch_ratio = expected_rewards_df['video_duration'].max()\n",
    "expected_rewards_df['video_duration'] = ((expected_rewards_df['video_duration'] - min_watch_ratio) / (max_watch_ratio - min_watch_ratio)) * (new_max - new_min) + new_min\n",
    "\n",
    "\n",
    "# thresholds = [0.7, 1.0, 1.3]\n",
    "# for threshold in thresholds:\n",
    "#     column_name = f\"watch_ratio{int(threshold * 10)}\"\n",
    "#     expected_rewards_df[column_name] = (expected_rewards_df['watch_ratio'] > threshold).astype(int)\n",
    "    \n",
    "# for threshold in thresholds:\n",
    "#     column_name = f\"play_duration{int(threshold * 10)}\"\n",
    "#     expected_rewards_df[column_name] = (expected_rewards_df['play_duration'] > threshold).astype(int)\n",
    "# # play duration less than 0.2 should be 1\n",
    "expected_rewards_df['play_lessthan02'] = (expected_rewards_df['play_duration'] < -0.2).astype(int)\n",
    "\n",
    "column_name_ = f\"watch_lessthan03\"\n",
    "expected_rewards_df[column_name_] = (expected_rewards_df['watch_ratio'] < -0.3).astype(int)\n",
    "\n",
    "\n",
    "avg_watch_ratio = expected_rewards_df.groupby('video_id')['watch_ratio'].mean().rename('avg_watch_ratio')\n",
    "avg_play_duration = expected_rewards_df.groupby('video_id')['play_duration'].mean().rename('avg_play_duration')\n",
    "\n",
    "avg_perusr_watch_ratio = expected_rewards_df.groupby('user_id')['watch_ratio'].mean().rename('avg_perusr_watch_ratio')\n",
    "avg_perusr_play_duration = expected_rewards_df.groupby('user_id')['play_duration'].mean().rename('avg_perusr_play_duration')\n",
    "\n",
    "expected_rewards_df = expected_rewards_df.merge(avg_watch_ratio, on='video_id')\n",
    "expected_rewards_df = expected_rewards_df.merge(avg_play_duration, on='video_id')\n",
    "expected_rewards_df['rational_watch_ratio'] = expected_rewards_df['watch_ratio'] / expected_rewards_df['avg_watch_ratio']\n",
    "expected_rewards_df['diff_play_duration'] = expected_rewards_df['play_duration'] - expected_rewards_df['avg_play_duration']\n",
    "\n",
    "mean = expected_rewards_df['rational_watch_ratio'].mean()\n",
    "std = expected_rewards_df['rational_watch_ratio'].std()\n",
    "expected_rewards_df['rational_watch_ratio'] = (expected_rewards_df['rational_watch_ratio'] - mean) / std\n",
    "new_min = 0.0  \n",
    "new_max = 1.5\n",
    "min_watch_ratio = expected_rewards_df['rational_watch_ratio'].min()\n",
    "max_watch_ratio = expected_rewards_df['rational_watch_ratio'].max()\n",
    "expected_rewards_df['rational_watch_ratio'] = ((expected_rewards_df['rational_watch_ratio'] - min_watch_ratio) / (max_watch_ratio - min_watch_ratio)) * (new_max - new_min) + new_min\n",
    "\n",
    "mean = expected_rewards_df['diff_play_duration'].mean()\n",
    "std = expected_rewards_df['diff_play_duration'].std()\n",
    "expected_rewards_df['diff_play_duration'] = (expected_rewards_df['diff_play_duration'] - mean) / std\n",
    "new_min = 0.0  \n",
    "new_max = 1.5\n",
    "min_watch_ratio = expected_rewards_df['diff_play_duration'].min()\n",
    "max_watch_ratio = expected_rewards_df['diff_play_duration'].max()\n",
    "expected_rewards_df['diff_play_duration'] = ((expected_rewards_df['diff_play_duration'] - min_watch_ratio) / (max_watch_ratio - min_watch_ratio)) * (new_max - new_min) + new_min\n",
    "\n",
    "# mean = expected_rewards_df['play_progress'].mean()\n",
    "# std = expected_rewards_df['play_progress'].std()\n",
    "# expected_rewards_df['play_progress'] = (expected_rewards_df['play_progress'] - mean) / std\n",
    "# new_min = 0.0\n",
    "# new_max = 1.5\n",
    "# min_watch_ratio = expected_rewards_df['play_progress'].min()\n",
    "# max_watch_ratio = expected_rewards_df['play_progress'].max()\n",
    "# expected_rewards_df['play_progress'] = ((expected_rewards_df['play_progress'] - min_watch_ratio) / (max_watch_ratio - min_watch_ratio)) * (new_max - new_min) + new_min\n",
    "\n",
    "# # avg_play_duration = expected_rewards_df[\"tot_play_duration\"]/expected_rewards_df[\"play_cnt\"]\n",
    "# avg_play_duration = expected_rewards_df['tot_play_duration'].mean()\n",
    "\n",
    "# avg_play_cnt = item_features_df['play_cnt'].mean()\n",
    "# avg_play_user_num = item_features_df['play_user_num'].mean()\n",
    "# avg_valid_play_cnt = item_features_df['valid_play_cnt'].mean()\n",
    "# avg_valid_play_user_num = item_features_df['valid_play_user_num'].mean()\n",
    "# avg_play_progress = item_features_df['play_progress'].mean()\n",
    "\n",
    "\n",
    "# expected_rewards_df[\"equity_metric_pd\"] = avg_play_duration-expected_rewards_df[\"tot_play_duration\"]\n",
    "# mean = expected_rewards_df['equity_metric_pd'].mean()\n",
    "# std = expected_rewards_df['equity_metric_pd'].std()\n",
    "# expected_rewards_df['equity_metric_pd'] = (expected_rewards_df['equity_metric_pd'] - mean) / std\n",
    "# new_min = 0.0\n",
    "# new_max = 1.5\n",
    "# min_watch_ratio = expected_rewards_df['equity_metric_pd'].min()\n",
    "# max_watch_ratio = expected_rewards_df['equity_metric_pd'].max()\n",
    "# expected_rewards_df['equity_metric_pd'] = ((expected_rewards_df['equity_metric_pd'] - min_watch_ratio) / (max_watch_ratio - min_watch_ratio)) * (new_max - new_min) + new_min\n",
    "\n",
    "# expected_rewards_df[\"equity_metric_user_num\"] = avg_play_user_num-expected_rewards_df[\"play_user_num\"]\n",
    "# mean = expected_rewards_df['equity_metric_user_num'].mean()\n",
    "# std = expected_rewards_df['equity_metric_user_num'].std()\n",
    "# expected_rewards_df['equity_metric_user_num'] = (expected_rewards_df['equity_metric_user_num'] - mean) / std\n",
    "# new_min = 0.0\n",
    "# new_max = 1.5\n",
    "# min_watch_ratio = expected_rewards_df['equity_metric_user_num'].min()\n",
    "# max_watch_ratio = expected_rewards_df['equity_metric_user_num'].max()\n",
    "# expected_rewards_df['equity_metric_user_num'] = ((expected_rewards_df['equity_metric_user_num'] - min_watch_ratio) / (max_watch_ratio - min_watch_ratio)) * (new_max - new_min) + new_min\n",
    "\n",
    "# expected_rewards_df[\"long_ratio\"] = expected_rewards_df[\"long_time_play_cnt\"]/expected_rewards_df[\"play_cnt\"]\n",
    "# mean = expected_rewards_df['long_ratio'].mean()\n",
    "# std = expected_rewards_df['long_ratio'].std()\n",
    "# expected_rewards_df['long_ratio'] = (expected_rewards_df['long_ratio'] - mean) / std\n",
    "# new_min = 0.0\n",
    "# new_max = 1.5\n",
    "# min_watch_ratio = expected_rewards_df['long_ratio'].min()\n",
    "# max_watch_ratio = expected_rewards_df['long_ratio'].max()\n",
    "# expected_rewards_df['long_ratio'] = ((expected_rewards_df['long_ratio'] - min_watch_ratio) / (max_watch_ratio - min_watch_ratio)) * (new_max - new_min) + new_min\n",
    "\n",
    "# expected_rewards_df[\"short_ratio\"] = expected_rewards_df[\"short_time_play_cnt\"]/expected_rewards_df[\"play_cnt\"]\n",
    "\n",
    "# mean = expected_rewards_df['play_progress'].mean()\n",
    "# std = expected_rewards_df['play_progress'].std()\n",
    "# expected_rewards_df['play_progress'] = (expected_rewards_df['play_progress'] - mean) / std\n",
    "# new_min = 0.0\n",
    "# new_max = 1.5\n",
    "# min_watch_ratio = expected_rewards_df['play_progress'].min()\n",
    "# max_watch_ratio = expected_rewards_df['play_progress'].max()\n",
    "# expected_rewards_df['play_progress'] = ((expected_rewards_df['play_progress'] - min_watch_ratio) / (max_watch_ratio - min_watch_ratio)) * (new_max - new_min) + new_min\n",
    "\n",
    "\n",
    "# avg_play_dur = expected_rewards_df['tot_play_duration']/expected_rewards_df['play_cnt']\n",
    "\n",
    "\n",
    "expected_rewards_df['upload_dt'] = pd.to_datetime(expected_rewards_df['upload_dt'].astype(str), format='%Y%m%d')\n",
    "most_recent_date = expected_rewards_df['upload_dt'].max()\n",
    "expected_rewards_df['days_since_upload'] = (most_recent_date - expected_rewards_df['upload_dt']).dt.days\n",
    "mean = expected_rewards_df['days_since_upload'].mean()\n",
    "std = expected_rewards_df['days_since_upload'].std()\n",
    "expected_rewards_df['days_since_upload'] = (expected_rewards_df['days_since_upload'] - mean) / std\n",
    "new_min = 0.0\n",
    "new_max = 1.5\n",
    "min_watch_ratio = expected_rewards_df['days_since_upload'].min()\n",
    "max_watch_ratio = expected_rewards_df['days_since_upload'].max()\n",
    "expected_rewards_df['days_since_upload'] = ((expected_rewards_df['days_since_upload'] - min_watch_ratio) / (max_watch_ratio - min_watch_ratio)) * (new_max - new_min) + new_min\n",
    "\n",
    "\n",
    "# mean = avg_perusr_play_duration.mean()\n",
    "# std = avg_perusr_play_duration.std()\n",
    "# expected_rewards_df['avg_perusr_play_duration'] = (avg_perusr_play_duration - mean) / std\n",
    "# new_min = 0.0\n",
    "# new_max = 1.0\n",
    "# min_watch_ratio = expected_rewards_df['avg_perusr_play_duration'].min()\n",
    "# max_watch_ratio = expected_rewards_df['avg_perusr_play_duration'].max()\n",
    "# expected_rewards_df['avg_perusr_play_duration'] = ((expected_rewards_df['avg_perusr_play_duration'] - min_watch_ratio) / (max_watch_ratio - min_watch_ratio)) * (new_max - new_min) + new_min\n",
    "\n",
    "# mean = avg_perusr_watch_ratio.mean()\n",
    "# std = avg_perusr_watch_ratio.std()\n",
    "# expected_rewards_df['avg_perusr_watch_ratio'] = (avg_perusr_watch_ratio - mean) / std\n",
    "# new_min = 0.0\n",
    "# new_max = 1.0\n",
    "# min_watch_ratio = expected_rewards_df['avg_perusr_watch_ratio'].min()\n",
    "# max_watch_ratio = expected_rewards_df['avg_perusr_watch_ratio'].max()\n",
    "# expected_rewards_df['avg_perusr_watch_ratio'] = ((expected_rewards_df['avg_perusr_watch_ratio'] - min_watch_ratio) / (max_watch_ratio - min_watch_ratio)) * (new_max - new_min) + new_min\n",
    "\n",
    "\n",
    "expected_rewards_df.to_csv(\"../../data/expected_rewards_df.csv\", index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.6410107052054337"
      ]
     },
     "execution_count": 43,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "max_rat = expected_rewards_df['watch_ratio'].mean()\n",
    "max_rat"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.5237514915524165"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "max_ = expected_rewards_df['play_duration'].mean()\n",
    "max_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "315072"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "max_ = expected_rewards_df['video_duration'].max()\n",
    "max_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# save both data\n",
    "expected_rewards_df.to_csv(\"../../data/expected_rewards_df.csv\", index=False)\n",
    "user_features_df.to_csv(\"../../data/user_features_df.csv\", index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "unique_video_ids_count = expected_rewards_df['video_id'].nunique()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "2062"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "unique_video_ids_count"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "myenv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
