{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "23026198",
   "metadata": {},
   "outputs": [],
   "source": [
    "import zlib\n",
    "\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import numpy as np\n",
    "from datasets import load_dataset, Dataset, Array3D, Features\n",
    "import inox\n",
    "from priors.utils import *\n",
    "from tqdm import trange, tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "d866cc43",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load all CelebA splits; keep_in_memory=True keeps the Arrow data in RAM\n",
    "dataset = load_dataset(\n",
    "    'student/celebA',\n",
    "    keep_in_memory=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "47ba760e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Subset size for quick experiments; set to None to train on the full split\n",
    "N_TRAIN = 10_000\n",
    "trainset = dataset['train']\n",
    "if N_TRAIN is not None:\n",
    "    trainset = trainset.select(range(N_TRAIN))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "545181f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Raw image column; only used below to check the stored image shape\n",
    "trainset_img = trainset['image']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "4fe41886",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(218, 178, 3)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# (height, width, channels) of one stored sample\n",
    "np.asarray(trainset_img[1]).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "cfbc3faa",
   "metadata": {},
   "outputs": [],
   "source": [
    "CONFIG = dict(\n",
    "    # Data\n",
    "    # duplicate=2,\n",
    "    corruption=50,  # corruption level in percent, consumed by corrupt()\n",
    "    img_shape=(218, 178, 3),  # (height, width, channels) of a CelebA image\n",
    "    # Architecture\n",
    "    hid_channels=(128, 256, 384, 512),\n",
    "    hid_blocks=(3, 3, 3, 3),\n",
    "    kernel_size=(3, 3),\n",
    "    emb_features=256,\n",
    "    heads={3: 4},\n",
    "    dropout=0.1,\n",
    "    # Sampling\n",
    "    sampler='ddpm',\n",
    "    heuristic=None,\n",
    "    sde={'a': 1e-3, 'b': 1e2},\n",
    "    discrete=64,\n",
    "    maxiter=3,\n",
    "    # Training\n",
    "    epochs=64,\n",
    "    batch_size=256,\n",
    "    scheduler='constant',\n",
    "    lr_init=1e-4,\n",
    "    lr_end=1e-6,\n",
    "    lr_warmup=0.0,\n",
    "    optimizer='adam',\n",
    "    weight_decay=None,\n",
    "    clip=1.0,\n",
    "    ema_decay=0.999,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "6d2bab79",
   "metadata": {},
   "outputs": [],
   "source": [
    "class MyDict:\n",
    "    \"\"\"Read-only attribute access to a dict: ``MyDict(d).key`` returns ``d['key']``.\"\"\"\n",
    "\n",
    "    def __init__(self, d):\n",
    "        self.d = d\n",
    "\n",
    "    def __getattr__(self, s):\n",
    "        # Only reached when normal attribute lookup fails. Raise AttributeError\n",
    "        # (not KeyError) so hasattr() and getattr(obj, name, default) work, and\n",
    "        # read `d` through __dict__ so an instance that has no `d` yet (e.g. one\n",
    "        # being rebuilt by copy/pickle) does not recurse into __getattr__ forever.\n",
    "        try:\n",
    "            return self.__dict__['d'][s]\n",
    "        except KeyError:\n",
    "            raise AttributeError(s) from None\n",
    "\n",
    "config = MyDict(CONFIG)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "9ac014d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Downstream cells (e.g. corrupt) read the image column as 'x'\n",
    "trainset = trainset.rename_column(original_column_name='image', new_column_name='x')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "8d901267",
   "metadata": {},
   "outputs": [],
   "source": [
    "def corrupt(rng, corruption, dataset: Dataset):\n",
    "    \"\"\"Randomly zero out entries of every image in ``dataset``.\n",
    "\n",
    "    Each entry of ``row['x']`` is kept with probability ``1 - corruption / 100``\n",
    "    and set to 0 otherwise, so ``corruption`` is the percentage of masked\n",
    "    entries and ``corruption=0`` returns the images unchanged.\n",
    "\n",
    "    Args:\n",
    "        rng: inox PRNG; one Bernoulli mask is drawn from it per row.\n",
    "        corruption: Percentage of entries to mask, in [0, 100].\n",
    "        dataset: Dataset with an image column ``'x'`` of shape ``config.img_shape``.\n",
    "\n",
    "    Returns:\n",
    "        An in-memory dataset where the ``'x'`` column is replaced by a float32\n",
    "        ``'y'`` column holding the masked images.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If ``corruption`` is outside [0, 100].\n",
    "    \"\"\"\n",
    "    if not 0 <= corruption <= 100:\n",
    "        raise ValueError(f'corruption must be in [0, 100], got {corruption}')\n",
    "    keep = 1 - corruption / 100\n",
    "\n",
    "    def transform(row):\n",
    "        x = np.asarray(row['x'])\n",
    "        # The mask covers the full config.img_shape (including channels), so\n",
    "        # channels are masked independently of each other.\n",
    "        A = np.asarray(rng.bernoulli(p=keep, shape=config.img_shape))\n",
    "        # Return the masked image; returning `x` would silently skip corruption.\n",
    "        return {'y': A * x}\n",
    "\n",
    "    types = {\n",
    "        'y': Array3D(shape=config.img_shape, dtype='float32'),\n",
    "    }\n",
    "\n",
    "    return dataset.map(\n",
    "        transform,\n",
    "        remove_columns=['x'],\n",
    "        features=Features(types),\n",
    "        keep_in_memory=True,\n",
    "        num_proc=1,\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "85c82185",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Parameter 'function'=<function corrupt.<locals>.transform at 0x74d20c403eb0> of the transform datasets.arrow_dataset.Dataset._map_single couldn't be hashed properly, a random hash was used instead. Make sure your transforms and parameters are serializable with pickle or dill for the dataset fingerprinting and caching to work. If you reuse this transform, the caching mechanism will consider it to be different from the previous calls and recompute everything. This warning is only showed once. Subsequent hashing failures won't be showed.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "549c023c621a4fb7b8f6750790c00900",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Map:   0%|          | 0/10000 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "41d0e255fb7a4575ab20c012747579f5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Map:   0%|          | 0/10000 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "lap = 0\n",
    "# Built-in hash() of a str is salted per interpreter (PYTHONHASHSEED), so\n",
    "# hash(('celeba', lap)) gave a different seed after every kernel restart.\n",
    "# crc32 is deterministic across runs.\n",
    "seed = zlib.crc32(f'celeba-{lap}'.encode()) % (1 << 16)\n",
    "rng = inox.random.PRNG(seed)\n",
    "\n",
    "# Corrupted copy first; then pass the clean images through corrupt() with\n",
    "# corruption=0 so both datasets share the same float32 'y' schema and row order.\n",
    "trainset_corrupted = corrupt(rng, config.corruption, trainset)\n",
    "trainset = corrupt(rng, 0, trainset)\n",
    "trainset_corrupted.set_format(type='numpy', columns=['y'])\n",
    "trainset.set_format(type='numpy', columns=['y'])\n",
    "trainset = trainset.rename_column('y', 'x')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "bb21f9ec",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sharding: a single mesh axis 'i' spanning all visible devices\n",
    "jax.config.update('jax_threefry_partitionable', True)\n",
    "\n",
    "mesh = jax.sharding.Mesh(jax.devices(), ('i',))\n",
    "# `replicated` copies an array to every device; `distributed` splits its\n",
    "# leading axis across the 'i' mesh axis\n",
    "replicated = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())\n",
    "distributed = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('i'))\n",
    "\n",
    "# SDE built from the CONFIG['sde'] hyper-parameters\n",
    "sde = VESDE(**config.sde)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "5e03a906",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the conditional network from the full CONFIG, seeded from rng\n",
    "model = make_model_conditional(\n",
    "    key=rng.split(),\n",
    "    **CONFIG,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "85b38768",
   "metadata": {},
   "outputs": [],
   "source": [
    "model.train(True)\n",
    "\n",
    "# Split the module into static structure, trainable parameters and other state\n",
    "static, params, others = model.partition(nn.Parameter)\n",
    "\n",
    "# Objective\n",
    "objective = ConditionalDenoiserLoss(sde=sde)\n",
    "\n",
    "# Optimizer\n",
    "# The loaders drop the last incomplete batch, so an epoch has\n",
    "# len(trainset) // batch_size steps. (`dataset` is the DatasetDict of splits:\n",
    "# its len() is the number of splits, not the number of training images.)\n",
    "steps_per_epoch = len(trainset) // config.batch_size\n",
    "steps = config.epochs * steps_per_epoch\n",
    "optimizer = Adam(\n",
    "    steps=steps,\n",
    "    scheduler=config.scheduler,\n",
    "    lr_init=config.lr_init,\n",
    "    lr_end=config.lr_end,\n",
    "    lr_warmup=config.lr_warmup,\n",
    "    weight_decay=config.weight_decay,\n",
    "    clip=config.clip,\n",
    ")\n",
    "opt_state = optimizer.init(params)\n",
    "\n",
    "# EMA of the parameters, initialised to the current parameters\n",
    "ema = EMA(decay=config.ema_decay)\n",
    "avrg = params\n",
    "\n",
    "# Training: replicate parameters and optimizer state on every device\n",
    "avrg, params, others, opt_state = jax.device_put((avrg, params, others, opt_state), replicated)\n",
    "\n",
    "@jax.jit\n",
    "@jax.vmap\n",
    "def augment(x, key):\n",
    "    \"\"\"Random flip along axis -2 followed by a random shake (delta=4) of one sample.\n",
    "\n",
    "    vmap maps over the leading axis of ``x`` and ``key``; passing the same keys\n",
    "    for two paired batches applies identical augmentations to both.\n",
    "    \"\"\"\n",
    "    keys = jax.random.split(key, 2)\n",
    "\n",
    "    x = random_flip(x, keys[0], axis=-2)\n",
    "    x = random_shake(x, keys[1], delta=4)\n",
    "\n",
    "    return x\n",
    "\n",
    "@jax.jit\n",
    "def ell(params, others, x, y_cond, key):\n",
    "    \"\"\"Conditional denoising loss for clean batch ``x`` given ``y_cond``.\n",
    "\n",
    "    Draws Gaussian noise ``z`` shaped like ``x`` and one ``t ~ Beta(3, 3)`` per\n",
    "    sample, then evaluates ``objective`` on the model rebuilt from the state.\n",
    "    \"\"\"\n",
    "    keys = jax.random.split(key, 3)\n",
    "\n",
    "    z = jax.random.normal(keys[0], shape=x.shape)\n",
    "    t = jax.random.beta(keys[1], a=3, b=3, shape=x.shape[:1])\n",
    "\n",
    "    return objective(static(params, others), x, z, t, y_cond, key=keys[2])\n",
    "\n",
    "@jax.jit\n",
    "def sgd_step(avrg, params, others, opt_state, x, y_cond, key):\n",
    "    \"\"\"One optimizer step on ``ell``; returns ``(loss, avrg, params, opt_state)``.\n",
    "\n",
    "    ``avrg`` is advanced with ``ema`` towards the updated parameters.\n",
    "    \"\"\"\n",
    "    loss, grads = jax.value_and_grad(ell)(params, others, x, y_cond, key)\n",
    "    updates, opt_state = optimizer.update(grads, opt_state, params)\n",
    "    params = optax.apply_updates(params, updates)\n",
    "    avrg = ema(avrg, params)\n",
    "\n",
    "    return loss, avrg, params, opt_state\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "87b2d99a",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "0it [00:02, ?it/s]                                               | 0/64 [00:00<?, ?it/s]\n",
      "  0%|                                                            | 0/64 [00:02<?, ?it/s]\n"
     ]
    },
    {
     "ename": "TypeError",
     "evalue": "Cannot concatenate arrays with shapes that differ in dimensions other than the one being concatenated: concatenating along dimension 3 for shapes (256, 56, 46, 512), (256, 55, 45, 384).",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mTypeError\u001b[0m                                 Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[15], line 26\u001b[0m\n\u001b[1;32m     22\u001b[0m     y \u001b[38;5;241m=\u001b[39m augment(y, aug_key)\n\u001b[1;32m     23\u001b[0m     y \u001b[38;5;241m=\u001b[39m flatten(y)\n\u001b[0;32m---> 26\u001b[0m     loss, avrg, params, opt_state \u001b[38;5;241m=\u001b[39m \u001b[43msgd_step\u001b[49m\u001b[43m(\u001b[49m\u001b[43mavrg\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mparams\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mothers\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mopt_state\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43my\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkey\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mrng\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msplit\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     27\u001b[0m     losses\u001b[38;5;241m.\u001b[39mappend(loss)\n\u001b[1;32m     29\u001b[0m loss_train \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39mstack(losses)\u001b[38;5;241m.\u001b[39mmean()\n",
      "    \u001b[0;31m[... skipping hidden 14 frame]\u001b[0m\n",
      "Cell \u001b[0;32mIn[14], line 49\u001b[0m, in \u001b[0;36msgd_step\u001b[0;34m(avrg, params, others, opt_state, x, y_cond, key)\u001b[0m\n\u001b[1;32m     47\u001b[0m \u001b[38;5;129m@jax\u001b[39m\u001b[38;5;241m.\u001b[39mjit\n\u001b[1;32m     48\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21msgd_step\u001b[39m(avrg, params, others, opt_state, x, y_cond, key):\n\u001b[0;32m---> 49\u001b[0m     loss, grads \u001b[38;5;241m=\u001b[39m \u001b[43mjax\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mvalue_and_grad\u001b[49m\u001b[43m(\u001b[49m\u001b[43mell\u001b[49m\u001b[43m)\u001b[49m\u001b[43m(\u001b[49m\u001b[43mparams\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mothers\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43my_cond\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkey\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     50\u001b[0m     updates, opt_state \u001b[38;5;241m=\u001b[39m optimizer\u001b[38;5;241m.\u001b[39mupdate(grads, opt_state, params)\n\u001b[1;32m     51\u001b[0m     params \u001b[38;5;241m=\u001b[39m optax\u001b[38;5;241m.\u001b[39mapply_updates(params, updates)\n",
      "    \u001b[0;31m[... skipping hidden 30 frame]\u001b[0m\n",
      "Cell \u001b[0;32mIn[14], line 45\u001b[0m, in \u001b[0;36mell\u001b[0;34m(params, others, x, y_cond, key)\u001b[0m\n\u001b[1;32m     42\u001b[0m z \u001b[38;5;241m=\u001b[39m jax\u001b[38;5;241m.\u001b[39mrandom\u001b[38;5;241m.\u001b[39mnormal(keys[\u001b[38;5;241m0\u001b[39m], shape\u001b[38;5;241m=\u001b[39mx\u001b[38;5;241m.\u001b[39mshape)\n\u001b[1;32m     43\u001b[0m t \u001b[38;5;241m=\u001b[39m jax\u001b[38;5;241m.\u001b[39mrandom\u001b[38;5;241m.\u001b[39mbeta(keys[\u001b[38;5;241m1\u001b[39m], a\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m3\u001b[39m, b\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m3\u001b[39m, shape\u001b[38;5;241m=\u001b[39mx\u001b[38;5;241m.\u001b[39mshape[:\u001b[38;5;241m1\u001b[39m])\n\u001b[0;32m---> 45\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mobjective\u001b[49m\u001b[43m(\u001b[49m\u001b[43mstatic\u001b[49m\u001b[43m(\u001b[49m\u001b[43mparams\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mothers\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mz\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mt\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43my_cond\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkey\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mkeys\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m2\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\n",
      "    \u001b[0;31m[... skipping hidden 1 frame]\u001b[0m\n",
      "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/inox/api.py:59\u001b[0m, in \u001b[0;36mouter.<locals>.wrapped\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m     56\u001b[0m \u001b[38;5;129m@wraps\u001b[39m(fun)\n\u001b[1;32m     57\u001b[0m \u001b[38;5;129m@api_boundary\u001b[39m\n\u001b[1;32m     58\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mwrapped\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[0;32m---> 59\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m tree_unmask(\u001b[43mfun\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mtree_mask\u001b[49m\u001b[43m(\u001b[49m\u001b[43margs\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mtree_mask\u001b[49m\u001b[43m(\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m)\n",
      "    \u001b[0;31m[... skipping hidden 15 frame]\u001b[0m\n",
      "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/inox/api.py:42\u001b[0m, in \u001b[0;36minner.<locals>.wrapped\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m     39\u001b[0m \u001b[38;5;129m@wraps\u001b[39m(fun)\n\u001b[1;32m     40\u001b[0m \u001b[38;5;129m@api_boundary\u001b[39m\n\u001b[1;32m     41\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mwrapped\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[0;32m---> 42\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m tree_mask(\u001b[43mfun\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mtree_unmask\u001b[49m\u001b[43m(\u001b[49m\u001b[43margs\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mtree_unmask\u001b[49m\u001b[43m(\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m)\n",
      "File \u001b[0;32m/workspace/diffusion-priors/priors/diffusion.py:431\u001b[0m, in \u001b[0;36mConditionalDenoiserLoss.__call__\u001b[0;34m(self, model, x0, z, t, y_cond, key)\u001b[0m\n\u001b[1;32m    429\u001b[0m xt \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msde(x0, z, t)\n\u001b[1;32m    430\u001b[0m \u001b[38;5;66;03m# we can give the corruption matrix as well? should we?\u001b[39;00m\n\u001b[0;32m--> 431\u001b[0m ft \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\u001b[43mxt\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msigma_t\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43my_cond\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkey\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    433\u001b[0m error \u001b[38;5;241m=\u001b[39m ft \u001b[38;5;241m-\u001b[39m x0\n\u001b[1;32m    435\u001b[0m \u001b[38;5;66;03m# what about norm 1, they said some stuff in the pallette class\u001b[39;00m\n",
      "    \u001b[0;31m[... skipping hidden 1 frame]\u001b[0m\n",
      "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/inox/api.py:59\u001b[0m, in \u001b[0;36mouter.<locals>.wrapped\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m     56\u001b[0m \u001b[38;5;129m@wraps\u001b[39m(fun)\n\u001b[1;32m     57\u001b[0m \u001b[38;5;129m@api_boundary\u001b[39m\n\u001b[1;32m     58\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mwrapped\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[0;32m---> 59\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m tree_unmask(\u001b[43mfun\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mtree_mask\u001b[49m\u001b[43m(\u001b[49m\u001b[43margs\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mtree_mask\u001b[49m\u001b[43m(\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m)\n",
      "    \u001b[0;31m[... skipping hidden 15 frame]\u001b[0m\n",
      "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/inox/api.py:42\u001b[0m, in \u001b[0;36minner.<locals>.wrapped\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m     39\u001b[0m \u001b[38;5;129m@wraps\u001b[39m(fun)\n\u001b[1;32m     40\u001b[0m \u001b[38;5;129m@api_boundary\u001b[39m\n\u001b[1;32m     41\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mwrapped\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[0;32m---> 42\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m tree_mask(\u001b[43mfun\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mtree_unmask\u001b[49m\u001b[43m(\u001b[49m\u001b[43margs\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mtree_unmask\u001b[49m\u001b[43m(\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m)\n",
      "File \u001b[0;32m/workspace/diffusion-priors/priors/diffusion.py:348\u001b[0m, in \u001b[0;36mConditionalDenoiser.__call__\u001b[0;34m(self, xt, sigma_t, y, key)\u001b[0m\n\u001b[1;32m    344\u001b[0m c_noise \u001b[38;5;241m=\u001b[39m jnp\u001b[38;5;241m.\u001b[39mlog(sigma_t)\n\u001b[1;32m    346\u001b[0m c_skip, c_out, c_in \u001b[38;5;241m=\u001b[39m c_skip[\u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m, \u001b[38;5;28;01mNone\u001b[39;00m], c_out[\u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m, \u001b[38;5;28;01mNone\u001b[39;00m], c_in[\u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m, \u001b[38;5;28;01mNone\u001b[39;00m]\n\u001b[0;32m--> 348\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m c_skip \u001b[38;5;241m*\u001b[39m xt \u001b[38;5;241m+\u001b[39m c_out \u001b[38;5;241m*\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnet\u001b[49m\u001b[43m(\u001b[49m\u001b[43mc_in\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mxt\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43memb\u001b[49m\u001b[43m(\u001b[49m\u001b[43mc_noise\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43my\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkey\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m/workspace/diffusion-priors/priors/utils.py:149\u001b[0m, in \u001b[0;36mConditionalFlatUNet.__call__\u001b[0;34m(self, x, t, y_cond, key)\u001b[0m\n\u001b[1;32m    146\u001b[0m y_cond \u001b[38;5;241m=\u001b[39m unflatten(y_cond, width \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m178\u001b[39m, height\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m218\u001b[39m)\n\u001b[1;32m    147\u001b[0m res \u001b[38;5;241m=\u001b[39m jnp\u001b[38;5;241m.\u001b[39mconcatenate((x, y_cond), axis \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m)\n\u001b[0;32m--> 149\u001b[0m res \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;21;43m__call__\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mres\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mt\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkey\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    151\u001b[0m \u001b[38;5;66;03m# TODO: Dropping some of the channels\u001b[39;00m\n\u001b[1;32m    152\u001b[0m res \u001b[38;5;241m=\u001b[39m res[\u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m, :\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnum_returned_channels]\n",
      "File \u001b[0;32m/workspace/diffusion-priors/priors/nn.py:241\u001b[0m, in \u001b[0;36mUNet.__call__\u001b[0;34m(self, x, t, key)\u001b[0m\n\u001b[1;32m    238\u001b[0m y \u001b[38;5;241m=\u001b[39m memory\u001b[38;5;241m.\u001b[39mpop()\n\u001b[1;32m    240\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m x \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m y:\n\u001b[0;32m--> 241\u001b[0m     x \u001b[38;5;241m=\u001b[39m \u001b[43mjnp\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mconcatenate\u001b[49m\u001b[43m(\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43my\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43maxis\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m-\u001b[39;49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m    243\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m block \u001b[38;5;129;01min\u001b[39;00m blocks:\n\u001b[1;32m    244\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(block, (ResBlock, AttBlock)):\n",
      "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/jax/_src/numpy/lax_numpy.py:4648\u001b[0m, in \u001b[0;36mconcatenate\u001b[0;34m(arrays, axis, dtype)\u001b[0m\n\u001b[1;32m   4646\u001b[0m k \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m16\u001b[39m\n\u001b[1;32m   4647\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(arrays_out) \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m1\u001b[39m:\n\u001b[0;32m-> 4648\u001b[0m   arrays_out \u001b[38;5;241m=\u001b[39m [lax\u001b[38;5;241m.\u001b[39mconcatenate(arrays_out[i:i\u001b[38;5;241m+\u001b[39mk], axis)\n\u001b[1;32m   4649\u001b[0m                 \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[38;5;241m0\u001b[39m, \u001b[38;5;28mlen\u001b[39m(arrays_out), k)]\n\u001b[1;32m   4650\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m arrays_out[\u001b[38;5;241m0\u001b[39m]\n",
      "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/jax/_src/numpy/lax_numpy.py:4648\u001b[0m, in \u001b[0;36m<listcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m   4646\u001b[0m k \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m16\u001b[39m\n\u001b[1;32m   4647\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(arrays_out) \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m1\u001b[39m:\n\u001b[0;32m-> 4648\u001b[0m   arrays_out \u001b[38;5;241m=\u001b[39m [\u001b[43mlax\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mconcatenate\u001b[49m\u001b[43m(\u001b[49m\u001b[43marrays_out\u001b[49m\u001b[43m[\u001b[49m\u001b[43mi\u001b[49m\u001b[43m:\u001b[49m\u001b[43mi\u001b[49m\u001b[38;5;241;43m+\u001b[39;49m\u001b[43mk\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43maxis\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   4649\u001b[0m                 \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[38;5;241m0\u001b[39m, \u001b[38;5;28mlen\u001b[39m(arrays_out), k)]\n\u001b[1;32m   4650\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m arrays_out[\u001b[38;5;241m0\u001b[39m]\n",
      "    \u001b[0;31m[... skipping hidden 9 frame]\u001b[0m\n",
      "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/jax/_src/lax/lax.py:6655\u001b[0m, in \u001b[0;36m_concatenate_shape_rule\u001b[0;34m(*operands, **kwargs)\u001b[0m\n\u001b[1;32m   6651\u001b[0m   msg \u001b[38;5;241m=\u001b[39m (\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mCannot concatenate arrays with shapes that differ in dimensions \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m   6652\u001b[0m          \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mother than the one being concatenated: concatenating along \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m   6653\u001b[0m          \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mdimension \u001b[39m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m for shapes \u001b[39m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m   6654\u001b[0m   shapes \u001b[38;5;241m=\u001b[39m [operand\u001b[38;5;241m.\u001b[39mshape \u001b[38;5;28;01mfor\u001b[39;00m operand \u001b[38;5;129;01min\u001b[39;00m operands]\n\u001b[0;32m-> 6655\u001b[0m   \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m(msg\u001b[38;5;241m.\u001b[39mformat(dimension, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m, \u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;241m.\u001b[39mjoin(\u001b[38;5;28mmap\u001b[39m(\u001b[38;5;28mstr\u001b[39m, shapes))))\n\u001b[1;32m   6657\u001b[0m concat_size \u001b[38;5;241m=\u001b[39m \u001b[38;5;28msum\u001b[39m(o\u001b[38;5;241m.\u001b[39mshape[dimension] \u001b[38;5;28;01mfor\u001b[39;00m o \u001b[38;5;129;01min\u001b[39;00m operands)\n\u001b[1;32m   6658\u001b[0m ex_shape \u001b[38;5;241m=\u001b[39m operands[\u001b[38;5;241m0\u001b[39m]\u001b[38;5;241m.\u001b[39mshape\n",
      "\u001b[0;31mTypeError\u001b[0m: Cannot concatenate arrays with shapes that differ in dimensions other than the one being concatenated: concatenating along dimension 3 for shapes (256, 56, 46, 512), (256, 55, 45, 384)."
     ]
    }
   ],
   "source": [
    "for epoch in (bar := trange(config.epochs, ncols=88)):\n",
    "    # Shuffle both datasets with the same seed so row i of the clean set and\n",
    "    # row i of the corrupted set stay paired (both have the same length).\n",
    "    shuffle_seed = seed + lap * config.epochs + epoch\n",
    "    loader = trainset.shuffle(seed=shuffle_seed).iter(\n",
    "        batch_size=config.batch_size, drop_last_batch=True\n",
    "    )\n",
    "    loader_yA = trainset_corrupted.shuffle(seed=shuffle_seed).iter(\n",
    "        batch_size=config.batch_size, drop_last_batch=True\n",
    "    )\n",
    "    losses = []\n",
    "\n",
    "    batches = zip(prefetch(loader), prefetch(loader_yA))\n",
    "    for batch_x, batch_y in tqdm(batches, total=len(trainset) // config.batch_size, leave=False):\n",
    "\n",
    "        x = batch_x['x']\n",
    "        y = batch_y['y']\n",
    "\n",
    "        # One key per sample, shared by x and y so both get the same flip/shake\n",
    "        aug_key = rng.split(len(x))\n",
    "\n",
    "        x = jax.device_put(x, distributed)\n",
    "        x = augment(x, aug_key)\n",
    "        x = flatten(x)\n",
    "\n",
    "        y = jax.device_put(y, distributed)\n",
    "        y = augment(y, aug_key)\n",
    "        y = flatten(y)\n",
    "\n",
    "        loss, avrg, params, opt_state = sgd_step(avrg, params, others, opt_state, x, y, key=rng.split())\n",
    "        losses.append(loss)\n",
    "\n",
    "    # With drop_last_batch=True an epoch is empty when the dataset is smaller\n",
    "    # than one batch; report NaN instead of crashing inside np.stack.\n",
    "    loss_train = np.stack(losses).mean() if losses else float('nan')\n",
    "\n",
    "    ## Validation\n",
    "    # loader = testset.iter(batch_size=config.batch_size, drop_last_batch=True)\n",
    "    # loader_yA = testset_corrupted_yA.iter(batch_size=config.batch_size, drop_last_batch=True)\n",
    "    # losses = []\n",
    "\n",
    "    # for batch_x, batch_y in zip(prefetch(loader), prefetch(loader_yA)):\n",
    "    #     x = batch_x['x']\n",
    "    #     x = jax.device_put(x, distributed)\n",
    "    #     x = flatten(x)\n",
    "\n",
    "    #     y_cond = batch_y['y']\n",
    "    #     y_cond = jax.device_put(y_cond, distributed)\n",
    "    #     y_cond = flatten(y_cond)\n",
    "\n",
    "    #     loss = ell(avrg, others, x, y_cond, key=rng.split())\n",
    "    #     losses.append(loss)\n",
    "\n",
    "    # loss_val = np.stack(losses).mean()\n",
    "\n",
    "    # bar.set_postfix(loss=loss_train, loss_val=loss_val)\n",
    "    bar.set_postfix(loss=loss_train)\n",
    "\n",
    "    # ## Eval\n",
    "    # if (epoch + 1) % 16 == 0:\n",
    "    #     model = static(avrg, others)\n",
    "    #     model.train(False)\n",
    "\n",
    "    #     x = sample_conditional(\n",
    "    #         model=model,\n",
    "    #         y_cond=y_eval,\n",
    "    #         key=rng.split(),\n",
    "    #         shard=True,\n",
    "    #         sampler=config.sampler,\n",
    "    #         steps=config.discrete,\n",
    "    #         maxiter=config.maxiter,\n",
    "    #     )\n",
    "    #     x = x.reshape(2, 2, 320, 320, 1)\n",
    "\n",
    "    #     # run.log({\n",
    "    #     #     'loss': loss_train,\n",
    "    #     #     'loss_val': loss_val,\n",
    "    #     #     'samples': wandb.Image(to_pil(x)),\n",
    "    #     # })\n",
    "    # else:\n",
    "    #     pass\n",
    "        # run.log({\n",
    "        #     'loss': loss_train,\n",
    "        #     'loss_val': loss_val,\n",
    "        # })"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "id": "980a08b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scratch: rebuild the paired loaders for debugging.\n",
    "# NOTE(review): relies on `epoch` left over from the training loop (hidden state).\n",
    "shuffle_seed = seed + lap * config.epochs + epoch\n",
    "loader = trainset.shuffle(seed=shuffle_seed).iter(batch_size=config.batch_size, drop_last_batch=True)\n",
    "loader_yA = trainset_corrupted.shuffle(seed=shuffle_seed).iter(batch_size=config.batch_size, drop_last_batch=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 80,
   "id": "a4e52ad1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Dataset({\n",
       "    features: ['y'],\n",
       "    num_rows: 10000\n",
       "})"
      ]
     },
     "execution_count": 80,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Expect a single 'y' column produced by corrupt()\n",
    "trainset_corrupted"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8cee1afb",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Dataset({\n",
       "    features: ['x'],\n",
       "    num_rows: 10000\n",
       "})"
      ]
     },
     "execution_count": 81,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "id": "05e662df",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Iterable"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bdbfc623",
   "metadata": {},
   "source": [
    "### Loader smoke test\n",
    "\n",
    "**FIXME: batch formatting fails.** Iterating the loaders raised `TypeError: ArrayExtensionArray.to_pylist() got an unexpected keyword argument 'maps_as_pydicts'` from the Python formatter of `datasets` (`pa_table.to_pydict()`), which points to a version mismatch between the installed `datasets` and `pyarrow`. Either align the two package versions or request numpy-formatted batches (e.g. `.with_format('numpy')` before `.iter(...)`) so the Python extractor is not used (TODO: confirm which).\n",
    "\n",
    "**FIXME: `prefetch` hangs on worker errors.** After the worker thread of `priors.data.prefetch` raised, the consumer waited indefinitely in `queue.get()` and the cell had to be interrupted. `prefetch` should forward the worker's exception (or at least its end sentinel) to the consumer. Until then, the smoke test below iterates the loaders directly so that errors surface in the cell.\n",
    "\n",
    "**NOTE(review):** if `x` and `y` are meant to be paired batch-for-batch (e.g. augmented with the same keys), confirm that `loader` is shuffled with the same seed as `loader_yA`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9366cd95",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Iterate the clean (`x`) and corrupted (`y`) loaders in lockstep.\n",
    "# The loaders are consumed directly rather than through `prefetch`, so an error\n",
    "# in the underlying iterator is raised here instead of hanging the cell.\n",
    "# `loader_yA` (and presumably `loader`) is a one-shot generator from `Dataset.iter`:\n",
    "# re-run the cell that builds the loaders before re-running this one.\n",
    "n_batches = 0\n",
    "for batch_x, batch_y in zip(loader, loader_yA):\n",
    "    x = batch_x['x']\n",
    "    y = batch_y['y']\n",
    "    assert len(x) == len(y), 'clean and corrupted batches differ in size'\n",
    "    n_batches += 1\n",
    "    # TODO: augment, flatten and sgd_step go here once both loaders iterate cleanly.\n",
    "n_batches"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
