{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Collect the problem IDs for each language pair in the matched dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import json\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import pickle\n",
    "\n",
    "from tqdm.notebook import tqdm\n",
    "from transformers import AutoTokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Matched dataset root: one sub-directory per language pair, each holding train.jsonl and val.jsonl.\n",
    "datapath = \"./data/codenet-jsonl-processed\"\n",
    "# Output directory of this notebook (the codepairs.pkl cache is written here).\n",
    "outdir = \"./random-data/codenet-jsonl-processed\"\n",
    "# Root of the Project CodeNet checkout; its data/ and metadata/ sub-folders are read.\n",
    "codenetdir = \"../Project_CodeNet\"\n",
    "\n",
    "# Create the output directory up front; exist_ok makes re-running this cell harmless.\n",
    "os.makedirs(outdir, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "90"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Language-pair directory names under datapath, hidden entries (e.g. .DS_Store) excluded.\n",
    "# os.listdir returns entries in arbitrary order, so sort to visit the pairs in the same order on every run.\n",
    "langpairs = sorted(x for x in os.listdir(datapath) if not x.startswith('.'))\n",
    "len(langpairs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer = AutoTokenizer.from_pretrained(\"microsoft/graphcodebert-base\")\n",
    "\n",
    "def tokenize(code):\n",
    "    \"\"\"Return how many tokens the GraphCodeBERT tokenizer yields for `code`.\n",
    "\n",
    "    Tokens that are exactly 'Ċ' or 'Ġ' are left out of the count.\n",
    "    \"\"\"\n",
    "    skipped = ('Ċ', 'Ġ')\n",
    "    return sum(1 for tok in tokenizer.tokenize(code) if tok not in skipped)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "4994d23f19c94a7fa586a127b999fb3e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/90 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# For every language pair: the set of problem ids that occur in train+val and\n",
    "# the total number of examples across both splits.\n",
    "probdict = {}\n",
    "\n",
    "for langs in tqdm(langpairs):\n",
    "    probids = set()\n",
    "    cnt = 0\n",
    "\n",
    "    # train and val are read identically, so one loop handles both splits.\n",
    "    for split in ('train.jsonl', 'val.jsonl'):\n",
    "        with open(os.path.join(datapath, langs, split), 'r') as f:\n",
    "            for line in f:\n",
    "                if not line.strip():\n",
    "                    # A blank line (e.g. a trailing newline) would make json.loads raise.\n",
    "                    continue\n",
    "                probids.add(json.loads(line)['prob'])\n",
    "                cnt += 1\n",
    "\n",
    "    probdict[langs] = {'probid': probids, 'n': cnt}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Read CodeNet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def check_length(row, thres=512):\n",
    "    \"\"\"Return True if the submission described by `row` has at most `thres` tokens.\n",
    "\n",
    "    `row` must provide 'problem_id', 'language', 'submission_id' and\n",
    "    'filename_ext'; the source file is read from\n",
    "    <codenetdir>/data/<problem_id>/<language>/<submission_id>.<filename_ext>\n",
    "    and measured with `tokenize`.\n",
    "\n",
    "    `thres` defaults to 512, the limit that used to be hard-coded here.\n",
    "    \"\"\"\n",
    "    path = os.path.join(\n",
    "        codenetdir, 'data', row['problem_id'], row['language'],\n",
    "        f\"{row['submission_id']}.{row['filename_ext']}\"\n",
    "    )\n",
    "    # Decode explicitly as UTF-8 and replace undecodable bytes: with the default\n",
    "    # strict decoding a single odd file raises UnicodeDecodeError and aborts the scan.\n",
    "    with open(path, 'r', encoding='utf-8', errors='replace') as f:\n",
    "        code = f.read().strip()\n",
    "\n",
    "    return tokenize(code) <= thres\n",
    "\n",
    "def read_codenet_metadata(problem, lang1, lang2, rng=None):\n",
    "    \"\"\"Sample equally many accepted submissions of `problem` in two languages.\n",
    "\n",
    "    Reads <codenetdir>/metadata/<problem>.csv, keeps the rows whose status is\n",
    "    'Accepted' and whose language is `lang1` or `lang2`, drops submissions\n",
    "    rejected by `check_length`, shuffles each language's submission ids and\n",
    "    truncates both lists to the length of the shorter one.\n",
    "\n",
    "    `rng` is anything with a numpy-style `permutation` method (for example\n",
    "    `np.random.default_rng(seed)`), which makes the sampling reproducible.\n",
    "    When omitted, the global `np.random` state is used, as before.\n",
    "\n",
    "    Returns (lang1_subs, lang2_subs): two lists of submission ids of equal length.\n",
    "    \"\"\"\n",
    "    shuffler = np.random if rng is None else rng\n",
    "\n",
    "    fpath = os.path.join(codenetdir, 'metadata', f'{problem}.csv')\n",
    "    metadata = pd.read_csv(fpath)\n",
    "    wanted = metadata['language'].isin([lang1, lang2]) & (metadata['status'] == 'Accepted')\n",
    "\n",
    "    lang1_subs, lang2_subs = [], []\n",
    "    for _, row in metadata[wanted].iterrows():\n",
    "        if not check_length(row):\n",
    "            continue\n",
    "        # Rows are already restricted to the two languages, so anything that is not lang1 is lang2.\n",
    "        bucket = lang1_subs if row['language'] == lang1 else lang2_subs\n",
    "        bucket.append(row['submission_id'])\n",
    "\n",
    "    n = min(len(lang1_subs), len(lang2_subs))\n",
    "\n",
    "    # Shuffle lang1 first, then lang2, so a given random state yields the same draw order as before.\n",
    "    # .tolist() converts the numpy scalars back to plain Python values, so callers\n",
    "    # (and anything they pickle) do not carry numpy types around.\n",
    "    lang1_subs = shuffler.permutation(lang1_subs)[:n].tolist()\n",
    "    lang2_subs = shuffler.permutation(lang2_subs)[:n].tolist()\n",
    "\n",
    "    return lang1_subs, lang2_subs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "df3683b33e514e47ac9102d582ba9dc9",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/90 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Problems:   0%|          | 0/791 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Problems:   0%|          | 0/688 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Problems:   0%|          | 0/822 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Problems:   0%|          | 0/714 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Problems:   0%|          | 0/312 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Problems:   0%|          | 0/533 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Problems:   0%|          | 0/912 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Problems:   0%|          | 0/1111 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Problems:   0%|          | 0/329 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Problems:   0%|          | 0/584 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Problems:   0%|          | 0/416 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Problems:   0%|          | 0/342 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Problems:   0%|          | 0/1007 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Problems:   0%|          | 0/973 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Problems:   0%|          | 0/737 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Problems:   0%|          | 0/606 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Problems:   0%|          | 0/535 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c73b9afdfc8f401fb12109e7129ad402",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Problems:   0%|          | 0/1236 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-11-fb7885746649>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     14\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     15\u001b[0m     \u001b[0;32mfor\u001b[0m \u001b[0mprob\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mtqdm\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mprobs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdesc\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"Problems\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mleave\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 16\u001b[0;31m         \u001b[0mlang1_subs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlang2_subs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mread_codenet_metadata\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mprob\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlang1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlang2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     17\u001b[0m         \u001b[0mallsubs_lang1\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlang1_subs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     18\u001b[0m         \u001b[0mallsubs_lang2\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlang2_subs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m<ipython-input-6-c4dd3b486190>\u001b[0m in \u001b[0;36mread_codenet_metadata\u001b[0;34m(problem, lang1, lang2)\u001b[0m\n\u001b[1;32m     23\u001b[0m     \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrow\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mmetadata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0miterrows\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     24\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 25\u001b[0;31m         \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mcheck_length\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrow\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     26\u001b[0m             \u001b[0;32mcontinue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     27\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m<ipython-input-6-c4dd3b486190>\u001b[0m in \u001b[0;36mcheck_length\u001b[0;34m(row)\u001b[0m\n\u001b[1;32m     11\u001b[0m         \u001b[0mcode\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mread\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstrip\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     12\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 13\u001b[0;31m     \u001b[0;32mreturn\u001b[0m \u001b[0mtokenize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcode\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m<=\u001b[0m \u001b[0mTHRES\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     14\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     15\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mread_codenet_metadata\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mproblem\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlang1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlang2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m<ipython-input-4-08ede14e0b3d>\u001b[0m in \u001b[0;36mtokenize\u001b[0;34m(code)\u001b[0m\n\u001b[1;32m      3\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mtokenize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcode\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      4\u001b[0m     \u001b[0mRM\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m'Ċ'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'Ġ'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m     \u001b[0mtokens\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mx\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mx\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mtokenizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtokenize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcode\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mx\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mRM\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      6\u001b[0m     \u001b[0;32mreturn\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtokens\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/envs/wmd/lib/python3.8/site-packages/transformers/tokenization_utils_fast.py\u001b[0m in \u001b[0;36mtokenize\u001b[0;34m(self, text, pair, add_special_tokens, **kwargs)\u001b[0m\n\u001b[1;32m    300\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    301\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mtokenize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtext\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpair\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mOptional\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mstr\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0madd_special_tokens\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mbool\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mList\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mstr\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 302\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mencode_plus\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtext\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtext\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtext_pair\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mpair\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0madd_special_tokens\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0madd_special_tokens\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtokens\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    303\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    304\u001b[0m     def set_truncation_and_padding(\n",
      "\u001b[0;32m~/anaconda3/envs/wmd/lib/python3.8/site-packages/transformers/tokenization_utils_base.py\u001b[0m in \u001b[0;36mtokens\u001b[0;34m(self, batch_index)\u001b[0m\n\u001b[1;32m    287\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_encodings\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    288\u001b[0m             \u001b[0;32mraise\u001b[0m \u001b[0mValueError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"tokens() is not available when using Python-based tokenizers\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 289\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_encodings\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mbatch_index\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtokens\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    290\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    291\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0msequence_ids\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_index\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mint\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mList\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mOptional\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mint\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "cache_path = os.path.join(outdir, 'codepairs.pkl')\n",
    "\n",
    "# Resume from an earlier run when a cache exists. Only a missing file means\n",
    "# \"start fresh\"; any other failure (corrupt cache, interrupt) must surface\n",
    "# instead of silently discarding finished work.\n",
    "# NOTE(review): pickle.load can execute arbitrary code - only load a cache written by this notebook.\n",
    "try:\n",
    "    with open(cache_path, 'rb') as f:\n",
    "        codepairs = pickle.load(f)\n",
    "except FileNotFoundError:\n",
    "    codepairs = {}\n",
    "\n",
    "for langs, vals in tqdm(probdict.items()):\n",
    "\n",
    "    # Language pairs already in the cache are not recomputed.\n",
    "    if langs in codepairs:\n",
    "        continue\n",
    "\n",
    "    lang1, lang2 = langs.split('-')\n",
    "    n = vals['n']\n",
    "    # Freeze the iteration order once so each problem stays aligned with its submission lists.\n",
    "    probs = list(vals['probid'])\n",
    "\n",
    "    allsubs = [\n",
    "        read_codenet_metadata(prob, lang1, lang2)\n",
    "        for prob in tqdm(probs, desc=\"Problems\", leave=False)\n",
    "    ]\n",
    "\n",
    "    # Round-robin over the problems: round `idx` takes the idx-th pair of every\n",
    "    # problem that still has one, until `n` pairs have been collected.\n",
    "    lang_codepairs = []\n",
    "    idx = 0\n",
    "    while len(lang_codepairs) < n:\n",
    "        added = False\n",
    "        for prob, (subs1, subs2) in zip(probs, allsubs):\n",
    "            if len(subs1) <= idx or len(subs2) <= idx:\n",
    "                continue\n",
    "\n",
    "            if len(lang_codepairs) >= n:\n",
    "                break\n",
    "\n",
    "            lang_codepairs.append((prob, subs1[idx], subs2[idx]))\n",
    "            added = True\n",
    "\n",
    "        if not added:\n",
    "            # Every problem is exhausted, i.e. fewer than `n` pairs exist.\n",
    "            # Without this exit the loop would spin forever.\n",
    "            break\n",
    "\n",
    "        idx += 1\n",
    "\n",
    "    codepairs[langs] = lang_codepairs\n",
    "\n",
    "    # Checkpoint after every language pair. Write to a temporary file and move it\n",
    "    # into place so an interrupt during the dump cannot leave a truncated cache.\n",
    "    tmp_path = cache_path + '.tmp'\n",
    "    with open(tmp_path, 'wb') as f:\n",
    "        pickle.dump(codepairs, f)\n",
    "    os.replace(tmp_path, cache_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
