{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
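    "# Generate Random Object Sorting Tasks\n",
    "\n",
    "Each object is a random feature vector. Each sequence is a set of distinct object indices. The target is the permutation (`argsort`) that orders the sequence by object index, prefixed with a start token. Datasets are saved to `object_sorting_datasets/`."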
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_sorting_dataset(vocab_size, dim, seqs_length, n_seqs, seed=None):\n",
    "    \"\"\"Generate a random object-sorting dataset.\n",
    "\n",
    "    Each of `vocab_size` objects gets a random standard-normal feature vector of\n",
    "    size `dim`. Each sequence holds `seqs_length` distinct object indices drawn\n",
    "    without replacement. The ordering to predict is the argsort of those indices,\n",
    "    i.e. objects are ranked by their index in `objects`.\n",
    "\n",
    "    Args:\n",
    "        vocab_size: number of distinct objects.\n",
    "        dim: dimensionality of each object's feature vector.\n",
    "        seqs_length: objects per sequence; must be <= vocab_size.\n",
    "        n_seqs: number of sequences sampled before duplicates are removed.\n",
    "        seed: optional seed for a private `np.random.default_rng`. If None, the\n",
    "            global `np.random` state is used, so a preceding `np.random.seed(...)`\n",
    "            still controls the output.\n",
    "\n",
    "    Returns:\n",
    "        objects: (vocab_size, dim) object features.\n",
    "        seqs: (n_unique, seqs_length) object indices; duplicate rows removed,\n",
    "            first occurrences kept in sampling order.\n",
    "        sorted_seqs: `seqs` sorted along each row.\n",
    "        arg_sort: (n_unique, seqs_length + 1) positions that sort each row of\n",
    "            `seqs`, prefixed with `start_token`.\n",
    "        object_seqs: (n_unique, seqs_length, dim) features, `objects[seqs]`.\n",
    "        start_token: id of the start token, equal to `seqs_length` (one past the\n",
    "            largest position index, so it never collides with a real position).\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if seqs_length > vocab_size (raised by `choice`).\n",
    "    \"\"\"\n",
    "    rng = np.random if seed is None else np.random.default_rng(seed)\n",
    "\n",
    "    # random features for each object\n",
    "    objects = rng.normal(size=(vocab_size, dim))\n",
    "\n",
    "    # each row: `seqs_length` distinct object indices out of `vocab_size`\n",
    "    seqs = np.array([rng.choice(vocab_size, size=seqs_length, replace=False) for _ in range(n_seqs)])\n",
    "\n",
    "    # remove duplicate seqs (unlikely when vocab_size >> seqs_length). `return_index`\n",
    "    # gives each unique row's first occurrence; sorting restores sampling order.\n",
    "    # (`return_inverse` would index the unique array, not `seqs`, and not dedupe.)\n",
    "    _, first_idxs = np.unique(seqs, axis=0, return_index=True)\n",
    "    seqs = seqs[np.sort(first_idxs)]\n",
    "\n",
    "    # feature sequences and sorting targets\n",
    "    object_seqs = objects[seqs]\n",
    "    sorted_seqs = np.sort(seqs, axis=1)\n",
    "    arg_sort = np.argsort(seqs, axis=1)\n",
    "\n",
    "    # prepend `start_token` to every ordering; use the local value rather than a\n",
    "    # notebook global so the function does not depend on cell execution order\n",
    "    start_token = seqs_length\n",
    "    start_tokens = np.full((len(arg_sort), 1), start_token, dtype=arg_sort.dtype)\n",
    "    arg_sort = np.hstack([start_tokens, arg_sort])\n",
    "\n",
    "    return objects, seqs, sorted_seqs, arg_sort, object_seqs, start_token"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Task 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# dataset 1\n",
    "vocab_size = 64\n",
    "dim = 32\n",
    "seqs_length = 10\n",
    "START_TOKEN = seqs_length  # start-token id: one past the largest position index\n",
    "n_seqs = 100_000\n",
    "SEED = 1  # use a different seed for each task's dataset\n",
    "\n",
    "# seed the global RNG so the saved dataset can be regenerated exactly\n",
    "np.random.seed(SEED)\n",
    "objects, seqs, sorted_seqs, arg_sort, object_seqs, start_token = create_sorting_dataset(vocab_size, dim, seqs_length, n_seqs)\n",
    "\n",
    "# `target` = ordering with the start token, last position dropped;\n",
    "# `labels` = the ordering itself (arg_sort shifted left by one).\n",
    "# Presumably a decoder-input / next-step-label pair for teacher forcing -- TODO confirm.\n",
    "target = arg_sort[:, :-1]\n",
    "labels = arg_sort[:, 1:]\n",
    "\n",
    "data = {\n",
    "    'objects': objects, 'seqs': seqs, 'sorted_seqs': sorted_seqs, 'arg_sort': arg_sort,\n",
    "    'object_seqs': object_seqs, 'target': target, 'labels': labels, 'start_token': start_token\n",
    "    }\n",
    "\n",
    "# np.save stores the dict as a pickled 0-d object array; load with allow_pickle=True and .item()\n",
    "os.makedirs('object_sorting_datasets', exist_ok=True)\n",
    "np.save('object_sorting_datasets/task1_object_sort_dataset.npy', data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Task 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# dataset 2 (same parameters, just re-generate objects randomly)\n",
    "vocab_size = 64\n",
    "dim = 32\n",
    "seqs_length = 10\n",
    "START_TOKEN = seqs_length  # start-token id: one past the largest position index\n",
    "n_seqs = 100_000\n",
    "SEED = 2  # use a different seed for each task's dataset\n",
    "\n",
    "# seed the global RNG so the saved dataset can be regenerated exactly\n",
    "np.random.seed(SEED)\n",
    "objects, seqs, sorted_seqs, arg_sort, object_seqs, start_token = create_sorting_dataset(vocab_size, dim, seqs_length, n_seqs)\n",
    "\n",
    "# `target` = ordering with the start token, last position dropped;\n",
    "# `labels` = the ordering itself (arg_sort shifted left by one).\n",
    "# Presumably a decoder-input / next-step-label pair for teacher forcing -- TODO confirm.\n",
    "target = arg_sort[:, :-1]\n",
    "labels = arg_sort[:, 1:]\n",
    "\n",
    "data = {\n",
    "    'objects': objects, 'seqs': seqs, 'sorted_seqs': sorted_seqs, 'arg_sort': arg_sort,\n",
    "    'object_seqs': object_seqs, 'target': target, 'labels': labels, 'start_token': start_token\n",
    "    }\n",
    "\n",
    "# np.save stores the dict as a pickled 0-d object array; load with allow_pickle=True and .item()\n",
    "os.makedirs('object_sorting_datasets', exist_ok=True)\n",
    "np.save('object_sorting_datasets/task2_object_sort_dataset.npy', data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
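    "## Task 2 Reshuffled\n",
    "\n",
    "Same sequences and target orderings as Task 2, but the object feature vectors are randomly permuted across object indices."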
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Task 2 with object features permuted across indices: the index sequences and\n",
    "# target orderings are unchanged, but each index now maps to a different feature vector.\n",
    "# allow_pickle=True is needed because the dict was saved as a pickled object array;\n",
    "# only load files this notebook produced.\n",
    "data = np.load('object_sorting_datasets/task2_object_sort_dataset.npy', allow_pickle=True).item()\n",
    "objects = data['objects']\n",
    "seqs = data['seqs']\n",
    "\n",
    "RESHUFFLE_SEED = 3\n",
    "np.random.seed(RESHUFFLE_SEED)  # reproducible reshuffle\n",
    "# random permutation of all object indices (no hard-coded vocabulary size)\n",
    "reshuffle = np.random.permutation(len(objects))\n",
    "objects_reshuffled = objects[reshuffle]\n",
    "object_seqs_reshuffled = objects_reshuffled[seqs]\n",
    "\n",
    "data['reshuffle'] = reshuffle\n",
    "data['objects'] = objects_reshuffled\n",
    "data['object_seqs'] = object_seqs_reshuffled\n",
    "\n",
    "np.save('object_sorting_datasets/task2_reshuffled_object_sort_dataset.npy', data)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.8.16 64-bit ('relml')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.16"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "feb2622714ee4f3cfc5c273fa3fe6cf9410db521c7e03d7e619a7b4bef5cf3da"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
