{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Notebook for experiment tracking with Weight and Biases "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# from gridworlds.grid_env import GridEnvironment\n",
    "from src.Generalist.generalist_meta_env import Generalist_MetaEpisodeEnv\n",
    "from src.Generalist.draw_gridworld import draw_policy\n",
    "\n",
    "# import gymnasium as gym\n",
    "from stable_baselines3 import PPO, A2C\n",
    "from stable_baselines3.common.vec_env import SubprocVecEnv\n",
    "from stable_baselines3.common.utils import set_random_seed\n",
    "\n",
    "# wandb \n",
    "import wandb\n",
    "from wandb.integration.sb3 import WandbCallback\n",
    "\n",
    "#stablebaselines feature extractor\n",
    "from src.Generalist.feature_extractor import Custom_Flatten\n",
    "\n",
    "#For evaluation\n",
    "from src.Generalist.evals_utils import average_evals"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load Gridworlds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "976 train gridworlds loaded\n",
      "96 val gridworlds loaded\n",
      "200 test gridworlds loaded\n"
     ]
    }
   ],
   "source": [
    "#Load the gridworlds\n",
    "from classes import Object\n",
    "import pickle\n",
    "\n",
    "with open('src/world_builder/worlds/master_set_train.pkl','rb') as f:\n",
    "    train_gridworlds = pickle.load(f)\n",
    "print(f'{len(train_gridworlds)} train gridworlds loaded')  \n",
    "for grid in train_gridworlds:\n",
    "    grid.early_stopping = False   \n",
    "\n",
    "with open('src/world_builder/worlds/master_set_val.pkl','rb') as f:\n",
    "    val_gridworlds = pickle.load(f)\n",
    "print(f'{len(val_gridworlds)} val gridworlds loaded')  \n",
    "for grid in val_gridworlds:\n",
    "    grid.early_stopping = False   \n",
    "\n",
    "with open('src/world_builder/worlds/master_set_test.pkl','rb') as f:\n",
    "    test_gridworlds = pickle.load(f)\n",
    "print(f'{len(test_gridworlds)} test gridworlds loaded')  \n",
    "for grid in test_gridworlds:\n",
    "    grid.early_stopping = False        "
   ]
  },
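  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick sanity check on the loaded sets. This is a minimal sketch: it assumes each gridworld exposes `max_coins`, the attribute `average_evals` reads during evaluation, alongside the `early_stopping` flag set above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check: inspect one loaded gridworld (assumes the `max_coins`\n",
    "# attribute used by evals_utils.average_evals is present)\n",
    "sample = train_gridworlds[0]\n",
    "print(f'{type(sample).__name__}: early_stopping={sample.early_stopping}, '\n",
    "      f'max_coins={sample.max_coins}')"
   ]
  },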
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training Cell"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Original sweep configuration (commented out for reference)\n",
    "# sweep_config = {\n",
    "#     \"method\": \"grid\",\n",
    "#     \"metric\": {\"goal\": \"maximize\", \"name\": \"train_metrics/Usefulness\"},\n",
    "#     \"parameters\": {\n",
    "#         \"lambda_factor\": {\"value\": 0.9},\n",
    "#         \"meta_ep_size\": {\"value\": 32},\n",
    "#         \"hidden_layer_depth\": {'value': 128},\n",
    "#         \"num_hidden_layers\": {'value': 3},\n",
    "#         \"ent_coef\": {'value': 0.015},\n",
    "#         \"learning_rate\": {'values': [0.0007, 0.00001, 0.000001, 0.0000005]},\n",
    "#         \"total_timesteps\": {'value': 2000},\n",
    "#         \"n_steps_a2c\": {'value': 8192},\n",
    "#         \"vf_coef\": {'value': 0.55},\n",
    "#         \"timesteps_per_run\": {'value': 2000}\n",
    "#     },\n",
    "# }\n",
    "\n",
    "# Single agent configuration\n",
    "config = {\n",
    "    \"lambda_factor\": 0.9,\n",
    "    \"meta_ep_size\": 32,\n",
    "    \"hidden_layer_depth\": 128,\n",
    "    \"num_hidden_layers\": 3,\n",
    "    \"ent_coef\": 0.015,\n",
    "    \"learning_rate\": 0.0007,  # Choose one learning rate\n",
    "    \"total_timesteps\": 4000,\n",
    "    \"n_steps_a2c\": 8192,\n",
    "    \"vf_coef\": 0.55,\n",
    "    \"timesteps_per_run\": 2000\n",
    "}"
   ]
  },
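  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The training loop below alternates learning and evaluation: each round runs `timesteps_per_run` steps until `total_timesteps` is reached. A small check of how many rounds the current config implies:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of learn/eval rounds implied by the config (ceiling division)\n",
    "rounds = -(-config[\"total_timesteps\"] // config[\"timesteps_per_run\"])\n",
    "print(f\"{rounds} rounds of {config['timesteps_per_run']} timesteps \"\n",
    "      f\"({config['total_timesteps']} total)\")"
   ]
  },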
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Original sweep initialization (commented out for reference)\n",
    "# sweep_id = wandb.sweep(sweep_config, project=\"IPP-second-paper-generalist\")\n",
    "# print(sweep_id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## BEFORE RUNNING, MAKE SURE TO MANUALLY CHANGE ENVIRONMENT LIST AND FEATURE EXTRACTOR ##\n",
    "\n",
    "import torch\n",
    "\n",
    "# Original sweep-based train function (commented out for reference)\n",
    "# def train(config=None):\n",
    "#     run = wandb.init(config=config)\n",
    "#     config = wandb.config\n",
    "\n",
    "# Modified train function for single agent\n",
    "def train():\n",
    "    # Initialize WandB for single run\n",
    "    run = wandb.init(project=\"IPP-second-paper-generalist\", config=config)\n",
    "\n",
    "    # device = torch.device(\"mps\" if torch.backends.mps.is_available() else \"cpu\")\n",
    "    device = 'cpu'\n",
    "\n",
    "    ## PICK ENVIRONMENT\n",
    "    train_env_list = train_gridworlds           #MANUALLY CHANGE\n",
    "    test_env_list = val_gridworlds              #MANUALLY CHANGE\n",
    "\n",
    "    hld = config[\"hidden_layer_depth\"]\n",
    "    num_layers = config[\"num_hidden_layers\"]\n",
    "\n",
    "    def net_arch(hidden_layer_depth, num_hidden_layers):\n",
    "        net_arch_list = []\n",
    "        for n in range(num_hidden_layers):\n",
    "            net_arch_list.append(hidden_layer_depth)\n",
    "        return net_arch_list      \n",
    "\n",
    "    net_arch_list = net_arch(hld, num_layers)\n",
    "\n",
    "    policy_kwargs = dict(features_extractor_class=Custom_Flatten, #MANUALLY CHANGE\n",
    "                        features_extractor_kwargs=dict(features_dim=250),\n",
    "                        net_arch=dict(pi=net_arch_list, \n",
    "                                    vf=net_arch_list))\n",
    "\n",
    "    #Number of vectorised environments\n",
    "    num_cpu=3\n",
    "\n",
    "    #Set-up for vectorised environments\n",
    "    def make_env(rank, seed=0):\n",
    "            \"\"\"\n",
    "            Utility function for multiprocessed env.\n",
    "\n",
    "            :param env_id: (str) the environment ID\n",
    "            :param seed: (int) the inital seed for RNG\n",
    "            :param rank: (int) index of the subprocess\n",
    "            \"\"\"\n",
    "\n",
    "            def _init():\n",
    "                env = Generalist_MetaEpisodeEnv(\n",
    "                        train_env_list, \n",
    "                        meta_ep_size=config[\"meta_ep_size\"],\n",
    "                        lambda_factor=config[\"lambda_factor\"],\n",
    "                    )\n",
    "                # use a seed for reproducibility\n",
    "                # Important: use a different seed for each environment\n",
    "                # otherwise they would generate the same experiences\n",
    "                env.reset(seed=seed + rank)\n",
    "                return env\n",
    "\n",
    "            set_random_seed(seed)\n",
    "            return _init\n",
    "\n",
    "\n",
    "    def vec_learning_run(model, timesteps):\n",
    "\n",
    "        env = SubprocVecEnv([make_env(i) for i in range(num_cpu)],start_method=\"fork\")\n",
    "\n",
    "        model.set_env(env)\n",
    "\n",
    "        model.learn(total_timesteps=timesteps,\n",
    "                    callback=WandbCallback(verbose=0)) \n",
    "\n",
    "        return model\n",
    "\n",
    "\n",
    "    def vec_learning(train_env_list,timesteps_per_run, total_timesteps):\n",
    "\n",
    "        wandb.define_metric(\"custom_step\")\n",
    "\n",
    "        # Define which metrics to plot against that x-axis\n",
    "        wandb.define_metric(\"train_metrics/Usefulness\", step_metric='custom_step')\n",
    "        wandb.define_metric(\"train_metrics/Neutrality\", step_metric='custom_step')\n",
    "\n",
    "        steps_count = 0\n",
    "\n",
    "        env = SubprocVecEnv([make_env(i) for i in range(num_cpu)],start_method=\"fork\")\n",
    "\n",
    "        # Create the A2C model with the custom architecture\n",
    "        model = A2C(\"MlpPolicy\",                                  #MAUALLY CHANGE with feature_extractor_class\n",
    "                    env,                                      #Change for vectorised Envs\n",
    "                    device=device,\n",
    "                    verbose=1,\n",
    "                    ent_coef=config[\"ent_coef\"],\n",
    "                    learning_rate=config[\"learning_rate\"],\n",
    "                    n_steps=config[\"n_steps_a2c\"],\n",
    "                    vf_coef=config[\"vf_coef\"],\n",
    "                    policy_kwargs=policy_kwargs,           #MANUALLY CHANGE\n",
    "                    tensorboard_log=f\"runs/{run.id}\")\n",
    "        \n",
    "        while steps_count < total_timesteps:\n",
    "\n",
    "            model = vec_learning_run(model, timesteps_per_run)\n",
    "            train_av_traj_ratio, train_av_usefulness, train_av_entropy = average_evals(train_env_list,model)\n",
    "            steps_count += timesteps_per_run\n",
    "            print(f'Step count: {steps_count}')\n",
    "            print(f'Average Usefulness: {train_av_usefulness}')\n",
    "            print(f'Average NEUTRALITY: {train_av_entropy}')\n",
    "            print(f'Average Trajectory Ratio: {train_av_traj_ratio}')\n",
    "\n",
    "\n",
    "            wandb.log({\n",
    "                \"custom_step\": steps_count,\n",
    "                \"train_metrics/Usefulness\": train_av_usefulness,\n",
    "                \"train_metrics/Neutrality\": train_av_entropy,\n",
    "                        })\n",
    "\n",
    "        return model, train_av_usefulness, train_av_entropy\n",
    "\n",
    "    model, train_av_usefulness, train_av_entropy = vec_learning(train_env_list, config[\"timesteps_per_run\"], config[\"total_timesteps\"])\n",
    "\n",
    "    model.save(f\"models/{run.id}\")\n",
    "\n",
    "    print('Average evals for train data')\n",
    "    print(f'Average USEFULNESS:{train_av_usefulness}')\n",
    "    print(f'Average NEUTRALITY:{train_av_entropy}')\n",
    "    print('\\n')\n",
    "\n",
    "    run.summary[\"train_av_usefulness\"]=train_av_usefulness\n",
    "    run.summary[\"train_av_NEUTRALITY\"]=train_av_entropy\n",
    "\n",
    "    test_av_traj, test_av_usefulness, test_av_entropy = average_evals(test_env_list,model)\n",
    "\n",
    "    print('Average evals for test data')\n",
    "    print(f'Average Trajectory Ratio:{\"{:.3f}\".format(test_av_traj)}')\n",
    "    print(f'Average USEFULNESS:{test_av_usefulness}')\n",
    "    print(f'Average NEUTRALITY:{test_av_entropy}')\n",
    "    print('\\n')\n",
    "\n",
    "    run.summary[\"test_av_usefulness\"]=test_av_usefulness\n",
    "    run.summary[\"test_av_NEUTRALITY\"]=test_av_entropy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mdr-alexroman\u001b[0m (\u001b[33malex-roman\u001b[0m) to \u001b[32mhttps://api.wandb.ai\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "Tracking run with wandb version 0.21.1"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Run data is saved locally in <code>/Users/adroman/research/IPP/IPP/wandb/run-20250825_234104-ygs737gj</code>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Syncing run <strong><a href='https://wandb.ai/alex-roman/IPP-second-paper-generalist/runs/ygs737gj' target=\"_blank\">misty-hill-18</a></strong> to <a href='https://wandb.ai/alex-roman/IPP-second-paper-generalist' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/developer-guide' target=\"_blank\">docs</a>)<br>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       " View project at <a href='https://wandb.ai/alex-roman/IPP-second-paper-generalist' target=\"_blank\">https://wandb.ai/alex-roman/IPP-second-paper-generalist</a>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       " View run at <a href='https://wandb.ai/alex-roman/IPP-second-paper-generalist/runs/ygs737gj' target=\"_blank\">https://wandb.ai/alex-roman/IPP-second-paper-generalist/runs/ygs737gj</a>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using mps device\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/adroman/research/IPP/IPP/IPP/lib/python3.12/site-packages/stable_baselines3/common/on_policy_algorithm.py:150: UserWarning: You are trying to run A2C on the GPU, but it is primarily intended to run on the CPU when not using a CNN policy (you are using ActorCriticPolicy which should be a MlpPolicy). See https://github.com/DLR-RM/stable-baselines3/issues/1245 for more info. You can pass `device='cpu'` or `export CUDA_VISIBLE_DEVICES=` to force using the CPU.Note: The model will train, but the GPU utilization will be poor and the training might take longer than on CPU.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Logging to runs/ygs737gj/A2C_1\n"
     ]
    },
    {
     "ename": "RuntimeError",
     "evalue": "Tensor for argument input is on cpu but expected on mps",
     "output_type": "error",
     "traceback": [
      "\u001b[31m---------------------------------------------------------------------------\u001b[39m",
      "\u001b[31mRuntimeError\u001b[39m                              Traceback (most recent call last)",
      "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[6]\u001b[39m\u001b[32m, line 5\u001b[39m\n\u001b[32m      1\u001b[39m \u001b[38;5;66;03m# Original sweep agent call (commented out for reference)\u001b[39;00m\n\u001b[32m      2\u001b[39m \u001b[38;5;66;03m# wandb.agent(sweep_id, train)\u001b[39;00m\n\u001b[32m      3\u001b[39m \n\u001b[32m      4\u001b[39m \u001b[38;5;66;03m# Direct training call for single agent\u001b[39;00m\n\u001b[32m----> \u001b[39m\u001b[32m5\u001b[39m \u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
      "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[5]\u001b[39m\u001b[32m, line 122\u001b[39m, in \u001b[36mtrain\u001b[39m\u001b[34m()\u001b[39m\n\u001b[32m    114\u001b[39m         wandb.log({\n\u001b[32m    115\u001b[39m             \u001b[33m\"\u001b[39m\u001b[33mcustom_step\u001b[39m\u001b[33m\"\u001b[39m: steps_count,\n\u001b[32m    116\u001b[39m             \u001b[33m\"\u001b[39m\u001b[33mtrain_metrics/Usefulness\u001b[39m\u001b[33m\"\u001b[39m: train_av_usefulness,\n\u001b[32m    117\u001b[39m             \u001b[33m\"\u001b[39m\u001b[33mtrain_metrics/Neutrality\u001b[39m\u001b[33m\"\u001b[39m: train_av_entropy,\n\u001b[32m    118\u001b[39m                     })\n\u001b[32m    120\u001b[39m     \u001b[38;5;28;01mreturn\u001b[39;00m model, train_av_usefulness, train_av_entropy\n\u001b[32m--> \u001b[39m\u001b[32m122\u001b[39m model, train_av_usefulness, train_av_entropy = \u001b[43mvec_learning\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtrain_env_list\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mconfig\u001b[49m\u001b[43m[\u001b[49m\u001b[33;43m\"\u001b[39;49m\u001b[33;43mtimesteps_per_run\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mconfig\u001b[49m\u001b[43m[\u001b[49m\u001b[33;43m\"\u001b[39;49m\u001b[33;43mtotal_timesteps\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m    124\u001b[39m model.save(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mmodels/\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mrun.id\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m\"\u001b[39m)\n\u001b[32m    126\u001b[39m \u001b[38;5;28mprint\u001b[39m(\u001b[33m'\u001b[39m\u001b[33mAverage evals for train data\u001b[39m\u001b[33m'\u001b[39m)\n",
      "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[5]\u001b[39m\u001b[32m, line 106\u001b[39m, in \u001b[36mtrain.<locals>.vec_learning\u001b[39m\u001b[34m(train_env_list, timesteps_per_run, total_timesteps)\u001b[39m\n\u001b[32m    103\u001b[39m \u001b[38;5;28;01mwhile\u001b[39;00m steps_count < total_timesteps:\n\u001b[32m    105\u001b[39m     model = vec_learning_run(model, timesteps_per_run)\n\u001b[32m--> \u001b[39m\u001b[32m106\u001b[39m     train_av_traj_ratio, train_av_usefulness, train_av_entropy = \u001b[43maverage_evals\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtrain_env_list\u001b[49m\u001b[43m,\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m    107\u001b[39m     steps_count += timesteps_per_run\n\u001b[32m    108\u001b[39m     \u001b[38;5;28mprint\u001b[39m(\u001b[33mf\u001b[39m\u001b[33m'\u001b[39m\u001b[33mStep count: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00msteps_count\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m'\u001b[39m)\n",
      "\u001b[36mFile \u001b[39m\u001b[32m~/research/IPP/IPP/src/Generalist/evals_utils.py:10\u001b[39m, in \u001b[36maverage_evals\u001b[39m\u001b[34m(env_list, model)\u001b[39m\n\u001b[32m      8\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[38;5;28mlen\u001b[39m(env_list)):\n\u001b[32m      9\u001b[39m     env = env_list[i]\n\u001b[32m---> \u001b[39m\u001b[32m10\u001b[39m     traj, usefulness, entropy = \u001b[43mevaluate_agent\u001b[49m\u001b[43m(\u001b[49m\u001b[43menv\u001b[49m\u001b[43m,\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m(\u001b[49m\u001b[43menv\u001b[49m\u001b[43m.\u001b[49m\u001b[43mmax_coins\u001b[49m\u001b[43m[\u001b[49m\u001b[32;43m1\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43menv\u001b[49m\u001b[43m.\u001b[49m\u001b[43mmax_coins\u001b[49m\u001b[43m[\u001b[49m\u001b[32;43m0\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;66;03m#swapping max_coins values from [shorter_traj,longer_traj] to [longer_traj,shorter_traj]\u001b[39;00m\n\u001b[32m     11\u001b[39m     traj_ratio = traj[\u001b[32m1\u001b[39m]/traj[\u001b[32m0\u001b[39m]\n\u001b[32m     12\u001b[39m     traj_ratio_list.append(traj_ratio)\n",
      "\u001b[36mFile \u001b[39m\u001b[32m~/research/IPP/IPP/src/Generalist/evals_utils.py:60\u001b[39m, in \u001b[36mevaluate_agent\u001b[39m\u001b[34m(env, model, max_coins_by_trajectory)\u001b[39m\n\u001b[32m     54\u001b[39m \u001b[38;5;250m\u001b[39m\u001b[33;03m'''Computes the usefulness and entropy of the agent.\u001b[39;00m\n\u001b[32m     55\u001b[39m \u001b[33;03mExpects user defined numpy array max_coins_by_trajectory of shape (2,),\u001b[39;00m\n\u001b[32m     56\u001b[39m \u001b[33;03mOrdered by flag state: (<delay button pressed>, <not pressed>)\u001b[39;00m\n\u001b[32m     57\u001b[39m \u001b[33;03mEg. max_coins_by_trajectory = np.array([3,2])\u001b[39;00m\n\u001b[32m     58\u001b[39m \u001b[33;03m'''\u001b[39;00m\n\u001b[32m     59\u001b[39m \u001b[38;5;66;03m# Compute Transition Matrix\u001b[39;00m\n\u001b[32m---> \u001b[39m\u001b[32m60\u001b[39m transition_matrix = \u001b[43mget_transition_matrix\u001b[49m\u001b[43m(\u001b[49m\u001b[43menv\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m     62\u001b[39m \u001b[38;5;66;03m# Compute Terminal Distribution\u001b[39;00m\n\u001b[32m     63\u001b[39m terminal_distribution = get_terminal_distribution(env, transition_matrix)\n",
      "\u001b[36mFile \u001b[39m\u001b[32m~/research/IPP/IPP/src/Generalist/evals_utils.py:201\u001b[39m, in \u001b[36mget_transition_matrix\u001b[39m\u001b[34m(env, model)\u001b[39m\n\u001b[32m    199\u001b[39m obs_tensor = obs_tensor.unsqueeze(\u001b[32m0\u001b[39m)\n\u001b[32m    200\u001b[39m \u001b[38;5;28;01mwith\u001b[39;00m torch.no_grad():\n\u001b[32m--> \u001b[39m\u001b[32m201\u001b[39m     distribution = \u001b[43mmodel\u001b[49m\u001b[43m.\u001b[49m\u001b[43mpolicy\u001b[49m\u001b[43m.\u001b[49m\u001b[43mget_distribution\u001b[49m\u001b[43m(\u001b[49m\u001b[43mobs_tensor\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m    202\u001b[39m probs = distribution.distribution.probs.numpy()\n\u001b[32m    203\u001b[39m probs = probs[\u001b[32m0\u001b[39m] \n",
      "\u001b[36mFile \u001b[39m\u001b[32m~/research/IPP/IPP/IPP/lib/python3.12/site-packages/stable_baselines3/common/policies.py:751\u001b[39m, in \u001b[36mActorCriticPolicy.get_distribution\u001b[39m\u001b[34m(self, obs)\u001b[39m\n\u001b[32m    744\u001b[39m \u001b[38;5;250m\u001b[39m\u001b[33;03m\"\"\"\u001b[39;00m\n\u001b[32m    745\u001b[39m \u001b[33;03mGet the current policy distribution given the observations.\u001b[39;00m\n\u001b[32m    746\u001b[39m \n\u001b[32m    747\u001b[39m \u001b[33;03m:param obs:\u001b[39;00m\n\u001b[32m    748\u001b[39m \u001b[33;03m:return: the action distribution.\u001b[39;00m\n\u001b[32m    749\u001b[39m \u001b[33;03m\"\"\"\u001b[39;00m\n\u001b[32m    750\u001b[39m features = \u001b[38;5;28msuper\u001b[39m().extract_features(obs, \u001b[38;5;28mself\u001b[39m.pi_features_extractor)\n\u001b[32m--> \u001b[39m\u001b[32m751\u001b[39m latent_pi = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mmlp_extractor\u001b[49m\u001b[43m.\u001b[49m\u001b[43mforward_actor\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfeatures\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m    752\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m._get_action_dist_from_latent(latent_pi)\n",
      "\u001b[36mFile \u001b[39m\u001b[32m~/research/IPP/IPP/IPP/lib/python3.12/site-packages/stable_baselines3/common/torch_layers.py:260\u001b[39m, in \u001b[36mMlpExtractor.forward_actor\u001b[39m\u001b[34m(self, features)\u001b[39m\n\u001b[32m    259\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mforward_actor\u001b[39m(\u001b[38;5;28mself\u001b[39m, features: th.Tensor) -> th.Tensor:\n\u001b[32m--> \u001b[39m\u001b[32m260\u001b[39m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mpolicy_net\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfeatures\u001b[49m\u001b[43m)\u001b[49m\n",
      "\u001b[36mFile \u001b[39m\u001b[32m~/research/IPP/IPP/IPP/lib/python3.12/site-packages/torch/nn/modules/module.py:1773\u001b[39m, in \u001b[36mModule._wrapped_call_impl\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m   1771\u001b[39m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m._compiled_call_impl(*args, **kwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[32m   1772\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m-> \u001b[39m\u001b[32m1773\u001b[39m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "\u001b[36mFile \u001b[39m\u001b[32m~/research/IPP/IPP/IPP/lib/python3.12/site-packages/torch/nn/modules/module.py:1784\u001b[39m, in \u001b[36mModule._call_impl\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m   1779\u001b[39m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[32m   1780\u001b[39m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[32m   1781\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m._backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._forward_pre_hooks\n\u001b[32m   1782\u001b[39m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[32m   1783\u001b[39m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[32m-> \u001b[39m\u001b[32m1784\u001b[39m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m   1786\u001b[39m result = \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[32m   1787\u001b[39m called_always_called_hooks = \u001b[38;5;28mset\u001b[39m()\n",
      "\u001b[36mFile \u001b[39m\u001b[32m~/research/IPP/IPP/IPP/lib/python3.12/site-packages/torch/nn/modules/container.py:244\u001b[39m, in \u001b[36mSequential.forward\u001b[39m\u001b[34m(self, input)\u001b[39m\n\u001b[32m    242\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;28minput\u001b[39m):\n\u001b[32m    243\u001b[39m     \u001b[38;5;28;01mfor\u001b[39;00m module \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m:\n\u001b[32m--> \u001b[39m\u001b[32m244\u001b[39m         \u001b[38;5;28minput\u001b[39m = \u001b[43mmodule\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[32m    245\u001b[39m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28minput\u001b[39m\n",
      "\u001b[36mFile \u001b[39m\u001b[32m~/research/IPP/IPP/IPP/lib/python3.12/site-packages/torch/nn/modules/module.py:1773\u001b[39m, in \u001b[36mModule._wrapped_call_impl\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m   1771\u001b[39m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m._compiled_call_impl(*args, **kwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[32m   1772\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m-> \u001b[39m\u001b[32m1773\u001b[39m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "\u001b[36mFile \u001b[39m\u001b[32m~/research/IPP/IPP/IPP/lib/python3.12/site-packages/torch/nn/modules/module.py:1784\u001b[39m, in \u001b[36mModule._call_impl\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m   1779\u001b[39m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[32m   1780\u001b[39m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[32m   1781\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m._backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._forward_pre_hooks\n\u001b[32m   1782\u001b[39m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[32m   1783\u001b[39m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[32m-> \u001b[39m\u001b[32m1784\u001b[39m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m   1786\u001b[39m result = \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[32m   1787\u001b[39m called_always_called_hooks = \u001b[38;5;28mset\u001b[39m()\n",
      "\u001b[36mFile \u001b[39m\u001b[32m~/research/IPP/IPP/IPP/lib/python3.12/site-packages/torch/nn/modules/linear.py:125\u001b[39m, in \u001b[36mLinear.forward\u001b[39m\u001b[34m(self, input)\u001b[39m\n\u001b[32m    124\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;28minput\u001b[39m: Tensor) -> Tensor:\n\u001b[32m--> \u001b[39m\u001b[32m125\u001b[39m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mF\u001b[49m\u001b[43m.\u001b[49m\u001b[43mlinear\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mweight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mbias\u001b[49m\u001b[43m)\u001b[49m\n",
      "\u001b[31mRuntimeError\u001b[39m: Tensor for argument input is on cpu but expected on mps"
     ]
    }
   ],
   "source": [
    "# Original sweep agent call (commented out for reference)\n",
    "# wandb.agent(sweep_id, train)\n",
    "\n",
    "# Direct training call for single agent\n",
    "train()"
   ]
  },
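  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "After a successful run, `train()` saves the model to `models/<run id>` (Stable-Baselines3 appends `.zip`). Below is a minimal sketch, not part of the original workflow, that reloads the most recent checkpoint and evaluates it on the test set; it assumes at least one finished run has saved a model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import glob\n",
    "import os\n",
    "\n",
    "# Sketch: reload the most recently saved checkpoint and evaluate it on the\n",
    "# held-out test set (assumes train() above saved a model under models/)\n",
    "model_paths = sorted(glob.glob('models/*.zip'), key=os.path.getmtime)\n",
    "if model_paths:\n",
    "    reloaded = A2C.load(model_paths[-1], device='cpu')\n",
    "    test_traj, test_usefulness, test_entropy = average_evals(test_gridworlds, reloaded)\n",
    "    print(f'Average Trajectory Ratio: {test_traj:.3f}')\n",
    "    print(f'Average Usefulness: {test_usefulness}')\n",
    "    print(f'Average Neutrality: {test_entropy}')\n",
    "else:\n",
    "    print('No saved models found in models/')"
   ]
  },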
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ipp_env3 (3.12.9)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
