{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fab4f6d3-ca40-4899-ae4a-1e2cd57f8995",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import pickle\n",
    "from human_eval.data import HUMAN_EVAL, read_problems, stream_jsonl, write_jsonl\n",
    "import numpy as np\n",
    "import pickle\n",
    "from tqdm import trange\n",
    "from sklearn.model_selection import KFold, StratifiedKFold, StratifiedGroupKFold\n",
    "import xgboost as xgb\n",
    "from matplotlib import pyplot as plt\n",
    "from sklearn.model_selection import cross_validate\n",
    "from sklearn.linear_model import LogisticRegression, Lasso\n",
    "from sklearn.preprocessing import normalize, StandardScaler, MinMaxScaler\n",
    "from sklearn.metrics import roc_auc_score, accuracy_score, precision_recall_curve\n",
    "from sklearn.metrics import f1_score\n",
    "from sklearn.model_selection import train_test_split\n",
    "from scipy.stats import spearmanr, kendalltau\n",
    "from matplotlib.pyplot import figure"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e60cd7ad-596f-43c4-9ddd-bbf735ba68c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of tasks processed below; task ids are built as 'HumanEval/<n>' for n in range(NUM_TASKS).\n",
    "NUM_TASKS = 164"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "817c60b8-ec6f-4d7e-bbd4-6d2e97c7c588",
   "metadata": {},
   "outputs": [],
   "source": [
    "def calc_mtd(dgms):\n",
    "    dgm0 = dgms[0]\n",
    "    dgm1 = dgms[1]\n",
    "    \n",
    "    mtd0 = np.sum(dgm0[dgm0 < np.inf])\n",
    "    if dgm1.shape[0]:\n",
    "        mtd1 = np.sum(dgm1[:, 1] - dgm1[:, 0])\n",
    "    else:\n",
    "        mtd1 = 0\n",
    "    \n",
    "    return mtd0, mtd1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5c0f7e28-625f-42a2-91e7-8fbd1be918cb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the per-task pickles and replace the last two entries of every record\n",
    "# with the (mtd0, mtd1) pairs returned by calc_mtd.\n",
    "# NOTE(security): pickle.load can execute arbitrary code - only load trusted files.\n",
    "results = {}\n",
    "results_code = {}\n",
    "\n",
    "for i in trange(NUM_TASKS):\n",
    "    # Context managers close each file as soon as it is read instead of leaking the handle.\n",
    "    with open('/he_results2/%d.pickle' % i, 'rb') as fin:\n",
    "        results1 = pickle.load(fin)\n",
    "    with open('/he_results2/%d_code.pickle' % i, 'rb') as fin:\n",
    "        results_code1 = pickle.load(fin)\n",
    "\n",
    "    # Only values of existing keys are replaced, so the dict size never changes mid-iteration.\n",
    "    for k, v in results1.items():\n",
    "        dgms_a, dgms_b = v[-2]['dgms'], v[-1]['dgms']\n",
    "        # Keep the first four fields and append the two diagram summaries.\n",
    "        results1[k] = v[:4] + [calc_mtd(dgms_a), calc_mtd(dgms_b)]\n",
    "\n",
    "    results.update(results1)\n",
    "    results_code.update(results_code1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f757c28-4561-41bc-8442-c22ee06a1979",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The combined cache read here can be regenerated with:\n",
    "#     with open('/he_results2/all.pickle', 'wb') as fout:\n",
    "#         pickle.dump((results, results_code), fout)\n",
    "# NOTE(security): pickle.load can execute arbitrary code - only load trusted files.\n",
    "# The context manager closes the file handle once the cache has been read.\n",
    "with open('/he_results2/all.pickle', 'rb') as fin:\n",
    "    (results, results_code) = pickle.load(fin)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d565f34f-3801-4e18-99ef-108bf76b0114",
   "metadata": {},
   "outputs": [],
   "source": [
    "samples = []\n",
    "num_samples_per_task = 5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8415632e-67e9-4992-aaa7-614ed9d1016e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rebuild `samples` from scratch so that re-running this cell cannot append\n",
    "# duplicates (duplicates would misalign the evaluation results with the features).\n",
    "samples = []\n",
    "\n",
    "for task_num in range(NUM_TASKS):\n",
    "    task_id = 'HumanEval/%d' % task_num\n",
    "    for seed in range(5):\n",
    "        # Element 0 of the per-(task, seed) record is indexed by sample number.\n",
    "        completions = results_code[(task_id, seed)][0]\n",
    "        samples.extend(\n",
    "            dict(task_id = task_id, completion = completions[i])\n",
    "            for i in range(num_samples_per_task)\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee3c9d89-ebcf-4640-9973-76637a8b43af",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the collected samples to samples_he.jsonl, the file passed to the evaluation command.\n",
    "write_jsonl(\"samples_he.jsonl\", samples)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "426032d9-e687-464e-a20b-59f0ca465a3c",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Run the evaluation command on the samples file; later cells read samples_he.jsonl_results.jsonl.\n",
    "!evaluate_functional_correctness samples_he.jsonl"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ce71648b-6dbb-4c78-b81d-d13964cffdfe",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One list of 0/1 pass flags per task, indexed by the numeric suffix of the task id.\n",
    "run_stat = [[] for _ in range(NUM_TASKS)]\n",
    "\n",
    "for record in stream_jsonl('samples_he.jsonl_results.jsonl'):\n",
    "    # The text after the last '/' in the task id selects the bucket.\n",
    "    bucket = run_stat[int(record['task_id'].split('/')[-1])]\n",
    "    bucket.append(int(record['passed']))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "10d3455d-1dd9-4443-9667-9e490a39d93d",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Display the raw per-task lists of pass flags.\n",
    "run_stat"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f4b3eb0d-df42-455f-b02a-640e73468d27",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sum and mean of all pass flags.\n",
    "# NOTE(review): assumes every task has the same number of results so the nested list is rectangular - confirm.\n",
    "np.sum(run_stat), np.mean(run_stat)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ba9b1a1d-288d-4226-9e7d-10190dd0ae40",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean of all pass flags (the same quantity as the second item of the previous cell).\n",
    "np.mean(run_stat)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e77a6596-d4ae-45f2-8a04-31b63738d834",
   "metadata": {},
   "source": [
    "### Random split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1e9eeaa7-27db-45cc-9bc7-ad889c83ec65",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Flat list of 0/1 labels, one per record of the evaluation results file, in file order.\n",
    "# NOTE(review): assumes the results file keeps the order in which the samples were written,\n",
    "# so that Y lines up row by row with the feature matrix - confirm.\n",
    "Y = [int(elem['passed']) for elem in stream_jsonl('samples_he.jsonl_results.jsonl')]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d096a04-91bc-460d-81e4-d2f3147205e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_probs():\n",
    "    all_probs = []\n",
    "\n",
    "    for task_num in range(NUM_TASKS):\n",
    "        features = []\n",
    "        task_id = 'HumanEval/%d' % task_num\n",
    "        for seed in range(5):\n",
    "            for seq in range(5):\n",
    "                prompt_len = results[(task_id, 0, 0, 0, 0)][0]\n",
    "                answer_len = results[(task_id, seed, seq, 0, 0)][1] - prompt_len\n",
    "                f_sample = [prompt_len, answer_len]\n",
    "                probs = results_code[(task_id, 0)][2][seq]            \n",
    "                probs = probs[prompt_len:prompt_len + answer_len]\n",
    "                        \n",
    "                all_probs.append(np.mean(np.log(probs)))\n",
    "    \n",
    "    return np.array(all_probs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0547d10e-fcae-4dcb-8528-3c6c3d09b923",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Compute the per-sample mean log-probabilities once; calc_quality_group reads all_probs as a global.\n",
    "all_probs = get_probs()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f649124-e6f8-4eff-8a23-db3683e4e185",
   "metadata": {},
   "outputs": [],
   "source": [
    "def prepare_features():\n",
    "    all_features = []\n",
    "\n",
    "    for task_num in range(NUM_TASKS):\n",
    "        features = []\n",
    "        task_id = 'HumanEval/%d' % task_num\n",
    "        for seed in range(5):\n",
    "            for seq in range(num_samples_per_task):\n",
    "\n",
    "                prompt_len = results[(task_id, 0, 0, 0, 0)][0]\n",
    "                answer_len = results[(task_id, seed, seq, 0, 0)][1] - prompt_len\n",
    "                f_sample = [] \n",
    "                \n",
    "                for layer in range(32):\n",
    "                    for head in range(32):\n",
    "                        f = results[(task_id, seed, seq, layer, head)]\n",
    "                        f = [f[2]/ prompt_len, f[3]/ answer_len, f[4][0] / answer_len, f[4][1] / answer_len, f[5][0] / prompt_len , f[5][1] / prompt_len]\n",
    "                        f_sample.extend(f)\n",
    "                        \n",
    "                all_features.append(f_sample)\n",
    "    \n",
    "    return all_features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f706bb71-fa65-4c88-a729-f882e8351109",
   "metadata": {},
   "outputs": [],
   "source": [
    "f_names = []\n",
    "cnt = 0\n",
    "f = ['prompt_self_att', 'answer_self_att', 'mtd_a_h0', 'mtd_a_h1', 'mtd_b_h0', 'mtd_b_h1']\n",
    "        \n",
    "for layer in range(32):\n",
    "    for head in range(32):\n",
    "        for f1 in f:\n",
    "            f_names.append('%d_%s_%d_%d' % (cnt, f1, layer, head))\n",
    "            cnt += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "89e6f8fc-efe2-428a-b6f7-742e403b93f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of generated feature names: 32 layers * 32 heads * 6 base names.\n",
    "len(f_names)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "12da6348-67b1-4f5a-b198-1edcfbc02aff",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Materialise the feature rows and the labels as NumPy arrays.\n",
    "X = np.array(prepare_features())\n",
    "Y = np.array(Y)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "67657e54-8e9e-4caa-90e6-b8b31a69cdd3",
   "metadata": {},
   "source": [
    "### Split by task"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4417b82e-d5e6-4306-ab95-3e4ee6d4225c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def calc_quality_group(train_idx, test_idx, samples_per_task=25):\n",
    "    \"\"\"Fit an XGBoost classifier on one fold and score the held-out rows.\n",
    "\n",
    "    Reads the module-level ``X`` and ``Y`` arrays.\n",
    "\n",
    "    Args:\n",
    "        train_idx: integer row indices of ``X``/``Y`` used for fitting.\n",
    "        test_idx: integer row indices of ``X``/``Y`` used for evaluation.\n",
    "        samples_per_task: number of consecutive rows belonging to one task;\n",
    "            ``row // samples_per_task`` is used as the task number (default 25).\n",
    "\n",
    "    Returns:\n",
    "        Tuple ``(roc_auc, f1, all_candidates)``. ``all_candidates`` holds, for\n",
    "        every task present in the test rows, its list of\n",
    "        ``(predicted_probability, label)`` pairs sorted by decreasing probability.\n",
    "    \"\"\"\n",
    "    train_idx = np.asarray(train_idx)\n",
    "    test_idx = np.asarray(test_idx)\n",
    "\n",
    "    # Index the arrays directly; testing `i in train_idx` for every row is a\n",
    "    # linear scan per row and ignored the test_idx argument altogether.\n",
    "    X_train, y_train = X[train_idx], Y[train_idx]\n",
    "    X_test, y_test = X[test_idx], Y[test_idx]\n",
    "    # Log-probability baseline (kept for reference):\n",
    "    #X_train = np.array(all_probs[train_idx]).reshape(len(train_idx), 1)\n",
    "    #X_test = np.array(all_probs[test_idx]).reshape(len(test_idx), 1)\n",
    "\n",
    "    print(X_train.shape)\n",
    "    print(X_test.shape)\n",
    "\n",
    "    clf = xgb.XGBClassifier(tree_method=\"hist\", max_bin = 64, n_estimators = 1000, eta = 0.1)\n",
    "    clf.fit(X_train, y_train)\n",
    "    y_pred = clf.predict_proba(X_test)[:, 1]\n",
    "    y_pred_class = clf.predict(X_test)\n",
    "\n",
    "    # Group the (score, label) pairs by task, in first-seen task order.\n",
    "    task_res = {}\n",
    "    for row, p, label in zip(test_idx, y_pred, y_test):\n",
    "        task_res.setdefault(row // samples_per_task, []).append((p, label))\n",
    "\n",
    "    # Rank each task's candidates from the highest to the lowest predicted probability.\n",
    "    all_candidates = [sorted(pred_list, key = lambda x : -x[0]) for pred_list in task_res.values()]\n",
    "\n",
    "    return roc_auc_score(y_test, y_pred), f1_score(y_test, y_pred_class), all_candidates"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "70a57f86-bc8d-4a8c-b301-54b5df2b2932",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Each task owns 25 consecutive rows, so its task number is repeated 25 times.\n",
    "groups = [task_num for task_num in range(NUM_TASKS) for _ in range(25)]\n",
    "\n",
    "kf = StratifiedGroupKFold(n_splits = 5, shuffle = True, random_state = 42)\n",
    "res_auc, res_f1, all_task_res = [], [], []\n",
    "\n",
    "# Grouping by task keeps all rows of one task on the same side of every split.\n",
    "for train_idx, test_idx in kf.split(range(X.shape[0]), Y, groups):\n",
    "    fold_auc, fold_f1, fold_candidates = calc_quality_group(train_idx, test_idx)\n",
    "    res_auc.append(fold_auc)\n",
    "    res_f1.append(fold_f1)\n",
    "    all_task_res.append(fold_candidates)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3e213f12-07d2-43c1-b319-24bcfb8d36aa",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-fold ROC AUC values collected by the cross-validation loop.\n",
    "print(res_auc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "504e4785-a9d8-4bb5-9a72-67bb3665a229",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean and standard deviation of the per-fold ROC AUC.\n",
    "np.mean(res_auc), np.std(res_auc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b8af50a1-2fff-4a6b-86f6-e5105efa045d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean and standard deviation of the per-fold F1 score.\n",
    "np.mean(res_f1), np.std(res_f1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9d8adc90-2d55-4df5-b88e-2112fc28be28",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Aggregate the ranking quality over the cross-validation folds.\n",
    "#   pass1  - sum of the labels of each task's top-ranked candidate\n",
    "#   pass10 - sum over tasks of the largest label among the ten top-ranked candidates\n",
    "pass1 = 0\n",
    "pass10 = 0\n",
    "cnt = 0\n",
    "\n",
    "pass1_list = []  # per-fold top-1 rate\n",
    "pass1_fold = 0\n",
    "cnt_fold = 0\n",
    "\n",
    "for fold in all_task_res:\n",
    "    # Label of the highest-ranked candidate of every task in this fold.\n",
    "    top1_labels = [ranked[0][1] for ranked in fold]\n",
    "    pass1_fold = sum(top1_labels)\n",
    "    cnt_fold = len(top1_labels)\n",
    "\n",
    "    pass1 += pass1_fold\n",
    "    pass10 += sum(max(pair[1] for pair in ranked[0:10]) for ranked in fold)\n",
    "    cnt += cnt_fold\n",
    "\n",
    "    pass1_list.append(pass1_fold / cnt_fold)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f1e0aeda-0515-42a4-a98d-0d3fc43c7ac6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean and standard deviation across folds of the per-fold top-1 rate.\n",
    "np.mean(pass1_list), np.std(pass1_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe13eea1-7f68-439e-a496-6c174fab47af",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pooled top-1 rate, pooled top-10 rate, and the number of tasks counted.\n",
    "pass1/cnt, pass10/cnt, cnt"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
