{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import subprocess\n",
    "\n",
    "import numpy as np\n",
    "import scipy.sparse as sp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Source dataset directory and its token-file sub-directory (name embeds seq_len).\n",
    "work_dir = '/work/dir'\n",
    "dataset = 'LF-AmazonTitles-131K'\n",
    "seq_len = 32\n",
    "dataset_dir = f'{work_dir}/Datasets/{dataset}'\n",
    "tok_dir = f'{dataset_dir}/bert-base-uncased-{seq_len}'\n",
    "\n",
    "# Output directories for the augmented copy; created here so later cells can write into them.\n",
    "# This cell needs `os`, so it must run after the imports cell.\n",
    "aug_dataset_dir = f'{dataset_dir}-Aug'\n",
    "aug_tok_dir = f'{aug_dataset_dir}/bert-base-uncased-{seq_len}'\n",
    "os.makedirs(aug_dataset_dir, exist_ok=True)\n",
    "os.makedirs(aug_tok_dir, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def copy_file(src_fpth, dst_fpth):\n",
    "    '''Copy `src_fpth` to `dst_fpth`; if the source does not exist, create `dst_fpth` with `touch`.\n",
    "\n",
    "    Args:\n",
    "        src_fpth: path of the file to copy.\n",
    "        dst_fpth: destination path.\n",
    "\n",
    "    Raises:\n",
    "        subprocess.CalledProcessError: if `cp` or `touch` exits with a non-zero status.\n",
    "    '''\n",
    "    if os.path.exists(src_fpth):\n",
    "        # check=True: a failed copy should stop the notebook instead of passing silently\n",
    "        subprocess.run(['cp', src_fpth, dst_fpth], check=True)\n",
    "    else: # for filter files create a dummy file\n",
    "        print('Creating dummy for ', dst_fpth)\n",
    "        subprocess.run(['touch', dst_fpth], check=True)\n",
    "        \n",
    "def copy_files(src_dir, dst_dir, files):\n",
    "    '''Copy every name in `files` from `src_dir` to `dst_dir`, printing each name first.'''\n",
    "    for fname in files:\n",
    "        print(f'Copying {fname}')\n",
    "        src_fpth = f'{src_dir}/{fname}'\n",
    "        dst_fpth = f'{dst_dir}/{fname}'\n",
    "        copy_file(src_fpth, dst_fpth)\n",
    "\n",
    "def write_mmap(np_arr, fpth, dtype=np.int64):\n",
    "    '''Write `np_arr` to `fpth` as a raw memory-mapped file.\n",
    "\n",
    "    Args:\n",
    "        np_arr: array to store; the file is created with the same shape.\n",
    "        fpth: output path, opened with mode 'w+' (created or overwritten).\n",
    "        dtype: on-disk element type; defaults to np.int64, the previously hard-coded value.\n",
    "    '''\n",
    "    np_arr_mmap = np.memmap(fpth, mode='w+', shape=np_arr.shape, dtype=dtype)\n",
    "    np_arr_mmap[:] = np_arr[:]\n",
    "    # write the data out now rather than relying on the memmap being garbage-collected\n",
    "    np_arr_mmap.flush()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training label matrix; its two dimensions size the memmaps opened below.\n",
    "trn_X_Y = sp.load_npz(f'{dataset_dir}/trn_X_Y.npz')\n",
    "num_trn, num_lbl = trn_X_Y.shape\n",
    "\n",
    "# Append one row per label column, each with a single 1 in that label's own column.\n",
    "identity = sp.diags(np.ones(num_lbl), format='csr')\n",
    "aug_trn_X_Y = sp.vstack([trn_X_Y, identity]).tocsr()\n",
    "\n",
    "# Read-only int64 views of the token files, seq_len columns each.\n",
    "trn_shape = (num_trn, seq_len)\n",
    "lbl_shape = (num_lbl, seq_len)\n",
    "trn_doc_ii = np.memmap(f'{tok_dir}/trn_doc_input_ids.dat', dtype=np.int64, mode='r', shape=trn_shape)\n",
    "trn_doc_am = np.memmap(f'{tok_dir}/trn_doc_attention_mask.dat', dtype=np.int64, mode='r', shape=trn_shape)\n",
    "lbl_ii = np.memmap(f'{tok_dir}/lbl_input_ids.dat', dtype=np.int64, mode='r', shape=lbl_shape)\n",
    "lbl_am = np.memmap(f'{tok_dir}/lbl_attention_mask.dat', dtype=np.int64, mode='r', shape=lbl_shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stack the label rows underneath the training-document rows (same row order as aug_trn_X_Y).\n",
    "aug_trn_doc_ii = np.vstack((trn_doc_ii, lbl_ii))\n",
    "aug_trn_doc_am = np.vstack((trn_doc_am, lbl_am))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Files carried over as-is into the augmented directories.\n",
    "dataset_files = ['tst_X_Y.npz', 'trn_filter_labels.txt', 'tst_filter_labels.txt']\n",
    "tok_files = ['tst_doc_input_ids.dat', 'tst_doc_attention_mask.dat', 'lbl_input_ids.dat', 'lbl_attention_mask.dat']\n",
    "\n",
    "copy_files(dataset_dir, aug_dataset_dir, dataset_files)\n",
    "copy_files(tok_dir, aug_tok_dir, tok_files)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Every augmented document row needs a matching row in the label matrix.\n",
    "assert aug_trn_X_Y.shape[0] == aug_trn_doc_ii.shape[0] == aug_trn_doc_am.shape[0]\n",
    "\n",
    "# Save the augmented matrix (not the original trn_X_Y) so it lines up with the augmented token files.\n",
    "sp.save_npz(f'{aug_dataset_dir}/trn_X_Y.npz', aug_trn_X_Y)\n",
    "write_mmap(aug_trn_doc_ii, f'{aug_tok_dir}/trn_doc_input_ids.dat')\n",
    "write_mmap(aug_trn_doc_am, f'{aug_tok_dir}/trn_doc_attention_mask.dat')"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
