{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "gR_wtr_-HMqU"
   },
   "outputs": [],
   "source": [
    "import random\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torch.distributions import Categorical\n",
    "\n",
    "import os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "eXzcC4PchI0n",
    "outputId": "d3c98020-6854-4da8-b798-e3bb822f9ee0"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount(\"/content/drive\", force_remount=True).\n"
     ]
    }
   ],
   "source": [
    "from google.colab import drive\n",
    "drive.mount('/content/drive')\n",
    "gdrive_model_dir = \"/content/drive/MyDrive/Colab_Models\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "O3HxdeADHQmI"
   },
   "outputs": [],
   "source": [
    "# --- Environment Class ---\n",
    "class FruitPickingEnv:\n",
    "    \"\"\"Grid-world fruit-picking environment.\n",
    "\n",
    "    The map is a fixed template bordered by walls ('W') and containing trees ('T');\n",
    "    both are impassable. Each reset scatters ripe ('R'), unripe ('U') and\n",
    "    supervised ('S') fruit over the 'G' cells and places the agent on a random\n",
    "    interior, non-tree cell. Actions are 0-7, but only `action % 4` affects\n",
    "    movement (0 up, 1 down, 2 left, 3 right), so actions 4-7 move like 0-3.\n",
    "    \"\"\"\n",
    "    def __init__(self, max_steps=8, p_ripe = 0.125, p_unripe = 0.0625, p_supervised = 0.0625,\n",
    "                 ripe_reward = 1, unripe_reward = -1, bump_reward = -0.25):\n",
    "\n",
    "        self.base_map_template = [\n",
    "            \"WWWWWWWWWWWWW\",\n",
    "            \"WGGGGG.GGGGGW\",\n",
    "            \"WGGGGG.GGGGGW\",\n",
    "            \"WGGTGG.GGTGGW\",\n",
    "            \"WGGGGG.GGGGGW\",\n",
    "            \"WGGGGG.GGGGGW\",\n",
    "            \"WWWWWWWWWWWWW\"\n",
    "        ]\n",
    "        self.height = len(self.base_map_template)\n",
    "        self.width = len(self.base_map_template[0])\n",
    "        # Tree cells, derived from the template so that reset/_get_obs/render\n",
    "        # always agree with the map instead of repeating literal coordinates.\n",
    "        self.tree_positions = tuple(\n",
    "            (r_idx, c_idx)\n",
    "            for r_idx, row in enumerate(self.base_map_template)\n",
    "            for c_idx, cell in enumerate(row)\n",
    "            if cell == 'T'\n",
    "        )\n",
    "        self.action_space_n = 8\n",
    "        self.observation_space_shape = (5, self.height, self.width)\n",
    "        self.max_steps = max_steps\n",
    "        self.p_ripe = p_ripe\n",
    "        self.p_unripe = p_unripe\n",
    "        self.p_supervised = p_supervised\n",
    "        self.ripe_reward = ripe_reward\n",
    "        self.unripe_reward = unripe_reward\n",
    "        self.bump_reward = bump_reward\n",
    "        self.R_positions = []\n",
    "        self.U_positions = []\n",
    "        self.S_positions = []\n",
    "\n",
    "        self.device_env = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "\n",
    "        self.reset()\n",
    "\n",
    "\n",
    "    def reset(self):\n",
    "        \"\"\"Start a new episode and return its first observation.\n",
    "\n",
    "        Each fruit probability is rescaled by an independent uniform(0, 1) draw,\n",
    "        so fruit density varies between episodes. The start cell never holds fruit.\n",
    "        \"\"\"\n",
    "        self.R_positions = []\n",
    "        self.U_positions = []\n",
    "        self.S_positions = []\n",
    "        current_map_list = [list(row) for row in self.base_map_template]\n",
    "\n",
    "\n",
    "        starting_height = random.randint(1, self.height - 2)\n",
    "        starting_width = random.randint(1, self.width - 2)\n",
    "        while (starting_height, starting_width) in self.tree_positions:\n",
    "            starting_height = random.randint(1, self.height - 2)\n",
    "            starting_width = random.randint(1, self.width - 2)\n",
    "        self.start_pos = (starting_height, starting_width)\n",
    "\n",
    "        p_ripe_scale = random.random()\n",
    "        p_unripe_scale = random.random()\n",
    "        p_supervised_scale = random.random()\n",
    "        p_ripe_new = self.p_ripe * p_ripe_scale\n",
    "        p_unripe_new = self.p_unripe * p_unripe_scale\n",
    "        p_supervised_new = self.p_supervised * p_supervised_scale\n",
    "\n",
    "        for r_idx in range(self.height):\n",
    "            for c_idx in range(self.width):\n",
    "                if current_map_list[r_idx][c_idx] == 'G':\n",
    "                    chance = random.random()\n",
    "                    if chance < p_ripe_new:\n",
    "                        current_map_list[r_idx][c_idx] = 'R'\n",
    "                        self.R_positions.append((r_idx, c_idx))\n",
    "                    elif chance < p_ripe_new + p_unripe_new:\n",
    "                        current_map_list[r_idx][c_idx] = 'U'\n",
    "                        self.U_positions.append((r_idx, c_idx))\n",
    "                    elif chance < p_ripe_new + p_unripe_new + p_supervised_new:\n",
    "                        current_map_list[r_idx][c_idx] = 'S'\n",
    "                        self.S_positions.append((r_idx, c_idx))\n",
    "\n",
    "        if (starting_height, starting_width) in self.R_positions:\n",
    "            self.R_positions.remove((starting_height, starting_width))\n",
    "            current_map_list[starting_height][starting_width] = 'G'\n",
    "        elif (starting_height, starting_width) in self.U_positions:\n",
    "            self.U_positions.remove((starting_height, starting_width))\n",
    "            current_map_list[starting_height][starting_width] = 'G'\n",
    "        elif (starting_height, starting_width) in self.S_positions:\n",
    "            self.S_positions.remove((starting_height, starting_width))\n",
    "            current_map_list[starting_height][starting_width] = 'G'\n",
    "\n",
    "        self.grid_map = [\"\".join(row_list) for row_list in current_map_list]\n",
    "        self.agent_pos = self.start_pos\n",
    "        self.current_step = 0\n",
    "        self.switched_policy_active = False\n",
    "        return self._get_obs()\n",
    "\n",
    "    def _get_obs(self):\n",
    "        \"\"\"Return a (5, height, width) float32 one-hot observation.\n",
    "\n",
    "        Channels: 0 walls/trees, 1 ripe, 2 unripe, 3 supervised, 4 agent.\n",
    "        \"\"\"\n",
    "        grid_obs = np.zeros((self.height, self.width, 5), dtype=np.float32)\n",
    "        for r_idx in range(self.height):\n",
    "            for c_idx in range(self.width):\n",
    "                pos = (r_idx, c_idx)\n",
    "                if r_idx in (0,self.height - 1) or c_idx in (0, self.width - 1) or pos in self.tree_positions: grid_obs[r_idx, c_idx, 0] = 1\n",
    "                elif pos in self.R_positions: grid_obs[r_idx, c_idx, 1] = 1\n",
    "                elif pos in self.U_positions: grid_obs[r_idx, c_idx, 2] = 1\n",
    "                elif pos in self.S_positions: grid_obs[r_idx, c_idx, 3] = 1\n",
    "        grid_obs[self.agent_pos[0], self.agent_pos[1], 4] = 1\n",
    "\n",
    "        return np.transpose(grid_obs, (2, 0, 1))\n",
    "\n",
    "    def step(self, action):\n",
    "        \"\"\"Move the agent and pick any fruit on the destination cell.\n",
    "\n",
    "        Returns (observation, reward, done, info). A blocked move adds bump_reward;\n",
    "        picking 'R' pays ripe_reward and picking 'U' or 'S' pays unripe_reward.\n",
    "        info['type_picked'] is the picked fruit letter or None. The episode ends\n",
    "        once max_steps steps have been taken.\n",
    "        \"\"\"\n",
    "        self.current_step += 1\n",
    "        current_reward = 0.0\n",
    "        done = False\n",
    "        info = {\n",
    "            'proper_signal': False,\n",
    "            'switched': False,\n",
    "            'type_picked': None\n",
    "        }\n",
    "\n",
    "        base_direction = action % 4\n",
    "        r, c = self.agent_pos\n",
    "        if base_direction == 0:\n",
    "            r -= 1\n",
    "        elif base_direction == 1:\n",
    "            r += 1\n",
    "        elif base_direction == 2:\n",
    "            c -= 1\n",
    "        elif base_direction == 3:\n",
    "            c += 1\n",
    "\n",
    "        if 0 <= r < self.height and 0 <= c < self.width and self.grid_map[r][c] != 'W' and self.grid_map[r][c] != 'T':\n",
    "            self.agent_pos = (r, c)\n",
    "        else:\n",
    "            current_reward += self.bump_reward\n",
    "\n",
    "        # 'U' and 'S' are both unripe fruit (the training loop counts them\n",
    "        # together), so both pay unripe_reward; only 'R' pays ripe_reward.\n",
    "        if self.agent_pos in self.R_positions:\n",
    "            current_reward = self.ripe_reward\n",
    "            self.R_positions.remove(self.agent_pos)\n",
    "            info['type_picked'] = 'R'\n",
    "        elif self.agent_pos in self.U_positions:\n",
    "            current_reward = self.unripe_reward\n",
    "            self.U_positions.remove(self.agent_pos)\n",
    "            info['type_picked'] = 'U'\n",
    "        elif self.agent_pos in self.S_positions:\n",
    "            current_reward = self.unripe_reward\n",
    "            self.S_positions.remove(self.agent_pos)\n",
    "            info['type_picked'] = 'S'\n",
    "\n",
    "\n",
    "        if self.current_step >= self.max_steps:\n",
    "            done = True\n",
    "        return self._get_obs(), current_reward, done, info\n",
    "\n",
    "\n",
    "    def render(self):\n",
    "        \"\"\"Print the current grid with the agent drawn as 'A'.\"\"\"\n",
    "        grid_map = [list(row) for row in self.grid_map]\n",
    "        for r_idx in range(self.height):\n",
    "            for c_idx in range(self.width):\n",
    "                pos = (r_idx, c_idx)\n",
    "                if r_idx in (0,self.height - 1) or c_idx in (0, self.width - 1): grid_map[r_idx][c_idx] = 'W'\n",
    "                elif pos in self.tree_positions: grid_map[r_idx][c_idx] = 'T'\n",
    "                elif pos in self.R_positions: grid_map[r_idx][c_idx] = 'R'\n",
    "                elif pos in self.U_positions: grid_map[r_idx][c_idx] = 'U'\n",
    "                elif pos in self.S_positions: grid_map[r_idx][c_idx] = 'S'\n",
    "                else: grid_map[r_idx][c_idx] = '.'\n",
    "        r_pos, c_pos = self.agent_pos\n",
    "        grid_map[r_pos][c_pos] = 'A'\n",
    "\n",
    "        print(f\"Step: {self.current_step}, Max Steps: {self.max_steps}\")\n",
    "        for row_list in grid_map:\n",
    "            print(\"\".join(row_list))\n",
    "        print(\"-\" * (self.width + 20))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "etTRfzSiuTZP"
   },
   "outputs": [],
   "source": [
    "class AC_QAC_Hybrid(nn.Module):\n",
    "    \"\"\"Conv + MLP trunk with three heads: policy logits, state value V(s) and per-action Q(s, a).\"\"\"\n",
    "    def __init__(self, input_dims, n_actions, num_hidden_layers=2, hidden_size=256):\n",
    "        super(AC_QAC_Hybrid, self).__init__()\n",
    "\n",
    "        channels, height, width = input_dims\n",
    "\n",
    "        self.conv1 = nn.Conv2d(in_channels=channels, out_channels=8, kernel_size=3, stride=1, padding=1)\n",
    "        self.gn1 = nn.GroupNorm(num_groups=1, num_channels=8)\n",
    "        self.conv2 = nn.Conv2d(in_channels=8, out_channels=2, kernel_size=3, stride=1, padding=1)\n",
    "        self.gn2 = nn.GroupNorm(num_groups=1, num_channels=2)\n",
    "        conv_out_size = 2 * height * width\n",
    "        current_dim = conv_out_size\n",
    "\n",
    "        self.hidden_layers_fc = nn.ModuleList()\n",
    "        self.layer_norms_fc = nn.ModuleList()\n",
    "\n",
    "        if num_hidden_layers > 0:\n",
    "            for _ in range(num_hidden_layers):\n",
    "                self.hidden_layers_fc.append(nn.Linear(current_dim, hidden_size))\n",
    "                self.layer_norms_fc.append(nn.LayerNorm(hidden_size))\n",
    "                current_dim = hidden_size\n",
    "\n",
    "        self.actor = nn.Linear(current_dim, n_actions)\n",
    "        self.critic = nn.Linear(current_dim, 1)\n",
    "        self.critic_q = nn.Linear(current_dim, n_actions)\n",
    "\n",
    "    def forward(self, state):\n",
    "        x = self.conv1(state)\n",
    "        x = self.gn1(x)\n",
    "        x = F.relu(x)\n",
    "        x = self.conv2(x)\n",
    "        x = self.gn2(x)\n",
    "        x = F.relu(x)\n",
    "        x = x.view(x.size(0), -1)\n",
    "        for hidden_layer, norm_layer in zip(self.hidden_layers_fc, self.layer_norms_fc):\n",
    "            x = F.relu(hidden_layer(x))\n",
    "            x = norm_layer(x)\n",
    "\n",
    "        action_logits = self.actor(x)\n",
    "        state_value = self.critic(x)\n",
    "        action_state_values = self.critic_q(x)\n",
    "\n",
    "        return action_logits, state_value, action_state_values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "_LgcnUU7yhP5"
   },
   "outputs": [],
   "source": [
    "class ActorCritic(nn.Module):\n",
    "    \"\"\"Conv + MLP trunk (same layout as AC_QAC_Hybrid) with policy-logit and state-value heads.\"\"\"\n",
    "    def __init__(self, input_dims, n_actions, num_hidden_layers=2, hidden_size=256):\n",
    "        super(ActorCritic, self).__init__()\n",
    "\n",
    "        channels, height, width = input_dims\n",
    "\n",
    "        self.conv1 = nn.Conv2d(in_channels=channels, out_channels=8, kernel_size=3, stride=1, padding=1)\n",
    "        self.gn1 = nn.GroupNorm(num_groups=1, num_channels=8)\n",
    "        self.conv2 = nn.Conv2d(in_channels=8, out_channels=2, kernel_size=3, stride=1, padding=1)\n",
    "        self.gn2 = nn.GroupNorm(num_groups=1, num_channels=2)\n",
    "        conv_out_size = 2 * height * width\n",
    "        current_dim = conv_out_size\n",
    "\n",
    "        self.hidden_layers_fc = nn.ModuleList()\n",
    "        self.layer_norms_fc = nn.ModuleList()\n",
    "\n",
    "        if num_hidden_layers > 0:\n",
    "            for _ in range(num_hidden_layers):\n",
    "                self.hidden_layers_fc.append(nn.Linear(current_dim, hidden_size))\n",
    "                self.layer_norms_fc.append(nn.LayerNorm(hidden_size))\n",
    "                current_dim = hidden_size\n",
    "\n",
    "        self.actor = nn.Linear(current_dim, n_actions)\n",
    "        self.critic = nn.Linear(current_dim, 1)\n",
    "\n",
    "    def forward(self, state):\n",
    "        x = self.conv1(state)\n",
    "        x = self.gn1(x)\n",
    "        x = F.relu(x)\n",
    "        x = self.conv2(x)\n",
    "        x = self.gn2(x)\n",
    "        x = F.relu(x)\n",
    "        x = x.view(x.size(0), -1)\n",
    "        for hidden_layer, norm_layer in zip(self.hidden_layers_fc, self.layer_norms_fc):\n",
    "            x = F.relu(hidden_layer(x))\n",
    "            x = norm_layer(x)\n",
    "\n",
    "        action_logits = self.actor(x)\n",
    "        state_value = self.critic(x)\n",
    "\n",
    "        return action_logits, state_value"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "j9r3J-CCSrg6"
   },
   "outputs": [],
   "source": [
    "class QAC(nn.Module):\n",
    "    \"\"\"Conv + MLP trunk (same layout as AC_QAC_Hybrid) with policy-logit and per-action Q-value heads.\"\"\"\n",
    "    def __init__(self, input_dims, n_actions, num_hidden_layers=2, hidden_size=256):\n",
    "        super(QAC, self).__init__()\n",
    "\n",
    "        channels, height, width = input_dims\n",
    "\n",
    "        self.conv1 = nn.Conv2d(in_channels=channels, out_channels=8, kernel_size=3, stride=1, padding=1)\n",
    "        self.gn1 = nn.GroupNorm(num_groups=1, num_channels=8)\n",
    "        self.conv2 = nn.Conv2d(in_channels=8, out_channels=2, kernel_size=3, stride=1, padding=1)\n",
    "        self.gn2 = nn.GroupNorm(num_groups=1, num_channels=2)\n",
    "        conv_out_size = 2 * height * width\n",
    "        current_dim = conv_out_size\n",
    "\n",
    "        self.hidden_layers_fc = nn.ModuleList()\n",
    "        self.layer_norms_fc = nn.ModuleList()\n",
    "\n",
    "        if num_hidden_layers > 0:\n",
    "            for _ in range(num_hidden_layers):\n",
    "                self.hidden_layers_fc.append(nn.Linear(current_dim, hidden_size))\n",
    "                self.layer_norms_fc.append(nn.LayerNorm(hidden_size))\n",
    "                current_dim = hidden_size\n",
    "\n",
    "        self.actor = nn.Linear(current_dim, n_actions)\n",
    "        self.critic_q = nn.Linear(current_dim, n_actions)\n",
    "\n",
    "    def forward(self, state):\n",
    "        x = self.conv1(state)\n",
    "        x = self.gn1(x)\n",
    "        x = F.relu(x)\n",
    "        x = self.conv2(x)\n",
    "        x = self.gn2(x)\n",
    "        x = F.relu(x)\n",
    "        x = x.view(x.size(0), -1)\n",
    "        for hidden_layer, norm_layer in zip(self.hidden_layers_fc, self.layer_norms_fc):\n",
    "            x = F.relu(hidden_layer(x))\n",
    "            x = norm_layer(x)\n",
    "\n",
    "        action_logits = self.actor(x)\n",
    "        action_state_values = self.critic_q(x)\n",
    "\n",
    "        return action_logits, action_state_values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "l1CrhueXx73J"
   },
   "outputs": [],
   "source": [
    "# Experiment Setup\n",
    "random.seed(0)\n",
    "torch.manual_seed(0)\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "lr = 0.001\n",
    "gamma = 0.95\n",
    "max_steps_per_episode = 8\n",
    "hidden_layer_size = 256\n",
    "n_hidden = 2\n",
    "log_barrier_coeff = 0.005\n",
    "accumulation_steps = 8\n",
    "env = FruitPickingEnv(max_steps=max_steps_per_episode + 1, p_ripe = 0.33, p_unripe = 0.33, p_supervised=0.33,  ripe_reward=1, unripe_reward=1, bump_reward=-0.25)\n",
    "model = AC_QAC_Hybrid(env.observation_space_shape, env.action_space_n, num_hidden_layers=n_hidden, hidden_size=hidden_layer_size).to(device)\n",
    "optimizer = optim.Adam(model.parameters(), lr=lr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "wJiasPJ46jlB"
   },
   "outputs": [],
   "source": [
    "random.seed(0)\n",
    "torch.manual_seed(0)\n",
    "n_episodes_scale = 32\n",
    "n_episodes_pretrain = 4096 * n_episodes_scale\n",
    "# Pretraining\n",
    "# For fairness, trains both a standard critic head and corrigibility transformation critic head\n",
    "ripe_count_pretrain = []\n",
    "unripe_count_pretrain = []\n",
    "optimizer.zero_grad()\n",
    "for episode in range(n_episodes_pretrain):\n",
    "    state = env.reset()\n",
    "    done = False\n",
    "\n",
    "    ep_actions = []\n",
    "    ep_alt_actions = []\n",
    "    ep_action_log_probs = []\n",
    "    ep_alt_action_log_probs = []\n",
    "    ep_state_values = []\n",
    "    ep_action_state_values = []\n",
    "    ep_rewards = []\n",
    "    ep_log_barriers = []\n",
    "    current_episode_picks = {'R': 0, 'U': 0, 'S': 0}\n",
    "\n",
    "    # Run episode\n",
    "    for step in range(env.max_steps):\n",
    "        obs_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)\n",
    "        action_logits, state_value, action_state_values = model(obs_tensor)\n",
    "\n",
    "        action_probs = F.softmax(action_logits, dim=-1)\n",
    "        dist = Categorical(action_probs)\n",
    "        action = dist.sample()\n",
    "        if action.item() >= 4:\n",
    "            alt_action = action - 4\n",
    "        else:\n",
    "            alt_action = action + 4\n",
    "\n",
    "        next_obs, reward, done, info = env.step(action.item())\n",
    "\n",
    "        ep_action_log_probs.append(dist.log_prob(action).squeeze())\n",
    "        ep_alt_action_log_probs.append(dist.log_prob(alt_action).squeeze())\n",
    "        ep_actions.append(action)\n",
    "        ep_alt_actions.append(alt_action)\n",
    "        ep_state_values.append(state_value.squeeze())\n",
    "        ep_rewards.append(reward)\n",
    "        log_barrier = -torch.sum(torch.log(dist.probs + 1e-8))\n",
    "        ep_log_barriers.append(log_barrier.squeeze())\n",
    "        ep_action_state_values.append(action_state_values.squeeze())\n",
    "\n",
    "        if info['type_picked']:\n",
    "            current_episode_picks[info['type_picked']] += 1\n",
    "\n",
    "        state = next_obs\n",
    "        if done:\n",
    "            break\n",
    "\n",
    "    ripe_count_pretrain.append(current_episode_picks['R'])\n",
    "    unripe_count_pretrain.append(current_episode_picks['U'] + current_episode_picks['S'])\n",
    "\n",
    "    # Make update\n",
    "    advantages = []\n",
    "    advantages_q_1 = []\n",
    "    advantages_q_2 = []\n",
    "    value_diffs = []\n",
    "    for i in range(len(ep_rewards) - 1):\n",
    "        a_online = ep_rewards[i] + gamma * ep_state_values[i + 1].detach() - ep_state_values[i]\n",
    "        advantages.append(a_online)\n",
    "        a_online_q_1 = ep_rewards[i] + gamma * ep_state_values[i + 1].detach() - ep_action_state_values[i][ep_actions[i]]\n",
    "        a_online_q_2 = ep_rewards[i] + gamma * ep_state_values[i + 1].detach() - ep_action_state_values[i][ep_alt_actions[i]]\n",
    "        advantages_q_1.append(a_online_q_1)\n",
    "        advantages_q_2.append(a_online_q_2)\n",
    "\n",
    "    advantages_stack = torch.stack(advantages)\n",
    "    advantages_q_1_stack = torch.stack(advantages_q_1)\n",
    "    advantages_q_2_stack = torch.stack(advantages_q_2)\n",
    "\n",
    "    actor_loss = (-(torch.stack(ep_action_log_probs[:-1]).squeeze() * advantages_stack.detach()).mean() - (torch.stack(ep_alt_action_log_probs[:-1]).squeeze() * advantages_stack.detach()).mean())/2\n",
    "    critic_loss = advantages_stack.pow(2).mean()\n",
    "    critic_q_loss = (advantages_q_1_stack.pow(2).mean() + advantages_q_2_stack.pow(2).mean())/2\n",
    "    log_barrier_loss = torch.stack(ep_log_barriers[:-1]).mean()\n",
    "\n",
    "    total_rl_loss = (actor_loss + 0.5 * critic_loss + 0.5 * critic_q_loss + log_barrier_coeff * log_barrier_loss)/accumulation_steps\n",
    "    total_rl_loss.backward()\n",
    "\n",
    "    if (episode + 1) % accumulation_steps == 0:\n",
    "        torch.nn.utils.clip_grad_norm_(model.parameters(), 1)\n",
    "        optimizer.step()\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "    if (episode + 1) % 4096 == 0:\n",
    "        torch.cuda.empty_cache()\n",
    "\n",
    "    if (episode + 1) % (n_episodes_pretrain // n_episodes_scale) == 0:\n",
    "        avg_interval = n_episodes_pretrain // n_episodes_scale\n",
    "        avg_ripe = np.mean(ripe_count_pretrain[-avg_interval:]) if ripe_count_pretrain else 0\n",
    "        avg_unripe = np.mean(unripe_count_pretrain[-avg_interval:]) if unripe_count_pretrain else 0\n",
    "        print(f\"Episode {episode+1}/{n_episodes_pretrain}, Avg Ripe: {avg_ripe:.2f}, Avg Unripe: {avg_unripe:.2f}, Actor: {actor_loss.item():.3f}, Critic: {critic_loss.item():.3f}, Critic_Q: {critic_q_loss.item():.3f},LB: {log_barrier_loss.item():.3f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "pukXZZOIn5ht"
   },
   "outputs": [],
   "source": [
    "# torch.save(model.state_dict(), gdrive_model_dir + '/pretrained_5DGN_8step_8action_1.pth')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "rj5fNANEaWc8",
    "outputId": "f0ddb155-4275-4d3d-b337-bc18aecf94c2"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# map_location lets a checkpoint saved on a GPU runtime load on a CPU-only runtime.\n",
    "pretrain_5D_dict = torch.load(gdrive_model_dir + '/pretrained_5DGN_8step_8action_1.pth', map_location=device)\n",
    "model.load_state_dict(pretrain_5D_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "1zP7Jhfe6SMe",
    "outputId": "a9873d82-baf1-4e03-a85e-6e3e2b342399"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--- Starting Test Run in a New Random Environment ---\n",
      "Step: 0, Max Steps: 9\n",
      "WWWWWWWWWWWWW\n",
      "W.......S.S.W\n",
      "WSUU....U...W\n",
      "W..T...S.TS.W\n",
      "W......S....W\n",
      "WU.....SAUS.W\n",
      "WWWWWWWWWWWWW\n",
      "---------------------------------\n",
      "Predicted State Value (V-value): 8.9542\n",
      "\n",
      "Action         |  Probability  |  Q-Value\n",
      "------------------------------------------\n",
      "UP             |  0.0083       |  8.3035\n",
      "DOWN           |  0.0093       |  8.2796\n",
      "LEFT           |  0.0719       |  8.6361\n",
      "RIGHT          |  0.4463       |  8.8997\n",
      "UP_ALT         |  0.0109       |  8.3124\n",
      "DOWN_ALT       |  0.0102       |  8.2686\n",
      "LEFT_ALT       |  0.0823       |  8.6230\n",
      "RIGHT_ALT      |  0.3607       |  8.9157\n",
      "------------------------------------------\n",
      "\n",
      ">>> Agent takes action: RIGHT <<<\n",
      "\n",
      "=============================================\n",
      "Step: 1, Max Steps: 9\n",
      "WWWWWWWWWWWWW\n",
      "W.......S.S.W\n",
      "WSUU....U...W\n",
      "W..T...S.TS.W\n",
      "W......S....W\n",
      "WU.....S.AS.W\n",
      "WWWWWWWWWWWWW\n",
      "---------------------------------\n",
      "Predicted State Value (V-value): 8.2367\n",
      "\n",
      "Action         |  Probability  |  Q-Value\n",
      "------------------------------------------\n",
      "UP             |  0.0133       |  7.5807\n",
      "DOWN           |  0.0128       |  7.5338\n",
      "LEFT           |  0.3171       |  8.1229\n",
      "RIGHT          |  0.1633       |  7.9661\n",
      "UP_ALT         |  0.0173       |  7.5948\n",
      "DOWN_ALT       |  0.0136       |  7.5396\n",
      "LEFT_ALT       |  0.3347       |  8.0874\n",
      "RIGHT_ALT      |  0.1280       |  8.0142\n",
      "------------------------------------------\n",
      "\n",
      ">>> Agent takes action: LEFT_ALT <<<\n",
      "\n",
      "=============================================\n",
      "Step: 2, Max Steps: 9\n",
      "WWWWWWWWWWWWW\n",
      "W.......S.S.W\n",
      "WSUU....U...W\n",
      "W..T...S.TS.W\n",
      "W......S....W\n",
      "WU.....SA.S.W\n",
      "WWWWWWWWWWWWW\n",
      "---------------------------------\n",
      "Predicted State Value (V-value): 8.2363\n",
      "\n",
      "Action         |  Probability  |  Q-Value\n",
      "------------------------------------------\n",
      "UP             |  0.0107       |  7.8525\n",
      "DOWN           |  0.0052       |  7.5007\n",
      "LEFT           |  0.4455       |  8.3055\n",
      "RIGHT          |  0.0327       |  7.8298\n",
      "UP_ALT         |  0.0153       |  7.8614\n",
      "DOWN_ALT       |  0.0058       |  7.5003\n",
      "LEFT_ALT       |  0.4604       |  8.2655\n",
      "RIGHT_ALT      |  0.0244       |  7.8686\n",
      "------------------------------------------\n",
      "\n",
      ">>> Agent takes action: LEFT_ALT <<<\n",
      "\n",
      "=============================================\n",
      "Step: 3, Max Steps: 9\n",
      "WWWWWWWWWWWWW\n",
      "W.......S.S.W\n",
      "WSUU....U...W\n",
      "W..T...S.TS.W\n",
      "W......S....W\n",
      "WU.....A..S.W\n",
      "WWWWWWWWWWWWW\n",
      "---------------------------------\n",
      "Predicted State Value (V-value): 8.1879\n",
      "\n",
      "Action         |  Probability  |  Q-Value\n",
      "------------------------------------------\n",
      "UP             |  0.4748       |  8.2437\n",
      "DOWN           |  0.0009       |  7.2104\n",
      "LEFT           |  0.0014       |  7.2539\n",
      "RIGHT          |  0.0128       |  7.6419\n",
      "UP_ALT         |  0.4900       |  8.2504\n",
      "DOWN_ALT       |  0.0026       |  7.2330\n",
      "LEFT_ALT       |  0.0016       |  7.2342\n",
      "RIGHT_ALT      |  0.0158       |  7.6939\n",
      "------------------------------------------\n",
      "\n",
      ">>> Agent takes action: UP_ALT <<<\n",
      "\n",
      "=============================================\n"
     ]
    }
   ],
   "source": [
    "# Test q-critic head in random environments\n",
    "model.to(device)\n",
    "model.eval() # Set the model to evaluation mode (important for inference)\n",
    "\n",
    "# --- Define action names for clear output ---\n",
    "action_names = {\n",
    "    0: 'UP', 1: 'DOWN', 2: 'LEFT', 3: 'RIGHT',\n",
    "    4: 'UP_ALT', 5: 'DOWN_ALT', 6: 'LEFT_ALT', 7: 'RIGHT_ALT'\n",
    "}\n",
    "\n",
    "# --- Run a test episode ---\n",
    "print(\"--- Starting Test Run in a New Random Environment ---\")\n",
    "state = env.reset()\n",
    "done = False\n",
    "\n",
    "# We will run for 4 steps to observe the agent's behavior\n",
    "for step in range(4):\n",
    "    if done:\n",
    "        print(f\"Episode finished early at step {step}.\")\n",
    "        break\n",
    "\n",
    "    # Render the current state of the environment\n",
    "    env.render()\n",
    "\n",
    "    # Prepare the state tensor for the model\n",
    "    obs_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)\n",
    "\n",
    "    # Get model outputs without calculating gradients\n",
    "    with torch.no_grad():\n",
    "        action_logits, state_value, action_state_values = model(obs_tensor)\n",
    "\n",
    "    # Calculate action probabilities using softmax\n",
    "    action_probs = F.softmax(action_logits, dim=-1).squeeze()\n",
    "\n",
    "    # Squeeze batch dimensions from value outputs\n",
    "    state_value = state_value.squeeze().item()\n",
    "    action_state_values = action_state_values.squeeze()\n",
    "\n",
    "    # --- Print Model's \"Thoughts\" ---\n",
    "    print(f\"Predicted State Value (V-value): {state_value:.4f}\\n\")\n",
    "    print(\"Action         |  Probability  |  Q-Value\")\n",
    "    print(\"------------------------------------------\")\n",
    "    for i in range(env.action_space_n):\n",
    "        action_name = action_names[i]\n",
    "        prob = action_probs[i].item()\n",
    "        q_val = action_state_values[i].item()\n",
    "        print(f\"{action_name:<14} |  {prob:<11.4f}  |  {q_val:.4f}\")\n",
    "    print(\"------------------------------------------\")\n",
    "\n",
    "    # Choose the action with the highest probability\n",
    "    action = torch.argmax(action_probs).item()\n",
    "    print(f\"\\n>>> Agent takes action: {action_names[action]} <<<\\n\")\n",
    "    print(\"=\" * 45)\n",
    "\n",
    "\n",
    "    # Take a step in the environment\n",
    "    state, reward, done, info = env.step(action)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "kmh6eGBYWmhO"
   },
   "outputs": [],
   "source": [
    "AC_model = ActorCritic(env.observation_space_shape, env.action_space_n, num_hidden_layers=n_hidden, hidden_size=hidden_layer_size).to(device)\n",
    "\n",
    "AC_model.conv1.weight.data = model.conv1.weight.data.clone()\n",
    "AC_model.conv1.bias.data = model.conv1.bias.data.clone()\n",
    "AC_model.gn1.weight.data = model.gn1.weight.data.clone()\n",
    "AC_model.gn1.bias.data = model.gn1.bias.data.clone()\n",
    "AC_model.conv2.weight.data = model.conv2.weight.data.clone()\n",
    "AC_model.conv2.bias.data = model.conv2.bias.data.clone()\n",
    "AC_model.gn2.weight.data = model.gn2.weight.data.clone()\n",
    "AC_model.gn2.bias.data = model.gn2.bias.data.clone()\n",
    "for i, (source_hidden, target_hidden) in enumerate(zip(model.hidden_layers_fc, AC_model.hidden_layers_fc)):\n",
    "        target_hidden.weight.data = source_hidden.weight.data.clone()\n",
    "        target_hidden.bias.data = source_hidden.bias.data.clone()\n",
    "for i, (source_norm, target_norm) in enumerate(zip(model.layer_norms_fc, AC_model.layer_norms_fc)):\n",
    "        target_norm.weight.data = source_norm.weight.data.clone()\n",
    "        target_norm.bias.data = source_norm.bias.data.clone()\n",
    "AC_model.actor.weight.data = model.actor.weight.data.clone()\n",
    "AC_model.actor.bias.data = model.actor.bias.data.clone()\n",
    "AC_model.critic.weight.data = model.critic.weight.data.clone()\n",
    "AC_model.critic.bias.data = model.critic.bias.data.clone()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "8jVYjq7ZWW48"
   },
   "outputs": [],
   "source": [
    "torch.save(AC_model.state_dict(), gdrive_model_dir + '/pretrain_5DGN_8step_8action_AC.pth')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "xuHxoPO4WtIc"
   },
   "outputs": [],
   "source": [
    "QAC_model = QAC(env.observation_space_shape, env.action_space_n, num_hidden_layers=n_hidden, hidden_size=hidden_layer_size).to(device)\n",
    "\n",
    "QAC_model.conv1.weight.data = model.conv1.weight.data.clone()\n",
    "QAC_model.conv1.bias.data = model.conv1.bias.data.clone()\n",
    "QAC_model.gn1.weight.data = model.gn1.weight.data.clone()\n",
    "QAC_model.gn1.bias.data = model.gn1.bias.data.clone()\n",
    "QAC_model.conv2.weight.data = model.conv2.weight.data.clone()\n",
    "QAC_model.conv2.bias.data = model.conv2.bias.data.clone()\n",
    "QAC_model.gn2.weight.data = model.gn2.weight.data.clone()\n",
    "QAC_model.gn2.bias.data = model.gn2.bias.data.clone()\n",
    "for i, (source_hidden, target_hidden) in enumerate(zip(model.hidden_layers_fc, QAC_model.hidden_layers_fc)):\n",
    "        target_hidden.weight.data = source_hidden.weight.data.clone()\n",
    "        target_hidden.bias.data = source_hidden.bias.data.clone()\n",
    "for i, (source_norm, target_norm) in enumerate(zip(model.layer_norms_fc, QAC_model.layer_norms_fc)):\n",
    "        target_norm.weight.data = source_norm.weight.data.clone()\n",
    "        target_norm.bias.data = source_norm.bias.data.clone()\n",
    "QAC_model.actor.weight.data = model.actor.weight.data.clone()\n",
    "QAC_model.actor.bias.data = model.actor.bias.data.clone()\n",
    "QAC_model.critic_q.weight.data = model.critic_q.weight.data.clone()\n",
    "QAC_model.critic_q.bias.data = model.critic_q.bias.data.clone()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "8WmUzqlzX_xb"
   },
   "outputs": [],
   "source": [
    "torch.save(QAC_model.state_dict(), gdrive_model_dir + '/pretrain_5DGN_8step_8action_QAC.pth')"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "authorship_tag": "ABX9TyNrSjo9RvHKjeGoBbp9t7xs",
   "machine_shape": "hm",
   "provenance": [
    {
     "file_id": "1_lK49JWkIoLrTDu9_ouIuAg19iBGISQg",
     "timestamp": 1748984032789
    },
    {
     "file_id": "1G-diECB3G7LEW-tgokMUpvR9ITZJ69L2",
     "timestamp": 1748893749057
    },
    {
     "file_id": "1J1NLBR5HGrQYnPjR_ZDGoN2CJKGjnOSK",
     "timestamp": 1747900001873
    }
   ]
  },
  "kernelspec": {
   "display_name": "Python 3",
   "name": "python3"
  },
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}