{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "from huggingface_hub import login\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import argparse\n",
    "from datasets import load_from_disk\n",
    "import json\n",
    "from transformers import pipeline\n",
    "from collections import defaultdict\n",
    "import torch\n",
    "\n",
    "from transformers import pipeline\n",
    "from collections import defaultdict\n",
    "import torch\n",
    "import json\n",
    "# plot distribution of ctr\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "config = json.load(open('../configs/mind-sport-abs-entangled/data.json'))\n",
    "\n",
    "data = load_from_disk(config[\"data_path\"].format(**config))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_train_df = data['reward'].to_pandas()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Abstract_len</th>\n",
       "      <th>__index_level_0__</th>\n",
       "      <th>headline_length</th>\n",
       "      <th>ctr</th>\n",
       "      <th>emotion_abstract</th>\n",
       "      <th>is_west_coast</th>\n",
       "      <th>is_central</th>\n",
       "      <th>is_east_coast</th>\n",
       "      <th>popularity</th>\n",
       "      <th>score</th>\n",
       "      <th>score_no_popularity</th>\n",
       "      <th>score_clean</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>count</th>\n",
       "      <td>10000.000000</td>\n",
       "      <td>10000.000000</td>\n",
       "      <td>10000.000000</td>\n",
       "      <td>10000.000000</td>\n",
       "      <td>10000.000000</td>\n",
       "      <td>10000.000000</td>\n",
       "      <td>10000.000000</td>\n",
       "      <td>10000.000000</td>\n",
       "      <td>10000.000000</td>\n",
       "      <td>10000.000000</td>\n",
       "      <td>10000.000000</td>\n",
       "      <td>10000.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>mean</th>\n",
       "      <td>49.789200</td>\n",
       "      <td>19933.388900</td>\n",
       "      <td>11.191800</td>\n",
       "      <td>0.468503</td>\n",
       "      <td>0.376384</td>\n",
       "      <td>0.084100</td>\n",
       "      <td>0.341700</td>\n",
       "      <td>0.327100</td>\n",
       "      <td>-13.306550</td>\n",
       "      <td>0.202194</td>\n",
       "      <td>0.468325</td>\n",
       "      <td>0.404293</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>std</th>\n",
       "      <td>22.634761</td>\n",
       "      <td>11451.147983</td>\n",
       "      <td>3.324202</td>\n",
       "      <td>0.171566</td>\n",
       "      <td>0.194722</td>\n",
       "      <td>0.277552</td>\n",
       "      <td>0.474303</td>\n",
       "      <td>0.469178</td>\n",
       "      <td>8.078831</td>\n",
       "      <td>0.191950</td>\n",
       "      <td>0.172807</td>\n",
       "      <td>0.163622</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>min</th>\n",
       "      <td>20.000000</td>\n",
       "      <td>2.000000</td>\n",
       "      <td>2.000000</td>\n",
       "      <td>0.015222</td>\n",
       "      <td>0.010451</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>-39.291969</td>\n",
       "      <td>-0.560800</td>\n",
       "      <td>0.005752</td>\n",
       "      <td>-0.078474</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>25%</th>\n",
       "      <td>26.000000</td>\n",
       "      <td>10089.000000</td>\n",
       "      <td>9.000000</td>\n",
       "      <td>0.349519</td>\n",
       "      <td>0.222880</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>-18.843198</td>\n",
       "      <td>0.081047</td>\n",
       "      <td>0.348956</td>\n",
       "      <td>0.291402</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>50%</th>\n",
       "      <td>59.000000</td>\n",
       "      <td>19874.500000</td>\n",
       "      <td>11.000000</td>\n",
       "      <td>0.479245</td>\n",
       "      <td>0.357927</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>-12.468786</td>\n",
       "      <td>0.200961</td>\n",
       "      <td>0.477326</td>\n",
       "      <td>0.410301</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>75%</th>\n",
       "      <td>71.000000</td>\n",
       "      <td>29723.750000</td>\n",
       "      <td>13.000000</td>\n",
       "      <td>0.574618</td>\n",
       "      <td>0.509017</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>-7.003462</td>\n",
       "      <td>0.326114</td>\n",
       "      <td>0.576350</td>\n",
       "      <td>0.506088</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>max</th>\n",
       "      <td>109.000000</td>\n",
       "      <td>39996.000000</td>\n",
       "      <td>35.000000</td>\n",
       "      <td>0.970515</td>\n",
       "      <td>0.982338</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>4.582940</td>\n",
       "      <td>0.977634</td>\n",
       "      <td>0.994264</td>\n",
       "      <td>0.982148</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "       Abstract_len  __index_level_0__  headline_length           ctr  \\\n",
       "count  10000.000000       10000.000000     10000.000000  10000.000000   \n",
       "mean      49.789200       19933.388900        11.191800      0.468503   \n",
       "std       22.634761       11451.147983         3.324202      0.171566   \n",
       "min       20.000000           2.000000         2.000000      0.015222   \n",
       "25%       26.000000       10089.000000         9.000000      0.349519   \n",
       "50%       59.000000       19874.500000        11.000000      0.479245   \n",
       "75%       71.000000       29723.750000        13.000000      0.574618   \n",
       "max      109.000000       39996.000000        35.000000      0.970515   \n",
       "\n",
       "       emotion_abstract  is_west_coast    is_central  is_east_coast  \\\n",
       "count      10000.000000   10000.000000  10000.000000   10000.000000   \n",
       "mean           0.376384       0.084100      0.341700       0.327100   \n",
       "std            0.194722       0.277552      0.474303       0.469178   \n",
       "min            0.010451       0.000000      0.000000       0.000000   \n",
       "25%            0.222880       0.000000      0.000000       0.000000   \n",
       "50%            0.357927       0.000000      0.000000       0.000000   \n",
       "75%            0.509017       0.000000      1.000000       1.000000   \n",
       "max            0.982338       1.000000      1.000000       1.000000   \n",
       "\n",
       "         popularity         score  score_no_popularity   score_clean  \n",
       "count  10000.000000  10000.000000         10000.000000  10000.000000  \n",
       "mean     -13.306550      0.202194             0.468325      0.404293  \n",
       "std        8.078831      0.191950             0.172807      0.163622  \n",
       "min      -39.291969     -0.560800             0.005752     -0.078474  \n",
       "25%      -18.843198      0.081047             0.348956      0.291402  \n",
       "50%      -12.468786      0.200961             0.477326      0.410301  \n",
       "75%       -7.003462      0.326114             0.576350      0.506088  \n",
       "max        4.582940      0.977634             0.994264      0.982148  "
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data_train_df.describe()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Abstract_len</th>\n",
       "      <th>__index_level_0__</th>\n",
       "      <th>headline_length</th>\n",
       "      <th>ctr</th>\n",
       "      <th>emotion_abstract</th>\n",
       "      <th>is_west_coast</th>\n",
       "      <th>is_central</th>\n",
       "      <th>is_east_coast</th>\n",
       "      <th>popularity</th>\n",
       "      <th>score</th>\n",
       "      <th>score_no_popularity</th>\n",
       "      <th>score_clean</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>count</th>\n",
       "      <td>10000.000000</td>\n",
       "      <td>10000.000000</td>\n",
       "      <td>10000.000000</td>\n",
       "      <td>10000.000000</td>\n",
       "      <td>10000.000000</td>\n",
       "      <td>10000.000000</td>\n",
       "      <td>10000.000000</td>\n",
       "      <td>10000.000000</td>\n",
       "      <td>10000.000000</td>\n",
       "      <td>10000.000000</td>\n",
       "      <td>10000.000000</td>\n",
       "      <td>10000.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>mean</th>\n",
       "      <td>49.789200</td>\n",
       "      <td>19933.388900</td>\n",
       "      <td>11.191800</td>\n",
       "      <td>0.468503</td>\n",
       "      <td>0.376384</td>\n",
       "      <td>0.084100</td>\n",
       "      <td>0.341700</td>\n",
       "      <td>0.327100</td>\n",
       "      <td>-13.306550</td>\n",
       "      <td>0.202194</td>\n",
       "      <td>0.468325</td>\n",
       "      <td>0.404293</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>std</th>\n",
       "      <td>22.634761</td>\n",
       "      <td>11451.147983</td>\n",
       "      <td>3.324202</td>\n",
       "      <td>0.171566</td>\n",
       "      <td>0.194722</td>\n",
       "      <td>0.277552</td>\n",
       "      <td>0.474303</td>\n",
       "      <td>0.469178</td>\n",
       "      <td>8.078831</td>\n",
       "      <td>0.191950</td>\n",
       "      <td>0.172807</td>\n",
       "      <td>0.163622</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>min</th>\n",
       "      <td>20.000000</td>\n",
       "      <td>2.000000</td>\n",
       "      <td>2.000000</td>\n",
       "      <td>0.015222</td>\n",
       "      <td>0.010451</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>-39.291969</td>\n",
       "      <td>-0.560800</td>\n",
       "      <td>0.005752</td>\n",
       "      <td>-0.078474</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>25%</th>\n",
       "      <td>26.000000</td>\n",
       "      <td>10089.000000</td>\n",
       "      <td>9.000000</td>\n",
       "      <td>0.349519</td>\n",
       "      <td>0.222880</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>-18.843198</td>\n",
       "      <td>0.081047</td>\n",
       "      <td>0.348956</td>\n",
       "      <td>0.291402</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>50%</th>\n",
       "      <td>59.000000</td>\n",
       "      <td>19874.500000</td>\n",
       "      <td>11.000000</td>\n",
       "      <td>0.479245</td>\n",
       "      <td>0.357927</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>-12.468786</td>\n",
       "      <td>0.200961</td>\n",
       "      <td>0.477326</td>\n",
       "      <td>0.410301</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>75%</th>\n",
       "      <td>71.000000</td>\n",
       "      <td>29723.750000</td>\n",
       "      <td>13.000000</td>\n",
       "      <td>0.574618</td>\n",
       "      <td>0.509017</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>-7.003462</td>\n",
       "      <td>0.326114</td>\n",
       "      <td>0.576350</td>\n",
       "      <td>0.506088</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>max</th>\n",
       "      <td>109.000000</td>\n",
       "      <td>39996.000000</td>\n",
       "      <td>35.000000</td>\n",
       "      <td>0.970515</td>\n",
       "      <td>0.982338</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>4.582940</td>\n",
       "      <td>0.977634</td>\n",
       "      <td>0.994264</td>\n",
       "      <td>0.982148</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "       Abstract_len  __index_level_0__  headline_length           ctr  \\\n",
       "count  10000.000000       10000.000000     10000.000000  10000.000000   \n",
       "mean      49.789200       19933.388900        11.191800      0.468503   \n",
       "std       22.634761       11451.147983         3.324202      0.171566   \n",
       "min       20.000000           2.000000         2.000000      0.015222   \n",
       "25%       26.000000       10089.000000         9.000000      0.349519   \n",
       "50%       59.000000       19874.500000        11.000000      0.479245   \n",
       "75%       71.000000       29723.750000        13.000000      0.574618   \n",
       "max      109.000000       39996.000000        35.000000      0.970515   \n",
       "\n",
       "       emotion_abstract  is_west_coast    is_central  is_east_coast  \\\n",
       "count      10000.000000   10000.000000  10000.000000   10000.000000   \n",
       "mean           0.376384       0.084100      0.341700       0.327100   \n",
       "std            0.194722       0.277552      0.474303       0.469178   \n",
       "min            0.010451       0.000000      0.000000       0.000000   \n",
       "25%            0.222880       0.000000      0.000000       0.000000   \n",
       "50%            0.357927       0.000000      0.000000       0.000000   \n",
       "75%            0.509017       0.000000      1.000000       1.000000   \n",
       "max            0.982338       1.000000      1.000000       1.000000   \n",
       "\n",
       "         popularity         score  score_no_popularity   score_clean  \n",
       "count  10000.000000  10000.000000         10000.000000  10000.000000  \n",
       "mean     -13.306550      0.202194             0.468325      0.404293  \n",
       "std        8.078831      0.191950             0.172807      0.163622  \n",
       "min      -39.291969     -0.560800             0.005752     -0.078474  \n",
       "25%      -18.843198      0.081047             0.348956      0.291402  \n",
       "50%      -12.468786      0.200961             0.477326      0.410301  \n",
       "75%       -7.003462      0.326114             0.576350      0.506088  \n",
       "max        4.582940      0.977634             0.994264      0.982148  "
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data_train_df.describe()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>score</th>\n",
       "      <th>ctr</th>\n",
       "      <th>popularity</th>\n",
       "      <th>emotion_abstract</th>\n",
       "      <th>is_west_coast</th>\n",
       "      <th>is_central</th>\n",
       "      <th>is_east_coast</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>count</th>\n",
       "      <td>10000.000000</td>\n",
       "      <td>10000.000000</td>\n",
       "      <td>10000.000000</td>\n",
       "      <td>10000.000000</td>\n",
       "      <td>10000.000000</td>\n",
       "      <td>10000.000000</td>\n",
       "      <td>10000.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>mean</th>\n",
       "      <td>0.202194</td>\n",
       "      <td>0.468503</td>\n",
       "      <td>-13.306550</td>\n",
       "      <td>0.376384</td>\n",
       "      <td>0.084100</td>\n",
       "      <td>0.341700</td>\n",
       "      <td>0.327100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>std</th>\n",
       "      <td>0.191950</td>\n",
       "      <td>0.171566</td>\n",
       "      <td>8.078831</td>\n",
       "      <td>0.194722</td>\n",
       "      <td>0.277552</td>\n",
       "      <td>0.474303</td>\n",
       "      <td>0.469178</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>min</th>\n",
       "      <td>-0.560800</td>\n",
       "      <td>0.015222</td>\n",
       "      <td>-39.291969</td>\n",
       "      <td>0.010451</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>25%</th>\n",
       "      <td>0.081047</td>\n",
       "      <td>0.349519</td>\n",
       "      <td>-18.843198</td>\n",
       "      <td>0.222880</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>50%</th>\n",
       "      <td>0.200961</td>\n",
       "      <td>0.479245</td>\n",
       "      <td>-12.468786</td>\n",
       "      <td>0.357927</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>75%</th>\n",
       "      <td>0.326114</td>\n",
       "      <td>0.574618</td>\n",
       "      <td>-7.003462</td>\n",
       "      <td>0.509017</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>max</th>\n",
       "      <td>0.977634</td>\n",
       "      <td>0.970515</td>\n",
       "      <td>4.582940</td>\n",
       "      <td>0.982338</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "              score           ctr    popularity  emotion_abstract  \\\n",
       "count  10000.000000  10000.000000  10000.000000      10000.000000   \n",
       "mean       0.202194      0.468503    -13.306550          0.376384   \n",
       "std        0.191950      0.171566      8.078831          0.194722   \n",
       "min       -0.560800      0.015222    -39.291969          0.010451   \n",
       "25%        0.081047      0.349519    -18.843198          0.222880   \n",
       "50%        0.200961      0.479245    -12.468786          0.357927   \n",
       "75%        0.326114      0.574618     -7.003462          0.509017   \n",
       "max        0.977634      0.970515      4.582940          0.982338   \n",
       "\n",
       "       is_west_coast    is_central  is_east_coast  \n",
       "count   10000.000000  10000.000000   10000.000000  \n",
       "mean        0.084100      0.341700       0.327100  \n",
       "std         0.277552      0.474303       0.469178  \n",
       "min         0.000000      0.000000       0.000000  \n",
       "25%         0.000000      0.000000       0.000000  \n",
       "50%         0.000000      0.000000       0.000000  \n",
       "75%         0.000000      1.000000       1.000000  \n",
       "max         1.000000      1.000000       1.000000  "
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data_train_df[['score', 'ctr', 'popularity', 'emotion_abstract', 'is_west_coast', 'is_central', 'is_east_coast']].describe()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>score</th>\n",
       "      <th>ctr</th>\n",
       "      <th>popularity</th>\n",
       "      <th>emotion_abstract</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>score</th>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.605545</td>\n",
       "      <td>0.533455</td>\n",
       "      <td>-0.509489</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ctr</th>\n",
       "      <td>0.605545</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>-0.342794</td>\n",
       "      <td>0.348193</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>popularity</th>\n",
       "      <td>0.533455</td>\n",
       "      <td>-0.342794</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>-0.977218</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>emotion_abstract</th>\n",
       "      <td>-0.509489</td>\n",
       "      <td>0.348193</td>\n",
       "      <td>-0.977218</td>\n",
       "      <td>1.000000</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                     score       ctr  popularity  emotion_abstract\n",
       "score             1.000000  0.605545    0.533455         -0.509489\n",
       "ctr               0.605545  1.000000   -0.342794          0.348193\n",
       "popularity        0.533455 -0.342794    1.000000         -0.977218\n",
       "emotion_abstract -0.509489  0.348193   -0.977218          1.000000"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data_train_df[['score', 'ctr', 'popularity', 'emotion_abstract']].corr()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# linear regression of score vs ctr using statsmodels\n",
    "import statsmodels.api as sm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# run 2sls regression\n",
    "from linearmodels.iv import IV2SLS\n",
    "\n",
    "# score ~ effort | is_central\n",
    "data_train_df['const'] = 1\n",
    "iv = IV2SLS(dependent=data_train_df['score'],\n",
    "            exog=data_train_df[['const']],\n",
    "            endog=data_train_df['popularity'],\n",
    "            instruments=data_train_df['is_central']).fit()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "                          IV-2SLS Estimation Summary                          \n",
      "==============================================================================\n",
      "Dep. Variable:                  score   R-squared:                      0.2734\n",
      "Estimator:                    IV-2SLS   Adj. R-squared:                 0.2733\n",
      "No. Observations:               10000   F-statistic:                    138.27\n",
      "Date:                Mon, Aug 04 2025   P-value (F-stat)                0.0000\n",
      "Time:                        19:25:32   Distribution:                  chi2(1)\n",
      "Cov. Estimator:                robust                                         \n",
      "                                                                              \n",
      "                             Parameter Estimates                              \n",
      "==============================================================================\n",
      "            Parameter  Std. Err.     T-stat    P-value    Lower CI    Upper CI\n",
      "------------------------------------------------------------------------------\n",
      "const          0.4043     0.0172     23.477     0.0000      0.3705      0.4380\n",
      "popularity     0.0152     0.0013     11.759     0.0000      0.0127      0.0177\n",
      "==============================================================================\n",
      "\n",
      "Endogenous: popularity\n",
      "Instruments: is_central\n",
      "Robust Covariance (Heteroskedastic)\n",
      "Debiased: False\n"
     ]
    }
   ],
   "source": [
    "print(iv)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.015187912904688128"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# print iv coefs\n",
    "iv.params['popularity']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "                            OLS Regression Results                            \n",
      "==============================================================================\n",
      "Dep. Variable:                  score   R-squared:                       0.989\n",
      "Model:                            OLS   Adj. R-squared:                  0.989\n",
      "Method:                 Least Squares   F-statistic:                 4.466e+05\n",
      "Date:                Mon, 04 Aug 2025   Prob (F-statistic):               0.00\n",
      "Time:                        19:25:32   Log-Likelihood:                 24835.\n",
      "No. Observations:               10000   AIC:                        -4.966e+04\n",
      "Df Residuals:                    9997   BIC:                        -4.964e+04\n",
      "Df Model:                           2                                         \n",
      "Covariance Type:            nonrobust                                         \n",
      "==============================================================================\n",
      "                 coef    std err          t      P>|t|      [0.025      0.975]\n",
      "------------------------------------------------------------------------------\n",
      "const         -0.0006      0.001     -1.006      0.314      -0.002       0.001\n",
      "popularity     0.0200   2.66e-05    749.677      0.000       0.020       0.020\n",
      "ctr            0.9995      0.001    797.608      0.000       0.997       1.002\n",
      "==============================================================================\n",
      "Omnibus:                        6.535   Durbin-Watson:                   1.975\n",
      "Prob(Omnibus):                  0.038   Jarque-Bera (JB):                6.526\n",
      "Skew:                          -0.055   Prob(JB):                       0.0383\n",
      "Kurtosis:                       2.939   Cond. No.                         104.\n",
      "==============================================================================\n",
      "\n",
      "Notes:\n",
      "[1] Standard Errors assume that the covariance matrix of the errors is correctly specified.\n"
     ]
    }
   ],
   "source": [
    "# linear regression of score vs ctr using statsmodels\n",
    "import statsmodels.api as sm\n",
    "\n",
    "X = data_train_df[['popularity', 'ctr']]\n",
    "X = sm.add_constant(X)\n",
    "y = data_train_df['score']\n",
    "model = sm.OLS(y, X)\n",
    "results = model.fit()\n",
    "print(results.summary())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "                            OLS Regression Results                            \n",
      "==============================================================================\n",
      "Dep. Variable:                  score   R-squared:                       0.285\n",
      "Model:                            OLS   Adj. R-squared:                  0.285\n",
      "Method:                 Least Squares   F-statistic:                     3977.\n",
      "Date:                Mon, 04 Aug 2025   Prob (F-statistic):               0.00\n",
      "Time:                        19:25:32   Log-Likelihood:                 3990.7\n",
      "No. Observations:               10000   AIC:                            -7977.\n",
      "Df Residuals:                    9998   BIC:                            -7963.\n",
      "Df Model:                           1                                         \n",
      "Covariance Type:            nonrobust                                         \n",
      "==============================================================================\n",
      "                 coef    std err          t      P>|t|      [0.025      0.975]\n",
      "------------------------------------------------------------------------------\n",
      "const          0.3709      0.003    118.532      0.000       0.365       0.377\n",
      "popularity     0.0127      0.000     63.063      0.000       0.012       0.013\n",
      "==============================================================================\n",
      "Omnibus:                       30.105   Durbin-Watson:                   1.958\n",
      "Prob(Omnibus):                  0.000   Jarque-Bera (JB):               30.335\n",
      "Skew:                           0.130   Prob(JB):                     2.59e-07\n",
      "Kurtosis:                       3.072   Cond. No.                         30.1\n",
      "==============================================================================\n",
      "\n",
      "Notes:\n",
      "[1] Standard Errors assume that the covariance matrix of the errors is correctly specified.\n"
     ]
    }
   ],
   "source": [
    "# linear regression of score vs ctr using statsmodels\n",
    "import statsmodels.api as sm\n",
    "\n",
    "X = data_train_df['popularity']\n",
    "X = sm.add_constant(X)\n",
    "y = data_train_df['score']\n",
    "model = sm.OLS(y, X)\n",
    "results = model.fit()\n",
    "print(results.summary())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "                            OLS Regression Results                            \n",
      "==============================================================================\n",
      "Dep. Variable:                  score   R-squared:                       0.367\n",
      "Model:                            OLS   Adj. R-squared:                  0.367\n",
      "Method:                 Least Squares   F-statistic:                     5789.\n",
      "Date:                Mon, 04 Aug 2025   Prob (F-statistic):               0.00\n",
      "Time:                        19:25:32   Log-Likelihood:                 4600.2\n",
      "No. Observations:               10000   AIC:                            -9196.\n",
      "Df Residuals:                    9998   BIC:                            -9182.\n",
      "Df Model:                           1                                         \n",
      "Covariance Type:            nonrobust                                         \n",
      "==============================================================================\n",
      "                 coef    std err          t      P>|t|      [0.025      0.975]\n",
      "------------------------------------------------------------------------------\n",
      "const         -0.1152      0.004    -25.933      0.000      -0.124      -0.107\n",
      "ctr            0.6775      0.009     76.084      0.000       0.660       0.695\n",
      "==============================================================================\n",
      "Omnibus:                      265.243   Durbin-Watson:                   2.008\n",
      "Prob(Omnibus):                  0.000   Jarque-Bera (JB):              286.376\n",
      "Skew:                          -0.414   Prob(JB):                     6.52e-63\n",
      "Kurtosis:                       2.951   Cond. No.                         7.14\n",
      "==============================================================================\n",
      "\n",
      "Notes:\n",
      "[1] Standard Errors assume that the covariance matrix of the errors is correctly specified.\n"
     ]
    }
   ],
   "source": [
    "# linear regression of score vs ctr using statsmodels\n",
    "import statsmodels.api as sm\n",
    "\n",
    "X = data_train_df['ctr']\n",
    "X = sm.add_constant(X)\n",
    "y = data_train_df['score']\n",
    "model = sm.OLS(y, X)\n",
    "results = model.fit()\n",
    "print(results.summary())"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llm-up",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
