{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fab4f6d3-ca40-4899-ae4a-1e2cd57f8995",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import pickle\n",
    "from human_eval.data import HUMAN_EVAL, read_problems, stream_jsonl, write_jsonl\n",
    "import numpy as np\n",
    "import pickle\n",
    "from tqdm import trange\n",
    "from sklearn.model_selection import KFold, StratifiedGroupKFold\n",
    "import xgboost as xgb\n",
    "from matplotlib import pyplot as plt\n",
    "from sklearn.model_selection import cross_validate\n",
    "from sklearn.linear_model import LogisticRegression, Lasso\n",
    "from sklearn.preprocessing import normalize, StandardScaler, MinMaxScaler\n",
    "from sklearn.metrics import roc_auc_score, f1_score, accuracy_score, precision_recall_curve\n",
    "from sklearn.model_selection import train_test_split\n",
    "from scipy.stats import spearmanr, kendalltau"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e60cd7ad-596f-43c4-9ddd-bbf735ba68c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Length of the per-task tables below; they are indexed directly by MBPP task number.\n",
    "NUM_TASKS = 510"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4e9d093b-a3e7-40ff-b1e2-4bdcb58a16b4",
   "metadata": {},
   "outputs": [],
   "source": [
    "def calc_mtd(dgms):\n",
    "    \"\"\"Total persistence (sum of finite bar lengths) of the H0 and H1 diagrams.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    dgms : sequence of arrays\n",
    "        dgms[0] (H0) and dgms[1] (H1) are read; each is presumably an (n, 2)\n",
    "        array of (birth, death) pairs -- TODO confirm against the pickles' producer.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    (mtd0, mtd1)\n",
    "        Sum of (death - birth) over bars with a finite death, for H0 and H1.\n",
    "        Both dimensions use the same definition, so non-zero births are\n",
    "        subtracted and bars that never die (death = inf) are skipped.\n",
    "        An empty diagram contributes 0.\n",
    "    \"\"\"\n",
    "    def total_persistence(dgm):\n",
    "        bars = np.asarray(dgm, dtype=float)\n",
    "        if bars.size == 0:\n",
    "            return 0.0\n",
    "        finite = np.isfinite(bars[:, 1])\n",
    "        return np.sum(bars[finite, 1] - bars[finite, 0])\n",
    "\n",
    "    return total_persistence(dgms[0]), total_persistence(dgms[1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0c88a3eb-ccab-4cfb-a0d2-80fa9129faa5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-task result pickles and the cache of the merged results.\n",
    "RESULTS_DIR = '/mbpp_results2'\n",
    "ALL_RESULTS_PICKLE = os.path.join(RESULTS_DIR, 'all.pickle')\n",
    "\n",
    "# Set to True to reload the merged results cached by a previous run instead of\n",
    "# rebuilding them from the per-task pickles (pickle.load can execute code, so\n",
    "# only point this at files you produced yourself).\n",
    "USE_CACHED_RESULTS = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "01efaf23-74f2-492c-a541-2d92b0aae909",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build `results` (looked up later by (task_id, seed, seq, layer, head)) and\n",
    "# `results_code` (looked up later by (task_id, 0)). The last two items of every\n",
    "# `results` value hold persistence diagrams; keep the first four items and\n",
    "# replace the diagrams with their calc_mtd summaries.\n",
    "results = {}\n",
    "results_code = {}\n",
    "\n",
    "if USE_CACHED_RESULTS:\n",
    "    with open(ALL_RESULTS_PICKLE, 'rb') as fh:\n",
    "        results, results_code = pickle.load(fh)\n",
    "else:\n",
    "    for task_num in trange(10, 510):\n",
    "        with open(os.path.join(RESULTS_DIR, '%d.pickle' % task_num), 'rb') as fh:\n",
    "            task_results = pickle.load(fh)\n",
    "        with open(os.path.join(RESULTS_DIR, '%d_code.pickle' % task_num), 'rb') as fh:\n",
    "            task_results_code = pickle.load(fh)\n",
    "\n",
    "        for key, value in task_results.items():\n",
    "            dgms_a, dgms_b = value[-2]['dgms'], value[-1]['dgms']\n",
    "            task_results[key] = value[:4] + [calc_mtd(dgms_a), calc_mtd(dgms_b)]\n",
    "\n",
    "        results.update(task_results)\n",
    "        results_code.update(task_results_code)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ad1bec1a-42fe-4463-9a12-0b3dcda044a9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the freshly built results so later runs can set USE_CACHED_RESULTS = True.\n",
    "if not USE_CACHED_RESULTS:\n",
    "    with open(ALL_RESULTS_PICKLE, 'wb') as fh:\n",
    "        pickle.dump((results, results_code), fh)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8415632e-67e9-4992-aaa7-614ed9d1016e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One evaluation record per task; `completions` is item 0 of results_code[(task_id, 0)]\n",
    "# (presumably the generated programs for every sequence -- TODO confirm).\n",
    "num_samples_per_task = 1\n",
    "samples = [\n",
    "    dict(task_id='MBPP/%d' % task_num,\n",
    "         completions=results_code[('MBPP/%d' % task_num, 0)][0])\n",
    "    for task_num in trange(10, 510)\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee3c9d89-ebcf-4640-9973-76637a8b43af",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write one JSON line per task (presumably read by evaluate_mbpp.py in the next cell).\n",
    "write_jsonl(filename=\"mbpp_samples.jsonl\", data=samples)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "905d895b-0d74-4119-88e4-27374fbf9ba9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the evaluator with the kernel's own interpreter; a bare `python` on PATH\n",
    "# may belong to a different environment than the one running this notebook.\n",
    "!\"{sys.executable}\" evaluate_mbpp.py"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ce71648b-6dbb-4c78-b81d-d13964cffdfe",
   "metadata": {},
   "outputs": [],
   "source": [
    "# run_stat[task_num] holds one list of 0/1 pass flags per evaluation record of that task.\n",
    "run_stat = [[] for _ in range(NUM_TASKS)]\n",
    "\n",
    "for record in stream_jsonl('mbpp_evaluation.jsonl'):\n",
    "    task_number = int(record['task_id'].split('/')[-1])\n",
    "    flags = [int(entry[1]['passed']) for entry in record['results']]\n",
    "    run_stat[task_number].append(flags)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "10d3455d-1dd9-4443-9667-9e490a39d93d",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Preview a handful of tasks instead of dumping all NUM_TASKS entries\n",
    "# (indices below 10 are presumably empty: only tasks 10..509 were sampled).\n",
    "run_stat[10:15]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e77a6596-d4ae-45f2-8a04-31b63738d834",
   "metadata": {},
   "source": [
    "### Random split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2b5efbd3-b640-4fe0-a31a-bf505d2d2365",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pass/fail labels, one per (task, completion). prepare_features() emits rows in\n",
    "# ascending task order, so group labels per task and flatten them in that same\n",
    "# order rather than relying on the evaluation file's line order.\n",
    "# NOTE(review): assumes each record's results are already in completion order\n",
    "# (seq 0..4) -- confirm against evaluate_mbpp.py.\n",
    "labels_by_task = {}\n",
    "\n",
    "for elem in stream_jsonl('mbpp_evaluation.jsonl'):\n",
    "    task_num = int(elem['task_id'].split('/')[-1])\n",
    "    labels_by_task.setdefault(task_num, []).extend(\n",
    "        int(x[1]['passed']) for x in elem['results'])\n",
    "\n",
    "Y = [label for task_num in sorted(labels_by_task) for label in labels_by_task[task_num]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3cf7faf6-cb40-4e0a-962a-d5173b605b05",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_probs(first_task=10, last_task=510, num_seqs=5):\n",
    "    \"\"\"Mean log-probability of each generated answer, one value per (task, seq).\n",
    "\n",
    "    Rows are ordered by task number (first_task .. last_task - 1), then by\n",
    "    sequence index, i.e. the same order as prepare_features(). Only seed 0 is\n",
    "    used. Token probabilities come from results_code[(task_id, 0)][2][seq]; the\n",
    "    answer tokens are the answer_len entries right after the prompt.\n",
    "\n",
    "    Returns a 1-D numpy array of length (last_task - first_task) * num_seqs.\n",
    "    \"\"\"\n",
    "    all_probs = []\n",
    "\n",
    "    for task_num in range(first_task, last_task):\n",
    "        task_id = 'MBPP/%d' % task_num\n",
    "        # Per-task lookups, hoisted out of the sequence loop.\n",
    "        prompt_len = results[(task_id, 0, 0, 0, 0)][0]\n",
    "        token_probs = results_code[(task_id, 0)][2]\n",
    "\n",
    "        for seq in range(num_seqs):\n",
    "            answer_len = results[(task_id, 0, seq, 0, 0)][1] - prompt_len\n",
    "            answer_probs = token_probs[seq][prompt_len:prompt_len + answer_len]\n",
    "            all_probs.append(np.mean(np.log(answer_probs)))\n",
    "\n",
    "    return np.array(all_probs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ce91b88e-15df-476d-8f0c-0c3a2e3254be",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One mean token log-probability per (task, completion) row.\n",
    "all_probs = get_probs()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f649124-e6f8-4eff-8a23-db3683e4e185",
   "metadata": {},
   "outputs": [],
   "source": [
    "def prepare_features(first_task=10, last_task=510, num_seqs=5, num_layers=32, num_heads=32):\n",
    "    \"\"\"Build one feature row per (task, seq) from the per-head entries of `results`.\n",
    "\n",
    "    For every (layer, head), six values are taken from\n",
    "    f = results[(task_id, 0, seq, layer, head)]:\n",
    "    f[2] / prompt_len, f[3] / answer_len, f[4][0] / answer_len,\n",
    "    f[4][1] / answer_len, f[5][0] / prompt_len, f[5][1] / prompt_len.\n",
    "    Columns are ordered by layer, then head, then those six values (the order\n",
    "    used for f_names); rows by task number, then sequence index. Only seed 0\n",
    "    is used.\n",
    "\n",
    "    Returns a list of rows, each with num_layers * num_heads * 6 values.\n",
    "    \"\"\"\n",
    "    all_features = []\n",
    "\n",
    "    for task_num in range(first_task, last_task):\n",
    "        task_id = 'MBPP/%d' % task_num\n",
    "        # Item 0 of the (seq 0, layer 0, head 0) entry is used as the prompt length;\n",
    "        # item 1 of each sequence's layer-0/head-0 entry as prompt + answer length.\n",
    "        prompt_len = results[(task_id, 0, 0, 0, 0)][0]\n",
    "\n",
    "        for seq in range(num_seqs):\n",
    "            answer_len = results[(task_id, 0, seq, 0, 0)][1] - prompt_len\n",
    "            f_sample = []\n",
    "\n",
    "            for layer in range(num_layers):\n",
    "                for head in range(num_heads):\n",
    "                    f = results[(task_id, 0, seq, layer, head)]\n",
    "                    f_sample.extend([\n",
    "                        f[2] / prompt_len,\n",
    "                        f[3] / answer_len,\n",
    "                        f[4][0] / answer_len,\n",
    "                        f[4][1] / answer_len,\n",
    "                        f[5][0] / prompt_len,\n",
    "                        f[5][1] / prompt_len,\n",
    "                    ])\n",
    "\n",
    "            all_features.append(f_sample)\n",
    "\n",
    "    return all_features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f706bb71-fa65-4c88-a729-f882e8351109",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Column names for X: layer -> head -> feature kind, the same order in which\n",
    "# prepare_features() emits values; each name is prefixed with its column index.\n",
    "FEATURE_KINDS = ['prompt_self_att', 'answer_self_att', 'mtd_a_h0', 'mtd_a_h1', 'mtd_b_h0', 'mtd_b_h1']\n",
    "\n",
    "f_names = []\n",
    "for layer in range(32):\n",
    "    for head in range(32):\n",
    "        for kind in FEATURE_KINDS:\n",
    "            f_names.append('%d_%s_%d_%d' % (len(f_names), kind, layer, head))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "89e6f8fc-efe2-428a-b6f7-742e403b93f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 32 layers x 32 heads x 6 feature kinds per head.\n",
    "assert len(f_names) == 32 * 32 * 6, len(f_names)\n",
    "len(f_names)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "12da6348-67b1-4f5a-b198-1edcfbc02aff",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Feature matrix (one row per task/completion) and the matching label vector.\n",
    "X, Y = np.array(prepare_features()), np.array(Y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b6790ef4-d858-49b1-9efa-5b671820ed9c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Every row of X must line up with exactly one label in Y and one entry of all_probs.\n",
    "assert X.shape[0] == len(Y) == len(all_probs), (X.shape, len(Y), len(all_probs))\n",
    "X.shape, len(Y)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "67657e54-8e9e-4caa-90e6-b8b31a69cdd3",
   "metadata": {},
   "source": [
    "### Split by task"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa7fa3b2-40fe-4ffa-8c66-6b3f1e776298",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _rank_candidates_by_task(task_of_row, y_pred, y_true):\n",
    "    \"\"\"Group (score, label) pairs by task; sort each group by descending score.\n",
    "\n",
    "    Groups are returned in order of first appearance. sorted() is stable, so\n",
    "    tied scores keep their row order.\n",
    "    \"\"\"\n",
    "    by_task = {}\n",
    "    for task, score, label in zip(task_of_row, y_pred, y_true):\n",
    "        by_task.setdefault(task, []).append((score, label))\n",
    "    return [sorted(pairs, key=lambda pair: -pair[0]) for pairs in by_task.values()]\n",
    "\n",
    "\n",
    "def calc_quality_group(train_idx, test_idx, samples_per_task=5):\n",
    "    \"\"\"Fit XGBoost on the training rows and evaluate it on the remaining rows.\n",
    "\n",
    "    Uses the notebook-level X and Y. Every row not listed in train_idx is a\n",
    "    test row (test_idx is not read; for a K-fold split it is exactly that\n",
    "    complement). Row i belongs to task group i // samples_per_task.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    (roc_auc, f1, all_candidates)\n",
    "        ROC-AUC of the predicted pass probability and F1 of the predicted\n",
    "        class on the test rows, plus, for each test task, its\n",
    "        (probability, label) pairs sorted by descending probability.\n",
    "    \"\"\"\n",
    "    features = np.asarray(X)\n",
    "    labels = np.asarray(Y)\n",
    "\n",
    "    # Boolean mask: a per-row `i in train_idx` scan would be O(n^2).\n",
    "    is_train = np.zeros(features.shape[0], dtype=bool)\n",
    "    is_train[train_idx] = True\n",
    "    test_rows = np.flatnonzero(~is_train)\n",
    "\n",
    "    X_train, y_train = features[is_train], labels[is_train]\n",
    "    X_test, y_test = features[test_rows], labels[test_rows]\n",
    "\n",
    "    print(X_train.shape)\n",
    "    print(X_test.shape)\n",
    "\n",
    "    clf = xgb.XGBClassifier(tree_method=\"hist\", max_bin = 64, n_estimators = 1000, eta = 0.1)\n",
    "    clf.fit(X_train, y_train)\n",
    "    y_pred = clf.predict_proba(X_test)[:, 1]\n",
    "    y_pred_class = clf.predict(X_test)\n",
    "\n",
    "    all_candidates = _rank_candidates_by_task(test_rows // samples_per_task, y_pred, y_test)\n",
    "    return roc_auc_score(y_test, y_pred), f1_score(y_test, y_pred_class), all_candidates"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0288f158-d3e3-4ba8-af83-c65509a71b11",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Each task contributes 5 consecutive rows (one per sampled completion); sharing one\n",
    "# group id keeps a task entirely inside either the train or the test fold.\n",
    "groups = np.arange(X.shape[0]) // 5\n",
    "\n",
    "kf = StratifiedGroupKFold(n_splits = 5, shuffle = True, random_state = 42)\n",
    "res_auc = []\n",
    "res_f1 = []\n",
    "all_task_res = []\n",
    "\n",
    "for train_idx, test_idx in kf.split(range(X.shape[0]), Y, groups):\n",
    "    auc, f1, task_res = calc_quality_group(train_idx, test_idx)\n",
    "    res_auc.append(auc)\n",
    "    res_f1.append(f1)\n",
    "    all_task_res.append(task_res)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "70aa8c90-66f5-48b1-8c81-72407699d2ec",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ROC-AUC over the task-grouped folds: (mean, std).\n",
    "(np.mean(res_auc), np.std(res_auc))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8fad624b-2c6f-4191-a87a-c0778b796a5d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# F1 over the task-grouped folds: (mean, std).\n",
    "(np.mean(res_f1), np.std(res_f1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f7835bb6-e35f-487c-a7f3-62a8847979d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ranking quality inside each held-out task. pass1 counts tasks whose top-ranked\n",
    "# candidate passes; pass10 counts tasks where any of the top 10 passes. Each task\n",
    "# has only 5 candidates, so pass10 is effectively best-of-all-5 (an oracle bound).\n",
    "pass1 = 0\n",
    "pass10 = 0\n",
    "cnt = 0\n",
    "pass1_list = []  # per-fold pass@1 rate\n",
    "\n",
    "for fold in all_task_res:\n",
    "    pass1_fold = 0\n",
    "    for task_pred in fold:\n",
    "        top_label = task_pred[0][1]\n",
    "        pass1 += top_label\n",
    "        pass1_fold += top_label\n",
    "        pass10 += max(label for _, label in task_pred[:10])\n",
    "    cnt += len(fold)\n",
    "    pass1_list.append(pass1_fold / len(fold))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8fc64f73-a02d-41a7-b365-e7375cf6750e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-fold pass@1 of the classifier's top-ranked candidate: (mean, std).\n",
    "(np.mean(pass1_list), np.std(pass1_list))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f96cc520-2e6c-4aa2-90af-97e7821070d5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fraction of all held-out tasks solved by the top-ranked candidate, and by any of\n",
    "# the top 10 ranked candidates (each task has only 5).\n",
    "(pass1 / cnt, pass10 / cnt)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
