{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch\n",
    "from torch import nn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "dsprites = np.load(\"/home/<user>/projects/disentanglement-multi-task/data/dsprites/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz\")\n",
    "dsprites = {key: dsprites[key] for key in ['imgs', 'latents_classes', 'latents_values']}\n",
    "dataset_x = dsprites['latents_values']\n",
    "dataset_x = torch.tensor(dataset_x).cuda().float()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.manual_seed(658)\n",
    "torch.backends.cudnn.deterministic = True\n",
    "torch.backends.cudnn.benchmark = False\n",
    "\n",
    "\n",
    "class Masked(nn.Module):\n",
    "    def __init__(self, split, input_shape):\n",
    "        super(Mask, self).__init__()\n",
    "        self.split = split\n",
    "        self.in = input_shape\n",
    "        self.mask = torch.zeros(self.in, dtype=torch.float32)\n",
    "        self.mask[self.split] = 1\n",
    "    \n",
    "    def forward(self, input):\n",
    "        return self.mask*input\n",
    "        \n",
    "    \n",
    "\n",
    "\n",
    "networks = []\n",
    "multitask_targets = []\n",
    "for network in range(50):\n",
    "    network = nn.Sequential(\n",
    "        nn.Linear(6, 300),\n",
    "        nn.Tanh(),\n",
    "        nn.Linear(300, 300),\n",
    "        nn.Tanh(),\n",
    "        nn.Linear(300, 300),\n",
    "        nn.Tanh(),\n",
    "        nn.Linear(300, 300),\n",
    "        nn.Tanh(),\n",
    "        nn.Linear(300, 1),\n",
    "    )\n",
    "    \n",
    "    for module in network.children():\n",
    "        if isinstance(module, nn.Linear):\n",
    "            torch.nn.init.normal_(module.weight.data, 0, 1.)\n",
    "            \n",
    "    network = network.cuda()\n",
    "    targets = network(dataset_x)\n",
    "    targets = targets.detach().cpu().numpy()\n",
    "    multitask_targets += [targets]\n",
    "\n",
    "multitask_targets = np.concatenate(multitask_targets, 1)\n",
    "\n",
    "dsprites['multitask_targets'] = multitask_targets\n",
    "np.savez_compressed(\"dsprites_multitask.npz\", **dsprites)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[ 17.085754     6.195522   -27.487833    15.12213     -0.60677946\n",
      "  39.41825    -16.609615    -9.520086   -14.257808   -15.026125\n",
      " -16.258022   -22.908615    13.826408    12.454214    -1.5997301\n",
      " -11.580022   -22.412027    26.544628    19.908724    11.491505\n",
      "  27.053112    26.539207    -6.8763776    4.025102     1.8178561\n",
      "  -3.357193    -4.5485506    3.8842144    1.068662    -4.0464535\n",
      " -13.159656    -9.261389    -3.2104754    0.96721    -20.876127\n",
      "   7.3661785    2.0018985  -19.01292     -5.56429     11.124806\n",
      " -10.706724   -34.558735    12.645012    -5.399464    21.121984\n",
      "   4.085326   -17.37117     -0.3847853    8.716465    -9.86611   ]\n"
     ]
    }
   ],
   "source": [
    "A = np.load('dsprites_multitask.npz')\n",
    "print(A['multitask_targets'][0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "A = np.load('/Users/stella/projects/disentanglement-multi-task/data/test_dsets/dsprites/dsprites_multitask.npz')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(737280, 50)\n"
     ]
    }
   ],
   "source": [
    "print(A['multitask_targets'].shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Custom Splits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_tasks=50\n",
    "num_true = 6\n",
    "\n",
    "\n",
    "#[1,1,0,0,0,0],\n",
    "#[1,1,1,0,0,0],\n",
    "#[1,1,1,1,0,0],etc\n",
    "grow_splits = []\n",
    "for i in range(num_tasks):\n",
    "    temp = np.zeros(num_true)\n",
    "    temp[:((i%(num_true-1))+2)] = 1.0\n",
    "    grow_splits.append(temp)\n",
    "\n",
    "\n",
    "#[1,1,0,0,0,0],\n",
    "#[0,0,1,1,0,0],\n",
    "#[0,0,0,0,1,1],etc.\n",
    "independent_splits_v1 = []\n",
    "div_size=2\n",
    "for i in range(num_tasks):\n",
    "    k = num_true//div_size\n",
    "    temp = np.zeros(num_true)\n",
    "    j = i%k\n",
    "    temp[div_size*j:div_size*(j+1)] = 1.0\n",
    "    independent_splits_v1.append(temp)\n",
    "    \n",
    "    \n",
    "independent_splits_v2 = []\n",
    "div_size=3\n",
    "for i in range(num_tasks):\n",
    "    k = num_true//div_size\n",
    "    temp = np.zeros(num_true)\n",
    "    j = i%k\n",
    "    temp[div_size*j:div_size*(j+1)] = 1.0\n",
    "    independent_splits_v2.append(temp)    \n",
    "\n",
    "#random\n",
    "#[1,0,1,1,0,0]\n",
    "#[0,0,1,0,1,1], etc. \n",
    "rs = np.random.RandomState(11)\n",
    "random_splits = []\n",
    "for i in range(num_tasks):\n",
    "    temp = np.zeros(num_true)\n",
    "    temp[rs.rand(len(temp))>0.5]=1.0\n",
    "    if(sum(temp)==0):\n",
    "        i = rs.randint(num_true)\n",
    "        temp[i]=1.0\n",
    "    random_splits.append(temp)\n",
    "    \n",
    "\n",
    "\n",
    "torch.manual_seed(658)\n",
    "torch.backends.cudnn.deterministic = True\n",
    "torch.backends.cudnn.benchmark = False\n",
    "\n",
    "\n",
    "class Masked(nn.Module):\n",
    "    def __init__(self, split, input_shape):\n",
    "        super(Masked, self).__init__()\n",
    "        self.split = split\n",
    "        self.input_shape = input_shape\n",
    "        self.mask = torch.zeros(self.input_shape, dtype=torch.float32).cuda()\n",
    "        self.mask[self.split] = 1\n",
    "    \n",
    "    def forward(self, input):\n",
    "        return self.mask*input\n",
    "        \n",
    "    \n",
    "SPLITS={\n",
    "    \"grow\":grow_splits,\n",
    "    \"independent_v1\":independent_splits_v1,\n",
    "    \"independent_v2\":independent_splits_v2,\n",
    "    \"random\":random_splits\n",
    "}\n",
    "    \n",
    "    \n",
    "    \n",
    "def save_splitted_dataset(split_type):\n",
    "    networks = []\n",
    "    multitask_targets = []\n",
    "    split_array = SPLITS[split_type]\n",
    "    for network in range(50):\n",
    "        network = nn.Sequential(\n",
    "            Masked(split_array[network], 6),\n",
    "            nn.Linear(6, 300),\n",
    "            nn.Tanh(),\n",
    "            nn.Linear(300, 300),\n",
    "            nn.Tanh(),\n",
    "            nn.Linear(300, 300),\n",
    "            nn.Tanh(),\n",
    "            nn.Linear(300, 300),\n",
    "            nn.Tanh(),\n",
    "            nn.Linear(300, 1),\n",
    "        )\n",
    "\n",
    "        for module in network.children():\n",
    "            if isinstance(module, nn.Linear):\n",
    "                torch.nn.init.normal_(module.weight.data, 0, 1.)\n",
    "\n",
    "        network = network.cuda()\n",
    "    \n",
    "        \n",
    "        targets = network(dataset_x)\n",
    "        targets = targets.detach().cpu().numpy()\n",
    "        multitask_targets += [targets]\n",
    "\n",
    "    multitask_targets = np.concatenate(multitask_targets, 1)\n",
    "\n",
    "    dsprites['multitask_targets'] = multitask_targets\n",
    "    np.savez_compressed(\"dsprites_multitask_{}_splits.npz\".format(split_type), **dsprites)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 71,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "grow\n",
      "independent_v1\n",
      "independent_v2\n",
      "random\n"
     ]
    }
   ],
   "source": [
    "print(\"grow\")\n",
    "save_splitted_dataset(\"grow\")\n",
    "print(\"independent_v1\")\n",
    "save_splitted_dataset(\"independent_v1\")\n",
    "print(\"independent_v2\")\n",
    "save_splitted_dataset(\"independent_v2\")\n",
    "print(\"random\")\n",
    "save_splitted_dataset(\"random\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 72,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[ 11.628903    -6.9355907  -18.089891   -21.871485   -51.237854\n",
      " -40.217728   -15.515591    -7.373314    13.756518    -7.5567546\n",
      "  16.42033     29.8013      -4.8933244    3.2823796  -23.642212\n",
      "  27.108505    -2.7973657  -14.442416    11.6041355   -6.3948245\n",
      "   9.82257     18.88667     13.748902     9.389744    14.644038\n",
      "  -1.185813   -22.49953    -18.795975   -17.900105   -33.955128\n",
      "  -8.455559   -11.455919    -0.29316142 -10.250931     6.180991\n",
      "  -7.1955323  -11.552702   -11.012647    20.333458    19.966867\n",
      "   8.528882    -8.024888    -1.8292567  -17.510044    -6.521972\n",
      "  -7.3071947   -6.7310696  -27.390554   -49.917305    -5.293173  ]\n"
     ]
    }
   ],
   "source": [
    "A = np.load('dsprites_multitask_random_splits.npz')\n",
    "print(A['multitask_targets'][0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  },
  "pycharm": {
   "stem_cell": {
    "cell_type": "raw",
    "source": [],
    "metadata": {
     "collapsed": false
    }
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}