{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Generate Random Object Sorting Tasks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_sorting_dataset(objects, seqs_length, n_seqs, rng=None):\n",
    "    \"\"\"Generate random object sequences together with their sorting targets.\n",
    "\n",
    "    Objects are identified by their row index in `objects`, and sequences are\n",
    "    sorted by that index.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    objects : array-like\n",
    "        Object representations, indexed along the first axis.\n",
    "    seqs_length : int\n",
    "        Number of distinct objects in each sequence (at most `len(objects)`).\n",
    "    n_seqs : int\n",
    "        Number of sequences to sample. Duplicate sequences are dropped, so\n",
    "        fewer than `n_seqs` rows may be returned.\n",
    "    rng : optional\n",
    "        Object providing a NumPy-style `choice` method (e.g. a seeded\n",
    "        `np.random.RandomState`). Defaults to the global `np.random` state.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    seqs : ndarray\n",
    "        Unique sequences of object indices, one per row, in sampling order.\n",
    "    sorted_seqs : ndarray\n",
    "        `seqs` with every row sorted in ascending order.\n",
    "    arg_sort : ndarray\n",
    "        For every row, the start token followed by the positions that sort it.\n",
    "    object_seqs : ndarray\n",
    "        `objects[seqs]`.\n",
    "    start_token : int\n",
    "        Equal to `seqs_length`, i.e. one past the largest valid position.\n",
    "    \"\"\"\n",
    "    objects = np.asarray(objects)\n",
    "    n_objects = len(objects)\n",
    "    if rng is None:\n",
    "        rng = np.random\n",
    "\n",
    "    # sample `n_seqs` sequences of `seqs_length` distinct indices out of `n_objects`\n",
    "    seqs = np.array([rng.choice(n_objects, size=seqs_length, replace=False) for _ in range(n_seqs)])\n",
    "\n",
    "    # remove duplicate seqs (although very unlikely); `return_index` gives the row of the\n",
    "    # first occurrence of each unique sequence, and sorting those rows keeps sampling order\n",
    "    _, first_occurrence_idxs = np.unique(seqs, axis=0, return_index=True)\n",
    "    seqs = seqs[np.sort(first_occurrence_idxs)]\n",
    "\n",
    "    # create object sequences\n",
    "    object_seqs = objects[seqs]\n",
    "\n",
    "    sorted_seqs = np.sort(seqs, axis=1)\n",
    "    arg_sort = np.argsort(seqs, axis=1)\n",
    "\n",
    "    # add `START_TOKEN` to beginning of sorting\n",
    "    start_token = seqs_length\n",
    "    start_tokens = np.full((len(arg_sort), 1), start_token)\n",
    "    arg_sort = np.hstack([start_tokens, arg_sort])\n",
    "\n",
    "    return seqs, sorted_seqs, arg_sort, object_seqs, start_token"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Task 1\n",
    "\n",
    "Generate `vocab_size` objects as random Gaussian vectors of dimension `dim`. Associate a random ordering with them."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# dataset 1\n",
    "vocab_size = 64\n",
    "dim = 8\n",
    "seqs_length = 10\n",
    "n_seqs = 100_000\n",
    "\n",
    "# every object is a random Gaussian vector; its row index defines its rank in the ordering\n",
    "objects = np.random.normal(size=(vocab_size, dim))\n",
    "\n",
    "seqs, sorted_seqs, arg_sort, object_seqs, start_token = create_sorting_dataset(objects, seqs_length, n_seqs)\n",
    "\n",
    "# shifted views of `arg_sort`: `target` drops the last position (so it begins with the\n",
    "# start token), `labels` drops the first position (the start token)\n",
    "target = arg_sort[:, :-1]\n",
    "labels = arg_sort[:, 1:]\n",
    "\n",
    "data = {\n",
    "    'objects': objects, 'seqs': seqs, 'sorted_seqs': sorted_seqs, 'arg_sort': arg_sort,\n",
    "    'object_seqs': object_seqs, 'target': target, 'labels': labels, 'start_token': start_token\n",
    "    }\n",
    "\n",
    "# np.save does not create missing directories\n",
    "os.makedirs('object_sorting_datasets', exist_ok=True)\n",
    "np.save('object_sorting_datasets/task1_object_sort_dataset.npy', data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Task 2\n",
    "\n",
    "Independently generate another random set of objects and an associated ordering."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# dataset 2 (same parameters, just re-generate objects randomly)\n",
    "vocab_size = 64\n",
    "dim = 8\n",
    "seqs_length = 10\n",
    "n_seqs = 100_000\n",
    "\n",
    "# fresh draw of random Gaussian object vectors; row index defines the rank in the ordering\n",
    "objects = np.random.normal(size=(vocab_size, dim))\n",
    "\n",
    "seqs, sorted_seqs, arg_sort, object_seqs, start_token = create_sorting_dataset(objects, seqs_length, n_seqs)\n",
    "\n",
    "# shifted views of `arg_sort`: `target` drops the last position (so it begins with the\n",
    "# start token), `labels` drops the first position (the start token)\n",
    "target = arg_sort[:, :-1]\n",
    "labels = arg_sort[:, 1:]\n",
    "\n",
    "data = {\n",
    "    'objects': objects, 'seqs': seqs, 'sorted_seqs': sorted_seqs, 'arg_sort': arg_sort,\n",
    "    'object_seqs': object_seqs, 'target': target, 'labels': labels, 'start_token': start_token\n",
    "    }\n",
    "\n",
    "# np.save does not create missing directories\n",
    "os.makedirs('object_sorting_datasets', exist_ok=True)\n",
    "np.save('object_sorting_datasets/task2_object_sort_dataset.npy', data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Task 2 Reshuffled\n",
    "\n",
    "Reshuffle the order of the objects in Task 2, keeping the same index sequences."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# the dataset was saved as a pickled dict, hence allow_pickle=True and .item();\n",
    "# only load files you trust this way\n",
    "data = np.load('object_sorting_datasets/task2_object_sort_dataset.npy', allow_pickle=True).item()\n",
    "objects = data['objects']\n",
    "seqs = data['seqs']\n",
    "\n",
    "# random permutation of the object rows, sized from the loaded objects rather than hard-coded\n",
    "reshuffle = np.random.permutation(len(objects))\n",
    "objects_ = objects[reshuffle]\n",
    "# the index sequences stay the same; only the object each index refers to changes\n",
    "object_seqs_ = objects_[seqs]\n",
    "\n",
    "data['reshuffle'] = reshuffle\n",
    "data['objects'] = objects_\n",
    "data['object_seqs'] = object_seqs_\n",
    "\n",
    "np.save('object_sorting_datasets/task2_reshuffled_object_sort_dataset.npy', data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Generate objects with attribute-product structure\n",
    "\n",
    "Generate two attributes as random Gaussian vectors and associate an ordering with each. Then generate objects as the Cartesian product of the two attributes. Associate an ordering with the objects in which one attribute forms the primary key and the other forms the secondary key."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# build objects whose representation is a pair of attribute vectors\n",
    "\n",
    "# attribute 1: number of possible values and dimension of each value's vector\n",
    "attr1_n_objects = 4\n",
    "attr1_embedding_dim = 4\n",
    "# attribute 2: number of possible values and dimension of each value's vector\n",
    "attr2_n_objects = 12\n",
    "attr2_embedding_dim = 8\n",
    "\n",
    "# one random Gaussian vector per attribute value\n",
    "attr1_objects = np.random.normal(size=(attr1_n_objects, attr1_embedding_dim))\n",
    "attr2_objects = np.random.normal(size=(attr2_n_objects, attr2_embedding_dim))\n",
    "\n",
    "# all (attribute 1, attribute 2) index pairs, attribute 1 varying slowest\n",
    "object_products = [(i, j) for i in range(attr1_n_objects) for j in range(attr2_n_objects)]\n",
    "\n",
    "# each object is the concatenation of its two attribute vectors\n",
    "objects = np.array([\n",
    "    np.concatenate([attr1_objects[i], attr2_objects[j]])\n",
    "    for i, j in object_products\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(48, 12)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# sanity check: one row per (attr1, attr2) pair, i.e. attr1_n_objects * attr2_n_objects rows,\n",
    "# and attr1_embedding_dim + attr2_embedding_dim columns\n",
    "objects.shape # (n_objects, object_dim)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# generate object sorting dataset\n",
    "seqs_length = 10\n",
    "n_seqs = 100_000\n",
    "\n",
    "seqs, sorted_seqs, arg_sort, object_seqs, start_token = create_sorting_dataset(objects, seqs_length, n_seqs)\n",
    "\n",
    "# shifted views of `arg_sort`: `target` drops the last position (so it begins with the\n",
    "# start token), `labels` drops the first position (the start token)\n",
    "target = arg_sort[:, :-1]\n",
    "labels = arg_sort[:, 1:]\n",
    "\n",
    "data = {\n",
    "    'objects': objects, 'attr1_objects': attr1_objects, 'attr2_objects': attr2_objects, \n",
    "    'seqs': seqs, 'sorted_seqs': sorted_seqs, 'arg_sort': arg_sort,\n",
    "    'object_seqs': object_seqs, 'target': target, 'labels': labels, 'start_token': start_token\n",
    "    }\n",
    "\n",
    "# np.save does not create missing directories\n",
    "os.makedirs('object_sorting_datasets', exist_ok=True)\n",
    "np.save('object_sorting_datasets/product_structure_object_sort_dataset.npy', data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Reshuffle Attribute-Product Structure Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# permute which vector represents each value of attribute 1; attribute 2 is left unchanged\n",
    "attr1_reshuffle = np.random.permutation(attr1_n_objects)\n",
    "attr2_reshuffle = np.arange(attr2_n_objects) # identity permutation\n",
    "\n",
    "attr1_objects_reshuffled = attr1_objects[attr1_reshuffle]\n",
    "attr2_objects_reshuffled = attr2_objects[attr2_reshuffle]\n",
    "\n",
    "# generate attribute-product objects and ordering \n",
    "object_products = [(attr1, attr2) for attr1 in range(attr1_n_objects) for attr2 in range(attr2_n_objects)]\n",
    "\n",
    "objects_reshuffled = []\n",
    "for attr1, attr2 in object_products:\n",
    "    # look the vectors up in the reshuffled tables; using the original tables here\n",
    "    # would reproduce the un-shuffled objects\n",
    "    attr1_object = attr1_objects_reshuffled[attr1] # get vector representation of attribute 1\n",
    "    attr2_object = attr2_objects_reshuffled[attr2] # get vector representation of attribute 2\n",
    "    object_ = np.concatenate([attr1_object, attr2_object]) # stack attributes to create object\n",
    "    objects_reshuffled.append(object_)\n",
    "\n",
    "# build the array from the list assembled above, not from the original `objects`\n",
    "objects_reshuffled = np.array(objects_reshuffled)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# index the reshuffled object table with the existing sequences\n",
    "object_seqs_reshuffled = objects_reshuffled[seqs]\n",
    "\n",
    "# record the permutations used and overwrite the object-related entries of `data`\n",
    "data.update({\n",
    "    'attr1_reshuffle': attr1_reshuffle,\n",
    "    'attr2_reshuffle': attr2_reshuffle,\n",
    "    'objects': objects_reshuffled,\n",
    "    'object_seqs': object_seqs_reshuffled,\n",
    "    'attr1_objects': attr1_objects_reshuffled,\n",
    "    'attr2_objects': attr2_objects_reshuffled,\n",
    "})\n",
    "\n",
    "np.save('object_sorting_datasets/product_structure_reshuffled_object_sort_dataset.npy', data)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.8.16 64-bit ('relml')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.16"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "feb2622714ee4f3cfc5c273fa3fe6cf9410db521c7e03d7e619a7b4bef5cf3da"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
