{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import time\n",
    "from typing import Union\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "from einops import rearrange\n",
    "from diffusers.schedulers.scheduling_ddpm import DDPMScheduler\n",
    "\n",
    "from hypnettorch.mnets import MLP\n",
    "from hypnettorch.hnets import HMLP\n",
    "\n",
    "from env import PushTEnv, NormalizeActionWrapper, SpaceConversionWrapper, StateStackWrapper\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "DECODER_SIZE_DICT = {\n",
    "    'xs': [50, 50],\n",
    "    's': [100, 100],\n",
    "    'm': [200, 200],\n",
    "    'l': [400, 400]\n",
    "}\n",
    "\n",
    "def vae_loss(x, x_recon, z_mean, z_logvar, kl_coeff = 1e-6):\n",
    "    recon_loss = nn.MSELoss()(x_recon, x)\n",
    "    kl_loss = -0.5 * torch.sum(1 + z_logvar - z_mean.pow(2) - z_logvar.exp())\n",
    "    return {\n",
    "        \"recon_loss\" : recon_loss,\n",
    "        \"kl_loss\": kl_coeff * kl_loss\n",
    "    }\n",
    "\n",
    "class VAEHyperNetModel(nn.Module):\n",
    "    def __init__(\n",
    "            self, \n",
    "            policy, \n",
    "            state_dim, \n",
    "            action_dim, \n",
    "            latent_dim,\n",
    "            vae_decoder_size, \n",
    "            traj_len, \n",
    "            stochastic_decoder = False, \n",
    "            kl_coeff = 1e-6\n",
    "        ):\n",
    "        super().__init__()\n",
    "        self.state_dim = state_dim\n",
    "        self.action_dim = action_dim\n",
    "        self.latent_dim = latent_dim\n",
    "        self.vae_decoder_size = vae_decoder_size\n",
    "        self.traj_len = traj_len\n",
    "        self.stochastic_decoder = stochastic_decoder\n",
    "        self.kl_coeff = kl_coeff\n",
    "        desired_shape = policy.param_shapes\n",
    "        self.hnet = HMLP(desired_shape, cond_in_size=0, uncond_in_size=latent_dim, layers=DECODER_SIZE_DICT[self.vae_decoder_size])\n",
    "        if self.stochastic_decoder:\n",
    "            self.log_var = nn.Parameter(torch.zeros(1))\n",
    "        self.criterion = nn.MSELoss()\n",
    "\n",
    "        # Encoder\n",
    "        self.encoder = nn.Sequential(\n",
    "            nn.Linear(self.traj_len * (self.state_dim + self.action_dim), 512),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(512, 512),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(512, 512),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(512, 512),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(512, self.latent_dim * 2)  # output both mean and log variance\n",
    "        )\n",
    "\n",
    "    def encode(self, x):\n",
    "        # Encoding\n",
    "        x_flattened = rearrange(x, \"b t sa -> b (t sa)\")\n",
    "        encoded = self.encoder(x_flattened)\n",
    "        mu, log_var = torch.chunk(encoded, 2, dim=-1)  # split the encoder output into mu and log_var components\n",
    "        return mu, log_var\n",
    "\n",
    "    def decode(self, x):\n",
    "        weights = self.hnet.forward(uncond_input=x)\n",
    "        return weights\n",
    "\n",
    "    def forward(self, batch_obs, policy, batch_act):\n",
    "        batch_size = batch_act.shape[0]\n",
    "        x = torch.cat([batch_obs, batch_act], -1)\n",
    "\n",
    "        mu, log_var = self.encode(x)\n",
    "\n",
    "        z = reparameterize(mu, log_var)\n",
    "\n",
    "        weights_mean = self.hnet.forward(uncond_input=z)\n",
    "        sampled_actions = []\n",
    "        for i in range(batch_size):\n",
    "            weight_mean_ = weights_mean[i]\n",
    "            if self.stochastic_decoder:\n",
    "                sampled_weights = []\n",
    "                for layer_mean in weight_mean_:\n",
    "                    layer_log_var = self.log_var\n",
    "\n",
    "                    # Compute standard deviation from log_var\n",
    "                    std = torch.exp(0.5 * layer_log_var)\n",
    "\n",
    "                    # Sample epsilon from standard normal distribution\n",
    "                    epsilon = torch.randn_like(std)\n",
    "                    \n",
    "                    # Reparameterization: sample from N(mean, var)\n",
    "                    sample = layer_mean + epsilon * std\n",
    "\n",
    "                    sampled_weights.append(sample)\n",
    "            \n",
    "            else:\n",
    "                sampled_weights = weight_mean_\n",
    "\n",
    "            # get actions from policy\n",
    "            sampled_actions.append(policy(batch_obs[i], sampled_weights))\n",
    "\n",
    "        reconstructed_actions = torch.stack(sampled_actions)\n",
    "\n",
    "        loss = vae_loss(batch_act, reconstructed_actions, mu, log_var, self.kl_coeff)\n",
    "\n",
    "        return reconstructed_actions, loss\n",
    "\n",
    "    def get_policy_weights(self, device, data = None):\n",
    "        if data is None:\n",
    "            # sample z from normal dist\n",
    "            z = reparameterize(torch.Tensor([[0]*self.latent_dim]), torch.Tensor([[0]*self.latent_dim])).to(device)\n",
    "        else:\n",
    "            states = data[\"states\"]  # [8, 64, 17]\n",
    "            actions = data[\"actions\"]  # [8, 64, 6]\n",
    "            x = torch.cat([states, actions], -1).to(device)\n",
    "            x_flattened = rearrange(x, \"b t sa -> b (t sa)\")\n",
    "            encoded = self.encoder(x_flattened)\n",
    "            mu, log_var = torch.chunk(encoded, 2, dim=-1)  # split the encoder output into mu and log_var components\n",
    "            z = reparameterize(mu, log_var)  # [8, latent_dim]\n",
    "\n",
    "        hnet_out = self.hnet.forward(uncond_input=z)  # [8, policy_weights]\n",
    "        # weights_mean = hnet_out[:-1]\n",
    "        # weights_logvar = hnet_out[-1]\n",
    "        # weights = reparameterize(weights_mean, weights_logvar)\n",
    "\n",
    "        return hnet_out\n",
    "\n",
    "\n",
    "\n",
    "def reparameterize(mu, log_var):\n",
    "    std = torch.exp(0.5 * log_var)\n",
    "    eps = torch.randn_like(std)\n",
    "    return mu + eps * std"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SinusoidalPosEmb(nn.Module):\n",
    "    def __init__(self, dim):\n",
    "        super().__init__()\n",
    "        self.dim = dim\n",
    "\n",
    "    def forward(self, x):\n",
    "        device = x.device\n",
    "        half_dim = self.dim // 2\n",
    "        emb = math.log(10000) / (half_dim - 1)\n",
    "        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)\n",
    "        emb = x[:, None] * emb[None, :]\n",
    "        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)\n",
    "        return emb\n",
    "\n",
    "\n",
    "class Downsample1d(nn.Module):\n",
    "    def __init__(self, dim):\n",
    "        super().__init__()\n",
    "        self.conv = nn.Conv1d(dim, dim, 3, 2, 1)\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.conv(x)\n",
    "\n",
    "class Upsample1d(nn.Module):\n",
    "    def __init__(self, dim):\n",
    "        super().__init__()\n",
    "        self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1)\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.conv(x)\n",
    "\n",
    "\n",
    "class Conv1dBlock(nn.Module):\n",
    "    '''\n",
    "        Conv1d --> GroupNorm --> Mish\n",
    "    '''\n",
    "\n",
    "    def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):\n",
    "        super().__init__()\n",
    "\n",
    "        self.block = nn.Sequential(\n",
    "            nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),\n",
    "            nn.GroupNorm(n_groups, out_channels),\n",
    "            nn.Mish(),\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.block(x)\n",
    "\n",
    "\n",
    "class ConditionalResidualBlock1D(nn.Module):\n",
    "    def __init__(self,\n",
    "            in_channels,\n",
    "            out_channels,\n",
    "            cond_dim,\n",
    "            kernel_size=3,\n",
    "            n_groups=8):\n",
    "        super().__init__()\n",
    "\n",
    "        self.blocks = nn.ModuleList([\n",
    "            Conv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups),\n",
    "            Conv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups),\n",
    "        ])\n",
    "\n",
    "        # FiLM modulation https://arxiv.org/abs/1709.07871\n",
    "        # predicts per-channel scale and bias\n",
    "        cond_channels = out_channels * 2\n",
    "        self.out_channels = out_channels\n",
    "        self.cond_encoder = nn.Sequential(\n",
    "            nn.Mish(),\n",
    "            nn.Linear(cond_dim, cond_channels),\n",
    "            nn.Unflatten(-1, (-1, 1))\n",
    "        )\n",
    "\n",
    "        # make sure dimensions compatible\n",
    "        self.residual_conv = nn.Conv1d(in_channels, out_channels, 1) \\\n",
    "            if in_channels != out_channels else nn.Identity()\n",
    "\n",
    "    def forward(self, x, cond):\n",
    "        '''\n",
    "            x : [ batch_size x in_channels x horizon ]\n",
    "            cond : [ batch_size x cond_dim]\n",
    "\n",
    "            returns:\n",
    "            out : [ batch_size x out_channels x horizon ]\n",
    "        '''\n",
    "        out = self.blocks[0](x)\n",
    "        embed = self.cond_encoder(cond)\n",
    "\n",
    "        embed = embed.reshape(\n",
    "            embed.shape[0], 2, self.out_channels, 1)\n",
    "        scale = embed[:,0,...]\n",
    "        bias = embed[:,1,...]\n",
    "        out = scale * out + bias\n",
    "\n",
    "        out = self.blocks[1](out)\n",
    "        out = out + self.residual_conv(x)\n",
    "        return out\n",
    "\n",
    "\n",
    "class ConditionalUnet1D(nn.Module):\n",
    "    def __init__(self,\n",
    "        input_dim,\n",
    "        global_cond_dim,\n",
    "        diffusion_step_embed_dim=256,\n",
    "        down_dims=[256,512,1024],\n",
    "        kernel_size=5,\n",
    "        n_groups=8\n",
    "        ):\n",
    "        \"\"\"\n",
    "        input_dim: Dim of actions.\n",
    "        global_cond_dim: Dim of global conditioning applied with FiLM\n",
    "          in addition to diffusion step embedding. This is usually obs_horizon * obs_dim\n",
    "        diffusion_step_embed_dim: Size of positional encoding for diffusion iteration k\n",
    "        down_dims: Channel size for each UNet level.\n",
    "          The length of this array determines numebr of levels.\n",
    "        kernel_size: Conv kernel size\n",
    "        n_groups: Number of groups for GroupNorm\n",
    "        \"\"\"\n",
    "\n",
    "        super().__init__()\n",
    "        all_dims = [input_dim] + list(down_dims)\n",
    "        start_dim = down_dims[0]\n",
    "\n",
    "        dsed = diffusion_step_embed_dim\n",
    "        diffusion_step_encoder = nn.Sequential(\n",
    "            SinusoidalPosEmb(dsed),\n",
    "            nn.Linear(dsed, dsed * 4),\n",
    "            nn.Mish(),\n",
    "            nn.Linear(dsed * 4, dsed),\n",
    "        )\n",
    "        cond_dim = dsed + global_cond_dim\n",
    "\n",
    "        in_out = list(zip(all_dims[:-1], all_dims[1:]))\n",
    "        mid_dim = all_dims[-1]\n",
    "        self.mid_modules = nn.ModuleList([\n",
    "            ConditionalResidualBlock1D(\n",
    "                mid_dim, mid_dim, cond_dim=cond_dim,\n",
    "                kernel_size=kernel_size, n_groups=n_groups\n",
    "            ),\n",
    "            ConditionalResidualBlock1D(\n",
    "                mid_dim, mid_dim, cond_dim=cond_dim,\n",
    "                kernel_size=kernel_size, n_groups=n_groups\n",
    "            ),\n",
    "        ])\n",
    "\n",
    "        down_modules = nn.ModuleList([])\n",
    "        for ind, (dim_in, dim_out) in enumerate(in_out):\n",
    "            is_last = ind >= (len(in_out) - 1)\n",
    "            down_modules.append(nn.ModuleList([\n",
    "                ConditionalResidualBlock1D(\n",
    "                    dim_in, dim_out, cond_dim=cond_dim,\n",
    "                    kernel_size=kernel_size, n_groups=n_groups),\n",
    "                ConditionalResidualBlock1D(\n",
    "                    dim_out, dim_out, cond_dim=cond_dim,\n",
    "                    kernel_size=kernel_size, n_groups=n_groups),\n",
    "                Downsample1d(dim_out) if not is_last else nn.Identity()\n",
    "            ]))\n",
    "\n",
    "        up_modules = nn.ModuleList([])\n",
    "        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):\n",
    "            is_last = ind >= (len(in_out) - 1)\n",
    "            up_modules.append(nn.ModuleList([\n",
    "                ConditionalResidualBlock1D(\n",
    "                    dim_out*2, dim_in, cond_dim=cond_dim,\n",
    "                    kernel_size=kernel_size, n_groups=n_groups),\n",
    "                ConditionalResidualBlock1D(\n",
    "                    dim_in, dim_in, cond_dim=cond_dim,\n",
    "                    kernel_size=kernel_size, n_groups=n_groups),\n",
    "                Upsample1d(dim_in) if not is_last else nn.Identity()\n",
    "            ]))\n",
    "\n",
    "        final_conv = nn.Sequential(\n",
    "            Conv1dBlock(start_dim, start_dim, kernel_size=kernel_size),\n",
    "            nn.Conv1d(start_dim, input_dim, 1),\n",
    "        )\n",
    "\n",
    "        self.diffusion_step_encoder = diffusion_step_encoder\n",
    "        self.up_modules = up_modules\n",
    "        self.down_modules = down_modules\n",
    "        self.final_conv = final_conv\n",
    "\n",
    "        print(\"number of parameters: {:e}\".format(\n",
    "            sum(p.numel() for p in self.parameters()))\n",
    "        )\n",
    "\n",
    "    def forward(self,\n",
    "            sample: torch.Tensor,\n",
    "            timestep: Union[torch.Tensor, float, int],\n",
    "            global_cond=None):\n",
    "        \"\"\"\n",
    "        x: (B,T,input_dim)\n",
    "        timestep: (B,) or int, diffusion step\n",
    "        global_cond: (B,global_cond_dim)\n",
    "        output: (B,T,input_dim)\n",
    "        \"\"\"\n",
    "        # (B,T,C)\n",
    "        sample = sample.moveaxis(-1,-2)\n",
    "        # (B,C,T)\n",
    "\n",
    "        # 1. time\n",
    "        timesteps = timestep\n",
    "        if not torch.is_tensor(timesteps):\n",
    "            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can\n",
    "            timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)\n",
    "        elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:\n",
    "            timesteps = timesteps[None].to(sample.device)\n",
    "        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML\n",
    "        timesteps = timesteps.expand(sample.shape[0])\n",
    "\n",
    "        global_feature = self.diffusion_step_encoder(timesteps)\n",
    "\n",
    "        if global_cond is not None:\n",
    "            global_feature = torch.cat([\n",
    "                global_feature, global_cond\n",
    "            ], axis=-1)\n",
    "\n",
    "        x = sample\n",
    "        h = []\n",
    "        for idx, (resnet, resnet2, downsample) in enumerate(self.down_modules):\n",
    "            x = resnet(x, global_feature)\n",
    "            x = resnet2(x, global_feature)\n",
    "            h.append(x)\n",
    "            x = downsample(x)\n",
    "\n",
    "        for mid_module in self.mid_modules:\n",
    "            x = mid_module(x, global_feature)\n",
    "\n",
    "        for idx, (resnet, resnet2, upsample) in enumerate(self.up_modules):\n",
    "            x = torch.cat((x, h.pop()), dim=1)\n",
    "            x = resnet(x, global_feature)\n",
    "            x = resnet2(x, global_feature)\n",
    "            x = upsample(x)\n",
    "\n",
    "        x = self.final_conv(x)\n",
    "\n",
    "        # (B,C,T)\n",
    "        x = x.moveaxis(-1,-2)\n",
    "        # (B,T,C)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "vae_ckpt_path = \"vae.pt\"\n",
    "diffusion_ckpt_path = \"diff.pt\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "policy = MLP(\n",
    "    n_in=10, \n",
    "    n_out=2, \n",
    "    hidden_layers=[256, 256], \n",
    "    no_weights=True\n",
    ")\n",
    "vae = VAEHyperNetModel(\n",
    "    policy, \n",
    "    state_dim=10, \n",
    "    action_dim=2, \n",
    "    latent_dim=256, \n",
    "    vae_decoder_size='l',\n",
    "    traj_len=16, \n",
    ").to(device)\n",
    "\n",
    "noise_pred_net = ConditionalUnet1D(\n",
    "    input_dim=1,\n",
    "    global_cond_dim=10 + 1,\n",
    "    diffusion_step_embed_dim=256,\n",
    "    down_dims=[32, 64, 128],\n",
    "    kernel_size=5\n",
    ")\n",
    "\n",
    "num_diffusion_iters = 100\n",
    "noise_scheduler = DDPMScheduler(\n",
    "    num_train_timesteps=num_diffusion_iters,\n",
    "    # the choise of beta schedule has big impact on performance\n",
    "    # we found squared cosine works the best\n",
    "    beta_schedule='squaredcos_cap_v2',\n",
    "    # clip output to [-1,1] to improve stability\n",
    "    clip_sample=True,\n",
    "    # our network predicts noise (instead of denoised action)\n",
    "    prediction_type='epsilon'\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "checkpoint = torch.load(vae_ckpt_path)\n",
    "vae.load_state_dict(checkpoint['model_state_dict'])\n",
    "for param in vae.parameters():\n",
    "    param.requires_grad = False\n",
    "noise_pred_net.load_state_dict(torch.load(diffusion_ckpt_path)['state_dict'])\n",
    "\n",
    "vae = vae.to(device)\n",
    "noise_pred_net = noise_pred_net.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "stats = {\n",
    "    \"obs\": {\n",
    "        \"max\": np.array([496.14618, 510.9579, 439.9153, 485.6641, 6.2830877], dtype=np.float32),\n",
    "        \"min\": np.array([13.456424, 32.938293, 57.471767, 108.27995, 0.00021559125], dtype=np.float32),\n",
    "    },\n",
    "    \"action\": {\n",
    "        \"max\": np.array([511.0, 511.0], dtype=np.float32),\n",
    "        \"min\": np.array([12.0, 25.0], dtype=np.float32),\n",
    "    },\n",
    "}\n",
    "\n",
    "env = PushTEnv(\n",
    "    seed=1\n",
    ")\n",
    "env = NormalizeActionWrapper(env, stats)\n",
    "env = SpaceConversionWrapper(env)\n",
    "env = StateStackWrapper(env, 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "obs = env.reset()[0]\n",
    "obs = obs[None]\n",
    "B = 1\n",
    "number_of_perturbations = 10\n",
    "noise_scale = 10\n",
    "latent_scaling_factor = 0.18215\n",
    "task_id = torch.tensor([[0]]).to(device)\n",
    "sleep_dt = 0.05\n",
    "action_horizon = 16\n",
    "step_idx = 0\n",
    "max_traj_len = 256\n",
    "\n",
    "state = torch.tensor(obs).to(device).to(torch.float32)\n",
    "done = np.zeros(B, dtype=bool)\n",
    "steps_to_perturb = [\n",
    "    # choose 10 steps to apply perturbation, without repeating\n",
    "    np.random.choice(256, number_of_perturbations, replace=False)\n",
    "    for _ in range(B)\n",
    "]\n",
    "env.setattr(\"move_t_range\", noise_scale)\n",
    "env.setattr(\"iter_indices_to_move_t\", steps_to_perturb[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "task_id.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "ename": "",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31mThe Kernel crashed while executing code in the current cell or a previous cell. \n",
      "\u001b[1;31mPlease review the code in the cell(s) to identify a possible cause of the failure. \n",
      "\u001b[1;31mClick <a href='https://aka.ms/vscodeJupyterKernelCrash'>here</a> for more info. \n",
      "\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
     ]
    }
   ],
   "source": [
    "while not np.all(done):\n",
    "    sample = torch.randn(\n",
    "        (B, 256),\n",
    "        device=device\n",
    "    )\n",
    "    sample = rearrange(sample, \"b (t s) -> b t s\", t = 256)\n",
    "    noise_scheduler.set_timesteps(num_diffusion_iters)\n",
    "    obs_cond = torch.tensor(obs, dtype=torch.float32, device=device)\n",
    "    # obs_cond = rearrange(dum, \"b (t s) -> b t s\", t = 2)\n",
    "    # obs_cond = obs_cond.flatten(start_dim=1)\n",
    "    obs_cond = torch.cat([obs_cond, task_id], dim=1)\n",
    "    for k in noise_scheduler.timesteps:\n",
    "        with torch.no_grad():\n",
    "            noise_pred = noise_pred_net(\n",
    "                sample=sample,\n",
    "                timestep=k,\n",
    "                global_cond=obs_cond\n",
    "            )\n",
    "            # noise_preds_norm = noise_pred.norm(dim=-2).mean().item()\n",
    "            # noise_pred_norms.append(noise_preds_norm)\n",
    "        sample = noise_scheduler.step(\n",
    "            model_output=noise_pred,\n",
    "            timestep=k,\n",
    "            sample=sample\n",
    "        ).prev_sample\n",
    "\n",
    "    sample = rearrange(sample, \"b t s -> b (t s)\") / latent_scaling_factor\n",
    "    policy_weights = vae.decode(sample)\n",
    "\n",
    "    time.sleep(10*sleep_dt)\n",
    "\n",
    "    for j in range(action_horizon):\n",
    "        if np.all(done):\n",
    "            break\n",
    "\n",
    "        if state.shape[0] == 1:\n",
    "            pred_acts = policy(state[0], policy_weights).detach().cpu().numpy()\n",
    "        else:\n",
    "            pred_acts = torch.stack([\n",
    "                        policy(state[i], policy_weights[i]) for i in range(B)\n",
    "                    ]).detach().cpu().numpy()         \n",
    "\n",
    "       \n",
    "        next_obs, reward, done_env, trunc, info = env.step(pred_acts)\n",
    "\n",
    "        env.render()\n",
    "        time.sleep(sleep_dt)\n",
    "\n",
    "        next_obs = next_obs[None]\n",
    "        reward = np.array([reward])\n",
    "        done_env = np.array([done_env])\n",
    "\n",
    "        state = torch.tensor(next_obs).to(device).to(torch.float32)\n",
    "\n",
    "        next_obs = np.array(next_obs)\n",
    "        reward = np.array(reward)\n",
    "        done_env = np.array(done_env)\n",
    "\n",
    "\n",
    "        done = np.logical_or(done, done_env)\n",
    "\n",
    "        ### could not find success in info\n",
    "\n",
    "        # update next obs for the env that hasn't ended\n",
    "        step_idx += 1\n",
    "        if step_idx >= max_traj_len:\n",
    "            done[:] = True\n",
    "            break\n",
    "    if np.all(done):\n",
    "        break\n",
    "    ### What does the env update logic mean in diffusion?\n",
    "    ##### issue ##################\n",
    "    obs = next_obs\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "lwd",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
