{"cells":[{"cell_type":"code","execution_count":13,"metadata":{"id":"gR_wtr_-HMqU","executionInfo":{"status":"ok","timestamp":1748991812312,"user_tz":420,"elapsed":25,"user":{"displayName":"Jeremy Rubinoff","userId":"01574885304845120943"}}},"outputs":[],"source":["import random\n","import numpy as np\n","import matplotlib.pyplot as plt\n","\n","import torch\n","import torch.nn as nn\n","import torch.nn.functional as F\n","import torch.optim as optim\n","from torch.distributions import Categorical\n","\n","import os"]},{"cell_type":"code","execution_count":14,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":687,"status":"ok","timestamp":1748991813002,"user":{"displayName":"Jeremy Rubinoff","userId":"01574885304845120943"},"user_tz":420},"id":"eXzcC4PchI0n","outputId":"a711d723-a856-4bb8-f1c1-68dbe89e773c"},"outputs":[{"output_type":"stream","name":"stdout","text":["Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount(\"/content/drive\", force_remount=True).\n"]}],"source":["from google.colab import drive\n","drive.mount('/content/drive')\n","gdrive_model_dir = \"/content/drive/MyDrive/Colab_Models\""]},{"cell_type":"code","execution_count":15,"metadata":{"id":"O3HxdeADHQmI","executionInfo":{"status":"ok","timestamp":1748991813008,"user_tz":420,"elapsed":3,"user":{"displayName":"Jeremy Rubinoff","userId":"01574885304845120943"}}},"outputs":[],"source":["# --- Environment Class ---\n","class FruitPickingEnv:\n","    def __init__(self, max_steps=8, p_ripe = 0.125, p_unripe = 0.0625, p_supervised = 0.0625,\n","                 ripe_reward = 1, unripe_reward = -1, bump_reward = -0.25):\n","\n","        self.base_map_template = [\n","            \"WWWWWWWWWWWWW\",\n","            \"WGGGGG.GGGGGW\",\n","            \"WGGGGG.GGGGGW\",\n","            \"WGGTGG.GGTGGW\",\n","            \"WGGGGG.GGGGGW\",\n","            \"WGGGGG.GGGGGW\",\n","            \"WWWWWWWWWWWWW\"\n","        ]\n","        self.height = len(self.base_map_template)\n","        self.width = len(self.base_map_template[0])\n","        self.action_space_n = 4\n","        self.observation_space_shape = (5, self.height, self.width)\n","        self.max_steps = max_steps\n","        self.p_ripe = p_ripe\n","        self.p_unripe = p_unripe\n","        self.p_supervised = p_supervised\n","        self.ripe_reward = ripe_reward\n","        self.unripe_reward = unripe_reward\n","        self.bump_reward = bump_reward\n","        self.R_positions = []\n","        self.U_positions = []\n","        self.S_positions = []\n","\n","        self.device_env = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n","\n","\n","        self.reset()\n","\n","\n","    def reset(self):\n","        self.R_positions = []\n","        self.U_positions = []\n","        self.S_positions = []\n","        current_map_list = [list(row) for row in self.base_map_template]\n","\n","        starting_height = random.randint(1, self.height - 2)\n","        starting_width = random.randint(1, self.width - 2)\n","        while (starting_height, starting_width) in ((3,3),(3,9)):\n","            starting_height = random.randint(1, self.height - 2)\n","            starting_width = random.randint(1, self.width - 2)\n","        self.start_pos = (starting_height, starting_width)\n","\n","        p_ripe_scale = random.random()\n","        p_unripe_scale = random.random()\n","        p_supervised_scale = random.random()\n","        p_ripe_new = self.p_ripe * p_ripe_scale\n","        p_unripe_new = self.p_unripe * p_unripe_scale\n","        p_supervised_new = self.p_supervised * p_supervised_scale\n","\n","        for r_idx in range(self.height):\n","            for c_idx in range(self.width):\n","                if current_map_list[r_idx][c_idx] == 'G':\n","                    chance = random.random()\n","                    if chance < p_ripe_new:\n","                        current_map_list[r_idx][c_idx] = 'R'\n","                        self.R_positions.append((r_idx, c_idx))\n","                    elif chance < p_ripe_new + p_unripe_new:\n","                        current_map_list[r_idx][c_idx] = 'U'\n","                        self.U_positions.append((r_idx, c_idx))\n","                    elif chance < p_ripe_new + p_unripe_new + p_supervised_new:\n","                        current_map_list[r_idx][c_idx] = 'S'\n","                        self.S_positions.append((r_idx, c_idx))\n","\n","        if (starting_height, starting_width) in self.R_positions:\n","            self.R_positions.remove((starting_height, starting_width))\n","            current_map_list[starting_height][starting_width] = 'G'\n","        elif (starting_height, starting_width) in self.U_positions:\n","            self.U_positions.remove((starting_height, starting_width))\n","            current_map_list[starting_height][starting_width] = 'G'\n","        elif (starting_height, starting_width) in self.S_positions:\n","            self.S_positions.remove((starting_height, starting_width))\n","            current_map_list[starting_height][starting_width] = 'G'\n","\n","        self.grid_map = [\"\".join(row_list) for row_list in current_map_list]\n","        self.agent_pos = self.start_pos\n","        self.current_step = 0\n","        self.switched_policy_active = True\n","        return self._get_obs()\n","\n","    def _get_obs(self):\n","        grid_obs = np.zeros((self.height, self.width, 5), dtype=np.float32)\n","        for r_idx in range(self.height):\n","            for c_idx in range(self.width):\n","                pos = (r_idx, c_idx)\n","                if r_idx in (0,self.height - 1) or c_idx in (0, self.width - 1) or pos in ((3,3),(3,9)): grid_obs[r_idx, c_idx, 0] = 1\n","                elif pos in self.R_positions: grid_obs[r_idx, c_idx, 1] = 1\n","                elif pos in self.U_positions: grid_obs[r_idx, c_idx, 2] = 1\n","                elif pos in self.S_positions: grid_obs[r_idx, c_idx, 3] = 1\n","        grid_obs[self.agent_pos[0], self.agent_pos[1], 4] = 1\n","\n","        return np.transpose(grid_obs, (2, 0, 1))\n","\n","    def step(self, action, update_acceptance):\n","        self.current_step += 1\n","        current_reward = 0.0\n","        done = False\n","        info = {\n","            'type_picked': None,\n","            'bumped': False\n","        }\n","\n","        action_to_take = action\n","\n","        base_direction = action_to_take\n","        r, c = self.agent_pos\n","        if base_direction == 0:\n","            r -= 1\n","        elif base_direction == 1:\n","            r += 1\n","        elif base_direction == 2:\n","            c -= 1\n","        elif base_direction == 3:\n","            c += 1\n","\n","        if 0 <= r < self.height and 0 <= c < self.width and self.grid_map[r][c] != 'W' and self.grid_map[r][c] != 'T':\n","            self.agent_pos = (r, c)\n","        else:\n","            info['bumped'] = True\n","\n","        if self.agent_pos in self.R_positions:\n","                self.R_positions.remove(self.agent_pos)\n","                info['type_picked'] = 'R'\n","        elif self.agent_pos in self.U_positions:\n","                self.U_positions.remove(self.agent_pos)\n","                info['type_picked'] = 'U'\n","        elif self.agent_pos in self.S_positions:\n","                self.S_positions.remove(self.agent_pos)\n","                info['type_picked'] = 'S'\n","        if self.current_step >= self.max_steps:\n","            done = True\n","        return self._get_obs(), info['type_picked'], info['bumped'], done\n","\n","\n","    def render(self):\n","        grid_map = [list(row) for row in self.grid_map]\n","        for r_idx in range(self.height):\n","            for c_idx in range(self.width):\n","                pos = (r_idx, c_idx)\n","                if r_idx in (0,self.height - 1) or c_idx in (0, self.width - 1): grid_map[r_idx][c_idx] = 'W'\n","                elif pos in ((3,3),(3,9)): grid_map[r_idx][c_idx] = 'T'\n","                elif pos in self.R_positions: grid_map[r_idx][c_idx] = 'R'\n","                elif pos in self.U_positions: grid_map[r_idx][c_idx] = 'U'\n","                elif pos in self.S_positions: grid_map[r_idx][c_idx] = 'S'\n","                else: grid_map[r_idx][c_idx] = '.'\n","        r_pos, c_pos = self.agent_pos\n","        grid_map[r_pos][c_pos] = 'A'\n","\n","        print(f\"Step: {self.current_step}, Max Steps: {self.max_steps}\")\n","        for row_list in grid_map:\n","            print(\"\".join(row_list))\n","        print(\"-\" * (self.width + 20))"]},{"cell_type":"code","execution_count":16,"metadata":{"id":"7nq5iykQb2B1","executionInfo":{"status":"ok","timestamp":1748991813013,"user_tz":420,"elapsed":1,"user":{"displayName":"Jeremy Rubinoff","userId":"01574885304845120943"}}},"outputs":[],"source":["class ActorCritic(nn.Module):\n","    def __init__(self, input_shape, n_actions, num_hidden_layers=2, hidden_size=256):\n","        super(ActorCritic, self).__init__()\n","\n","        channels, height, width = input_shape\n","\n","        self.conv1 = nn.Conv2d(in_channels=channels, out_channels=8, kernel_size=3, stride=1, padding=1)\n","        self.gn1 = nn.GroupNorm(num_groups=1, num_channels=8)\n","        self.conv2 = nn.Conv2d(in_channels=8, out_channels=2, kernel_size=3, stride=1, padding=1)\n","        self.gn2 = nn.GroupNorm(num_groups=1, num_channels=2)\n","        conv_out_size = 2 * height * width\n","        current_dim = conv_out_size\n","\n","        self.hidden_layers_fc = nn.ModuleList()\n","        self.layer_norms_fc = nn.ModuleList()\n","\n","        if num_hidden_layers > 0:\n","            for _ in range(num_hidden_layers):\n","                self.hidden_layers_fc.append(nn.Linear(current_dim, hidden_size))\n","                self.layer_norms_fc.append(nn.LayerNorm(hidden_size))\n","                current_dim = hidden_size\n","\n","\n","        self.actor = nn.Linear(current_dim, n_actions)\n","        self.critic_r = nn.Linear(current_dim, 1)\n","        self.critic_u = nn.Linear(current_dim, 1)\n","        self.critic_b = nn.Linear(current_dim, 1)\n","\n","    def forward(self, state):\n","        x = self.conv1(state)\n","        x = self.gn1(x)\n","        x = F.relu(x)\n","        x = self.conv2(x)\n","        x = self.gn2(x)\n","        x = F.relu(x)\n","        x = x.view(x.size(0), -1)\n","\n","        for hidden_layer, norm_layer in zip(self.hidden_layers_fc, self.layer_norms_fc):\n","            x = F.relu(hidden_layer(x))\n","            x = norm_layer(x)\n","\n","        action_logits = self.actor(x)\n","        state_value_r = self.critic_r(x)\n","        state_value_u = self.critic_u(x)\n","        state_value_b = self.critic_b(x)\n","\n","        return action_logits, state_value_r, state_value_u, state_value_b"]},{"cell_type":"code","execution_count":17,"metadata":{"id":"yZ6Kpsny3k6o","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1748991813059,"user_tz":420,"elapsed":43,"user":{"displayName":"Jeremy Rubinoff","userId":"01574885304845120943"}},"outputId":"0970ef98-aa16-432b-c95f-1e601ef8e9ca"},"outputs":[{"output_type":"execute_result","data":{"text/plain":["<torch._C.Generator at 0x7f9b6814db70>"]},"metadata":{},"execution_count":17}],"source":["#Set seed\n","random.seed(0)\n","torch.manual_seed(0)"]},{"cell_type":"code","execution_count":18,"metadata":{"id":"l1CrhueXx73J","executionInfo":{"status":"ok","timestamp":1748991813069,"user_tz":420,"elapsed":8,"user":{"displayName":"Jeremy Rubinoff","userId":"01574885304845120943"}}},"outputs":[],"source":["# Experiment Setup\n","device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n","lr = 0.001\n","gamma = 0.95\n","max_steps_per_episode = 8\n","hidden_layer_size = 256\n","n_hidden = 2\n","log_barrier_coeff = 0.001\n","accumulation_steps = 8\n","env = FruitPickingEnv(max_steps=max_steps_per_episode + 1, p_ripe = 0.33, p_unripe = 0.33,  p_supervised=0.33, ripe_reward=1, unripe_reward=-1, bump_reward=-0.25)\n","model = ActorCritic(env.observation_space_shape, env.action_space_n, num_hidden_layers=n_hidden, hidden_size=hidden_layer_size).to(device)\n","optimizer = optim.Adam(model.parameters(), lr=lr)"]},{"cell_type":"code","execution_count":19,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"T225RNy8ru0J","outputId":"5183b2a8-f812-4327-b355-c3c4017eb931","executionInfo":{"status":"ok","timestamp":1748993863075,"user_tz":420,"elapsed":1546094,"user":{"displayName":"Jeremy Rubinoff","userId":"01574885304845120943"}}},"outputs":[{"output_type":"stream","name":"stdout","text":["Episode 4096/131072, Avg Ripe: 0.53, Avg Unripe: 1.08, Avg Bump: 1.34, LB: 11.755\n","Episode 8192/131072, Avg Ripe: 0.64, Avg Unripe: 0.68, Avg Bump: 0.17, LB: 13.448\n","Episode 12288/131072, Avg Ripe: 1.16, Avg Unripe: 0.76, Avg Bump: 0.25, LB: 14.007\n","Episode 16384/131072, Avg Ripe: 1.54, Avg Unripe: 0.73, Avg Bump: 0.26, LB: 14.622\n","Episode 20480/131072, Avg Ripe: 1.72, Avg Unripe: 0.59, Avg Bump: 0.20, LB: 17.559\n","Episode 24576/131072, Avg Ripe: 1.84, Avg Unripe: 0.52, Avg Bump: 0.17, LB: 17.137\n","Episode 28672/131072, Avg Ripe: 1.96, Avg Unripe: 0.51, Avg Bump: 0.13, LB: 21.835\n","Episode 32768/131072, Avg Ripe: 2.09, Avg Unripe: 0.47, Avg Bump: 0.09, LB: 20.221\n","Episode 36864/131072, Avg Ripe: 2.14, Avg Unripe: 0.45, Avg Bump: 0.09, LB: 20.059\n","Episode 40960/131072, Avg Ripe: 2.23, Avg Unripe: 0.42, Avg Bump: 0.09, LB: 21.993\n","Episode 45056/131072, Avg Ripe: 2.28, Avg Unripe: 0.36, Avg Bump: 0.13, LB: 17.167\n","Episode 49152/131072, Avg Ripe: 2.33, Avg Unripe: 0.31, Avg Bump: 0.08, LB: 21.160\n","Episode 53248/131072, Avg Ripe: 2.36, Avg Unripe: 0.29, Avg Bump: 0.07, LB: 22.059\n","Episode 57344/131072, Avg Ripe: 2.42, Avg Unripe: 0.27, Avg Bump: 0.11, LB: 15.876\n","Episode 61440/131072, Avg Ripe: 2.42, Avg Unripe: 0.26, Avg Bump: 0.08, LB: 20.700\n","Episode 65536/131072, Avg Ripe: 2.45, Avg Unripe: 0.24, Avg Bump: 0.06, LB: 18.608\n","Episode 69632/131072, Avg Ripe: 2.51, Avg Unripe: 0.23, Avg Bump: 0.06, LB: 21.663\n","Episode 73728/131072, Avg Ripe: 2.48, Avg Unripe: 0.20, Avg Bump: 0.05, LB: 19.576\n","Episode 77824/131072, Avg Ripe: 2.48, Avg Unripe: 0.19, Avg Bump: 0.06, LB: 19.468\n","Episode 81920/131072, Avg Ripe: 2.55, Avg Unripe: 0.20, Avg Bump: 0.05, LB: 19.286\n","Episode 86016/131072, Avg Ripe: 2.50, Avg Unripe: 0.21, Avg Bump: 0.06, LB: 25.705\n","Episode 90112/131072, Avg Ripe: 2.54, Avg Unripe: 0.19, Avg Bump: 0.07, LB: 17.235\n","Episode 94208/131072, Avg Ripe: 2.50, Avg Unripe: 0.19, Avg Bump: 0.05, LB: 23.192\n","Episode 98304/131072, Avg Ripe: 2.53, Avg Unripe: 0.20, Avg Bump: 0.05, LB: 21.502\n","Episode 102400/131072, Avg Ripe: 2.52, Avg Unripe: 0.18, Avg Bump: 0.05, LB: 23.591\n","Episode 106496/131072, Avg Ripe: 2.50, Avg Unripe: 0.19, Avg Bump: 0.05, LB: 19.428\n","Episode 110592/131072, Avg Ripe: 2.52, Avg Unripe: 0.18, Avg Bump: 0.04, LB: 21.034\n","Episode 114688/131072, Avg Ripe: 2.49, Avg Unripe: 0.17, Avg Bump: 0.04, LB: 18.786\n","Episode 118784/131072, Avg Ripe: 2.50, Avg Unripe: 0.17, Avg Bump: 0.04, LB: 14.258\n","Episode 122880/131072, Avg Ripe: 2.56, Avg Unripe: 0.19, Avg Bump: 0.03, LB: 19.822\n","Episode 126976/131072, Avg Ripe: 2.53, Avg Unripe: 0.17, Avg Bump: 0.04, LB: 15.388\n","Episode 131072/131072, Avg Ripe: 2.53, Avg Unripe: 0.17, Avg Bump: 0.04, LB: 22.242\n"]}],"source":["n_episodes_scale = 32\n","n_episodes_forced = 4096 * n_episodes_scale\n","# Forced Policy Pretraining\n","ripe_count_forced = []\n","unripe_count_forced = []\n","bump_count_forced = []\n","optimizer.zero_grad()\n","for episode in range(n_episodes_forced):\n","    state = env.reset()\n","    done = False\n","\n","    ep_action_log_probs = []\n","    ep_state_values_r = []\n","    ep_state_values_u = []\n","    ep_state_values_b = []\n","    ep_log_barriers = []\n","    ep_picks = []\n","    ep_bumps = []\n","\n","    # Run episode\n","    for step in range(env.max_steps):\n","        obs_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)\n","        action_logits,  state_value_r, state_value_u, state_value_b = model(obs_tensor)\n","\n","        action_probs = F.softmax(action_logits, dim=-1)\n","        dist = Categorical(action_probs)\n","        action = dist.sample()\n","\n","        next_obs, type_picked, bumped, done = env.step(action.item(), 0)\n","\n","        ep_action_log_probs.append(dist.log_prob(action).squeeze())\n","        ep_state_values_r.append(state_value_r.squeeze())\n","        ep_state_values_u.append(state_value_u.squeeze())\n","        ep_state_values_b.append(state_value_b.squeeze())\n","        ep_picks.append(type_picked)\n","        ep_bumps.append(bumped)\n","        log_barrier = -torch.sum(torch.log(dist.probs + 1e-8))\n","        ep_log_barriers.append(log_barrier.squeeze())\n","\n","        state = next_obs\n","        if done:\n","            break\n","\n","    ripe_count_forced.append(ep_picks.count('R'))\n","    unripe_count_forced.append(ep_picks.count('U') + ep_picks.count('S'))\n","    bump_count_forced.append(ep_bumps.count(True))\n","\n","\n","    # Make update\n","    advantages = []\n","    for i in range(len(ep_action_log_probs) - 1):\n","        advantage = torch.stack([\n","                    gamma * ep_state_values_r[i + 1].detach() - ep_state_values_r[i],\n","                    gamma * ep_state_values_u[i + 1].detach() - ep_state_values_u[i],\n","                    gamma * ep_state_values_b[i + 1].detach() - ep_state_values_b[i]\n","                ])\n","        if ep_picks[i] == 'R':\n","            advantage[0] += 1\n","        elif ep_picks[i] == 'U' or ep_picks[i] == 'S':\n","            advantage[1] += 1\n","        if ep_bumps[i]:\n","            advantage[2] += 1\n","        advantages.append(advantage)\n","\n","\n","    advantages_stack = torch.stack(advantages).to(device)\n","    weights = torch.tensor([env.ripe_reward, env.unripe_reward, env.bump_reward], dtype=torch.float32).to(device)\n","    actor_loss = -(torch.stack(ep_action_log_probs[:-1]).squeeze() * (advantages_stack * weights).sum(dim=1).detach()).mean()\n","    critic_loss = (advantages_stack.pow(2)).mean()\n","    log_barrier_loss = torch.stack(ep_log_barriers[:-1]).mean()\n","\n","    total_rl_loss = (1.0 * actor_loss + 0.5 * critic_loss + log_barrier_coeff * log_barrier_loss)/accumulation_steps\n","    total_rl_loss.backward()\n","    if (episode + 1) % accumulation_steps == 0:\n","        torch.nn.utils.clip_grad_norm_(model.parameters(), 1)\n","        optimizer.step()\n","        optimizer.zero_grad()\n","\n","    if (episode + 1) % 4096 == 0:\n","        torch.cuda.empty_cache()\n","\n","    if (episode + 1) % (n_episodes_forced // n_episodes_scale) == 0:\n","        avg_interval = n_episodes_forced // n_episodes_scale\n","        avg_ripe = np.mean(ripe_count_forced[-avg_interval:]) if ripe_count_forced else 0\n","        avg_unripe = np.mean(unripe_count_forced[-avg_interval:]) if unripe_count_forced else 0\n","        avg_bump = np.mean(bump_count_forced[-avg_interval:]) if bump_count_forced else 0\n","        print(f\"Episode {episode+1}/{n_episodes_forced}, Avg Ripe: {avg_ripe:.2f}, Avg Unripe: {avg_unripe:.2f}, Avg Bump: {avg_bump:.2f}, LB: {log_barrier_loss.item():.3f}\")"]},{"cell_type":"code","source":["torch.save(model.state_dict(), gdrive_model_dir + '/forced_value_5DGN_8step_1.pth')"],"metadata":{"id":"FBm1DTmvF9cn","executionInfo":{"status":"ok","timestamp":1748993863096,"user_tz":420,"elapsed":2,"user":{"displayName":"Jeremy Rubinoff","userId":"01574885304845120943"}}},"execution_count":20,"outputs":[]},{"cell_type":"code","source":["# forced_5D_8step_1_dict = torch.load(gdrive_model_dir + '/forced_value_5D_8step_1.pth')\n","# model.load_state_dict(forced_5D_8step_1_dict)"],"metadata":{"id":"tCZt_3wyM8Sv","executionInfo":{"status":"ok","timestamp":1748993863115,"user_tz":420,"elapsed":2,"user":{"displayName":"Jeremy Rubinoff","userId":"01574885304845120943"}}},"execution_count":21,"outputs":[]},{"cell_type":"code","execution_count":22,"metadata":{"id":"rtty5p9y6t_a","executionInfo":{"status":"ok","timestamp":1748993863134,"user_tz":420,"elapsed":5,"user":{"displayName":"Jeremy Rubinoff","userId":"01574885304845120943"}}},"outputs":[],"source":["# # Lower fruit spawn rates\n","# n_episodes_scale_2 = 32\n","# n_episodes_forced_2 = 4096 * n_episodes_scale_2\n","# ripe_count_forced = []\n","# unripe_count_forced = []\n","# bump_count_forced = []\n","# optimizer.zero_grad()\n","# for episode in range(n_episodes_forced_2):\n","#     state = env.reset()\n","#     done = False\n","\n","#     ep_action_log_probs = []\n","#     ep_state_values_r = []\n","#     ep_state_values_u = []\n","#     ep_state_values_b = []\n","#     ep_log_barriers = []\n","#     ep_picks = []\n","#     ep_bumps = []\n","\n","#     # Run episode\n","#     for step in range(env.max_steps):\n","#         obs_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)\n","#         action_logits,  state_value_r, state_value_u, state_value_b = model(obs_tensor)\n","\n","#         action_probs = F.softmax(action_logits, dim=-1)\n","#         dist = Categorical(action_probs)\n","#         action = dist.sample()\n","\n","#         next_obs, type_picked, bumped, done = env.step(action.item(), 0)\n","\n","#         ep_action_log_probs.append(dist.log_prob(action).squeeze())\n","#         ep_state_values_r.append(state_value_r.squeeze())\n","#         ep_state_values_u.append(state_value_u.squeeze())\n","#         ep_state_values_b.append(state_value_b.squeeze())\n","#         ep_picks.append(type_picked)\n","#         ep_bumps.append(bumped)\n","#         log_barrier = -torch.sum(torch.log(dist.probs + 1e-8))\n","#         ep_log_barriers.append(log_barrier.squeeze())\n","\n","#         state = next_obs\n","#         if done:\n","#             break\n","\n","#     ripe_count_forced.append(ep_picks.count('R'))\n","#     unripe_count_forced.append(ep_picks.count('U') + ep_picks.count('S'))\n","#     bump_count_forced.append(ep_bumps.count(True))\n","\n","\n","#     # Make update\n","#     advantages = []\n","#     for i in range(len(ep_action_log_probs) - 1):\n","#         advantage = torch.stack([\n","#                     gamma * ep_state_values_r[i + 1].detach() - ep_state_values_r[i],\n","#                     gamma * ep_state_values_u[i + 1].detach() - ep_state_values_u[i],\n","#                     gamma * ep_state_values_b[i + 1].detach() - ep_state_values_b[i]\n","#                 ])\n","#         if ep_picks[i] == 'R':\n","#             advantage[0] += 1\n","#         elif ep_picks[i] == 'U' or ep_picks[i] == 'S':\n","#             advantage[1] += 1\n","#         if ep_bumps[i]:\n","#             advantage[2] += 1\n","#         advantages.append(advantage)\n","\n","\n","#     advantages_stack = torch.stack(advantages).to(device)\n","#     weights = torch.tensor([env.ripe_reward, env.unripe_reward, env.bump_reward], dtype=torch.float32).to(device)\n","#     actor_loss = -(torch.stack(ep_action_log_probs[:-1]).squeeze() * (advantages_stack * weights).sum(dim=1).detach()).mean()\n","#     critic_loss = (advantages_stack.pow(2) * weights.abs()).mean()\n","#     log_barrier_loss = torch.stack(ep_log_barriers[:-1]).mean()\n","\n","#     total_rl_loss = (1.0 * actor_loss + 0.5 * critic_loss + log_barrier_coeff * log_barrier_loss)/accumulation_steps\n","#     total_rl_loss.backward()\n","#     if (episode + 1) % accumulation_steps == 0:\n","#         torch.nn.utils.clip_grad_norm_(model.parameters(), 1)\n","#         optimizer.step()\n","#         optimizer.zero_grad()\n","\n","#     if (episode + 1) % 4096 == 0:\n","#         torch.cuda.empty_cache()\n","\n","#     if (episode + 1) % (n_episodes_forced_2 // n_episodes_scale_2) == 0:\n","#         avg_interval = n_episodes_forced_2 // n_episodes_scale_2\n","#         avg_ripe = np.mean(ripe_count_forced[-avg_interval:]) if ripe_count_forced else 0\n","#         avg_unripe = np.mean(unripe_count_forced[-avg_interval:]) if unripe_count_forced else 0\n","#         avg_bump = np.mean(bump_count_forced[-avg_interval:]) if bump_count_forced else 0\n","#         env.p_ripe -= 0.1 / n_episodes_scale_2\n","#         env.p_unripe -= 0.1 / n_episodes_scale_2\n","#         env.p_supervised -= 0.1 / n_episodes_scale_2\n","#         print(f\"Episode {episode+1}/{n_episodes_forced_2},  Avg Ripe: {avg_ripe:.2f}, Avg Unripe: {avg_unripe:.2f}, Avg Bump: {avg_bump:.2f}, LB: {log_barrier_loss.item():.3f}, P_Ripe: {env.p_ripe:.3f}, P_Unripe: {env.p_unripe:.3f}, P_Supervised: {env.p_supervised:.3f}\")\n"]},{"cell_type":"code","execution_count":23,"metadata":{"id":"godznhVn6nIi","executionInfo":{"status":"ok","timestamp":1748993863165,"user_tz":420,"elapsed":1,"user":{"displayName":"Jeremy Rubinoff","userId":"01574885304845120943"}}},"outputs":[],"source":["# torch.save(model.state_dict(), gdrive_model_dir + '/forced_value_5D_8step_2.pth')"]},{"cell_type":"code","source":["# forced_5D_8step_2_dict = torch.load(gdrive_model_dir + '/forced_value_5D_8step_2.pth')\n","# model.load_state_dict(forced_5D_8step_2_dict)"],"metadata":{"id":"HrTuLhq_2rX1","executionInfo":{"status":"ok","timestamp":1748993863209,"user_tz":420,"elapsed":27,"user":{"displayName":"Jeremy Rubinoff","userId":"01574885304845120943"}}},"execution_count":24,"outputs":[]}],"metadata":{"colab":{"machine_shape":"hm","provenance":[{"file_id":"1z4LIZ2juZBvxOr5kxz4TXfq9kGbQ4DYw","timestamp":1748984027375},{"file_id":"1J1NLBR5HGrQYnPjR_ZDGoN2CJKGjnOSK","timestamp":1748280680241}],"authorship_tag":"ABX9TyO/egK3X+e/kmLHB40QLzt3"},"kernelspec":{"display_name":"Python 3","name":"python3"},"language_info":{"name":"python"}},"nbformat":4,"nbformat_minor":0}