{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e0b53c94",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "# Path to the JSON list of evaluation records; later cells read the keys 'label',\n",
    "# 'samples', 'greedy_sample', 'samples_prob' and 'greedy_sample_prob'. TODO: fill in.\n",
    "SFT_DATA_PATH = \"\"\n",
    "\n",
    "# Decode explicitly as UTF-8 so loading does not depend on the platform's locale encoding.\n",
    "with open(SFT_DATA_PATH, \"r\", encoding=\"utf-8\") as f:\n",
    "    sft_data = json.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "7a1baf56",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Pull the stored token-logprob records out of every instance.\n",
    "# NOTE(review): greedy_sample_probs is not read by any later cell in this notebook.\n",
    "greedy_sample_probs, samples_probs = [], []\n",
    "for record in sft_data:\n",
    "    greedy_sample_probs.append(record['greedy_sample_prob'])\n",
    "    samples_probs.append(record['samples_prob'])\n",
    "\n",
    "def get_logprob(list_dict):\n",
    "    logprobs = []\n",
    "    for i in range(1, len(list_dict)):\n",
    "        token_dict = list_dict[i]\n",
    "        logprobs.append(token_dict['logprob'])\n",
    "    \n",
    "    return np.array(logprobs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "87259a39",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Peek at the first record to inspect the structure the later cells index into.\n",
    "sft_data[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "1d63df49",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "np.float64(0.7092)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sklearn.metrics import roc_auc_score\n",
    "\n",
    "Labels = [instance['label'] for instance in sft_data]\n",
    "\n",
    "# Per instance: for every sampled generation, sum its (at most) 20 lowest token\n",
    "# logprobs and divide by len() of that generation's raw record list, then score the\n",
    "# instance by the standard deviation of those values across its samples.\n",
    "# NOTE(review): the divisor counts every record (including the one get_logprob\n",
    "# skips), not the number of summed terms -- confirm this normalization is intended.\n",
    "min_samples_probs_temp1 = []\n",
    "for instance_samples in samples_probs:\n",
    "    normalized_low_sums = [\n",
    "        np.sum(sorted(get_logprob(records))[:20]) / len(records)\n",
    "        for records in instance_samples\n",
    "    ]\n",
    "    min_samples_probs_temp1.append(np.std(normalized_low_sums))\n",
    "\n",
    "roc_auc_score(Labels, np.array(min_samples_probs_temp1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dfaff118",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/liangrenzhao/.local/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n",
      "500it [00:07, 68.45it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy = 0.562\n",
      "Precision = 0.5714285714285714\n",
      "Recall = 0.496\n",
      "F1Score = 0.5310492505353319\n",
      "AUC = 0.585784\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "def tokenize_code(sample, tokenizer, length=512):\n",
    "    return tokenizer.encode(sample)[:length] if length else tokenizer.encode(sample)\n",
    "\n",
    "def strip_code(sample):\n",
    "    return sample.strip().split('\\n\\n\\n')[0] if '\\n\\n\\n' in sample else sample.strip().split('```')[0]\n",
    "\n",
    "import json\n",
    "from transformers import AutoTokenizer\n",
    "from sklearn.metrics import precision_score, recall_score, accuracy_score, f1_score, roc_auc_score\n",
    "from Levenshtein import distance as levenshtein_distance\n",
    "\n",
    "def evaluate_classification(y_true, y_pred, y_pred_prob=None):\n",
    "    \"\"\"Compute binary-classification metrics with scikit-learn.\n",
    "\n",
    "    Args:\n",
    "        y_true: ground-truth labels.\n",
    "        y_pred: hard predicted labels.\n",
    "        y_pred_prob: optional continuous scores; when given, 'AUC' is added.\n",
    "\n",
    "    Returns:\n",
    "        dict with keys 'Precision', 'Recall', 'Accuracy', 'F1 Score' and, when\n",
    "        scores were supplied, 'AUC'.\n",
    "    \"\"\"\n",
    "    metrics = {}\n",
    "    metrics['Precision'] = precision_score(y_true, y_pred)\n",
    "    metrics['Recall'] = recall_score(y_true, y_pred)\n",
    "    metrics['Accuracy'] = accuracy_score(y_true, y_pred)\n",
    "    metrics['F1 Score'] = f1_score(y_true, y_pred)\n",
    "    if y_pred_prob is None:\n",
    "        return metrics\n",
    "    metrics['AUC'] = roc_auc_score(y_true, y_pred_prob)\n",
    "    return metrics\n",
    "\n",
    "\n",
    "def get_edit_distance_distribution_star(samples, gready_sample, tokenizer, length = 512):\n",
    "    \"\"\"Token-level edit distances from the greedy completion to each sample.\n",
    "\n",
    "    Every text is cleaned with strip_code and encoded with tokenize_code\n",
    "    (truncated to `length` ids); levenshtein_distance is then applied to the\n",
    "    token-id sequences.\n",
    "\n",
    "    Args:\n",
    "        samples: iterable of sampled completions.\n",
    "        gready_sample: the greedy completion (parameter name kept as-is for\n",
    "            keyword-argument compatibility).\n",
    "        tokenizer: object exposing encode(text).\n",
    "        length: truncation limit forwarded to tokenize_code.\n",
    "\n",
    "    Returns:\n",
    "        (distances, max_length): one distance per sample in input order, and the\n",
    "        longest token sequence seen, the greedy output included.\n",
    "    \"\"\"\n",
    "    greedy_ids = tokenize_code(strip_code(gready_sample), tokenizer, length)\n",
    "    distances = []\n",
    "    longest = len(greedy_ids)\n",
    "    for candidate in samples:\n",
    "        candidate_ids = tokenize_code(strip_code(candidate), tokenizer, length)\n",
    "        distances.append(levenshtein_distance(greedy_ids, candidate_ids))\n",
    "        longest = max(longest, len(candidate_ids))\n",
    "    return distances, longest\n",
    "\n",
    "def calculate_ratio(numbers, alpha=0.05):\n",
    "    count = sum(1 for num in numbers if num <= alpha)\n",
    "    total = len(numbers)\n",
    "    ratio = count / total if total > 0 else 0\n",
    "    return ratio\n",
    "\n",
    "\n",
    "from tqdm import tqdm\n",
    "import numpy as np\n",
    "# TODO: set to the model/tokenizer directory that produced the samples.\n",
    "model_path = \"\"\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_path)\n",
    "\n",
    "# alpha: a sample is a hit when its edit distance to the greedy output is\n",
    "#        <= alpha * (longest tokenized output for that task).\n",
    "# xi:    threshold on the per-task hit ratio used for the hard 0/1 predictions.\n",
    "alpha = 0.05\n",
    "xi = 0.01\n",
    "Results = []\n",
    "Labels = []\n",
    "# Iterate the list itself (not enumerate) so tqdm can read len() and draw a real progress bar.\n",
    "for task in tqdm(sft_data):\n",
    "    dist, ml = get_edit_distance_distribution_star(task['samples'], task['greedy_sample'], tokenizer)\n",
    "    peak = calculate_ratio(dist, alpha * ml)\n",
    "    Results.append(peak)\n",
    "    Labels.append(task['label'])\n",
    "\n",
    "metric = evaluate_classification(Labels, [score > xi for score in Results], Results)\n",
    "\n",
    "print(f'Accuracy = {metric[\"Accuracy\"]}')\n",
    "print(f'Precision = {metric[\"Precision\"]}')\n",
    "print(f'Recall = {metric[\"Recall\"]}')\n",
    "print(f'F1Score = {metric[\"F1 Score\"]}')\n",
    "print(f'AUC = {metric[\"AUC\"]}')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "CDD",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
