{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "import json\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from toolz import memoize\n",
    "import datetime\n",
    "import math\n",
    "\n",
    "from tqdm import tqdm\n",
    "from src.utils.mouselab_jas import MouselabJas\n",
    "from src.utils.distributions import Normal, expectation\n",
    "from src.utils.env_creation import create_tree, create_init\n",
    "from src.utils.env_export import create_json\n",
    "from src.utils.data_classes import MouselabConfig, Action\n",
    "from simulation import run_simulation\n",
    "from src.policy.jas_voc_policy import JAS_voc_policy\n",
    "from src.policy.jas_policy import RandomPolicy, ExhaustivePolicy, RandomNPolicy\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "from matplotlib import pyplot as plt\n",
    "from src.utils.utils import sigma_to_tau\n",
    "import numpy as np\n",
    "from src.utils.env_export import format_payoff\n",
    "from src.utils.khalili_env import get_env\n",
    "\n",
    "\n",
    "# Configure seaborn once: every sns.set()/sns.set_theme() call re-applies defaults for omitted\n",
    "# arguments (e.g. font_scale=1), so chained calls would silently undo earlier settings.\n",
    "sns.set_theme(font_scale=1.5, rc={'figure.figsize': (10, 6)})\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Raw mturk data export; the context manager closes the file once it is parsed (JSON is UTF-8)\n",
    "with open(\"./data/dataclips_5.json\", encoding=\"utf-8\") as dataclip_file:\n",
    "    data = json.load(dataclip_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Column positions within each row of data[\"values\"] (rows follow the order of data[\"fields\"])\n",
    "language_index, response_data_index, begin_index, end_index = (\n",
    "    data[\"fields\"].index(column)\n",
    "    for column in (\"language\", \"datastring\", \"beginhit\", \"endhit\")\n",
    ")\n",
    "\n",
    "# strptime format of the beginhit / endhit timestamps\n",
    "f = '%Y-%m-%d %H:%M:%S.%f'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fixed from paper\n",
    "env, config = get_env(5, term_belief=False)\n",
    "\n",
    "# Cost weight of the VOC policy, also fixed from the paper\n",
    "voc_cost_weight = 0.5798921379230035\n",
    "voc_policy = JAS_voc_policy(\n",
    "    discrete_observations=True,\n",
    "    cost_weight=voc_cost_weight,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def convert_action(project: int, criteria: int, expert: int, config: MouselabConfig) -> Action:\n",
    "    \"\"\"Translate a recorded click (project, criteria, expert) into an environment Action.\n",
    "\n",
    "    Queries are numbered project-major: project p, criterion c maps to\n",
    "    p * config.num_criterias + c + 1. The +1 offset leaves query 0 unused by clicks\n",
    "    (presumably the root node of the tree -- TODO confirm in MouselabJas).\n",
    "    \"\"\"\n",
    "    project_offset = project * config.num_criterias\n",
    "    return Action(expert=expert, query=project_offset + criteria + 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "385it [00:00, 1694.98it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Exclude 168 not really\n",
      "Exclude 248 not really\n",
      "Exclude 283 no\n",
      "Quiz exclusions 10\n",
      "Effort question exclusions 3\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Check survey responses, filter participants that respond \"No\" to the attention question\n",
    "\n",
    "exclude_answers = [\"no\"]  # prefix-matched, so e.g. \"not really\" is excluded as well\n",
    "known_workers = []\n",
    "excluded = []\n",
    "quiz_excluded = 0\n",
    "participant_age = []\n",
    "participant_gender = []\n",
    "participant_durations = []\n",
    "\n",
    "# Parse raw mturk data; guard clauses skip rows that are not analysed\n",
    "for p_index, p_data in tqdm(enumerate(data[\"values\"])):\n",
    "    response_data = p_data[response_data_index]\n",
    "    if response_data is None:  # empty response\n",
    "        continue\n",
    "    p_res_obj = json.loads(response_data)\n",
    "    if p_res_obj[\"workerId\"].startswith(\"debug\"):  # debug session\n",
    "        continue\n",
    "\n",
    "    worker = p_index  # participants are keyed by row position, not workerId\n",
    "    if worker in known_workers:\n",
    "        print(\"Duplicate worker\", worker)\n",
    "    else:\n",
    "        known_workers.append(worker)\n",
    "\n",
    "    p_res = p_res_obj[\"data\"]\n",
    "    quiz_failures = p_res_obj[\"questiondata\"].get(\"quiz_failures\", 0)\n",
    "    completed_test_trials = sum(\n",
    "        1\n",
    "        for entry in p_res\n",
    "        if \"trial_id\" in entry[\"trialdata\"] and entry[\"trialdata\"][\"trial_id\"].startswith(\"test\")\n",
    "    )\n",
    "\n",
    "    if quiz_failures >= 3:\n",
    "        quiz_excluded += 1\n",
    "        continue\n",
    "    if not (quiz_failures < 3 and completed_test_trials == 10):\n",
    "        continue\n",
    "\n",
    "    # Survey-text trials carry the attention answer (\"Reward\") plus Age and Gender\n",
    "    for entry in p_res:\n",
    "        trialdata = entry[\"trialdata\"]\n",
    "        if trialdata[\"trial_type\"] != \"survey-text\":\n",
    "            continue\n",
    "        survey = trialdata[\"response\"]\n",
    "        attention_response: str = survey[\"Reward\"].lower()\n",
    "        if any(attention_response.startswith(answer) for answer in exclude_answers):\n",
    "            print(\"Exclude\", worker, attention_response)\n",
    "            excluded.append(worker)\n",
    "            continue\n",
    "        participant_age.append(survey[\"Age\"].lower())\n",
    "        participant_gender.append(survey[\"Gender\"].lower())\n",
    "        if p_data[begin_index] and p_data[end_index]:\n",
    "            begin = datetime.datetime.strptime(p_data[begin_index], f)\n",
    "            end = datetime.datetime.strptime(p_data[end_index], f)\n",
    "            participant_durations.append((end - begin).total_seconds())\n",
    "\n",
    "print(\"Quiz exclusions\", quiz_excluded)\n",
    "print(\"Effort question exclusions\", len(excluded))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "29.08053691275168"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def calculate_mean_age(ages):\n",
    "    \"\"\"Return the mean of the free-text age answers that are whole numbers.\n",
    "\n",
    "    Surrounding whitespace is stripped before parsing; answers that are still not\n",
    "    plain decimal numbers are printed as errors and skipped. Returns NaN rather than\n",
    "    raising ZeroDivisionError when no answer is valid.\n",
    "    \"\"\"\n",
    "    valid_ages = 0\n",
    "    total_age = 0\n",
    "    for age in ages:\n",
    "        cleaned = age.strip()\n",
    "        # isdecimal, unlike isdigit, rejects characters int() cannot parse (e.g. superscript digits)\n",
    "        if cleaned.isdecimal():\n",
    "            valid_ages += 1\n",
    "            total_age += int(cleaned)\n",
    "        else:\n",
    "            print(\"Error\", age)\n",
    "    if valid_ages == 0:\n",
    "        return float(\"nan\")\n",
    "    return total_age / valid_ages\n",
    "calculate_mean_age(participant_age)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(147, 148, 3)"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def count_gender(genders):\n",
    "    \"\"\"Tally free-text gender answers as a (male, female, non-binary) tuple.\n",
    "\n",
    "    Answers are lower-cased and stripped before being matched against the known\n",
    "    spellings below; unmatched answers are printed and left uncounted.\n",
    "    \"\"\"\n",
    "    spellings = {\n",
    "        \"male\": {\"male\", \"m\", \"man\", \"men\", \"masculin\", \"masculine\"},\n",
    "        \"female\": {\"female\", \"f\", \"woman\", \"women\", \"feminine\", \"i am female\", \"femenine\"},\n",
    "        \"nb\": {\"non binary\", \"non-binary\"},\n",
    "    }\n",
    "    tally = dict.fromkeys(spellings, 0)\n",
    "    for gender in genders:\n",
    "        normalized = gender.lower().strip()\n",
    "        category = next((name for name, known in spellings.items() if normalized in known), None)\n",
    "        if category is None:\n",
    "            print(\"Error\", gender)\n",
    "        else:\n",
    "            tally[category] += 1\n",
    "    return tally[\"male\"], tally[\"female\"], tally[\"nb\"]\n",
    "count_gender(participant_gender)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "22.013963583333336 10.902171210172384\n"
     ]
    }
   ],
   "source": [
    "# Effective hourly wage at the median completion time\n",
    "pay, bonus = 3.5, 0.5  # flat pay and bonus used for the estimate (currency presumably USD)\n",
    "median_duration = np.median(participant_durations) / 60  # seconds -> minutes\n",
    "pay_per_minute = (pay + bonus) / median_duration\n",
    "pay_per_hour = 60 * pay_per_minute\n",
    "print(median_duration, pay_per_hour)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "385it [39:18,  6.12s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Responses: 298\n",
      "Responseswith excluded : 301\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "df_index = [\"Participant\", \"Condition\", \"TrialId\", \"Score\", \"ExpectedScore\", \"NumClicks\", \"Actions\", \"Selection\",\n",
    "        \"Seed\", \"ClickAgreement\"]\n",
    "df_all_data = [] # With excluded to calculate bonus payments\n",
    "df_data = []\n",
    "\n",
    "# NOTE(review): bonus_data, good_responses, demographics, participant_responses and\n",
    "# participant_survey are not used within this cell\n",
    "bonus_data = {}\n",
    "known_workers = []\n",
    "good_responses = 0\n",
    "demographics = []\n",
    "\n",
    "participant_actions = []\n",
    "\n",
    "# Parse raw mturk data into dataframe\n",
    "for p_index, p_data in tqdm(enumerate(data[\"values\"])):\n",
    "    language = p_data[language_index]\n",
    "    response_data = p_data[response_data_index]\n",
    "    # Session duration in seconds; None when either timestamp is missing\n",
    "    if p_data[begin_index] and p_data[end_index]:\n",
    "        begin = datetime.datetime.strptime(p_data[begin_index], f)\n",
    "        end = datetime.datetime.strptime(p_data[end_index], f)\n",
    "        duration = (end - begin).total_seconds()\n",
    "    else:\n",
    "        duration = None\n",
    "    # Filter out empty responses and debug sessions\n",
    "    if response_data is not None:\n",
    "        p_res_obj = json.loads(response_data)\n",
    "        if not p_res_obj[\"workerId\"].startswith(\"debug\"):\n",
    "            condition = p_res_obj[\"condition\"]\n",
    "            worker = p_index  # participants are keyed by row position, not workerId\n",
    "            if worker in known_workers:\n",
    "                print(\"Duplicate worker\", worker)\n",
    "            else:\n",
    "                known_workers.append(worker)\n",
    "            p_res = p_res_obj[\"data\"]\n",
    "            participant_responses = []\n",
    "            questiondata = p_res_obj[\"questiondata\"]\n",
    "            quiz_failures = questiondata.get(\"quiz_failures\", 0)\n",
    "            bonus = questiondata.get(\"final_bonus\", 0)\n",
    "            participant_survey = {\"Participant\": worker, \"Condition\": condition, \"Language\": language, \"QuizAttempts\": 0, \"QuizFailures\": quiz_failures, \"Bonus\": bonus, \"Duration\": duration}\n",
    "            completed_test_trials = sum(\n",
    "                1\n",
    "                for entry in p_res\n",
    "                if 'trial_id' in entry['trialdata'] and entry['trialdata']['trial_id'].startswith(\"test\")\n",
    "            )\n",
    "            # Keep participants with fewer than 3 quiz failures who completed all 10 test trials\n",
    "            if (quiz_failures < 3) and (completed_test_trials == 10):\n",
    "                for i in range(len(p_res)):\n",
    "                    trialdata = p_res[i]['trialdata']\n",
    "                    # Only test trials are scored; training and survey trials are skipped\n",
    "                    if not ('trial_id' in trialdata and trialdata['trial_id'].startswith(\"test\")):\n",
    "                        continue\n",
    "                    trial_id = trialdata['trial_id']\n",
    "                    seed = int(trialdata['seed'])\n",
    "                    ground_truth = trialdata['ground_truth']\n",
    "                    # Each click is (project, criteria, expert)\n",
    "                    clicks = trialdata[\"clicks\"]\n",
    "                    num_clicks = len(clicks)\n",
    "                    selected_project = trialdata[\"selected_project\"]\n",
    "                    term_reward = trialdata[\"reward\"]\n",
    "                    expected_reward = trialdata[\"expected_reward\"]\n",
    "                    # Replay the trial in the environment with the recorded seed\n",
    "                    env.reset(seed=seed)\n",
    "                    actions = [convert_action(*click, config) for click in clicks]\n",
    "                    participant_actions.extend(actions)\n",
    "                    cost = 0\n",
    "                    # Click agreement: 1 when the participant's action is among the VOC policy's best actions in the current state\n",
    "                    click_agreement = []\n",
    "                    for action in actions:\n",
    "                        optimal_actions = voc_policy.get_best_actions(env, eps=0.001)\n",
    "                        click_agreement.append(1 if action in optimal_actions else 0)\n",
    "                        _, reward, _, _ = env.step(action)\n",
    "                        cost += reward\n",
    "                    # NOTE(review): termination uses get_best_actions' default eps while clicks above use eps=0.001 -- confirm intended\n",
    "                    optimal_actions = voc_policy.get_best_actions(env)\n",
    "                    click_agreement.append(1 if env.term_action in optimal_actions else 0)\n",
    "                    # Criterion nodes of the selected project (same numbering as convert_action)\n",
    "                    path = np.array(range(1, config.num_criterias+1))+(selected_project*config.num_criterias)\n",
    "                    env_expected_reward = cost + env.expected_path_value(path, env.state)\n",
    "                    env_term_reward = cost + env.path_value(path)\n",
    "                    # The recorded trial data must match the replayed environment\n",
    "                    assert np.all(np.isclose(ground_truth, env.ground_truth.tolist()))\n",
    "                    assert np.all(np.isclose(np.array(format_payoff(config.num_projects, config.num_criterias, env.expert_truths.tolist())), trialdata['payoff_matrix']))\n",
    "                    assert np.isclose(term_reward, env_term_reward)\n",
    "                    assert np.isclose(expected_reward, env_expected_reward)\n",
    "\n",
    "                    row = [worker, condition, trial_id, term_reward, expected_reward, num_clicks, clicks, selected_project, seed, np.mean(click_agreement)]\n",
    "                    # df_all keeps effort-question exclusions (for bonus payments); df drops them\n",
    "                    if worker not in excluded:\n",
    "                        df_data.append(row)\n",
    "                    df_all_data.append(row)\n",
    "\n",
    "df = pd.DataFrame(df_data, columns=df_index)\n",
    "df_all = pd.DataFrame(df_all_data, columns=df_index)\n",
    "print(\"Responses:\", len(df[\"Participant\"].unique()))\n",
    "print(\"Responses with excluded:\", len(df_all[\"Participant\"].unique()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Condition</th>\n",
       "      <th>Score</th>\n",
       "      <th>ExpectedScore</th>\n",
       "      <th>NumClicks</th>\n",
       "      <th>Selection</th>\n",
       "      <th>Seed</th>\n",
       "      <th>ClickAgreement</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Participant</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0.0</td>\n",
       "      <td>3.887710</td>\n",
       "      <td>3.634325</td>\n",
       "      <td>5.0</td>\n",
       "      <td>2.4</td>\n",
       "      <td>16.5</td>\n",
       "      <td>0.316667</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1.0</td>\n",
       "      <td>3.582799</td>\n",
       "      <td>3.465661</td>\n",
       "      <td>4.0</td>\n",
       "      <td>1.5</td>\n",
       "      <td>16.5</td>\n",
       "      <td>0.175000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2.0</td>\n",
       "      <td>3.819524</td>\n",
       "      <td>3.548929</td>\n",
       "      <td>3.4</td>\n",
       "      <td>2.0</td>\n",
       "      <td>16.5</td>\n",
       "      <td>0.250000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>2.0</td>\n",
       "      <td>3.208749</td>\n",
       "      <td>3.418823</td>\n",
       "      <td>3.4</td>\n",
       "      <td>1.9</td>\n",
       "      <td>16.5</td>\n",
       "      <td>0.115000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>0.0</td>\n",
       "      <td>3.873925</td>\n",
       "      <td>3.741854</td>\n",
       "      <td>4.7</td>\n",
       "      <td>2.1</td>\n",
       "      <td>16.5</td>\n",
       "      <td>0.683333</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>379</th>\n",
       "      <td>0.0</td>\n",
       "      <td>3.735126</td>\n",
       "      <td>3.700479</td>\n",
       "      <td>4.4</td>\n",
       "      <td>1.7</td>\n",
       "      <td>16.5</td>\n",
       "      <td>0.430000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>380</th>\n",
       "      <td>0.0</td>\n",
       "      <td>3.717647</td>\n",
       "      <td>3.670769</td>\n",
       "      <td>5.0</td>\n",
       "      <td>2.1</td>\n",
       "      <td>16.5</td>\n",
       "      <td>0.583333</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>381</th>\n",
       "      <td>0.0</td>\n",
       "      <td>3.680862</td>\n",
       "      <td>3.610198</td>\n",
       "      <td>5.0</td>\n",
       "      <td>2.6</td>\n",
       "      <td>16.5</td>\n",
       "      <td>0.383333</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>383</th>\n",
       "      <td>1.0</td>\n",
       "      <td>3.931509</td>\n",
       "      <td>3.668889</td>\n",
       "      <td>4.8</td>\n",
       "      <td>2.5</td>\n",
       "      <td>16.5</td>\n",
       "      <td>0.316667</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>384</th>\n",
       "      <td>2.0</td>\n",
       "      <td>3.358313</td>\n",
       "      <td>3.350739</td>\n",
       "      <td>2.1</td>\n",
       "      <td>1.1</td>\n",
       "      <td>16.5</td>\n",
       "      <td>0.083333</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>298 rows × 7 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "             Condition     Score  ExpectedScore  NumClicks  Selection  Seed  \\\n",
       "Participant                                                                   \n",
       "0                  0.0  3.887710       3.634325        5.0        2.4  16.5   \n",
       "1                  1.0  3.582799       3.465661        4.0        1.5  16.5   \n",
       "2                  2.0  3.819524       3.548929        3.4        2.0  16.5   \n",
       "4                  2.0  3.208749       3.418823        3.4        1.9  16.5   \n",
       "6                  0.0  3.873925       3.741854        4.7        2.1  16.5   \n",
       "...                ...       ...            ...        ...        ...   ...   \n",
       "379                0.0  3.735126       3.700479        4.4        1.7  16.5   \n",
       "380                0.0  3.717647       3.670769        5.0        2.1  16.5   \n",
       "381                0.0  3.680862       3.610198        5.0        2.6  16.5   \n",
       "383                1.0  3.931509       3.668889        4.8        2.5  16.5   \n",
       "384                2.0  3.358313       3.350739        2.1        1.1  16.5   \n",
       "\n",
       "             ClickAgreement  \n",
       "Participant                  \n",
       "0                  0.316667  \n",
       "1                  0.175000  \n",
       "2                  0.250000  \n",
       "4                  0.115000  \n",
       "6                  0.683333  \n",
       "...                     ...  \n",
       "379                0.430000  \n",
       "380                0.583333  \n",
       "381                0.383333  \n",
       "383                0.316667  \n",
       "384                0.083333  \n",
       "\n",
       "[298 rows x 7 columns]"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import warnings\n",
    "warnings.simplefilter(action='ignore', category=FutureWarning)\n",
    "df.groupby(\"Participant\").mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead tr th {\n",
       "        text-align: left;\n",
       "    }\n",
       "\n",
       "    .dataframe thead tr:last-of-type th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr>\n",
       "      <th></th>\n",
       "      <th colspan=\"2\" halign=\"left\">Condition</th>\n",
       "      <th colspan=\"2\" halign=\"left\">Score</th>\n",
       "      <th colspan=\"2\" halign=\"left\">ExpectedScore</th>\n",
       "      <th colspan=\"2\" halign=\"left\">NumClicks</th>\n",
       "      <th colspan=\"2\" halign=\"left\">Selection</th>\n",
       "      <th colspan=\"2\" halign=\"left\">Seed</th>\n",
       "      <th colspan=\"2\" halign=\"left\">ClickAgreement</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th></th>\n",
       "      <th>mean</th>\n",
       "      <th>std</th>\n",
       "      <th>mean</th>\n",
       "      <th>std</th>\n",
       "      <th>mean</th>\n",
       "      <th>std</th>\n",
       "      <th>mean</th>\n",
       "      <th>std</th>\n",
       "      <th>mean</th>\n",
       "      <th>std</th>\n",
       "      <th>mean</th>\n",
       "      <th>std</th>\n",
       "      <th>mean</th>\n",
       "      <th>std</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Participant</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>293</th>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>4.198191</td>\n",
       "      <td>0.521779</td>\n",
       "      <td>3.660236</td>\n",
       "      <td>0.107887</td>\n",
       "      <td>4.5</td>\n",
       "      <td>0.527046</td>\n",
       "      <td>2.0</td>\n",
       "      <td>1.333333</td>\n",
       "      <td>16.5</td>\n",
       "      <td>3.02765</td>\n",
       "      <td>0.266667</td>\n",
       "      <td>0.169967</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>75</th>\n",
       "      <td>2.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>4.171736</td>\n",
       "      <td>0.540657</td>\n",
       "      <td>3.729907</td>\n",
       "      <td>0.037723</td>\n",
       "      <td>4.8</td>\n",
       "      <td>0.632456</td>\n",
       "      <td>1.5</td>\n",
       "      <td>1.269296</td>\n",
       "      <td>16.5</td>\n",
       "      <td>3.02765</td>\n",
       "      <td>0.375000</td>\n",
       "      <td>0.205067</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>225</th>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>4.159595</td>\n",
       "      <td>0.541685</td>\n",
       "      <td>3.697053</td>\n",
       "      <td>0.085047</td>\n",
       "      <td>4.9</td>\n",
       "      <td>0.316228</td>\n",
       "      <td>2.4</td>\n",
       "      <td>1.429841</td>\n",
       "      <td>16.5</td>\n",
       "      <td>3.02765</td>\n",
       "      <td>0.490000</td>\n",
       "      <td>0.224461</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>210</th>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>4.155702</td>\n",
       "      <td>0.403269</td>\n",
       "      <td>3.714527</td>\n",
       "      <td>0.093303</td>\n",
       "      <td>3.6</td>\n",
       "      <td>0.966092</td>\n",
       "      <td>1.8</td>\n",
       "      <td>1.135292</td>\n",
       "      <td>16.5</td>\n",
       "      <td>3.02765</td>\n",
       "      <td>0.668333</td>\n",
       "      <td>0.268679</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>266</th>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>4.148768</td>\n",
       "      <td>0.416167</td>\n",
       "      <td>3.724769</td>\n",
       "      <td>0.030657</td>\n",
       "      <td>5.0</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>2.5</td>\n",
       "      <td>1.433721</td>\n",
       "      <td>16.5</td>\n",
       "      <td>3.02765</td>\n",
       "      <td>0.500000</td>\n",
       "      <td>0.260579</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>226</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>3.068694</td>\n",
       "      <td>0.899954</td>\n",
       "      <td>3.369509</td>\n",
       "      <td>0.113367</td>\n",
       "      <td>2.2</td>\n",
       "      <td>0.632456</td>\n",
       "      <td>0.7</td>\n",
       "      <td>0.483046</td>\n",
       "      <td>16.5</td>\n",
       "      <td>3.02765</td>\n",
       "      <td>0.040000</td>\n",
       "      <td>0.126491</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>215</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>3.063188</td>\n",
       "      <td>0.444662</td>\n",
       "      <td>3.029154</td>\n",
       "      <td>0.283073</td>\n",
       "      <td>4.9</td>\n",
       "      <td>0.316228</td>\n",
       "      <td>1.9</td>\n",
       "      <td>1.523884</td>\n",
       "      <td>16.5</td>\n",
       "      <td>3.02765</td>\n",
       "      <td>0.490000</td>\n",
       "      <td>0.285038</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>361</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>3.057655</td>\n",
       "      <td>1.130552</td>\n",
       "      <td>3.406485</td>\n",
       "      <td>0.016826</td>\n",
       "      <td>0.3</td>\n",
       "      <td>0.948683</td>\n",
       "      <td>2.0</td>\n",
       "      <td>1.490712</td>\n",
       "      <td>16.5</td>\n",
       "      <td>3.02765</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>90</th>\n",
       "      <td>2.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>3.020367</td>\n",
       "      <td>1.194276</td>\n",
       "      <td>3.401241</td>\n",
       "      <td>0.007475</td>\n",
       "      <td>1.6</td>\n",
       "      <td>0.699206</td>\n",
       "      <td>1.3</td>\n",
       "      <td>1.159502</td>\n",
       "      <td>16.5</td>\n",
       "      <td>3.02765</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>370</th>\n",
       "      <td>2.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>2.851281</td>\n",
       "      <td>0.935422</td>\n",
       "      <td>3.050600</td>\n",
       "      <td>0.304528</td>\n",
       "      <td>3.4</td>\n",
       "      <td>1.173788</td>\n",
       "      <td>2.4</td>\n",
       "      <td>1.349897</td>\n",
       "      <td>16.5</td>\n",
       "      <td>3.02765</td>\n",
       "      <td>0.330000</td>\n",
       "      <td>0.204547</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>301 rows × 14 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "            Condition          Score           ExpectedScore            \\\n",
       "                 mean  std      mean       std          mean       std   \n",
       "Participant                                                              \n",
       "293               1.0  0.0  4.198191  0.521779      3.660236  0.107887   \n",
       "75                2.0  0.0  4.171736  0.540657      3.729907  0.037723   \n",
       "225               1.0  0.0  4.159595  0.541685      3.697053  0.085047   \n",
       "210               1.0  0.0  4.155702  0.403269      3.714527  0.093303   \n",
       "266               1.0  0.0  4.148768  0.416167      3.724769  0.030657   \n",
       "...               ...  ...       ...       ...           ...       ...   \n",
       "226               0.0  0.0  3.068694  0.899954      3.369509  0.113367   \n",
       "215               0.0  0.0  3.063188  0.444662      3.029154  0.283073   \n",
       "361               0.0  0.0  3.057655  1.130552      3.406485  0.016826   \n",
       "90                2.0  0.0  3.020367  1.194276      3.401241  0.007475   \n",
       "370               2.0  0.0  2.851281  0.935422      3.050600  0.304528   \n",
       "\n",
       "            NumClicks           Selection            Seed           \\\n",
       "                 mean       std      mean       std  mean      std   \n",
       "Participant                                                          \n",
       "293               4.5  0.527046       2.0  1.333333  16.5  3.02765   \n",
       "75                4.8  0.632456       1.5  1.269296  16.5  3.02765   \n",
       "225               4.9  0.316228       2.4  1.429841  16.5  3.02765   \n",
       "210               3.6  0.966092       1.8  1.135292  16.5  3.02765   \n",
       "266               5.0  0.000000       2.5  1.433721  16.5  3.02765   \n",
       "...               ...       ...       ...       ...   ...      ...   \n",
       "226               2.2  0.632456       0.7  0.483046  16.5  3.02765   \n",
       "215               4.9  0.316228       1.9  1.523884  16.5  3.02765   \n",
       "361               0.3  0.948683       2.0  1.490712  16.5  3.02765   \n",
       "90                1.6  0.699206       1.3  1.159502  16.5  3.02765   \n",
       "370               3.4  1.173788       2.4  1.349897  16.5  3.02765   \n",
       "\n",
       "            ClickAgreement            \n",
       "                      mean       std  \n",
       "Participant                           \n",
       "293               0.266667  0.169967  \n",
       "75                0.375000  0.205067  \n",
       "225               0.490000  0.224461  \n",
       "210               0.668333  0.268679  \n",
       "266               0.500000  0.260579  \n",
       "...                    ...       ...  \n",
       "226               0.040000  0.126491  \n",
       "215               0.490000  0.285038  \n",
       "361               0.000000  0.000000  \n",
       "90                0.000000  0.000000  \n",
       "370               0.330000  0.204547  \n",
       "\n",
       "[301 rows x 14 columns]"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Mean and std of each column per participant, ranked by mean Score (highest first)\n",
    "participant_scores = (\n",
    "    df_all\n",
    "    .groupby(\"Participant\")\n",
    "    .agg([\"mean\", \"std\"])\n",
    "    .sort_values(by=(\"Score\", \"mean\"), ascending=False)\n",
    ")\n",
    "participant_scores"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "293,0.75\n",
      "75,0.75\n",
      "225,0.75\n",
      "210,0.75\n",
      "266,0.75\n",
      "7,0.75\n",
      "205,0.75\n",
      "44,0.75\n",
      "250,0.75\n",
      "31,0.75\n",
      "312,0.75\n",
      "344,0.75\n",
      "324,0.75\n",
      "307,0.75\n",
      "165,0.75\n",
      "333,0.75\n",
      "82,0.75\n",
      "28,0.75\n",
      "240,0.75\n",
      "358,0.75\n",
      "85,0.75\n",
      "280,0.75\n",
      "180,0.75\n",
      "172,0.75\n",
      "217,0.75\n",
      "276,0.75\n",
      "195,0.75\n",
      "367,0.75\n",
      "354,0.75\n",
      "15,0.75\n",
      "302,0.75\n",
      "359,0.75\n",
      "308,0.75\n",
      "321,0.75\n",
      "89,0.75\n",
      "216,0.75\n",
      "288,0.75\n",
      "65,0.75\n",
      "212,0.75\n",
      "118,0.75\n",
      "51,0.75\n",
      "178,0.75\n",
      "181,0.75\n",
      "126,0.75\n",
      "234,0.75\n",
      "24,0.75\n",
      "166,0.75\n",
      "84,0.75\n",
      "222,0.75\n",
      "204,0.75\n",
      "34,0.75\n",
      "56,0.75\n",
      "223,0.75\n",
      "96,0.75\n",
      "29,0.75\n",
      "338,0.75\n",
      "262,0.75\n",
      "48,0.75\n",
      "151,0.75\n",
      "268,0.75\n",
      "383,0.75\n",
      "171,0.75\n",
      "373,0.75\n",
      "169,0.75\n",
      "260,0.75\n",
      "335,0.75\n",
      "343,0.75\n",
      "197,0.75\n",
      "190,0.75\n",
      "346,0.75\n",
      "186,0.75\n",
      "349,0.75\n",
      "193,0.75\n",
      "112,0.75\n",
      "130,0.75\n",
      "177,0.75\n",
      "340,0.75\n",
      "52,0.75\n",
      "9,0.75\n",
      "364,0.75\n",
      "313,0.75\n",
      "120,0.75\n",
      "184,0.75\n",
      "156,0.75\n",
      "331,0.75\n",
      "162,0.75\n",
      "0,0.75\n",
      "272,0.75\n",
      "353,0.75\n",
      "140,0.75\n",
      "11,0.75\n",
      "6,0.75\n",
      "122,0.75\n",
      "378,0.75\n",
      "257,0.75\n",
      "237,0.75\n",
      "78,0.75\n",
      "264,0.75\n",
      "125,0.75\n",
      "294,0.75\n",
      "183,0.75\n",
      "131,0.75\n",
      "301,0.75\n",
      "279,0.75\n",
      "214,0.75\n",
      "352,0.75\n",
      "128,0.75\n",
      "285,0.75\n",
      "98,0.75\n",
      "292,0.75\n",
      "368,0.75\n",
      "213,0.75\n",
      "366,0.75\n",
      "2,0.75\n",
      "164,0.75\n",
      "305,0.75\n",
      "191,0.75\n",
      "345,0.75\n",
      "66,0.75\n",
      "16,0.75\n",
      "148,0.75\n",
      "258,0.75\n",
      "374,0.75\n",
      "220,0.75\n",
      "32,0.75\n",
      "137,0.75\n",
      "297,0.75\n",
      "372,0.75\n",
      "256,0.75\n",
      "330,0.75\n",
      "199,0.75\n",
      "21,0.75\n",
      "265,0.75\n",
      "259,0.75\n",
      "254,0.75\n",
      "109,0.75\n",
      "233,0.75\n",
      "154,0.75\n",
      "261,0.75\n",
      "13,0.75\n",
      "86,0.75\n",
      "22,0.75\n",
      "196,0.75\n",
      "203,0.75\n",
      "91,0.75\n",
      "117,0.75\n",
      "209,0.75\n",
      "159,0.75\n",
      "244,0.75\n",
      "229,0.75\n",
      "269,0.75\n",
      "194,0.25\n",
      "275,0.25\n",
      "379,0.25\n",
      "114,0.25\n",
      "135,0.25\n",
      "143,0.25\n",
      "360,0.25\n",
      "375,0.25\n",
      "304,0.25\n",
      "380,0.25\n",
      "201,0.25\n",
      "342,0.25\n",
      "239,0.25\n",
      "255,0.25\n",
      "300,0.25\n",
      "242,0.25\n",
      "291,0.25\n",
      "231,0.25\n",
      "369,0.25\n",
      "283,0.25\n",
      "170,0.25\n",
      "271,0.25\n",
      "290,0.25\n",
      "274,0.25\n",
      "381,0.25\n",
      "278,0.25\n",
      "168,0.25\n",
      "355,0.25\n",
      "362,0.25\n",
      "87,0.25\n",
      "311,0.25\n",
      "94,0.25\n",
      "42,0.25\n",
      "189,0.25\n",
      "176,0.25\n",
      "235,0.25\n",
      "236,0.25\n",
      "55,0.25\n",
      "59,0.25\n",
      "232,0.25\n",
      "192,0.25\n",
      "243,0.25\n",
      "157,0.25\n",
      "155,0.25\n",
      "337,0.25\n",
      "152,0.25\n",
      "45,0.25\n",
      "318,0.25\n",
      "27,0.25\n",
      "167,0.25\n",
      "228,0.25\n",
      "315,0.25\n",
      "282,0.25\n",
      "1,0.25\n",
      "115,0.25\n",
      "43,0.25\n",
      "303,0.25\n",
      "341,0.25\n",
      "12,0.25\n",
      "241,0.25\n",
      "132,0.25\n",
      "296,0.25\n",
      "108,0.25\n",
      "273,0.25\n",
      "37,0.25\n",
      "310,0.25\n",
      "270,0.25\n",
      "101,0.25\n",
      "211,0.25\n",
      "136,0.25\n",
      "323,0.25\n",
      "246,0.25\n",
      "123,0.25\n",
      "46,0.25\n",
      "219,0.25\n",
      "284,0.25\n",
      "376,0.25\n",
      "317,0.25\n",
      "33,0.25\n",
      "185,0.25\n",
      "327,0.25\n",
      "111,0.25\n",
      "77,0.25\n",
      "36,0.25\n",
      "179,0.25\n",
      "23,0.25\n",
      "363,0.25\n",
      "173,0.25\n",
      "146,0.25\n",
      "121,0.25\n",
      "57,0.25\n",
      "328,0.25\n",
      "334,0.25\n",
      "336,0.25\n",
      "47,0.25\n",
      "149,0.25\n",
      "289,0.25\n",
      "326,0.25\n",
      "329,0.25\n",
      "88,0.25\n",
      "322,0.25\n",
      "145,0.25\n",
      "319,0.25\n",
      "377,0.25\n",
      "263,0.25\n",
      "142,0.25\n",
      "116,0.25\n",
      "161,0.25\n",
      "20,0.25\n",
      "384,0.25\n",
      "54,0.25\n",
      "332,0.25\n",
      "39,0.25\n",
      "253,0.25\n",
      "119,0.25\n",
      "10,0.25\n",
      "206,0.25\n",
      "158,0.25\n",
      "295,0.25\n",
      "105,0.25\n",
      "320,0.25\n",
      "175,0.25\n",
      "287,0.25\n",
      "267,0.25\n",
      "63,0.25\n",
      "306,0.25\n",
      "281,0.25\n",
      "298,0.25\n",
      "249,0.25\n",
      "129,0.25\n",
      "49,0.25\n",
      "286,0.25\n",
      "316,0.25\n",
      "150,0.25\n",
      "200,0.25\n",
      "245,0.25\n",
      "103,0.25\n",
      "357,0.25\n",
      "248,0.25\n",
      "4,0.25\n",
      "230,0.25\n",
      "69,0.25\n",
      "8,0.25\n",
      "30,0.25\n",
      "238,0.25\n",
      "226,0.25\n",
      "215,0.25\n",
      "361,0.25\n",
      "90,0.25\n",
      "370,0.25\n"
     ]
    }
   ],
   "source": [
    "# Split the score-ranked participants in half: the first half (rounded up) gets the\n",
    "# high bonus, the remainder the low bonus. Prints one `participant_id,bonus` line each.\n",
    "HIGH_BONUS_AMOUNT = 0.75\n",
    "LOW_BONUS_AMOUNT = 0.25\n",
    "\n",
    "sorted_participants = participant_scores.index.tolist()\n",
    "half_participants = math.ceil(len(sorted_participants) / 2)\n",
    "high_bonus = sorted_participants[:half_participants]\n",
    "low_bonus = sorted_participants[half_participants:]\n",
    "# NOTE(review): participants tied on mean Score at the cutoff are assigned by sort order - confirm acceptable\n",
    "# Loop variable is not named `id` so the builtin is not shadowed for later cells.\n",
    "for participant_id in high_bonus:\n",
    "    print(f\"{participant_id},{HIGH_BONUS_AMOUNT}\")\n",
    "for participant_id in low_bonus:\n",
    "    print(f\"{participant_id},{LOW_BONUS_AMOUNT}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): this exports `df`, while the summary cells above are built from `df_all` - confirm which frame is intended\n",
    "df.to_csv(path_or_buf=\"./data/experiment_results/exp_5.csv\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "jas",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
