{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "import itertools\n",
    "from matplotlib import pyplot as plt\n",
    "import os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def generate_data(generator, task_name, path='data', train_size=10_000, val_size=1_000, test_size=2_000, batch_size=32):\n",
    "    Xs, ys, _, _ = next(generator)\n",
    "    total_size = train_size + test_size + val_size\n",
    "    \n",
    "    try:\n",
    "        while(len(Xs) < total_size):\n",
    "            print(len(Xs))\n",
    "            X, y, _, _ = next(generator)\n",
    "            # Xs += list(X)\n",
    "            # ys += list(y)\n",
    "            Xs = np.vstack((Xs, X))\n",
    "            ys = np.vstack((ys, y))\n",
    "            if len(Xs) > total_size * 2:\n",
    "                print('length achieved')\n",
    "                Xs, ys = np.unique(Xs, axis=0), np.unique(ys, axis=0)\n",
    "    except(KeyboardInterrupt):\n",
    "        print(\"Interrupted\")\n",
    "        Xs, ys = np.unique(Xs, axis=0), np.unique(ys, axis=0)\n",
    "\n",
    "    Xs = np.vstack(Xs)\n",
    "    ys = np.vstack(ys)\n",
    "    \n",
    "    print(Xs.shape, ys.shape)\n",
    "    # _, inds = np.unique(Xs, axis=0, return_index=True)\n",
    "    inds = np.random.permutation(range(len(Xs)))\n",
    "    \n",
    "    Xs = Xs[inds][:total_size]\n",
    "    ys = ys[inds][:total_size]\n",
    "    \n",
    "    np.save(f'{path}/{task_name}_train_X.npy', Xs[:train_size] )\n",
    "    np.save(f'{path}/{task_name}_train_y.npy', ys[:train_size] )\n",
    "\n",
    "    np.save(f'{path}/{task_name}_val_X.npy', Xs[train_size:train_size+val_size] )\n",
    "    np.save(f'{path}/{task_name}_val_y.npy', ys[train_size:train_size+val_size] )\n",
    "\n",
    "    np.save(f'{path}/{task_name}_test_X.npy', Xs[train_size+val_size:train_size+val_size+test_size] )\n",
    "    np.save(f'{path}/{task_name}_test_y.npy', ys[train_size+val_size:train_size+val_size+test_size] )\n",
    "    # return inds"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Copy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "class copy_generator:\n",
    "    def __init__(self, seq_len, batch_size, num_tokens):\n",
    "        self.src_mask = torch.ones(batch_size, seq_len).bool()\n",
    "        self.tgt_mask = torch.ones(batch_size, 2 * seq_len + 1).bool()\n",
    "        \n",
    "        self.enc_seq_len = seq_len\n",
    "        self.dec_seq_len = 2 * seq_len\n",
    "        self.batch_size = batch_size\n",
    "        self.num_tokens = num_tokens\n",
    "    \n",
    "    def __next__(self):\n",
    "        X = np.zeros([self.batch_size, self.enc_seq_len]).astype(int)\n",
    "        y = np.zeros([self.batch_size, self.dec_seq_len+1]).astype(int)\n",
    "        y[:, 0] = 1\n",
    "        for i in range(self.batch_size):\n",
    "            sequence_length = self.enc_seq_len\n",
    "            random_sequence = np.random.randint(2, self.num_tokens, sequence_length)\n",
    "            \n",
    "            X[i, :sequence_length] = random_sequence\n",
    "            y[i, 1: 2 * sequence_length + 1] = np.concatenate([random_sequence] * 2)\n",
    "\n",
    "        return X, y, self.src_mask, self.tgt_mask        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# X, y, _, _, = next(gen)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "10000\n",
      "20000\n",
      "30000\n",
      "40000\n",
      "50000\n",
      "60000\n",
      "70000\n",
      "80000\n",
      "90000\n",
      "100000\n",
      "110000\n",
      "120000\n",
      "(130000, 48) (130000, 97)\n"
     ]
    }
   ],
   "source": [
    "# SEQ_LEN = 24\n",
    "\n",
    "# task_name = f'copy{SEQ_LEN}'\n",
    "# BATCH_SIZE = 10_000\n",
    "# NUM_TOKENS = 10\n",
    "\n",
    "# train_size = 100_000\n",
    "# val_size = 10_000\n",
    "# test_size = 20_000\n",
    "\n",
    "# path = f'../synthetic/data{SEQ_LEN}'\n",
    "# os.system(f'mkdir {path}')\n",
    "\n",
    "# gen = copy_generator(seq_len=SEQ_LEN, batch_size=BATCH_SIZE, num_tokens=NUM_TOKENS)\n",
    "# generate_data(gen, task_name=task_name, path=path, train_size=train_size, val_size=val_size, test_size=test_size, batch_size=BATCH_SIZE)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Reverse"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "class reverse_generator:\n",
    "    def __init__(self, seq_len, batch_size, num_tokens):\n",
    "        self.src_mask = torch.ones(batch_size, seq_len).bool()\n",
    "        self.tgt_mask = torch.ones(batch_size, seq_len + 1).bool()\n",
    "        \n",
    "        self.enc_seq_len = seq_len\n",
    "        self.dec_seq_len = seq_len\n",
    "        self.batch_size = batch_size\n",
    "        self.num_tokens = num_tokens\n",
    "    \n",
    "    def __next__(self):\n",
    "        X = np.zeros([self.batch_size, self.enc_seq_len]).astype(int)\n",
    "        y = np.zeros([self.batch_size, self.dec_seq_len+1]).astype(int)\n",
    "        y[:, 0] = 1\n",
    "        for i in range(self.batch_size):\n",
    "            sequence_length = self.enc_seq_len\n",
    "            random_sequence = np.random.randint(2, self.num_tokens, sequence_length)\n",
    "            \n",
    "            X[i, :sequence_length] = random_sequence\n",
    "            y[i, 1: 2 * sequence_length + 1] = random_sequence[::-1]\n",
    "\n",
    "        return X, y, self.src_mask, self.tgt_mask        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "10000\n",
      "20000\n",
      "30000\n",
      "40000\n",
      "50000\n",
      "60000\n",
      "70000\n",
      "80000\n",
      "90000\n",
      "100000\n",
      "110000\n",
      "120000\n",
      "(130000, 240) (130000, 241)\n"
     ]
    }
   ],
   "source": [
    "# SEQ_LEN = 240\n",
    "\n",
    "# task_name = f'reverse{SEQ_LEN}'\n",
    "# BATCH_SIZE = 10000\n",
    "# NUM_TOKENS = 10\n",
    "\n",
    "# train_size = 100_000\n",
    "# val_size = 10_000\n",
    "# test_size = 20_000\n",
    "\n",
    "# path = f'../synthetic/data{SEQ_LEN}'\n",
    "# os.system(f'mkdir {path}')\n",
    "\n",
    "# gen = reverse_generator(seq_len=SEQ_LEN, batch_size=BATCH_SIZE, num_tokens=NUM_TOKENS)\n",
    "# generate_data(gen, task_name=task_name, path=path, train_size=train_size, val_size=val_size, test_size=test_size, batch_size=BATCH_SIZE)"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "15ebdd31b1273fe4d2b1fe1822219a570cf61693f7cab545dbe286c10cf9691f"
  },
  "kernelspec": {
   "display_name": "Python 3.8.0 ('dpenv')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.0"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
