{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# All imports for the notebook, deduplicated: stdlib first, then third-party.\n",
    "import argparse\n",
    "import json\n",
    "from collections import defaultdict\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "import statsmodels.api as sm\n",
    "import torch\n",
    "from datasets import load_dataset, load_from_disk\n",
    "from huggingface_hub import login\n",
    "from linearmodels.iv import IV2SLS\n",
    "from transformers import pipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the dataset config; `data_path` is a template filled in from the config's own keys.\n",
    "CONFIG_PATH = '../configs/mind-news-abs-entangled/data.json'\n",
    "# Context manager so the config file handle is closed after reading.\n",
    "with open(CONFIG_PATH) as config_file:\n",
    "    config = json.load(config_file)\n",
    "\n",
    "data = load_from_disk(config[\"data_path\"].format(**config))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE: despite the `train` in its name, this frame holds the dataset's 'reward' split.\n",
    "data_train_df = data['reward'].to_pandas()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>AbstractSent</th>\n",
       "      <th>TitleSent</th>\n",
       "      <th>Abstract_len</th>\n",
       "      <th>__index_level_0__</th>\n",
       "      <th>headline_length</th>\n",
       "      <th>ctr</th>\n",
       "      <th>emotion_abstract</th>\n",
       "      <th>is_political</th>\n",
       "      <th>popularity</th>\n",
       "      <th>score</th>\n",
       "      <th>score_no_popularity</th>\n",
       "      <th>score_clean</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>count</th>\n",
       "      <td>5865.000000</td>\n",
       "      <td>5865.000000</td>\n",
       "      <td>10000.000000</td>\n",
       "      <td>10000.000000</td>\n",
       "      <td>10000.000000</td>\n",
       "      <td>10000.000000</td>\n",
       "      <td>10000.000000</td>\n",
       "      <td>10000.000000</td>\n",
       "      <td>10000.000000</td>\n",
       "      <td>10000.000000</td>\n",
       "      <td>10000.000000</td>\n",
       "      <td>10000.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>mean</th>\n",
       "      <td>0.278236</td>\n",
       "      <td>0.371027</td>\n",
       "      <td>51.060800</td>\n",
       "      <td>19933.388900</td>\n",
       "      <td>11.051400</td>\n",
       "      <td>0.367295</td>\n",
       "      <td>0.275136</td>\n",
       "      <td>0.420000</td>\n",
       "      <td>-5.292557</td>\n",
       "      <td>-0.163734</td>\n",
       "      <td>0.365522</td>\n",
       "      <td>0.263143</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>std</th>\n",
       "      <td>0.183100</td>\n",
       "      <td>0.184043</td>\n",
       "      <td>21.856529</td>\n",
       "      <td>11451.147983</td>\n",
       "      <td>3.232275</td>\n",
       "      <td>0.182384</td>\n",
       "      <td>0.181634</td>\n",
       "      <td>0.493583</td>\n",
       "      <td>3.649938</td>\n",
       "      <td>0.374447</td>\n",
       "      <td>0.272021</td>\n",
       "      <td>0.256942</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>min</th>\n",
       "      <td>0.006778</td>\n",
       "      <td>0.015100</td>\n",
       "      <td>20.000000</td>\n",
       "      <td>2.000000</td>\n",
       "      <td>2.000000</td>\n",
       "      <td>0.010144</td>\n",
       "      <td>0.006778</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>-19.314130</td>\n",
       "      <td>-1.892311</td>\n",
       "      <td>-0.544550</td>\n",
       "      <td>-0.686546</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>25%</th>\n",
       "      <td>0.137603</td>\n",
       "      <td>0.224585</td>\n",
       "      <td>27.000000</td>\n",
       "      <td>10089.000000</td>\n",
       "      <td>9.000000</td>\n",
       "      <td>0.222399</td>\n",
       "      <td>0.136155</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>-7.370638</td>\n",
       "      <td>-0.381551</td>\n",
       "      <td>0.177677</td>\n",
       "      <td>0.086886</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>50%</th>\n",
       "      <td>0.224702</td>\n",
       "      <td>0.351585</td>\n",
       "      <td>60.000000</td>\n",
       "      <td>19874.500000</td>\n",
       "      <td>11.000000</td>\n",
       "      <td>0.345777</td>\n",
       "      <td>0.221131</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>-4.207797</td>\n",
       "      <td>-0.130132</td>\n",
       "      <td>0.361460</td>\n",
       "      <td>0.261492</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>75%</th>\n",
       "      <td>0.384900</td>\n",
       "      <td>0.488557</td>\n",
       "      <td>71.000000</td>\n",
       "      <td>29723.750000</td>\n",
       "      <td>13.000000</td>\n",
       "      <td>0.485294</td>\n",
       "      <td>0.378504</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>-2.517367</td>\n",
       "      <td>0.096742</td>\n",
       "      <td>0.546601</td>\n",
       "      <td>0.433981</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>max</th>\n",
       "      <td>0.973050</td>\n",
       "      <td>0.964121</td>\n",
       "      <td>117.000000</td>\n",
       "      <td>39996.000000</td>\n",
       "      <td>48.000000</td>\n",
       "      <td>0.964121</td>\n",
       "      <td>0.973050</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.155115</td>\n",
       "      <td>1.091538</td>\n",
       "      <td>1.372134</td>\n",
       "      <td>1.245182</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "       AbstractSent    TitleSent  Abstract_len  __index_level_0__  \\\n",
       "count   5865.000000  5865.000000  10000.000000       10000.000000   \n",
       "mean       0.278236     0.371027     51.060800       19933.388900   \n",
       "std        0.183100     0.184043     21.856529       11451.147983   \n",
       "min        0.006778     0.015100     20.000000           2.000000   \n",
       "25%        0.137603     0.224585     27.000000       10089.000000   \n",
       "50%        0.224702     0.351585     60.000000       19874.500000   \n",
       "75%        0.384900     0.488557     71.000000       29723.750000   \n",
       "max        0.973050     0.964121    117.000000       39996.000000   \n",
       "\n",
       "       headline_length           ctr  emotion_abstract  is_political  \\\n",
       "count     10000.000000  10000.000000      10000.000000  10000.000000   \n",
       "mean         11.051400      0.367295          0.275136      0.420000   \n",
       "std           3.232275      0.182384          0.181634      0.493583   \n",
       "min           2.000000      0.010144          0.006778      0.000000   \n",
       "25%           9.000000      0.222399          0.136155      0.000000   \n",
       "50%          11.000000      0.345777          0.221131      0.000000   \n",
       "75%          13.000000      0.485294          0.378504      1.000000   \n",
       "max          48.000000      0.964121          0.973050      1.000000   \n",
       "\n",
       "         popularity         score  score_no_popularity   score_clean  \n",
       "count  10000.000000  10000.000000         10000.000000  10000.000000  \n",
       "mean      -5.292557     -0.163734             0.365522      0.263143  \n",
       "std        3.649938      0.374447             0.272021      0.256942  \n",
       "min      -19.314130     -1.892311            -0.544550     -0.686546  \n",
       "25%       -7.370638     -0.381551             0.177677      0.086886  \n",
       "50%       -4.207797     -0.130132             0.361460      0.261492  \n",
       "75%       -2.517367      0.096742             0.546601      0.433981  \n",
       "max        0.155115      1.091538             1.372134      1.245182  "
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Summary statistics for every numeric column of the loaded split.\n",
    "data_train_df.describe()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>AbstractSent</th>\n",
       "      <th>TitleSent</th>\n",
       "      <th>Abstract_len</th>\n",
       "      <th>__index_level_0__</th>\n",
       "      <th>headline_length</th>\n",
       "      <th>ctr</th>\n",
       "      <th>emotion_abstract</th>\n",
       "      <th>is_political</th>\n",
       "      <th>popularity</th>\n",
       "      <th>score</th>\n",
       "      <th>score_no_popularity</th>\n",
       "      <th>score_clean</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>count</th>\n",
       "      <td>5865.000000</td>\n",
       "      <td>5865.000000</td>\n",
       "      <td>10000.000000</td>\n",
       "      <td>10000.000000</td>\n",
       "      <td>10000.000000</td>\n",
       "      <td>10000.000000</td>\n",
       "      <td>10000.000000</td>\n",
       "      <td>10000.000000</td>\n",
       "      <td>10000.000000</td>\n",
       "      <td>10000.000000</td>\n",
       "      <td>10000.000000</td>\n",
       "      <td>10000.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>mean</th>\n",
       "      <td>0.278236</td>\n",
       "      <td>0.371027</td>\n",
       "      <td>51.060800</td>\n",
       "      <td>19933.388900</td>\n",
       "      <td>11.051400</td>\n",
       "      <td>0.367295</td>\n",
       "      <td>0.275136</td>\n",
       "      <td>0.420000</td>\n",
       "      <td>-5.292557</td>\n",
       "      <td>-0.163734</td>\n",
       "      <td>0.365522</td>\n",
       "      <td>0.263143</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>std</th>\n",
       "      <td>0.183100</td>\n",
       "      <td>0.184043</td>\n",
       "      <td>21.856529</td>\n",
       "      <td>11451.147983</td>\n",
       "      <td>3.232275</td>\n",
       "      <td>0.182384</td>\n",
       "      <td>0.181634</td>\n",
       "      <td>0.493583</td>\n",
       "      <td>3.649938</td>\n",
       "      <td>0.374447</td>\n",
       "      <td>0.272021</td>\n",
       "      <td>0.256942</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>min</th>\n",
       "      <td>0.006778</td>\n",
       "      <td>0.015100</td>\n",
       "      <td>20.000000</td>\n",
       "      <td>2.000000</td>\n",
       "      <td>2.000000</td>\n",
       "      <td>0.010144</td>\n",
       "      <td>0.006778</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>-19.314130</td>\n",
       "      <td>-1.892311</td>\n",
       "      <td>-0.544550</td>\n",
       "      <td>-0.686546</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>25%</th>\n",
       "      <td>0.137603</td>\n",
       "      <td>0.224585</td>\n",
       "      <td>27.000000</td>\n",
       "      <td>10089.000000</td>\n",
       "      <td>9.000000</td>\n",
       "      <td>0.222399</td>\n",
       "      <td>0.136155</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>-7.370638</td>\n",
       "      <td>-0.381551</td>\n",
       "      <td>0.177677</td>\n",
       "      <td>0.086886</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>50%</th>\n",
       "      <td>0.224702</td>\n",
       "      <td>0.351585</td>\n",
       "      <td>60.000000</td>\n",
       "      <td>19874.500000</td>\n",
       "      <td>11.000000</td>\n",
       "      <td>0.345777</td>\n",
       "      <td>0.221131</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>-4.207797</td>\n",
       "      <td>-0.130132</td>\n",
       "      <td>0.361460</td>\n",
       "      <td>0.261492</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>75%</th>\n",
       "      <td>0.384900</td>\n",
       "      <td>0.488557</td>\n",
       "      <td>71.000000</td>\n",
       "      <td>29723.750000</td>\n",
       "      <td>13.000000</td>\n",
       "      <td>0.485294</td>\n",
       "      <td>0.378504</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>-2.517367</td>\n",
       "      <td>0.096742</td>\n",
       "      <td>0.546601</td>\n",
       "      <td>0.433981</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>max</th>\n",
       "      <td>0.973050</td>\n",
       "      <td>0.964121</td>\n",
       "      <td>117.000000</td>\n",
       "      <td>39996.000000</td>\n",
       "      <td>48.000000</td>\n",
       "      <td>0.964121</td>\n",
       "      <td>0.973050</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.155115</td>\n",
       "      <td>1.091538</td>\n",
       "      <td>1.372134</td>\n",
       "      <td>1.245182</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "       AbstractSent    TitleSent  Abstract_len  __index_level_0__  \\\n",
       "count   5865.000000  5865.000000  10000.000000       10000.000000   \n",
       "mean       0.278236     0.371027     51.060800       19933.388900   \n",
       "std        0.183100     0.184043     21.856529       11451.147983   \n",
       "min        0.006778     0.015100     20.000000           2.000000   \n",
       "25%        0.137603     0.224585     27.000000       10089.000000   \n",
       "50%        0.224702     0.351585     60.000000       19874.500000   \n",
       "75%        0.384900     0.488557     71.000000       29723.750000   \n",
       "max        0.973050     0.964121    117.000000       39996.000000   \n",
       "\n",
       "       headline_length           ctr  emotion_abstract  is_political  \\\n",
       "count     10000.000000  10000.000000      10000.000000  10000.000000   \n",
       "mean         11.051400      0.367295          0.275136      0.420000   \n",
       "std           3.232275      0.182384          0.181634      0.493583   \n",
       "min           2.000000      0.010144          0.006778      0.000000   \n",
       "25%           9.000000      0.222399          0.136155      0.000000   \n",
       "50%          11.000000      0.345777          0.221131      0.000000   \n",
       "75%          13.000000      0.485294          0.378504      1.000000   \n",
       "max          48.000000      0.964121          0.973050      1.000000   \n",
       "\n",
       "         popularity         score  score_no_popularity   score_clean  \n",
       "count  10000.000000  10000.000000         10000.000000  10000.000000  \n",
       "mean      -5.292557     -0.163734             0.365522      0.263143  \n",
       "std        3.649938      0.374447             0.272021      0.256942  \n",
       "min      -19.314130     -1.892311            -0.544550     -0.686546  \n",
       "25%       -7.370638     -0.381551             0.177677      0.086886  \n",
       "50%       -4.207797     -0.130132             0.361460      0.261492  \n",
       "75%       -2.517367      0.096742             0.546601      0.433981  \n",
       "max        0.155115      1.091538             1.372134      1.245182  "
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# NOTE(review): identical to the previous cell (same call, same output); candidate for deletion.\n",
    "data_train_df.describe()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>score</th>\n",
       "      <th>ctr</th>\n",
       "      <th>popularity</th>\n",
       "      <th>emotion_abstract</th>\n",
       "      <th>is_political</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>count</th>\n",
       "      <td>10000.000000</td>\n",
       "      <td>10000.000000</td>\n",
       "      <td>10000.000000</td>\n",
       "      <td>10000.000000</td>\n",
       "      <td>10000.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>mean</th>\n",
       "      <td>-0.163734</td>\n",
       "      <td>0.367295</td>\n",
       "      <td>-5.292557</td>\n",
       "      <td>0.275136</td>\n",
       "      <td>0.420000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>std</th>\n",
       "      <td>0.374447</td>\n",
       "      <td>0.182384</td>\n",
       "      <td>3.649938</td>\n",
       "      <td>0.181634</td>\n",
       "      <td>0.493583</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>min</th>\n",
       "      <td>-1.892311</td>\n",
       "      <td>0.010144</td>\n",
       "      <td>-19.314130</td>\n",
       "      <td>0.006778</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>25%</th>\n",
       "      <td>-0.381551</td>\n",
       "      <td>0.222399</td>\n",
       "      <td>-7.370638</td>\n",
       "      <td>0.136155</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>50%</th>\n",
       "      <td>-0.130132</td>\n",
       "      <td>0.345777</td>\n",
       "      <td>-4.207797</td>\n",
       "      <td>0.221131</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>75%</th>\n",
       "      <td>0.096742</td>\n",
       "      <td>0.485294</td>\n",
       "      <td>-2.517367</td>\n",
       "      <td>0.378504</td>\n",
       "      <td>1.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>max</th>\n",
       "      <td>1.091538</td>\n",
       "      <td>0.964121</td>\n",
       "      <td>0.155115</td>\n",
       "      <td>0.973050</td>\n",
       "      <td>1.000000</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "              score           ctr    popularity  emotion_abstract  \\\n",
       "count  10000.000000  10000.000000  10000.000000      10000.000000   \n",
       "mean      -0.163734      0.367295     -5.292557          0.275136   \n",
       "std        0.374447      0.182384      3.649938          0.181634   \n",
       "min       -1.892311      0.010144    -19.314130          0.006778   \n",
       "25%       -0.381551      0.222399     -7.370638          0.136155   \n",
       "50%       -0.130132      0.345777     -4.207797          0.221131   \n",
       "75%        0.096742      0.485294     -2.517367          0.378504   \n",
       "max        1.091538      0.964121      0.155115          0.973050   \n",
       "\n",
       "       is_political  \n",
       "count  10000.000000  \n",
       "mean       0.420000  \n",
       "std        0.493583  \n",
       "min        0.000000  \n",
       "25%        0.000000  \n",
       "50%        0.000000  \n",
       "75%        1.000000  \n",
       "max        1.000000  "
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Summary statistics for the score and the explanatory columns examined below.\n",
    "data_train_df[['score', 'ctr', 'popularity', 'emotion_abstract', 'is_political']].describe()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>score</th>\n",
       "      <th>ctr</th>\n",
       "      <th>popularity</th>\n",
       "      <th>emotion_abstract</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>score</th>\n",
       "      <td>1.000000</td>\n",
       "      <td>-0.009883</td>\n",
       "      <td>0.729619</td>\n",
       "      <td>-0.727575</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ctr</th>\n",
       "      <td>-0.009883</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>-0.509350</td>\n",
       "      <td>0.508720</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>popularity</th>\n",
       "      <td>0.729619</td>\n",
       "      <td>-0.509350</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>-0.997712</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>emotion_abstract</th>\n",
       "      <td>-0.727575</td>\n",
       "      <td>0.508720</td>\n",
       "      <td>-0.997712</td>\n",
       "      <td>1.000000</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                     score       ctr  popularity  emotion_abstract\n",
       "score             1.000000 -0.009883    0.729619         -0.727575\n",
       "ctr              -0.009883  1.000000   -0.509350          0.508720\n",
       "popularity        0.729619 -0.509350    1.000000         -0.997712\n",
       "emotion_abstract -0.727575  0.508720   -0.997712          1.000000"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Pairwise Pearson correlations (pandas default) among score and its candidate drivers.\n",
    "data_train_df[['score', 'ctr', 'popularity', 'emotion_abstract']].corr()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# linear regression of score vs ctr using statsmodels\n",
    "import statsmodels.api as sm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Two-stage least squares: score ~ const + popularity, with popularity\n",
    "# treated as endogenous and instrumented by is_political.\n",
    "from linearmodels.iv import IV2SLS\n",
    "\n",
    "# Add the intercept column on a copy so data_train_df is not mutated\n",
    "# (otherwise re-running the earlier describe()/corr() cells would pick up `const`).\n",
    "iv_df = data_train_df.assign(const=1)\n",
    "iv = IV2SLS(dependent=iv_df['score'],\n",
    "            exog=iv_df[['const']],\n",
    "            endog=iv_df['popularity'],\n",
    "            instruments=iv_df['is_political']).fit()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "                          IV-2SLS Estimation Summary                          \n",
      "==============================================================================\n",
      "Dep. Variable:                  score   R-squared:                      0.5291\n",
      "Estimator:                    IV-2SLS   Adj. R-squared:                 0.5291\n",
      "No. Observations:               10000   F-statistic:                    144.50\n",
      "Date:                Mon, Aug 04 2025   P-value (F-stat)                0.0000\n",
      "Time:                        21:14:15   Distribution:                  chi2(1)\n",
      "Cov. Estimator:                robust                                         \n",
      "                                                                              \n",
      "                             Parameter Estimates                              \n",
      "==============================================================================\n",
      "            Parameter  Std. Err.     T-stat    P-value    Lower CI    Upper CI\n",
      "------------------------------------------------------------------------------\n",
      "const          0.2631     0.0354     7.4277     0.0000      0.1937      0.3326\n",
      "popularity     0.0807     0.0067     12.021     0.0000      0.0675      0.0938\n",
      "==============================================================================\n",
      "\n",
      "Endogenous: popularity\n",
      "Instruments: is_political\n",
      "Robust Covariance (Heteroskedastic)\n",
      "Debiased: False\n"
     ]
    }
   ],
   "source": [
    "# Full 2SLS estimation summary (parameters, standard errors, instrument list).\n",
    "print(iv)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.08065608794971979"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 2SLS point estimate of the popularity coefficient only\n",
    "iv.params['popularity']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "                            OLS Regression Results                            \n",
      "==============================================================================\n",
      "Dep. Variable:                  score   R-squared:                       0.709\n",
      "Model:                            OLS   Adj. R-squared:                  0.709\n",
      "Method:                 Least Squares   F-statistic:                 1.218e+04\n",
      "Date:                Mon, 04 Aug 2025   Prob (F-statistic):               0.00\n",
      "Time:                        21:14:15   Log-Likelihood:                 1807.2\n",
      "No. Observations:               10000   AIC:                            -3608.\n",
      "Df Residuals:                    9997   BIC:                            -3587.\n",
      "Df Model:                           2                                         \n",
      "Covariance Type:            nonrobust                                         \n",
      "==============================================================================\n",
      "                 coef    std err          t      P>|t|      [0.025      0.975]\n",
      "------------------------------------------------------------------------------\n",
      "const         -0.0008      0.005     -0.180      0.857      -0.010       0.008\n",
      "popularity     0.1004      0.001    156.076      0.000       0.099       0.102\n",
      "ctr            1.0029      0.013     77.921      0.000       0.978       1.028\n",
      "==============================================================================\n",
      "Omnibus:                        6.464   Durbin-Watson:                   1.974\n",
      "Prob(Omnibus):                  0.039   Jarque-Bera (JB):                6.457\n",
      "Skew:                          -0.054   Prob(JB):                       0.0396\n",
      "Kurtosis:                       2.940   Cond. No.                         42.6\n",
      "==============================================================================\n",
      "\n",
      "Notes:\n",
      "[1] Standard Errors assume that the covariance matrix of the errors is correctly specified.\n"
     ]
    }
   ],
   "source": [
    "# OLS: score ~ const + popularity + ctr (statsmodels `sm` is imported in an earlier cell).\n",
    "# Descriptive names avoid silently overwriting X/y/results shared across regression cells.\n",
    "ols_pop_ctr_exog = sm.add_constant(data_train_df[['popularity', 'ctr']])\n",
    "ols_pop_ctr = sm.OLS(data_train_df['score'], ols_pop_ctr_exog).fit()\n",
    "print(ols_pop_ctr.summary())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "                            OLS Regression Results                            \n",
      "==============================================================================\n",
      "Dep. Variable:                  score   R-squared:                       0.532\n",
      "Model:                            OLS   Adj. R-squared:                  0.532\n",
      "Method:                 Least Squares   F-statistic:                 1.138e+04\n",
      "Date:                Mon, 04 Aug 2025   Prob (F-statistic):               0.00\n",
      "Time:                        21:14:15   Log-Likelihood:                -565.71\n",
      "No. Observations:               10000   AIC:                             1135.\n",
      "Df Residuals:                    9998   BIC:                             1150.\n",
      "Df Model:                           1                                         \n",
      "Covariance Type:            nonrobust                                         \n",
      "==============================================================================\n",
      "                 coef    std err          t      P>|t|      [0.025      0.975]\n",
      "------------------------------------------------------------------------------\n",
      "const          0.2324      0.005     51.526      0.000       0.224       0.241\n",
      "popularity     0.0749      0.001    106.682      0.000       0.073       0.076\n",
      "==============================================================================\n",
      "Omnibus:                        8.795   Durbin-Watson:                   1.977\n",
      "Prob(Omnibus):                  0.012   Jarque-Bera (JB):                8.783\n",
      "Skew:                           0.072   Prob(JB):                       0.0124\n",
      "Kurtosis:                       3.015   Cond. No.                         11.5\n",
      "==============================================================================\n",
      "\n",
      "Notes:\n",
      "[1] Standard Errors assume that the covariance matrix of the errors is correctly specified.\n"
     ]
    }
   ],
   "source": [
    "# OLS: score ~ const + popularity, without the ctr control.\n",
    "ols_pop_exog = sm.add_constant(data_train_df['popularity'])\n",
    "ols_pop = sm.OLS(data_train_df['score'], ols_pop_exog).fit()\n",
    "print(ols_pop.summary())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "                            OLS Regression Results                            \n",
      "==============================================================================\n",
      "Dep. Variable:                  score   R-squared:                       0.000\n",
      "Model:                            OLS   Adj. R-squared:                 -0.000\n",
      "Method:                 Least Squares   F-statistic:                    0.9766\n",
      "Date:                Mon, 04 Aug 2025   Prob (F-statistic):              0.323\n",
      "Time:                        21:14:15   Log-Likelihood:                -4365.3\n",
      "No. Observations:               10000   AIC:                             8735.\n",
      "Df Residuals:                    9998   BIC:                             8749.\n",
      "Df Model:                           1                                         \n",
      "Covariance Type:            nonrobust                                         \n",
      "==============================================================================\n",
      "                 coef    std err          t      P>|t|      [0.025      0.975]\n",
      "------------------------------------------------------------------------------\n",
      "const         -0.1563      0.008    -18.561      0.000      -0.173      -0.140\n",
      "ctr           -0.0203      0.021     -0.988      0.323      -0.061       0.020\n",
      "==============================================================================\n",
      "Omnibus:                      415.524   Durbin-Watson:                   2.003\n",
      "Prob(Omnibus):                  0.000   Jarque-Bera (JB):              479.395\n",
      "Skew:                          -0.491   Prob(JB):                    7.96e-105\n",
      "Kurtosis:                       3.430   Cond. No.                         6.25\n",
      "==============================================================================\n",
      "\n",
      "Notes:\n",
      "[1] Standard Errors assume that the covariance matrix of the errors is correctly specified.\n"
     ]
    }
   ],
   "source": [
    "# OLS: score ~ const + ctr alone.\n",
    "ols_ctr_exog = sm.add_constant(data_train_df['ctr'])\n",
    "ols_ctr = sm.OLS(data_train_df['score'], ols_ctr_exog).fit()\n",
    "print(ols_ctr.summary())"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llm-up",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
