{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import csv\n",
    "import json\n",
    "import random\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "from collections import defaultdict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def stat_unique(data: pd.DataFrame, key):\n",
    "    if key is None:\n",
    "        print('Total length: {}'.format(len(data)))\n",
    "    elif isinstance(key, str):\n",
    "        print('Number of unique {}: {}'.format(key, len(data[key].unique())))\n",
    "        return len(data[key].unique())\n",
    "    elif isinstance(key, list):\n",
    "        print('Number of unique [{}]: {}'.format(','.join(key), len(data.drop_duplicates(key, keep='first'))))\n",
    "        return len(data.drop_duplicates(key, keep='first'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "D:\\Anaconda3\\envs\\py36\\lib\\site-packages\\IPython\\core\\interactiveshell.py:3072: DtypeWarning: Columns (17) have mixed types.Specify dtype option on import or set low_memory=False.\n",
      "  interactivity=interactivity, compiler=compiler, result=result)\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>order_id</th>\n",
       "      <th>assignment_id</th>\n",
       "      <th>user_id</th>\n",
       "      <th>assistment_id</th>\n",
       "      <th>problem_id</th>\n",
       "      <th>original</th>\n",
       "      <th>correct</th>\n",
       "      <th>attempt_count</th>\n",
       "      <th>ms_first_response</th>\n",
       "      <th>tutor_mode</th>\n",
       "      <th>...</th>\n",
       "      <th>hint_count</th>\n",
       "      <th>hint_total</th>\n",
       "      <th>overlap_time</th>\n",
       "      <th>template_id</th>\n",
       "      <th>answer_id</th>\n",
       "      <th>answer_text</th>\n",
       "      <th>first_action</th>\n",
       "      <th>bottom_hint</th>\n",
       "      <th>opportunity</th>\n",
       "      <th>opportunity_original</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>33022537</td>\n",
       "      <td>277618</td>\n",
       "      <td>64525</td>\n",
       "      <td>33139</td>\n",
       "      <td>51424</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>32454</td>\n",
       "      <td>tutor</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>3</td>\n",
       "      <td>32454</td>\n",
       "      <td>30799</td>\n",
       "      <td>NaN</td>\n",
       "      <td>26</td>\n",
       "      <td>0</td>\n",
       "      <td>NaN</td>\n",
       "      <td>1</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>33022709</td>\n",
       "      <td>277618</td>\n",
       "      <td>64525</td>\n",
       "      <td>33150</td>\n",
       "      <td>51435</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>4922</td>\n",
       "      <td>tutor</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>3</td>\n",
       "      <td>4922</td>\n",
       "      <td>30799</td>\n",
       "      <td>NaN</td>\n",
       "      <td>55</td>\n",
       "      <td>0</td>\n",
       "      <td>NaN</td>\n",
       "      <td>2</td>\n",
       "      <td>2.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>35450204</td>\n",
       "      <td>220674</td>\n",
       "      <td>70363</td>\n",
       "      <td>33159</td>\n",
       "      <td>51444</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>2</td>\n",
       "      <td>25390</td>\n",
       "      <td>tutor</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>3</td>\n",
       "      <td>42000</td>\n",
       "      <td>30799</td>\n",
       "      <td>NaN</td>\n",
       "      <td>88</td>\n",
       "      <td>0</td>\n",
       "      <td>NaN</td>\n",
       "      <td>1</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>35450295</td>\n",
       "      <td>220674</td>\n",
       "      <td>70363</td>\n",
       "      <td>33110</td>\n",
       "      <td>51395</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>4859</td>\n",
       "      <td>tutor</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>3</td>\n",
       "      <td>4859</td>\n",
       "      <td>30059</td>\n",
       "      <td>NaN</td>\n",
       "      <td>41</td>\n",
       "      <td>0</td>\n",
       "      <td>NaN</td>\n",
       "      <td>2</td>\n",
       "      <td>2.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>35450311</td>\n",
       "      <td>220674</td>\n",
       "      <td>70363</td>\n",
       "      <td>33196</td>\n",
       "      <td>51481</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>14</td>\n",
       "      <td>19813</td>\n",
       "      <td>tutor</td>\n",
       "      <td>...</td>\n",
       "      <td>3</td>\n",
       "      <td>4</td>\n",
       "      <td>124564</td>\n",
       "      <td>30060</td>\n",
       "      <td>NaN</td>\n",
       "      <td>65</td>\n",
       "      <td>0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>3</td>\n",
       "      <td>3.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>5 rows × 30 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "   order_id  assignment_id  user_id  assistment_id  problem_id  original  \\\n",
       "0  33022537         277618    64525          33139       51424         1   \n",
       "1  33022709         277618    64525          33150       51435         1   \n",
       "2  35450204         220674    70363          33159       51444         1   \n",
       "3  35450295         220674    70363          33110       51395         1   \n",
       "4  35450311         220674    70363          33196       51481         1   \n",
       "\n",
       "   correct  attempt_count  ms_first_response tutor_mode  ... hint_count  \\\n",
       "0        1              1              32454      tutor  ...          0   \n",
       "1        1              1               4922      tutor  ...          0   \n",
       "2        0              2              25390      tutor  ...          0   \n",
       "3        1              1               4859      tutor  ...          0   \n",
       "4        0             14              19813      tutor  ...          3   \n",
       "\n",
       "   hint_total  overlap_time  template_id answer_id  answer_text first_action  \\\n",
       "0           3         32454        30799       NaN           26            0   \n",
       "1           3          4922        30799       NaN           55            0   \n",
       "2           3         42000        30799       NaN           88            0   \n",
       "3           3          4859        30059       NaN           41            0   \n",
       "4           4        124564        30060       NaN           65            0   \n",
       "\n",
       "  bottom_hint  opportunity  opportunity_original  \n",
       "0         NaN            1                   1.0  \n",
       "1         NaN            2                   2.0  \n",
       "2         NaN            1                   1.0  \n",
       "3         NaN            2                   2.0  \n",
       "4         0.0            3                   3.0  \n",
       "\n",
       "[5 rows x 30 columns]"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "raw_data = pd.read_csv('../../data/assistment/assistment2009.csv', encoding = 'ISO-8859-15', dtype={'skill_id': str})\n",
    "raw_data.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "raw_data = raw_data.rename(columns={'user_id': 'student_id',\n",
    "                                    'problem_id': 'question_id',\n",
    "                                    'skill_id': 'knowledge_id',\n",
    "                                    'skill_name': 'knowledge_name',\n",
    "                                    })\n",
    "all_data = raw_data.loc[:, ['student_id', 'question_id', 'knowledge_id', 'knowledge_name', 'correct']].dropna()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total length: 325637\n",
      "Number of unique [student_id,question_id]: 270478\n",
      "Number of unique student_id: 4151\n",
      "Number of unique question_id: 16891\n",
      "Number of unique knowledge_id: 111\n"
     ]
    }
   ],
   "source": [
    "stat_unique(all_data, None)\n",
    "a=stat_unique(all_data, ['student_id', 'question_id'])\n",
    "b=stat_unique(all_data, 'student_id')\n",
    "c=stat_unique(all_data, 'question_id')\n",
    "d=stat_unique(all_data, 'knowledge_id')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Filter data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "selected_data = all_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "filter 15370 questions\n"
     ]
    }
   ],
   "source": [
    "# filter questions\n",
    "n_students = selected_data.groupby('question_id')['student_id'].count()\n",
    "question_filter = n_students[n_students < 50].index.tolist()\n",
    "print(f'filter {len(question_filter)} questions')\n",
    "selected_data = selected_data[~selected_data['question_id'].isin(question_filter)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "filter 2120 students\n"
     ]
    }
   ],
   "source": [
    "# filter students\n",
    "n_questions = selected_data.groupby('student_id')['question_id'].count()\n",
    "student_filter = n_questions[n_questions < 20].index.tolist()\n",
    "print(f'filter {len(student_filter)} students')\n",
    "selected_data = selected_data[~selected_data['student_id'].isin(student_filter)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# get question to knowledge map\n",
    "q2k = {}\n",
    "table = selected_data.loc[:, ['question_id', 'knowledge_id']].drop_duplicates()\n",
    "for i, row in table.iterrows():\n",
    "    q = row['question_id']\n",
    "    q2k[q] = set(map(int, str(int(float(row['knowledge_id']))).split('_')))\n",
    "    \n",
    "# get knowledge to question map\n",
    "k2q = {}\n",
    "for q, ks in q2k.items():\n",
    "    for k in ks:\n",
    "        k2q.setdefault(k, set())\n",
    "        k2q[k].add(q)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "filter 8 knowledges\n"
     ]
    }
   ],
   "source": [
    "# filter knowledges\n",
    "#selected_knowledges = { k for k, q in k2q.items()}\n",
    "selected_knowledges = { k for k, q in k2q.items() if len(q) >= 10}\n",
    "print(f'filter {len(k2q) - len(selected_knowledges)} knowledges')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# update maps\n",
    "q2k = {q : ks for q, ks in q2k.items() if ks & selected_knowledges}\n",
    "k2q = {}\n",
    "for q, ks in q2k.items():\n",
    "    for k in ks:\n",
    "        k2q.setdefault(k, set())\n",
    "        k2q[k].add(q)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# update data\n",
    "selected_data = selected_data[selected_data.apply(lambda x: x['question_id'] in q2k, axis=1)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# renumber the students\n",
    "s2n = {}\n",
    "cnt = 0\n",
    "for i, row in selected_data.iterrows():\n",
    "    if row.student_id not in s2n:\n",
    "        s2n[row.student_id] = cnt\n",
    "        cnt += 1\n",
    "selected_data.loc[:, 'student_id'] = selected_data.loc[:, 'student_id'].apply(lambda x: s2n[x])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# renumber the questions\n",
    "q2n = {}\n",
    "cnt = 0\n",
    "for i, row in selected_data.iterrows():\n",
    "    if row.question_id not in q2n:\n",
    "        q2n[row.question_id] = cnt\n",
    "        cnt += 1\n",
    "selected_data.loc[:, 'question_id'] = selected_data.loc[:, 'question_id'].apply(lambda x: q2n[x])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# renumber the knowledges\n",
    "k2n = {}\n",
    "cnt = 0\n",
    "for i, row in selected_data.iterrows():\n",
    "    for k in str(row.knowledge_id).split('_'):\n",
    "        if int(float(k)) not in k2n:\n",
    "            k2n[int(float(k))] = cnt\n",
    "            cnt += 1\n",
    "selected_data.loc[:, 'knowledge_id'] = selected_data.loc[:, 'knowledge_id'].apply(lambda x: '_'.join(map(lambda y: str(k2n[int(float(y))]), str(x).split('_'))))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total length: 101470\n",
      "Number of unique [student_id,question_id]: 71515\n",
      "Number of unique student_id: 1291\n",
      "Number of unique question_id: 1485\n",
      "Number of unique knowledge_id: 35\n",
      "Average #questions per knowledge: 59.4\n"
     ]
    }
   ],
   "source": [
    "stat_unique(selected_data, None)\n",
    "a=stat_unique(selected_data, ['student_id', 'question_id'])\n",
    "b=stat_unique(selected_data, 'student_id')\n",
    "c=stat_unique(selected_data, 'question_id')\n",
    "d=stat_unique(selected_data, 'knowledge_id')\n",
    "print('Average #questions per knowledge: {}'.format((len(q2k) / len(k2q))))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "selected_data.to_csv('selected_data.csv', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "# save concept map\n",
    "q2k = {}\n",
    "table = selected_data.loc[:, ['question_id', 'knowledge_id']].drop_duplicates()\n",
    "for i, row in table.iterrows():\n",
    "    q = str(row['question_id'])\n",
    "    q2k[q] = list(map(int, str(row['knowledge_id']).split('_')))\n",
    "with open('concept_map.json', 'w') as f:\n",
    "    json.dump(q2k, f)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## parse data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "def parse_data(data):\n",
    "    \"\"\" \n",
    "\n",
    "    Args:\n",
    "        data: list of triplets (sid, qid, score)\n",
    "        \n",
    "    Returns:\n",
    "        student based datasets: defaultdict {sid: {qid: score}}\n",
    "        question based datasets: defaultdict {qid: {sid: score}}\n",
    "    \"\"\"\n",
    "    stu_data = defaultdict(lambda: defaultdict(dict))\n",
    "    ques_data = defaultdict(lambda: defaultdict(dict))\n",
    "    for i, row in data.iterrows():\n",
    "        sid = row.student_id\n",
    "        qid = row.question_id\n",
    "        correct = row.correct\n",
    "        stu_data[sid][qid] = correct\n",
    "        ques_data[qid][sid] = correct\n",
    "    return stu_data, ques_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "data = []\n",
    "for i, row in selected_data.iterrows():\n",
    "    data.append([row.student_id, row.question_id, row.correct])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "stu_data, ques_data = parse_data(selected_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_size = 0.2\n",
    "least_test_length=150"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "51465\n",
      "50005\n"
     ]
    }
   ],
   "source": [
    "n_students = len(stu_data)\n",
    "if isinstance(test_size, float):\n",
    "    test_size = int(n_students * test_size)\n",
    "train_size = n_students - test_size\n",
    "assert(train_size > 0 and test_size > 0)\n",
    "\n",
    "students = list(range(n_students))\n",
    "random.shuffle(students)\n",
    "if least_test_length is not None:\n",
    "    student_lens = defaultdict(int)\n",
    "    for t in data:\n",
    "        student_lens[t[0]] += 1\n",
    "    students = [student for student in students\n",
    "                if student_lens[student] >= least_test_length]\n",
    "test_students = set(students[:test_size])\n",
    "\n",
    "train_data = [record for record in data if record[0] not in test_students]\n",
    "test_data = [record for record in data if record[0] in test_students]\n",
    "print(len(train_data))\n",
    "print(len(test_data))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "def renumber_student_id(data):\n",
    "    \"\"\"\n",
    "\n",
    "    Args:\n",
    "        data: list of triplets (sid, qid, score)\n",
    "    \n",
    "    Returns:\n",
    "        renumbered datasets: list of triplets (sid, qid, score)\n",
    "    \"\"\"\n",
    "    student_ids = sorted(set(t[0] for t in data))\n",
    "    renumber_map = {sid: i for i, sid in enumerate(student_ids)}\n",
    "    data = [(renumber_map[t[0]], t[1], t[2]) for t in data]\n",
    "    return data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_data = renumber_student_id(train_data)\n",
    "test_data = renumber_student_id(test_data)\n",
    "all_data = renumber_student_id(data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train records length: 51465\n",
      "test records length: 50005\n",
      "all records length: 101470\n"
     ]
    }
   ],
   "source": [
    "print(f'train records length: {len(train_data)}')\n",
    "print(f'test records length: {len(test_data)}')\n",
    "print(f'all records length: {len(all_data)}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## save data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "def save_to_csv(data, path):\n",
    "    \"\"\"\n",
    "\n",
    "    Args:\n",
    "        data: list of triplets (sid, qid, correct)\n",
    "        path: str representing saving path\n",
    "    \"\"\"\n",
    "    pd.DataFrame.from_records(sorted(data), columns=['student_id', 'question_id', 'correct']).to_csv(path, index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "save_to_csv(train_data, 'train_triples.csv')\n",
    "save_to_csv(test_data, 'test_triples.csv')\n",
    "save_to_csv(all_data, 'triples.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "metadata = {\"num_students\": n_students, \n",
    "            \"num_questions\": c,\n",
    "            \"num_concepts\": d, \n",
    "            \"num_records\": len(all_data), \n",
    "            \"num_train_students\": n_students - len(test_students), \n",
    "            \"num_test_students\": len(test_students)}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('metadata.json', 'w') as f:\n",
    "    json.dump(metadata, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
