{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "5d901b6c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import json\n",
    "import math\n",
    "\n",
    "from tqdm.notebook import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "7473e21a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fraction of each problem's submissions that is kept (never fewer than one per problem).\n",
    "PERC = 0.25\n",
    "\n",
    "# One sub-directory per language pair, each holding a test.jsonl.\n",
    "testpath = \"./data/codenet-jsonl-processed\"\n",
    "# PLBART generations: one sub-directory per language pair ('-' replaced by '_'), each holding a code.jsonl.\n",
    "plbart_gencode = \"./PLBART/generated-code\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "aab5ae10",
   "metadata": {},
   "source": [
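    "#### Subsample the test set\n",
    "\n",
    "For each language pair, keep a fraction `PERC` of the submissions of every problem (at least one) in `test.jsonl`, and drop the PLBART generations in `code.jsonl` that belong to submissions that were not kept. A copy of the pre-filtering `test.jsonl` is kept as `test-original.jsonl`; both `test.jsonl` and `code.jsonl` are rewritten in place."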
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "147bc4d2",
   "metadata": {},
   "outputs": [],
   "source": [
    "def getlen(x):\n",
    "    \"\"\"Return the total number of items across all values of the mapping `x` (0 for an empty mapping).\"\"\"\n",
    "    return sum(len(v) for v in x.values())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "e4c5c52c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "eca204e2cb88455e9da06b9b97af2da1",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/90 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Fixed seed so that re-running the notebook reproduces the same subsample.\n",
    "SEED = 42\n",
    "rng = np.random.default_rng(SEED)\n",
    "\n",
    "\n",
    "def read_original(path, backup_path):\n",
    "    \"\"\"Return the unfiltered lines of `path`.\n",
    "\n",
    "    The first call copies `path` to `backup_path`; later calls read the backup,\n",
    "    because by then `path` has been overwritten with a filtered subset.\n",
    "    \"\"\"\n",
    "    if os.path.exists(backup_path):\n",
    "        with open(backup_path, 'r') as f:\n",
    "            return f.readlines()\n",
    "    with open(path, 'r') as f:\n",
    "        lines = f.readlines()\n",
    "    with open(backup_path, 'w') as f:\n",
    "        f.writelines(lines)\n",
    "    return lines\n",
    "\n",
    "\n",
    "def parse_jsonl(lines):\n",
    "    \"\"\"Parse JSON Lines into a list of dicts, skipping blank lines.\"\"\"\n",
    "    return [json.loads(raw) for raw in lines if raw.strip()]\n",
    "\n",
    "\n",
    "def subsample_ids(records, perc, rng):\n",
    "    \"\"\"Map each problem to a random set of its submission ids.\n",
    "\n",
    "    Keeps ceil(perc * n) of the n ids of a problem, clamped to [1, n].\n",
    "    \"\"\"\n",
    "    by_prob = {}\n",
    "    for rec in records:\n",
    "        by_prob.setdefault(rec['prob'], []).append(rec['submission_id'])\n",
    "    kept = {}\n",
    "    for prob, subids in by_prob.items():\n",
    "        n = min(len(subids), max(1, math.ceil(perc * len(subids))))\n",
    "        # Sample positions rather than values so ids keep their original Python type.\n",
    "        picked = rng.choice(len(subids), n, replace=False)\n",
    "        kept[prob] = {subids[i] for i in picked}\n",
    "    return kept\n",
    "\n",
    "\n",
    "def keep_subsampled(records, keep, missing_msg):\n",
    "    \"\"\"Return the records whose submission id was kept for their problem.\n",
    "\n",
    "    Raises ValueError(f'{prob} {missing_msg}') if a record's problem is not in `keep`.\n",
    "    \"\"\"\n",
    "    out = []\n",
    "    for rec in records:\n",
    "        prob = rec['prob']\n",
    "        if prob not in keep:\n",
    "            raise ValueError(f'{prob} {missing_msg}')\n",
    "        if rec['submission_id'] in keep[prob]:\n",
    "            out.append(rec)\n",
    "    return out\n",
    "\n",
    "\n",
    "def write_jsonl(path, records):\n",
    "    \"\"\"Overwrite `path` with one JSON object per line.\"\"\"\n",
    "    with open(path, 'w') as f:\n",
    "        f.writelines(json.dumps(rec) + '\\n' for rec in records)\n",
    "\n",
    "\n",
    "for langpair in tqdm(sorted(os.listdir(testpath))):\n",
    "    if langpair.startswith('.') or not os.path.isdir(os.path.join(testpath, langpair)):\n",
    "        continue\n",
    "\n",
    "    fpath = os.path.join(testpath, langpair, 'test.jsonl')\n",
    "    fpath_org = os.path.join(testpath, langpair, 'test-original.jsonl')\n",
    "    plbart_dir = os.path.join(plbart_gencode, langpair.replace('-', '_'))\n",
    "    fpath_plbart = os.path.join(plbart_dir, 'code.jsonl')\n",
    "    fpath_plbart_org = os.path.join(plbart_dir, 'code-original.jsonl')\n",
    "\n",
    "    # A test backup without a PLBART backup means an earlier version of this cell already\n",
    "    # filtered code.jsonl in place; filtering it again would silently drop more records.\n",
    "    if os.path.exists(fpath_org) and not os.path.exists(fpath_plbart_org):\n",
    "        raise RuntimeError(f'{fpath_plbart} may already be filtered; '\n",
    "                           f'restore the full file as {fpath_plbart_org} and re-run')\n",
    "\n",
    "    # Back up PLBART first so a missing code.jsonl cannot trip the check above on the next run.\n",
    "    gencode = parse_jsonl(read_original(fpath_plbart, fpath_plbart_org))\n",
    "    testdata = parse_jsonl(read_original(fpath, fpath_org))\n",
    "\n",
    "    subsampled = subsample_ids(testdata, PERC, rng)\n",
    "\n",
    "    # Filter both files before writing either, so a bad record leaves both untouched.\n",
    "    test_kept = keep_subsampled(testdata, subsampled, 'not found in subsampled')\n",
    "    plbart_kept = keep_subsampled(gencode, subsampled, 'not in plbart subsampled')\n",
    "\n",
    "    write_jsonl(fpath, test_kept)\n",
    "    write_jsonl(fpath_plbart, plbart_kept)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "7158246a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'Scala-Ruby'"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Name of the last directory entry visited by the subsampling loop above (the loop variable stays in the notebook namespace).\n",
    "langpair"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "9682f868",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "671"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# getlen of `subsampled` as left by the last iteration of the loop above, i.e. for the language pair shown in `langpair` only.\n",
    "getlen(subsampled)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f0b2244a",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
