{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "# VMAS environment components\n",
    "from vmas.simulator.environment import Environment\n",
    "from smart_grid import SmartGridScenario  # <-- The scenario class you provided\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Instantiate the scenario class imported in the first cell.\n",
    "scenario = SmartGridScenario()\n",
    "\n",
    "# Extra keyword arguments: the Environment cell forwards them via **kwargs\n",
    "# and also reads `episode_length` from here as the env's max_steps.\n",
    "kwargs = dict(\n",
    "    episode_length=80,  # override if needed\n",
    "    verbose=False,\n",
    "    # Presumably one building type per agent (5 entries here, 5 agents in the\n",
    "    # printed outputs) -- confirm the mapping in SmartGridScenario.\n",
    "    building_types=[5, 1, 1, 1, 5],\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the vectorized VMAS environment directly with Environment(...)\n",
    "# (VMAS also offers a make_env helper).\n",
    "\n",
    "device = torch.device(\"cpu\")  # or \"cuda\" if you want GPU\n",
    "\n",
    "# Number of parallel environments. It is passed to Environment as `num_envs`,\n",
    "# and later cells size their random actions with `batch_dim`, so the two must\n",
    "# stay in sync -- change it here only.\n",
    "batch_dim = 32\n",
    "seed = 0  # explicit seed forwarded to Environment(seed=...)\n",
    "\n",
    "env = Environment(\n",
    "    scenario=scenario,\n",
    "    num_envs=batch_dim,\n",
    "    device=device,\n",
    "    max_steps=kwargs.get(\"episode_length\", 80),\n",
    "    seed=seed,\n",
    "    **kwargs\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Observations Type: <class 'list'>\n",
      "Agent 0 obs shape: torch.Size([32, 4])\n",
      "tensor([[0.0000, 1.1353, 3.0250, 0.0000],\n",
      "        [0.0000, 1.1353, 3.0250, 0.0000],\n",
      "        [0.0000, 1.1353, 3.0250, 0.0000],\n",
      "        [0.0000, 1.1353, 3.0250, 0.0000],\n",
      "        [0.0000, 1.1353, 3.0250, 0.0000],\n",
      "        [0.0000, 1.1353, 3.0250, 0.0000],\n",
      "        [0.0000, 1.1353, 3.0250, 0.0000],\n",
      "        [0.0000, 1.1353, 3.0250, 0.0000],\n",
      "        [0.0000, 1.1353, 3.0250, 0.0000],\n",
      "        [0.0000, 1.1353, 3.0250, 0.0000],\n",
      "        [0.0000, 1.1353, 3.0250, 0.0000],\n",
      "        [0.0000, 1.1353, 3.0250, 0.0000],\n",
      "        [0.0000, 1.1353, 3.0250, 0.0000],\n",
      "        [0.0000, 1.1353, 3.0250, 0.0000],\n",
      "        [0.0000, 1.1353, 3.0250, 0.0000],\n",
      "        [0.0000, 1.1353, 3.0250, 0.0000],\n",
      "        [0.0000, 1.1353, 3.0250, 0.0000],\n",
      "        [0.0000, 1.1353, 3.0250, 0.0000],\n",
      "        [0.0000, 1.1353, 3.0250, 0.0000],\n",
      "        [0.0000, 1.1353, 3.0250, 0.0000],\n",
      "        [0.0000, 1.1353, 3.0250, 0.0000],\n",
      "        [0.0000, 1.1353, 3.0250, 0.0000],\n",
      "        [0.0000, 1.1353, 3.0250, 0.0000],\n",
      "        [0.0000, 1.1353, 3.0250, 0.0000],\n",
      "        [0.0000, 1.1353, 3.0250, 0.0000],\n",
      "        [0.0000, 1.1353, 3.0250, 0.0000],\n",
      "        [0.0000, 1.1353, 3.0250, 0.0000],\n",
      "        [0.0000, 1.1353, 3.0250, 0.0000],\n",
      "        [0.0000, 1.1353, 3.0250, 0.0000],\n",
      "        [0.0000, 1.1353, 3.0250, 0.0000],\n",
      "        [0.0000, 1.1353, 3.0250, 0.0000],\n",
      "        [0.0000, 1.1353, 3.0250, 0.0000]])\n",
      "Agent 1 obs shape: torch.Size([32, 4])\n",
      "tensor([[0.0000, 0.5676, 3.0250, 0.0000],\n",
      "        [0.0000, 0.5676, 3.0250, 0.0000],\n",
      "        [0.0000, 0.5676, 3.0250, 0.0000],\n",
      "        [0.0000, 0.5676, 3.0250, 0.0000],\n",
      "        [0.0000, 0.5676, 3.0250, 0.0000],\n",
      "        [0.0000, 0.5676, 3.0250, 0.0000],\n",
      "        [0.0000, 0.5676, 3.0250, 0.0000],\n",
      "        [0.0000, 0.5676, 3.0250, 0.0000],\n",
      "        [0.0000, 0.5676, 3.0250, 0.0000],\n",
      "        [0.0000, 0.5676, 3.0250, 0.0000],\n",
      "        [0.0000, 0.5676, 3.0250, 0.0000],\n",
      "        [0.0000, 0.5676, 3.0250, 0.0000],\n",
      "        [0.0000, 0.5676, 3.0250, 0.0000],\n",
      "        [0.0000, 0.5676, 3.0250, 0.0000],\n",
      "        [0.0000, 0.5676, 3.0250, 0.0000],\n",
      "        [0.0000, 0.5676, 3.0250, 0.0000],\n",
      "        [0.0000, 0.5676, 3.0250, 0.0000],\n",
      "        [0.0000, 0.5676, 3.0250, 0.0000],\n",
      "        [0.0000, 0.5676, 3.0250, 0.0000],\n",
      "        [0.0000, 0.5676, 3.0250, 0.0000],\n",
      "        [0.0000, 0.5676, 3.0250, 0.0000],\n",
      "        [0.0000, 0.5676, 3.0250, 0.0000],\n",
      "        [0.0000, 0.5676, 3.0250, 0.0000],\n",
      "        [0.0000, 0.5676, 3.0250, 0.0000],\n",
      "        [0.0000, 0.5676, 3.0250, 0.0000],\n",
      "        [0.0000, 0.5676, 3.0250, 0.0000],\n",
      "        [0.0000, 0.5676, 3.0250, 0.0000],\n",
      "        [0.0000, 0.5676, 3.0250, 0.0000],\n",
      "        [0.0000, 0.5676, 3.0250, 0.0000],\n",
      "        [0.0000, 0.5676, 3.0250, 0.0000],\n",
      "        [0.0000, 0.5676, 3.0250, 0.0000],\n",
      "        [0.0000, 0.5676, 3.0250, 0.0000]])\n",
      "Agent 2 obs shape: torch.Size([32, 4])\n",
      "tensor([[0.0000, 0.5676, 3.0250, 0.0000],\n",
      "        [0.0000, 0.5676, 3.0250, 0.0000],\n",
      "        [0.0000, 0.5676, 3.0250, 0.0000],\n",
      "        [0.0000, 0.5676, 3.0250, 0.0000],\n",
      "        [0.0000, 0.5676, 3.0250, 0.0000],\n",
      "        [0.0000, 0.5676, 3.0250, 0.0000],\n",
      "        [0.0000, 0.5676, 3.0250, 0.0000],\n",
      "        [0.0000, 0.5676, 3.0250, 0.0000],\n",
      "        [0.0000, 0.5676, 3.0250, 0.0000],\n",
      "        [0.0000, 0.5676, 3.0250, 0.0000],\n",
      "        [0.0000, 0.5676, 3.0250, 0.0000],\n",
      "        [0.0000, 0.5676, 3.0250, 0.0000],\n",
      "        [0.0000, 0.5676, 3.0250, 0.0000],\n",
      "        [0.0000, 0.5676, 3.0250, 0.0000],\n",
      "        [0.0000, 0.5676, 3.0250, 0.0000],\n",
      "        [0.0000, 0.5676, 3.0250, 0.0000],\n",
      "        [0.0000, 0.5676, 3.0250, 0.0000],\n",
      "        [0.0000, 0.5676, 3.0250, 0.0000],\n",
      "        [0.0000, 0.5676, 3.0250, 0.0000],\n",
      "        [0.0000, 0.5676, 3.0250, 0.0000],\n",
      "        [0.0000, 0.5676, 3.0250, 0.0000],\n",
      "        [0.0000, 0.5676, 3.0250, 0.0000],\n",
      "        [0.0000, 0.5676, 3.0250, 0.0000],\n",
      "        [0.0000, 0.5676, 3.0250, 0.0000],\n",
      "        [0.0000, 0.5676, 3.0250, 0.0000],\n",
      "        [0.0000, 0.5676, 3.0250, 0.0000],\n",
      "        [0.0000, 0.5676, 3.0250, 0.0000],\n",
      "        [0.0000, 0.5676, 3.0250, 0.0000],\n",
      "        [0.0000, 0.5676, 3.0250, 0.0000],\n",
      "        [0.0000, 0.5676, 3.0250, 0.0000],\n",
      "        [0.0000, 0.5676, 3.0250, 0.0000],\n",
      "        [0.0000, 0.5676, 3.0250, 0.0000]])\n",
      "Agent 3 obs shape: torch.Size([32, 4])\n",
      "tensor([[0.0000, 0.5676, 3.0250, 0.0000],\n",
      "        [0.0000, 0.5676, 3.0250, 0.0000],\n",
      "        [0.0000, 0.5676, 3.0250, 0.0000],\n",
      "        [0.0000, 0.5676, 3.0250, 0.0000],\n",
      "        [0.0000, 0.5676, 3.0250, 0.0000],\n",
      "        [0.0000, 0.5676, 3.0250, 0.0000],\n",
      "        [0.0000, 0.5676, 3.0250, 0.0000],\n",
      "        [0.0000, 0.5676, 3.0250, 0.0000],\n",
      "        [0.0000, 0.5676, 3.0250, 0.0000],\n",
      "        [0.0000, 0.5676, 3.0250, 0.0000],\n",
      "        [0.0000, 0.5676, 3.0250, 0.0000],\n",
      "        [0.0000, 0.5676, 3.0250, 0.0000],\n",
      "        [0.0000, 0.5676, 3.0250, 0.0000],\n",
      "        [0.0000, 0.5676, 3.0250, 0.0000],\n",
      "        [0.0000, 0.5676, 3.0250, 0.0000],\n",
      "        [0.0000, 0.5676, 3.0250, 0.0000],\n",
      "        [0.0000, 0.5676, 3.0250, 0.0000],\n",
      "        [0.0000, 0.5676, 3.0250, 0.0000],\n",
      "        [0.0000, 0.5676, 3.0250, 0.0000],\n",
      "        [0.0000, 0.5676, 3.0250, 0.0000],\n",
      "        [0.0000, 0.5676, 3.0250, 0.0000],\n",
      "        [0.0000, 0.5676, 3.0250, 0.0000],\n",
      "        [0.0000, 0.5676, 3.0250, 0.0000],\n",
      "        [0.0000, 0.5676, 3.0250, 0.0000],\n",
      "        [0.0000, 0.5676, 3.0250, 0.0000],\n",
      "        [0.0000, 0.5676, 3.0250, 0.0000],\n",
      "        [0.0000, 0.5676, 3.0250, 0.0000],\n",
      "        [0.0000, 0.5676, 3.0250, 0.0000],\n",
      "        [0.0000, 0.5676, 3.0250, 0.0000],\n",
      "        [0.0000, 0.5676, 3.0250, 0.0000],\n",
      "        [0.0000, 0.5676, 3.0250, 0.0000],\n",
      "        [0.0000, 0.5676, 3.0250, 0.0000]])\n",
      "Agent 4 obs shape: torch.Size([32, 4])\n",
      "tensor([[0.0000, 1.1353, 3.0250, 0.0000],\n",
      "        [0.0000, 1.1353, 3.0250, 0.0000],\n",
      "        [0.0000, 1.1353, 3.0250, 0.0000],\n",
      "        [0.0000, 1.1353, 3.0250, 0.0000],\n",
      "        [0.0000, 1.1353, 3.0250, 0.0000],\n",
      "        [0.0000, 1.1353, 3.0250, 0.0000],\n",
      "        [0.0000, 1.1353, 3.0250, 0.0000],\n",
      "        [0.0000, 1.1353, 3.0250, 0.0000],\n",
      "        [0.0000, 1.1353, 3.0250, 0.0000],\n",
      "        [0.0000, 1.1353, 3.0250, 0.0000],\n",
      "        [0.0000, 1.1353, 3.0250, 0.0000],\n",
      "        [0.0000, 1.1353, 3.0250, 0.0000],\n",
      "        [0.0000, 1.1353, 3.0250, 0.0000],\n",
      "        [0.0000, 1.1353, 3.0250, 0.0000],\n",
      "        [0.0000, 1.1353, 3.0250, 0.0000],\n",
      "        [0.0000, 1.1353, 3.0250, 0.0000],\n",
      "        [0.0000, 1.1353, 3.0250, 0.0000],\n",
      "        [0.0000, 1.1353, 3.0250, 0.0000],\n",
      "        [0.0000, 1.1353, 3.0250, 0.0000],\n",
      "        [0.0000, 1.1353, 3.0250, 0.0000],\n",
      "        [0.0000, 1.1353, 3.0250, 0.0000],\n",
      "        [0.0000, 1.1353, 3.0250, 0.0000],\n",
      "        [0.0000, 1.1353, 3.0250, 0.0000],\n",
      "        [0.0000, 1.1353, 3.0250, 0.0000],\n",
      "        [0.0000, 1.1353, 3.0250, 0.0000],\n",
      "        [0.0000, 1.1353, 3.0250, 0.0000],\n",
      "        [0.0000, 1.1353, 3.0250, 0.0000],\n",
      "        [0.0000, 1.1353, 3.0250, 0.0000],\n",
      "        [0.0000, 1.1353, 3.0250, 0.0000],\n",
      "        [0.0000, 1.1353, 3.0250, 0.0000],\n",
      "        [0.0000, 1.1353, 3.0250, 0.0000],\n",
      "        [0.0000, 1.1353, 3.0250, 0.0000]])\n"
     ]
    }
   ],
   "source": [
    "# Reset every environment in the batch.\n",
    "observations = env.reset()\n",
    "# One observation entry per agent: a list of per-agent tensors here (see the\n",
    "# printed type), or presumably a dict keyed by agent name if the env is built\n",
    "# with dict spaces. A single stacked Tensor is handled as well.\n",
    "n_rows_to_show = 3  # print only the first few batch rows, not all batch_dim of them\n",
    "\n",
    "print(\"Observations Type:\", type(observations))\n",
    "if isinstance(observations, torch.Tensor):\n",
    "    print(\"Observations Shape:\", observations.shape)\n",
    "    print(\"Observations:\", observations[:n_rows_to_show])\n",
    "else:\n",
    "    per_agent = (\n",
    "        observations.items() if isinstance(observations, dict) else enumerate(observations)\n",
    "    )\n",
    "    for i, obs in per_agent:\n",
    "        print(f\"Agent {i} obs shape:\", obs.shape)\n",
    "        print(obs[:n_rows_to_show])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Action space: Tuple(Box(-1.0, 1.0, (2,), float32), Box(-1.0, 1.0, (2,), float32), Box(-1.0, 1.0, (2,), float32), Box(-1.0, 1.0, (2,), float32), Box(-1.0, 1.0, (2,), float32))\n"
     ]
    }
   ],
   "source": [
    "# Show the action space the env exposes; fall back to a hint if it has none.\n",
    "try:\n",
    "    action_space = env.action_space\n",
    "except AttributeError:\n",
    "    print(\"env.action_space not found; consider how actions are expected to be passed.\")\n",
    "else:\n",
    "    print(\"Action space:\", action_space)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Next Observations shape: Varies (list)\n",
      "Rewards shape/type: <class 'list'> [tensor([-0.3521,  0.7491, -0.3592,  0.9280,  0.5963, -0.1119,  0.4324,  0.1824,\n",
      "         0.6232, -0.3733, -0.0750,  0.7803, -0.3258,  0.5254,  0.3459, -0.0165,\n",
      "         0.7310,  0.8827,  0.6394,  0.5455, -0.1093, -0.2562,  0.1944,  0.5190,\n",
      "         0.1790,  0.3840,  0.6190,  0.3962,  0.5764, -0.4139, -0.4453,  0.9460]), tensor([ 0.7119,  0.8690,  1.0771,  0.8353,  1.2657,  0.1022, -0.1696,  0.7997,\n",
      "         0.6131,  0.3057,  0.0974,  0.8985,  0.6679,  0.5175,  0.7967,  0.8353,\n",
      "         0.1476,  0.3018,  0.8834,  0.3519,  0.1054,  0.6458,  0.2776,  0.4093,\n",
      "         0.4541,  0.8235,  0.2270,  0.1942,  0.3916, -0.0805,  0.5282,  0.6534]), tensor([ 1.2363,  0.2559,  0.5107,  0.3184, -0.1950,  0.3430,  0.6164, -0.1363,\n",
      "        -0.0580,  0.1957,  0.1926,  0.2771,  0.2240,  0.4631,  0.4455,  0.5900,\n",
      "         1.2737,  0.1060,  1.1855,  0.7327,  0.4220, -0.2181,  1.1464, -0.0590,\n",
      "         0.3077,  0.7243,  0.3209,  0.1069,  0.5982,  0.4916, -0.0384,  0.7122]), tensor([-0.0502,  0.4573,  0.4620,  0.4439, -0.0127, -0.0555,  0.4401,  0.3315,\n",
      "         0.6507,  0.0581,  1.1169, -0.0250,  0.8994, -0.2040, -0.0844,  0.8769,\n",
      "         0.8889,  0.8797,  1.0395,  0.7170, -0.0783, -0.0198, -0.1737, -0.1071,\n",
      "         0.3256,  0.5783,  1.0190, -0.1787,  0.9434,  0.7530,  0.2232,  1.1187]), tensor([ 0.7207, -0.4179,  0.8278,  0.7374,  1.0322,  0.1689, -0.3324,  0.4958,\n",
      "         0.6043,  0.7623, -0.2975, -0.4748, -0.4378,  0.7042, -0.1626,  0.0082,\n",
      "        -0.2769,  0.7809,  0.4663,  0.0419, -0.1912,  0.6796, -0.3511,  0.0057,\n",
      "         0.0980, -0.3468,  0.5513,  0.8128, -0.0371,  0.3953,  0.0823, -0.2661])]\n",
      "Dones shape/type: <class 'torch.Tensor'> tensor([False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False])\n",
      "Infos: [{'battery': tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
      "        0., 0., 0., 0., 0., 0., 0., 0.]), 'postponed': tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
      "        0., 0., 0., 0., 0., 0., 0., 0.]), 'grid_cost': tensor([0.2850, 2.4453, 0.2710, 2.7962, 2.1455, 0.7561, 1.8240, 1.3335, 2.1982,\n",
      "        0.2435, 0.8286, 2.5065, 0.3366, 2.0064, 1.6542, 0.9434, 2.4097, 2.7073,\n",
      "        2.2300, 2.0458, 0.7614, 0.4732, 1.3572, 1.9939, 1.3269, 1.7290, 2.1901,\n",
      "        1.7530, 2.1065, 0.1638, 0.1023, 2.8314]), 'reward': tensor([-0.3521,  0.7491, -0.3592,  0.9280,  0.5963, -0.1119,  0.4324,  0.1824,\n",
      "         0.6232, -0.3733, -0.0750,  0.7803, -0.3258,  0.5254,  0.3459, -0.0165,\n",
      "         0.7310,  0.8827,  0.6394,  0.5455, -0.1093, -0.2562,  0.1944,  0.5190,\n",
      "         0.1790,  0.3840,  0.6190,  0.3962,  0.5764, -0.4139, -0.4453,  0.9460])}, {'battery': tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
      "        0., 0., 0., 0., 0., 0., 0., 0.]), 'postponed': tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
      "        0., 0., 0., 0., 0., 0., 0., 0.]), 'grid_cost': tensor([1.8844, 2.1926, 2.6009, 2.1266, 2.9709, 0.6884, 0.1551, 2.0567, 1.6907,\n",
      "        1.0875, 0.6789, 2.2505, 1.7981, 1.5031, 2.0507, 2.1266, 0.7775, 1.0800,\n",
      "        2.2208, 1.1781, 0.6946, 1.7548, 1.0324, 1.2908, 1.3787, 2.1034, 0.9331,\n",
      "        0.8689, 1.2561, 0.3300, 1.5240, 1.7696]), 'reward': tensor([ 0.7119,  0.8690,  1.0771,  0.8353,  1.2657,  0.1022, -0.1696,  0.7997,\n",
      "         0.6131,  0.3057,  0.0974,  0.8985,  0.6679,  0.5175,  0.7967,  0.8353,\n",
      "         0.1476,  0.3018,  0.8834,  0.3519,  0.1054,  0.6458,  0.2776,  0.4093,\n",
      "         0.4541,  0.8235,  0.2270,  0.1942,  0.3916, -0.0805,  0.5282,  0.6534])}, {'battery': tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
      "        0., 0., 0., 0., 0., 0., 0., 0.]), 'postponed': tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
      "        0., 0., 0., 0., 0., 0., 0., 0.]), 'grid_cost': tensor([2.9131, 0.9899, 1.4898, 1.1125, 0.1053, 1.1608, 1.6970, 0.2204, 0.3741,\n",
      "        0.8718, 0.8658, 1.0315, 0.9272, 1.3963, 1.3618, 1.6453, 2.9866, 0.6959,\n",
      "        2.8134, 1.9252, 1.3158, 0.0601, 2.7367, 0.3721, 1.0914, 1.9087, 1.1174,\n",
      "        0.6976, 1.6614, 1.4523, 0.4126, 1.8850]), 'reward': tensor([ 1.2363,  0.2559,  0.5107,  0.3184, -0.1950,  0.3430,  0.6164, -0.1363,\n",
      "        -0.0580,  0.1957,  0.1926,  0.2771,  0.2240,  0.4631,  0.4455,  0.5900,\n",
      "         1.2737,  0.1060,  1.1855,  0.7327,  0.4220, -0.2181,  1.1464, -0.0590,\n",
      "         0.3077,  0.7243,  0.3209,  0.1069,  0.5982,  0.4916, -0.0384,  0.7122])}, {'battery': tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
      "        0., 0., 0., 0., 0., 0., 0., 0.]), 'postponed': tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
      "        0., 0., 0., 0., 0., 0., 0., 0.]), 'grid_cost': tensor([0.3895, 1.3850, 1.3942, 1.3586, 0.4629, 0.3790, 1.3512, 1.1381, 1.7644,\n",
      "        0.6019, 2.6789, 0.4388, 2.2523, 0.0877, 0.3222, 2.2082, 2.2316, 2.2135,\n",
      "        2.5270, 1.8944, 0.3343, 0.4490, 0.1471, 0.2779, 1.1266, 1.6224, 2.4868,\n",
      "        0.1373, 2.3386, 1.9651, 0.9257, 2.6825]), 'reward': tensor([-0.0502,  0.4573,  0.4620,  0.4439, -0.0127, -0.0555,  0.4401,  0.3315,\n",
      "         0.6507,  0.0581,  1.1169, -0.0250,  0.8994, -0.2040, -0.0844,  0.8769,\n",
      "         0.8889,  0.8797,  1.0395,  0.7170, -0.0783, -0.0198, -0.1737, -0.1071,\n",
      "         0.3256,  0.5783,  1.0190, -0.1787,  0.9434,  0.7530,  0.2232,  1.1187])}, {'battery': tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
      "        0., 0., 0., 0., 0., 0., 0., 0.]), 'postponed': tensor([0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "        0.0000, 0.0000, 1.1351, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "        0.0000, 0.0000, 0.0000, 0.0000, 0.0000]), 'grid_cost': tensor([2.3896, 0.1560, 2.5996, 2.4223, 3.0006, 1.3070, 0.3236, 1.9483, 2.1612,\n",
      "        2.4712, 0.3921, 0.0443, 0.1170, 2.3571, 0.6568, 0.9918, 0.4326, 2.5076,\n",
      "        1.8905, 1.0579, 0.6006, 2.3090, 0.2870, 0.9869, 1.1681, 0.2954, 2.0573,\n",
      "        2.5703, 0.9031, 1.7511, 1.1372, 0.4538]), 'reward': tensor([ 0.7207, -0.4179,  0.8278,  0.7374,  1.0322,  0.1689, -0.3324,  0.4958,\n",
      "         0.6043,  0.7623, -0.2975, -0.7235, -0.4378,  0.7042, -0.1626,  0.0082,\n",
      "        -0.2769,  0.7809,  0.4663,  0.0419, -0.1912,  0.6796, -0.3511,  0.0057,\n",
      "         0.0980, -0.3468,  0.5513,  0.8128, -0.0371,  0.3953,  0.0823, -0.2661])}]\n"
     ]
    }
   ],
   "source": [
    "# Sample one random continuous action tensor per agent, shaped [batch_dim, action_dim].\n",
    "num_agents = len(observations)\n",
    "# Per-agent action size read from the env's action space (a Tuple of Box spaces,\n",
    "# see the action-space cell) instead of hard-coding 2 (from_grid, from_battery).\n",
    "action_dim = env.action_space[0].shape[0]\n",
    "\n",
    "# NOTE(review): torch.rand samples [0, 1), i.e. only the non-negative half of the\n",
    "# printed Box(-1.0, 1.0) action space -- confirm that is intended for this scenario.\n",
    "# Actions are created on `device` so switching the env to CUDA keeps them aligned.\n",
    "random_actions = [\n",
    "    torch.rand(batch_dim, action_dim, device=device)\n",
    "    for _ in range(num_agents)\n",
    "]\n",
    "\n",
    "# Step the environment\n",
    "next_observations, rewards, dones, infos = env.step(random_actions)\n",
    "\n",
    "print(\"Next Observations shape:\", next_observations.shape\n",
    "      if isinstance(next_observations, torch.Tensor) else \"Varies (list)\")\n",
    "print(\"Rewards shape/type:\", type(rewards), rewards)\n",
    "print(\"Dones shape/type:\", type(dones), dones)\n",
    "print(\"Infos:\", infos)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Step 0: rew=[tensor([ 0.3323,  0.2853,  0.5602,  0.5133,  0.6647,  0.0801,  0.6839,  0.7550,\n",
      "         0.8704,  0.2973,  0.9952,  0.0349, -0.2726,  0.6680,  0.9121,  0.4249,\n",
      "         0.4476,  0.3711,  0.4470,  0.2299,  0.6516,  0.9401, -0.2477, -0.0710,\n",
      "         1.0217,  0.1164,  0.5840, -0.2002,  0.3268, -0.3435, -0.2020, -0.1940]), tensor([ 0.3613,  0.6515,  1.2097,  0.4120,  1.1708,  0.1263,  0.3848,  0.0776,\n",
      "        -0.0413,  1.2752,  0.9398, -0.0231,  0.4895,  1.1582,  0.4352,  0.1944,\n",
      "        -0.1086, -0.0881,  0.9697,  0.0909, -0.0493, -0.0085,  1.0944,  0.1695,\n",
      "         0.8122,  0.4493,  1.2810,  0.1938,  1.1201,  1.1529,  0.2551,  0.8766]), tensor([ 0.2090, -0.2243,  0.7440,  0.4382,  0.7256,  0.5093, -0.0221,  0.8412,\n",
      "         0.7296, -0.0543,  0.3934,  0.9250,  0.6834, -0.0977,  0.4221,  0.2949,\n",
      "         0.7782,  1.1143,  0.0595, -0.1423,  0.9719,  0.1291,  0.0910,  0.9848,\n",
      "         0.8180,  0.6786,  1.1219,  0.0704,  0.4205,  0.2547,  0.1453,  0.6204]), tensor([ 1.0056,  0.9585, -0.0308,  0.4850,  1.1758,  0.6762,  1.1229, -0.0083,\n",
      "        -0.1421, -0.2412,  0.0499, -0.0564,  0.5359,  0.0333,  1.1344,  0.6774,\n",
      "         0.0809,  0.9801, -0.2320,  0.6395,  0.0036,  0.4556,  0.1499,  0.2684,\n",
      "         0.9332,  0.1438, -0.2115,  0.6792,  1.0415,  0.0942, -0.0660,  0.1664]), tensor([ 0.3933, -0.4888,  0.5761, -0.3672, -0.2664,  0.2828,  0.3397, -0.2389,\n",
      "        -0.0037,  0.9507,  0.7487, -0.2114, -0.1432, -0.4392,  0.1437, -0.3297,\n",
      "        -0.2870,  0.1159, -0.4830, -0.0036,  0.9750,  0.1081,  0.1723, -0.1070,\n",
      "         0.5545,  0.6101,  0.2293, -0.4196,  0.2637, -0.1259,  0.0979, -0.3523])], done=tensor([False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False])\n",
      "Step 1: rew=[tensor([-0.1614,  0.4681,  0.7450, -0.4838, -0.3535, -0.0394,  0.9735,  0.9331,\n",
      "        -0.1370,  0.2474,  0.9336,  0.4300,  0.3856, -0.3187,  0.4572,  0.9869,\n",
      "         0.3032,  0.5268,  0.7464, -0.2847,  0.4121,  0.5113,  0.2814,  0.3054,\n",
      "        -0.4736,  0.2971, -0.1646,  0.0980,  0.2745,  0.0190,  0.9546,  0.9316]), tensor([ 0.7532,  0.1584,  0.1088,  0.6299,  0.8437,  1.1226, -0.1977, -0.1762,\n",
      "         0.8495, -0.2152,  1.1062,  0.7110, -0.0177,  0.4279,  0.4761,  1.0664,\n",
      "         0.7163,  0.0406,  0.2166,  0.2566,  1.2361,  0.4422, -0.2363, -0.1037,\n",
      "         0.5718,  1.0462,  0.3075, -0.2123,  1.0806,  0.3499,  0.5223,  0.1246]), tensor([ 0.7745,  0.4133, -0.1845,  0.5844,  0.8161,  1.2900,  1.0356,  0.0703,\n",
      "        -0.1794, -0.2053, -0.1524,  0.7780,  0.2645,  1.2341,  1.1389,  0.3400,\n",
      "         1.1655,  0.4086,  0.6058,  0.3626,  0.2309,  0.1116,  0.9981,  0.2003,\n",
      "         1.2874,  0.9228,  1.2861,  0.1407,  1.0995,  0.6366,  0.3615,  0.8238]), tensor([ 0.7531,  0.1250,  1.1045,  0.0868,  0.8144,  0.6877, -0.2473,  0.8435,\n",
      "         0.9242, -0.2168,  0.1797,  0.8388,  0.0336,  0.0902, -0.0604,  0.6972,\n",
      "        -0.2132,  0.7598,  0.3771,  0.2509,  0.1800,  0.8715,  1.2488,  0.4827,\n",
      "         1.2134,  1.2423,  0.4264,  0.9127, -0.1471,  0.9343,  0.0275,  0.9145]), tensor([-0.2117, -0.0335, -0.2729,  0.4553,  0.1349,  1.0268,  0.2945, -0.0093,\n",
      "        -0.3161,  0.0438,  0.8308,  0.7637, -0.3252,  0.3708, -0.1149,  0.0747,\n",
      "         0.4752,  0.0336,  0.2119,  0.1921,  1.0129,  0.2915, -0.0399,  0.6929,\n",
      "         0.5194,  0.1329,  0.4843,  0.9552,  0.8111,  0.3836, -0.4836,  0.1551])], done=tensor([False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False])\n",
      "Step 2: rew=[tensor([-0.3918, -0.2992,  0.8292, -0.2562, -0.4454,  0.2926, -0.1933,  0.9971,\n",
      "         0.8717, -0.1649, -0.2584,  0.6071,  0.5910,  1.0034, -0.2482,  0.8567,\n",
      "         0.1976,  0.7634, -0.4952,  0.9985,  0.5657,  0.1661,  0.7434,  0.2709,\n",
      "         0.0582,  0.1012,  0.4590,  0.8817,  0.1178,  0.2074,  0.8877,  0.2500]), tensor([ 0.4894,  1.1621,  0.3564,  0.2422,  0.3947,  0.7559, -0.1110,  1.2219,\n",
      "         1.1547,  1.1220,  1.2907,  0.2793, -0.0138,  0.3711,  0.9140,  0.5941,\n",
      "         1.2498,  0.8784,  1.2923,  1.1794,  1.2241,  0.9257,  0.3938, -0.0551,\n",
      "         0.1380,  0.4081,  0.7625,  0.4011,  1.1730,  0.5730,  0.8290,  1.0491]), tensor([ 0.6264,  1.2900,  0.1136,  1.0319,  0.0186, -0.0910,  0.7181, -0.0386,\n",
      "         0.3184,  0.4775,  0.6483,  0.8794,  0.9091,  0.2375,  0.0973,  1.0585,\n",
      "         0.6668, -0.1063,  1.1502,  0.3294,  1.1787,  1.2251, -0.1315, -0.1753,\n",
      "         0.4417,  0.7442, -0.0778,  0.2664,  0.7056,  0.4220, -0.0412,  1.0180]), tensor([ 1.0414,  0.5626,  0.0874,  0.4996, -0.0023, -0.1797,  0.8215,  1.2655,\n",
      "         0.7318,  0.5264,  0.2505, -0.2188,  0.6542, -0.1583,  0.6254,  0.5150,\n",
      "        -0.0270, -0.0341,  0.6118,  0.6122, -0.1151,  0.2373,  0.2003,  0.2596,\n",
      "         0.6522, -0.1940, -0.1521,  0.0907,  1.1686, -0.0064,  0.4157,  0.3907]), tensor([ 0.7099,  0.1672,  0.5077,  0.9973,  0.8130, -0.4110,  1.0366, -0.0896,\n",
      "        -0.3368,  0.0704, -0.0816,  0.6562,  0.2897,  0.5585,  0.8236,  0.9228,\n",
      "        -0.4804,  0.8585, -0.2060, -0.3844,  0.0319,  0.3053,  0.3025,  0.2228,\n",
      "         0.5011,  0.7149,  0.7086,  0.5944, -0.3857,  0.1662, -1.0432,  0.3340])], done=tensor([False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False])\n",
      "Step 3: rew=[tensor([ 0.2215,  0.1001,  0.3617,  0.4511,  0.8840,  0.1417,  0.0189,  0.7476,\n",
      "        -0.1564,  0.4828,  0.1383,  0.1984,  0.1620, -0.1689, -0.2071,  0.3993,\n",
      "         0.0561,  0.2087, -1.1081, -0.3164, -0.0247,  0.1489,  0.8478,  0.0919,\n",
      "         0.2709,  0.4338,  0.7979,  0.7230,  0.9580,  0.3205,  1.0444,  0.6234]), tensor([ 0.0755,  1.0085,  0.2567, -0.0467,  1.0057,  1.0337,  0.7050,  0.3998,\n",
      "         0.4433, -0.0236, -0.0746,  0.1066,  1.1411,  0.3633, -0.1938,  0.7004,\n",
      "         1.0117,  0.5696,  1.1052,  0.5015,  0.6515,  0.4937,  0.9875,  0.5609,\n",
      "         0.6097,  0.7663,  1.2252,  0.8527, -0.1287, -0.0023,  1.0891,  1.1207]), tensor([ 1.1857,  0.2410,  0.1812,  0.7091,  0.7812,  0.8351,  0.9864,  0.5503,\n",
      "         1.0422,  1.2426,  0.9374, -0.0865,  0.2083, -0.0900,  0.4459,  1.1964,\n",
      "         0.4545,  0.9169,  0.1643,  0.6067,  0.6875,  0.9366,  0.7633,  0.3630,\n",
      "         0.1187,  1.0461,  0.7916,  0.2788, -0.0188,  0.2845,  0.9076,  0.7246]), tensor([ 1.0530,  0.5909,  1.1013,  1.0582,  0.7341,  0.6526,  0.5247,  0.7294,\n",
      "         1.0863,  0.0205,  0.6040,  0.9865,  0.3132,  1.2638,  0.1714, -0.1133,\n",
      "         1.1444,  0.1752, -0.1321,  0.1255,  0.0069,  0.1555,  1.0183,  0.7457,\n",
      "         0.6379,  0.9029,  0.1080,  0.3193,  0.8646,  1.0187,  1.1010,  0.4561]), tensor([ 0.0204, -0.2495,  0.1513, -0.3715, -0.0575,  0.9659,  0.7239,  0.0470,\n",
      "         0.6307, -0.0829, -0.2567,  0.7015, -0.3980,  0.5536,  0.7649, -0.0654,\n",
      "        -0.4543, -0.0594,  1.0389,  0.2617, -0.4164,  0.6299,  0.7797, -0.1274,\n",
      "        -0.4447,  1.0095,  0.8045, -0.3034,  1.0275,  0.6068, -0.3370,  0.3069])], done=tensor([False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False])\n",
      "Step 4: rew=[tensor([ 0.8597, -0.3257,  0.2736,  0.1313, -0.4018,  0.4217,  0.3328,  0.7649,\n",
      "        -0.1966,  0.6470, -0.2725,  0.7410,  0.5086,  0.3505, -0.1760,  0.1104,\n",
      "         0.8604, -0.4077,  0.9454, -0.4524,  0.5042, -0.4583, -0.2228, -0.4608,\n",
      "        -0.4245,  0.7889,  0.9128, -0.2714,  0.0815,  0.8891,  0.8360, -0.0510]), tensor([ 1.1258,  0.5863,  0.1538, -0.0550,  0.7930, -0.2177,  0.2321,  0.1927,\n",
      "         0.8562, -0.1254,  0.6997,  0.7106,  0.9416,  0.8150,  0.4907,  0.9845,\n",
      "         0.9311,  0.3435,  1.2907,  0.7933,  0.5645,  0.3119,  0.3186, -0.0890,\n",
      "         0.2744, -0.0364,  0.3275,  0.0950,  1.1218,  0.5180,  0.9220, -0.1974]), tensor([ 0.8370,  1.2925,  0.1959,  1.1189, -0.1273,  1.2049,  1.2455,  0.4422,\n",
      "         0.7146, -0.2282,  0.2291,  0.3600,  0.3792,  0.2189,  0.6950, -0.1473,\n",
      "         0.8329, -0.1572,  0.2893,  0.4680, -0.0058,  0.7562,  0.5121,  0.1482,\n",
      "         0.9570,  0.1037,  0.5918,  0.7165,  0.3971,  1.0365,  0.9067,  0.4123]), tensor([ 0.9565,  1.1619,  0.1406,  0.6029,  1.1029,  1.2825, -0.1309,  0.9276,\n",
      "         0.3585, -0.1104, -0.0390,  1.1668,  1.0398,  0.5824,  0.1806,  0.8400,\n",
      "         0.2706,  0.1267,  0.3525,  0.0364,  0.5748,  1.2643,  0.0664,  0.2417,\n",
      "         1.0385,  0.7968,  0.5726,  1.2721,  0.8933,  0.4234,  0.4350,  0.5177]), tensor([-0.2733,  0.6879, -0.4448,  0.2282,  0.7588, -0.3218,  0.4322,  0.9110,\n",
      "         0.9429,  0.1536,  0.6499, -0.0344,  0.4018,  0.1781, -0.1224, -0.1521,\n",
      "         0.8423,  1.0121,  0.0460,  0.0077, -0.4681,  0.7047,  0.2717, -0.2754,\n",
      "         0.4763,  0.7841,  0.8223,  0.5805,  0.2424, -0.4223,  0.6448,  0.7628])], done=tensor([False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False])\n"
     ]
    }
   ],
   "source": [
    "# Roll out a few steps with random actions and summarize each step compactly.\n",
    "n_steps = 5\n",
    "obs = env.reset()\n",
    "for t in range(n_steps):\n",
    "    # One random [batch_dim, action_dim] tensor per agent, on the env's device.\n",
    "    random_actions = [\n",
    "        torch.rand(batch_dim, action_dim, device=device) for _ in range(num_agents)\n",
    "    ]\n",
    "    obs, rew, done, info = env.step(random_actions)\n",
    "    # Full per-env tensors flood the output; show per-agent batch-mean rewards and\n",
    "    # how many of the batch_dim environments report done.\n",
    "    mean_rew = [round(r.mean().item(), 4) for r in rew]\n",
    "    print(f\"Step {t}: mean rew per agent={mean_rew}, done={int(done.sum())}/{batch_dim}\")\n",
    "    # `done` is per environment; finished envs would need resetting (presumably via\n",
    "    # env.reset_at(index) -- check the VMAS API) before stepping them further.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Agent 0 battery charge: tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
      "        0., 0., 0., 0., 0., 0., 0., 0.])\n",
      "Agent 0 postponed demand: tensor([0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "        0.0000, 0.0084, 0.0000, 0.3035, 0.0000, 0.4275, 0.0000, 0.0000, 0.0000,\n",
      "        0.0000, 0.0000, 0.0000, 0.0000, 0.0000])\n",
      "Agent 1 battery charge: tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
      "        0., 0., 0., 0., 0., 0., 0., 0.])\n",
      "Agent 1 postponed demand: tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
      "        0., 0., 0., 0., 0., 0., 0., 0.])\n",
      "Agent 2 battery charge: tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
      "        0., 0., 0., 0., 0., 0., 0., 0.])\n",
      "Agent 2 postponed demand: tensor([0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "        0.1041, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "        0.0000, 0.0000, 0.0000, 0.0000, 0.0000])\n",
      "Agent 3 battery charge: tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
      "        0., 0., 0., 0., 0., 0., 0., 0.])\n",
      "Agent 3 postponed demand: tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
      "        0., 0., 0., 0., 0., 0., 0., 0.])\n",
      "Agent 4 battery charge: tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
      "        0., 0., 0., 0., 0., 0., 0., 0.])\n",
      "Agent 4 postponed demand: tensor([0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "        0.0000, 0.0000, 0.7942, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "        0.0000, 0.0000, 0.0000, 0.0000, 0.0000])\n"
     ]
    }
   ],
   "source": [
    "# Inspect the custom per-agent state fields the smart-grid scenario keeps.\n",
    "for idx, agt in enumerate(env.world.agents):\n",
    "    agent_state = agt.state\n",
    "    print(f\"Agent {idx} battery charge:\", agent_state.battery_charge)\n",
    "    print(f\"Agent {idx} postponed demand:\", agent_state.postponed)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "marl",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
