{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0",
   "metadata": {
    "id": "xNdF5r8H8F_e",
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "# Import Packages"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1",
   "metadata": {
    "executionInfo": {
     "elapsed": 46,
     "status": "ok",
     "timestamp": 1753529019710,
     "user": {
      "displayName": "",
      "userId": "15031539783321282003"
     },
     "user_tz": 300
    },
    "id": "EUXSTQOd_Ajw"
   },
   "outputs": [],
   "source": [
    "from functools import partial\n",
    "import gymnasium as gym\n",
    "from gymnasium.envs.registration import register\n",
    "import hashlib\n",
    "import random\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "import huggingface_hub\n",
    "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
    "from peft import PeftModel\n",
    "import os\n",
    "from dotenv import load_dotenv\n",
    "from transformers import DynamicCache\n",
    "import itertools\n",
    "from matplotlib.colors import LinearSegmentedColormap\n",
    "from matplotlib.colors import LogNorm\n",
    "from collections import deque\n",
    "import warnings"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "# Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3",
   "metadata": {
    "id": "yM7_l-DSGAXn"
   },
   "outputs": [],
   "source": [
    "baseline_trials = 1000\n",
    "meta_RL_trials = 100\n",
    "length = 30\n",
    "max_context_length = 4096\n",
    "model_path=\"meta-llama/Llama-3.2-3B\"\n",
    "adapter_path=\"username/path_to_trained_adaptor\"\n",
    "env_seed=None\n",
    "action_space=[\"left\", \"right\", \"up\", \"down\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4",
   "metadata": {
    "id": "XeU7Fz8L8CwF",
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "# Environment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5",
   "metadata": {
    "executionInfo": {
     "elapsed": 30,
     "status": "ok",
     "timestamp": 1753528979771,
     "user": {
      "displayName": "",
      "userId": "15031539783321282003"
     },
     "user_tz": 300
    },
    "id": "euOHwHF29NcT"
   },
   "outputs": [],
   "source": [
    "class Environment:\n",
    "\n",
    "    def __init__(self, make_kwargs, random_action_prob=0.0):\n",
    "        self.make_kwargs = make_kwargs\n",
    "        self.random_action_prob = random_action_prob\n",
    "\n",
    "        self.env = gym.make(**self.make_kwargs)\n",
    "\n",
    "        self.done = True\n",
    "        self.observations = []\n",
    "        self.rewards = []\n",
    "\n",
    "        self.record = False\n",
    "        self.video = []\n",
    "        self.text = []\n",
    "\n",
    "        self.action_map = {}\n",
    "\n",
    "    def reset(self, record=False, **kwargs):\n",
    "        self.record = record\n",
    "        observation, _ = self.env.reset(**kwargs)\n",
    "        observation_mapped = self.map_observation(observation)\n",
    "        self.done = False\n",
    "\n",
    "        self.observations = []\n",
    "        self.rewards = []\n",
    "        self.observations.append(observation_mapped)\n",
    "\n",
    "        if self.record:\n",
    "            self.video = []\n",
    "            self.text = []\n",
    "            self.video.append(self.env.render())\n",
    "            self.text.append(f\"observation {observation_mapped}\")\n",
    "\n",
    "        return observation_mapped\n",
    "\n",
    "    def step(self, action, **kwargs):\n",
    "        if self.is_action_valid(action):\n",
    "            action_mapped = self.map_action(action)\n",
    "            random_action = self.random_action_prob > random.random()\n",
    "            action_taken = self.env.action_space.sample() if random_action else action_mapped\n",
    "            observation, reward, terminated, truncated, _ = self.env.step(action_taken, **kwargs)\n",
    "            observation_mapped = self.map_observation(observation)\n",
    "            if random_action:\n",
    "                observation_mapped = observation_mapped + '*'\n",
    "            reward_mapped = self.map_reward(reward)\n",
    "        else:\n",
    "            observation_mapped, reward_mapped, terminated, truncated = self.invalid_action_result(action)\n",
    "\n",
    "        self.done = terminated or truncated\n",
    "\n",
    "        self.observations.append(observation_mapped)\n",
    "        self.rewards.append(reward_mapped)\n",
    "\n",
    "        if self.record:\n",
    "            self.video.append(self.env.render())\n",
    "            self.text.append(f\"action {action} observation {observation_mapped} reward {reward_mapped} terminated {terminated} truncated {truncated}\")\n",
    "\n",
    "        return observation_mapped, reward_mapped, terminated, truncated\n",
    "\n",
    "    def is_done(self):\n",
    "        return self.done\n",
    "\n",
    "    def map_action(self, action):\n",
    "        return self.action_map[action]\n",
    "\n",
    "    def map_reward(self, reward):\n",
    "        return float(reward)\n",
    "\n",
    "    def map_observation(self, observation):\n",
    "        return str(observation)\n",
    "\n",
    "    def is_action_valid(self, action):\n",
    "        return action in self.action_map.keys()\n",
    "\n",
    "    def invalid_action_result(self, action):\n",
    "        observation = f\"Invalid - Pick From {list(self.action_map.keys())}\"\n",
    "        reward = 0.0\n",
    "        terminated = False\n",
    "        truncated = False\n",
    "        return observation, reward, terminated, truncated\n",
    "\n",
    "    @staticmethod\n",
    "    def pick_from_dict(map_dict, key):\n",
    "        try:\n",
    "            value = map_dict[key]\n",
    "        except KeyError:\n",
    "            key_list = list(map_dict.keys())\n",
    "            hash_bytes = hashlib.sha256(key.encode(\"utf-8\")).digest()\n",
    "            hash_int = int.from_bytes(hash_bytes, byteorder=\"big\")\n",
    "            index = hash_int % len(key_list)\n",
    "            corrected_key = key_list[index]\n",
    "            value = map_dict[corrected_key]\n",
    "        return value\n",
    "\n",
    "    def sample_action(self):\n",
    "        return random.choice(list(self.action_map.keys()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6",
   "metadata": {
    "executionInfo": {
     "elapsed": 47,
     "status": "ok",
     "timestamp": 1753528979814,
     "user": {
      "displayName": "",
      "userId": "15031539783321282003"
     },
     "user_tz": 300
    },
    "id": "DbgmSZtq9VA0"
   },
   "outputs": [],
   "source": [
    "class FrozenLake(Environment):\n",
    "\n",
    "    def __init__(self, max_episode_steps=100, is_slippery=False, gridmap=None, gridmap_kwargs=None, random_action_prob=0.0, seed=None):\n",
    "        if gridmap is not None:\n",
    "            self.gridmap = gridmap\n",
    "        elif gridmap_kwargs is not None:\n",
    "            self.gridmap = self.generate_unique_map(seed=seed, **gridmap_kwargs)\n",
    "        else:\n",
    "            raise ValueError(\"gridmap or gridmap_kwargs must be given\")\n",
    "        make_kwargs = {\n",
    "            \"id\": \"FrozenLake-v1\",\n",
    "            \"desc\": self.gridmap,\n",
    "            \"is_slippery\": is_slippery,\n",
    "            \"render_mode\": \"rgb_array\",\n",
    "            \"max_episode_steps\": max_episode_steps,\n",
    "        }\n",
    "        super().__init__(make_kwargs, random_action_prob=random_action_prob)\n",
    "\n",
    "        self.action_map = {\n",
    "            \"left\": 0,\n",
    "            \"down\": 1,\n",
    "            \"right\": 2,\n",
    "            \"up\": 3,\n",
    "        }\n",
    "\n",
    "    def map_reward(self, reward):\n",
    "        if reward == 0.0:\n",
    "            return None\n",
    "        new_reward = float(reward / (len(self.rewards) + 1))\n",
    "        return new_reward\n",
    "\n",
    "    def invalid_action_result(self, action):\n",
    "        observation = f\"Invalid - Pick From {list(self.action_map.keys())}\"\n",
    "        reward = 0.0\n",
    "        terminated = False\n",
    "        truncated = False\n",
    "        return observation, reward, terminated, truncated\n",
    "\n",
    "    @staticmethod\n",
    "    def map_is_valid(gridmap, min_hops):\n",
    "        rows = len(gridmap)\n",
    "        cols = len(gridmap[0])\n",
    "        board = [list(row) for row in gridmap]\n",
    "    \n",
    "        found_start = False\n",
    "        for r in range(rows):\n",
    "            for c in range(cols):\n",
    "                if board[r][c] == \"S\":\n",
    "                    found_start = True\n",
    "                    state = r * cols + c\n",
    "                    result = FrozenLake.find_path_to_goal(gridmap, state)\n",
    "                    if result is None:\n",
    "                        return False  # No path to goal from this 'S'\n",
    "                    _, actions = result\n",
    "                    if len(actions) < min_hops:\n",
    "                        return False  # Path exists but is too short\n",
    "        return found_start  # True if at least one 'S', otherwise False\n",
    "\n",
    "    @staticmethod\n",
    "    def generate_map(min_width=4, max_width=4, min_height=4, max_height=4, hole_prob=0.7, start_pos=0, start_pos_prob=None, goal_pos=15, goal_pos_prob=None, rng=None, seed=None):\n",
    "        if rng is None:\n",
    "            rng = random.Random(seed)\n",
    "\n",
    "        width = rng.randint(min_width, max_width)\n",
    "        height = rng.randint(min_height, max_height)\n",
    "\n",
    "        map_index = ['F'] * (width * height)\n",
    "        avalible_index = list(range(width * height))\n",
    "\n",
    "        if type(start_pos) is int:\n",
    "            start_pos = [start_pos]\n",
    "\n",
    "        if start_pos is None and start_pos_prob is None:\n",
    "            start_pos = [rng.choice(avalible_index)]\n",
    "        elif start_pos is None and start_pos_prob is not None:\n",
    "            start_pos = []\n",
    "            for i in avalible_index.copy():\n",
    "                if rng.random() < start_pos_prob:\n",
    "                    start_pos.append(i)\n",
    "\n",
    "        for p in start_pos:\n",
    "            map_index[p] = 'S'\n",
    "            avalible_index.remove(p)\n",
    "\n",
    "        if type(goal_pos) is int:\n",
    "            goal_pos = [goal_pos]\n",
    "\n",
    "        if goal_pos is None and goal_pos_prob is None:\n",
    "            goal_pos = [rng.choice(avalible_index)]\n",
    "        elif goal_pos is None and goal_pos_prob is not None:\n",
    "            goal_pos = []\n",
    "            for i in avalible_index.copy():\n",
    "                if rng.random() < goal_pos_prob:\n",
    "                    goal_pos.append(i)\n",
    "\n",
    "        for p in goal_pos:\n",
    "            map_index[p] = 'G'\n",
    "            avalible_index.remove(p)\n",
    "\n",
    "        for i in avalible_index.copy():\n",
    "            if rng.random() < hole_prob:\n",
    "                map_index[i] = 'H'\n",
    "                avalible_index.remove(i)\n",
    "\n",
    "        map = []\n",
    "        for i in range(height):\n",
    "            row = ''.join(map_index[i*width:(i+1)*width])\n",
    "            map.append(row)\n",
    "\n",
    "        return map\n",
    "\n",
    "    @staticmethod\n",
    "    def generate_valid_map(min_hops=0, seed=None, **kwargs):\n",
    "        rng = random.Random(seed)\n",
    "\n",
    "        while True:\n",
    "            map = FrozenLake.generate_map(rng=rng, seed=None, **kwargs)\n",
    "            if FrozenLake.map_is_valid(map, min_hops):\n",
    "                return map\n",
    "\n",
    "    @staticmethod\n",
    "    def generate_unique_map(other_gridmaps=[], **kwargs):\n",
    "        while True:\n",
    "            gridmap = FrozenLake.generate_valid_map(**kwargs)\n",
    "            if gridmap not in other_gridmaps:\n",
    "                return gridmap\n",
    "\n",
    "    @staticmethod\n",
    "    def find_path_to_goal(gridmap, state):\n",
    "        rows = len(gridmap)\n",
    "        cols = len(gridmap[0])\n",
    "        board = [list(row) for row in gridmap]\n",
    "    \n",
    "        start_r = state // cols\n",
    "        start_c = state % cols\n",
    "        start_pos = (start_r, start_c)\n",
    "    \n",
    "        # Find goal positions\n",
    "        goals = [(i, j) for i in range(rows) for j in range(cols) if board[i][j] == \"G\"]\n",
    "        if not goals:\n",
    "            return None\n",
    "    \n",
    "        # Ordered directions to match action_map indices\n",
    "        directions = [\n",
    "            (0, -1),  # left  -> 0\n",
    "            (1, 0),   # down  -> 1\n",
    "            (0, 1),   # right -> 2\n",
    "            (-1, 0),  # up    -> 3\n",
    "        ]\n",
    "    \n",
    "        queue = deque()\n",
    "        queue.append((start_pos, [], []))  # (current position, path_so_far, actions_so_far)\n",
    "        visited = set()\n",
    "    \n",
    "        while queue:\n",
    "            (r, c), path, actions = queue.popleft()\n",
    "            if (r, c) in goals:\n",
    "                full_path = [start_pos] + path\n",
    "                return full_path, actions\n",
    "            if (r, c) in visited:\n",
    "                continue\n",
    "            visited.add((r, c))\n",
    "            for action, (dr, dc) in enumerate(directions):\n",
    "                nr, nc = r + dr, c + dc\n",
    "                if 0 <= nr < rows and 0 <= nc < cols and board[nr][nc] != \"H\":\n",
    "                    queue.append(\n",
    "                        ((nr, nc), path + [(nr, nc)], actions + [action])\n",
    "                    )\n",
    "        return None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7",
   "metadata": {
    "executionInfo": {
     "elapsed": 20,
     "status": "ok",
     "timestamp": 1753528979837,
     "user": {
      "displayName": "",
      "userId": "15031539783321282003"
     },
     "user_tz": 300
    },
    "id": "1QEyqQ8Y9dbV"
   },
   "outputs": [],
   "source": [
    "func = partial(\n",
    "    FrozenLake,\n",
    "    max_episode_steps=50,\n",
    "    is_slippery=False,\n",
    "    gridmap_kwargs={\n",
    "        \"other_gridmaps\": [],\n",
    "        \"min_width\": 3,\n",
    "        \"max_width\": 5,\n",
    "        \"min_height\": 3,\n",
    "        \"max_height\": 5,\n",
    "        \"hole_prob\": 0.2,\n",
    "        \"start_pos\": None,\n",
    "        \"start_pos_prob\": 0.1,\n",
    "        \"goal_pos\": None,\n",
    "        \"goal_pos_prob\": 0.1,\n",
    "        \"min_hops\": 4,\n",
    "    },\n",
    "    random_action_prob=0.0,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "9",
   "metadata": {
    "id": "Ff7I5RPB7-DQ",
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "# Stream"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "10",
   "metadata": {
    "executionInfo": {
     "elapsed": 24,
     "status": "ok",
     "timestamp": 1753528979863,
     "user": {
      "displayName": "",
      "userId": "15031539783321282003"
     },
     "user_tz": 300
    },
    "id": "HOMAxsBa70hA"
   },
   "outputs": [],
   "source": [
    "class NumpyBuffer:\n",
    "    def __init__(self, dtype, size_increment):\n",
    "        self.dtype = dtype\n",
    "        self.size_increment = size_increment\n",
    "        self.storage = np.empty(self.size_increment, dtype=self.dtype)\n",
    "        self.index = 0\n",
    "\n",
    "    def __len__(self):\n",
    "        return self.index\n",
    "\n",
    "    def increase_size(self):\n",
    "        new_size = self.storage.shape[0] + self.size_increment\n",
    "        new_storage = np.empty(new_size, dtype=self.dtype)\n",
    "        new_storage[:self.index] = self.storage[:self.index]\n",
    "        self.storage = new_storage\n",
    "\n",
    "    def add(self, data):\n",
    "        data = np.atleast_1d(data)\n",
    "        assert data.dtype == self.storage.dtype\n",
    "        assert len(data.shape) == 1\n",
    "        length = data.shape[0]\n",
    "        new_index = self.index + length\n",
    "        while self.storage.shape[0] < new_index:\n",
    "            self.increase_size()\n",
    "        self.storage[self.index:new_index] = data\n",
    "        self.index = new_index\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        return self.storage[:self.index][idx]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "11",
   "metadata": {
    "editable": true,
    "executionInfo": {
     "elapsed": 40,
     "status": "ok",
     "timestamp": 1753528979891,
     "user": {
      "displayName": "",
      "userId": "15031539783321282003"
     },
     "user_tz": 300
    },
    "id": "pOX0bjwmmKEd",
    "tags": []
   },
   "outputs": [],
   "source": [
    "class Stream:\n",
    "    def __init__(self, tokenizer, size_increment):\n",
    "        self.struct = []\n",
    "        self.token_dtype = np.dtype([\n",
    "            (\"ids\", np.int64),\n",
    "            (\"types\", np.int64),\n",
    "            (\"episodes\", np.int64),\n",
    "            (\"groups\", np.int64),\n",
    "            (\"rewards\", np.float64),\n",
    "            (\"elements\", np.int64),\n",
    "        ])\n",
    "        self.tokens = NumpyBuffer(dtype=self.token_dtype, size_increment=size_increment)\n",
    "        self.tokenizer = tokenizer\n",
    "\n",
    "        self.roles = {\n",
    "            \"environment\": {\n",
    "                \"header\": \"environment\",\n",
    "                \"footer\": self.tokenizer.eos_token,\n",
    "                },\n",
    "            \"begin\": {\n",
    "                \"header\": \"begin\",\n",
    "                \"footer\": self.tokenizer.eos_token,\n",
    "                },\n",
    "            \"thought\": {\n",
    "                \"header\": \"thought:\",\n",
    "                \"footer\": self.tokenizer.eos_token,\n",
    "                },\n",
    "            \"observation\": {\n",
    "                \"header\": \"observation:\",\n",
    "                \"footer\": self.tokenizer.eos_token,\n",
    "                },\n",
    "            \"action\": {\n",
    "                \"header\": \"action:\",\n",
    "                \"footer\": self.tokenizer.eos_token,\n",
    "                },\n",
    "            \"reward\": {\n",
    "                \"header\": \"reward:\",\n",
    "                \"footer\": self.tokenizer.eos_token,\n",
    "                },\n",
    "            \"end\": {\n",
    "                \"header\": \"end\",\n",
    "                \"footer\": self.tokenizer.eos_token,\n",
    "                },\n",
    "        }\n",
    "\n",
    "        # Precompute header and footer token IDs for each role\n",
    "        type_id = 0\n",
    "        for role_name, role in self.roles.items():\n",
    "            role[\"header_ids\"] = self.tokenizer.encode(role[\"header\"], add_special_tokens=False)\n",
    "            role[\"footer_ids\"] = self.tokenizer.encode(role[\"footer\"], add_special_tokens=False)\n",
    "            role[\"header_type\"] = type_id\n",
    "            role[\"content_type\"] = type_id + 1\n",
    "            role[\"footer_type\"] = type_id + 2\n",
    "            type_id += 3\n",
    "\n",
    "    def add_struct(self, struct):\n",
    "        tokens = self.struct_to_tokens(struct)\n",
    "        struct_length = len(self.struct)\n",
    "        tokens[\"elements\"] = [x + struct_length for x in tokens[\"elements\"]]\n",
    "        self.struct.extend(struct)\n",
    "        self.tokens.add(tokens)\n",
    "\n",
    "    def __getitem__(self, key):\n",
    "        if isinstance(key, tuple):\n",
    "            idx, mode = key\n",
    "            if mode == \"struct\":\n",
    "                return self.struct[idx]\n",
    "            elif mode == \"tokens\":\n",
    "                return self.tokens[idx]\n",
    "        else:\n",
    "            return self.struct[key]\n",
    "\n",
    "    def get_tokens_length(self):\n",
    "        return len(self.tokens)\n",
    "\n",
    "    def tokenize(self, content):\n",
    "        return self.tokenizer.encode(content, add_special_tokens=False)\n",
    "\n",
    "    def detokenize(self, content_ids):\n",
    "        return self.tokenizer.decode(content_ids, skip_special_tokens=False)\n",
    "\n",
    "    def struct_to_tokens(self, struct):\n",
    "        ids = []\n",
    "        types = []\n",
    "        episodes = []\n",
    "        groups = []\n",
    "        rewards = []\n",
    "        elements = []\n",
    "\n",
    "        for i, seg in enumerate(struct):\n",
    "            role_name = seg[\"role\"]\n",
    "\n",
    "            try:\n",
    "                role = self.roles[role_name]\n",
    "            except KeyError:\n",
    "                raise KeyError(f\"Unknown role: {role_name}\")\n",
    "\n",
    "            content_ids = self.tokenize(str(seg[\"content\"]))\n",
    "            full_ids = role[\"header_ids\"] + content_ids + role[\"footer_ids\"]\n",
    "            ids.extend(full_ids)\n",
    "            types.extend(\n",
    "                [role[\"header_type\"]] * len(role[\"header_ids\"])\n",
    "                + [role[\"content_type\"]] * len(content_ids)\n",
    "                + [role[\"footer_type\"]] * len(role[\"footer_ids\"])\n",
    "            )\n",
    "            episodes.extend([int(seg[\"episode\"])] * len(full_ids))\n",
    "            groups.extend([int(seg[\"group\"])] * len(full_ids))\n",
    "            rewards.extend([float(seg[\"reward\"])] + [0.0] * (len(full_ids) - 1))\n",
    "            elements.extend([i] * len(full_ids))\n",
    "\n",
    "        tokens = np.zeros(len(ids), dtype=self.token_dtype)\n",
    "        tokens['ids'] = ids\n",
    "        tokens['types'] = types\n",
    "        tokens['episodes'] = episodes\n",
    "        tokens['groups'] = groups\n",
    "        tokens['rewards'] = rewards\n",
    "        tokens['elements'] = elements\n",
    "\n",
    "        return tokens"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "12",
   "metadata": {
    "id": "eF0w5k4_8NBc",
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "# Baselines"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13",
   "metadata": {
    "executionInfo": {
     "elapsed": 13,
     "status": "ok",
     "timestamp": 1753528979902,
     "user": {
      "displayName": "",
      "userId": "15031539783321282003"
     },
     "user_tz": 300
    },
    "id": "g4QTS1fnGy0W"
   },
   "outputs": [],
   "source": [
    "def run_random(env_func, length, env_seed):\n",
    "    env = env_func(seed=env_seed)\n",
    "    rng = np.random.default_rng(seed=None)\n",
    "\n",
    "    rewards = []\n",
    "    actions = []\n",
    "    observations = []\n",
    "\n",
    "    for t in range(length):\n",
    "        assert env.is_done()\n",
    "        while True:\n",
    "            if env.is_done():\n",
    "                obs = env.reset()\n",
    "                observations.append([obs])\n",
    "                actions.append([])\n",
    "                rewards.append([])\n",
    "    \n",
    "            action = str(rng.choice(list(env.action_map.keys())))\n",
    "    \n",
    "            actions[-1].append(action)\n",
    "            obs, reward, terminated, truncated = env.step(action)\n",
    "            rewards[-1].append(reward)\n",
    "            observations[-1].append(obs)\n",
    "\n",
    "            if terminated or truncated:\n",
    "                break\n",
    "\n",
    "    return rewards, actions, observations, env"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "14",
   "metadata": {
    "executionInfo": {
     "elapsed": 13,
     "status": "ok",
     "timestamp": 1753528980120,
     "user": {
      "displayName": "",
      "userId": "15031539783321282003"
     },
     "user_tz": 300
    },
    "id": "3aSjJ_bs-8eL"
   },
   "outputs": [],
   "source": [
    "def run_max(env_func, length, env_seed):\n",
    "    env = env_func(seed=env_seed)\n",
    "    inv_action_map = {v: k for k, v in env.action_map.items()}\n",
    "\n",
    "    rewards = []\n",
    "    actions = []\n",
    "    observations = []\n",
    "\n",
    "    obs = \"\"\n",
    "    for t in range(length):\n",
    "        assert env.is_done()\n",
    "        while True:\n",
    "            if env.is_done():\n",
    "                obs = env.reset()\n",
    "                observations.append([obs])\n",
    "                actions.append([])\n",
    "                rewards.append([])\n",
    "    \n",
    "            _, path_actions = env.find_path_to_goal(env.gridmap, int(obs))\n",
    "            action_int = path_actions[0]\n",
    "            action = str(inv_action_map[action_int])\n",
    "    \n",
    "            actions[-1].append(action)\n",
    "            obs, reward, terminated, truncated = env.step(action)\n",
    "            rewards[-1].append(reward)\n",
    "            observations[-1].append(obs)\n",
    "\n",
    "            if terminated or truncated:\n",
    "                break\n",
    "\n",
    "    return rewards, actions, observations, env"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "15",
   "metadata": {
    "id": "XFoJ96Ah8cYh",
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "# Meta-RL"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "16",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 591,
     "referenced_widgets": [
      "cdc43c5e72314048b9ba42921bdae13f",
      "0ce586a5a7794d25a896c715dcb6d2c5",
      "a510a866c9a94d89b78592826099ae12",
      "95420d7a4457429e812ebd8df713982b",
      "bf2f75a0fdab459e98bedea30d01d27a",
      "f61aacc262714418ab0f275249c12da5",
      "49bfd14afa40429cbf7f452a851a15b4",
      "c0eca66030f346f497cab28fb3d9bb37",
      "cd66fab1e03d4ba7b096e5d7cba960d6",
      "114b740337db498c9aef61098ec8ff61",
      "227777048e054cbd8f13638fda2c7442",
      "d4f3671a3b44440982dc46472a2da12b",
      "873f20a82e2340429cea2f1a48b4bf54",
      "b96bf7d5e922471db3d49f0acda1a13b",
      "2b00fa2d36334dc7a7fc296db0c578c4",
      "0517378f3e8a4dce9cd040a7289c15c0",
      "65b5fc640bee45e7a6c86719eba893b2",
      "66b3269f36e3401cb84926ceab65ef50",
      "a83460b1d4004b149ce407df5c05c7c4",
      "567023535c2b4a8da26601689eb8e19f",
      "56a6407352414743b123341423c78d48",
      "33758e4a1df34c588f334ea4cf01a678",
      "d6019ffc0b284ca8843bb5a3d2cc53d7",
      "f343031ef359460daf8fed0968b4fe4a",
      "2bfb50d38dca405b91d73cd682dd8be8",
      "5af0e713bc4c41a092e76969c0560557",
      "d928085a1fef4c3fada04007caf7724a",
      "0b66e67f15634e7e9e29bd36629fce58",
      "2d548b5cf67e4690a7294e13e0d75696",
      "1dec6971d41b41cf91a2d36793dfa6a6",
      "e0f24675e6ab43d68029472745699a1f",
      "138fabf1598f4dadb5bb8889d9cce168",
      "16723cbe1a064467ac367cf30aa47861",
      "980ace0a44ec4464b52f8f2947b99fc2",
      "2c1791b0e5904bcba832e78d6bbe7ab6",
      "4fe4ae6e9b7a4cb9974b25f0c353c7ed",
      "99cda24ade694ceab00fa544b63421f3",
      "055f9fe449d84304b2fd88efe0fe39fc",
      "70c0cae9a8a246bf8054f00ddeff5c2b",
      "b45b9c5178994841884cf5d23d8b0b33",
      "b18d9c143214465a98b82f80db99c8b0",
      "76b57e0ab63f4adb95b516e6a8b31595",
      "51bd48cf4d0941bd8106256a29e8a723",
      "bceb2586f91741559c21e34bfe7a325c",
      "3df732e64b0049cc9a26a6e30d03ae02",
      "11753fc2407d473da19373a22c5d2fcc",
      "ce402a1cc4b440b69a322467eedc303c",
      "3b4d285f89e34894a472817d52c8aab4",
      "aa308070bbe34f45a2732543ceb86a76",
      "45ab9852986f4e38a9e3ba15669db79b",
      "16afce42d19645c890f3038d2e0e17ff",
      "6b63270a429840c59060e9a1fda10ef5",
      "4f0ad5654e6648ffaed10d6738268f7d",
      "20ae8cb2168a483b8e30370a97308a61",
      "ba96312e1e6f493cad6d010ed5e7e313",
      "eac45d7bb3544e65899969e868c53398",
      "0cb4ddfbbbf045d1940d35aa51cc7ccb",
      "a8f1eff795054b27b41b2aa1a66ed281",
      "311edc4861254d05bc7a00243e12dd27",
      "3cf973a1184640248c7f70e6a31fe87f",
      "3f24896064224740a7df3c2289f69a30",
      "b20b1de2b0a54f23808915d60da1f8f2",
      "d47065e2ba3c47e2bfead8659703219a",
      "be93a76c1d7c4e039fa1f4ccb9acb853",
      "4770842163bc4e9f893c588e1980a7ca",
      "adf13c2c1f79408a99edb5decd4d369f",
      "4e41ea515fb54a08bf552734543e33e9",
      "9192b7b54ad44579ac0a813bbc83275d",
      "2afc78dea0c64ce2b2b69252b91ca21e",
      "2f3e765468bc4b7a95b5d584ae7d92a1",
      "4ee0aba64d444b1b8adaaf2ebc25de36",
      "10ab9e82a71a4b438dfc8a5cd024308d",
      "926024d4ecdc47a0b1cf8d26832dc6df",
      "d6e542bdf9fd460da6f980dafc491774",
      "a14f7e4ac4d54edd8ab15f7701908f84",
      "127ed116c77c45bdbe57d874ad5b7536",
      "67b51117b6404189a66a319ccecf6f6b",
      "bacc36d3bebf4541939903097e4bf9f9",
      "122a982e414840a3bf2fc392e7b60fdf",
      "2d759cdeed264a2c93339506bd490ed1",
      "221971a1808d4862bfc2f7b37ef100c4",
      "502702bf74ea48f099f464ec864d2ed5",
      "224917b35cc54e20865abd5df5e10589",
      "2de17e43c3464b248d568ca3d4af27bf",
      "a0af209cfe2744acb523c48e14454732",
      "a806d6a828814303b715a02684d8eddb",
      "4679d80ec11b4d9890fab8d7ba58f2f1",
      "94a3c57ded6b4606894aec25ea41c871",
      "c88e32c1017a4e38ba5c2665f9c089b9",
      "dd07226be92e47b9b0627c365ed80e60",
      "966ac06517294959a5a43a7f3b56f2c6",
      "3c619bce5e1c4f63a7aad87283928e1e",
      "95bb28cd843e4e67ae5ccd008954a2b9",
      "6bcf7370896b494798040a37b4551ed1",
      "7a533b32014b45779edfa351125ecb0f",
      "86960950865f4b00acabf72da49b4b15",
      "74175c03e98a4f16b5f8ca62c6fadef0",
      "6dcc34143ec74363bc28fb58e5d2be5c",
      "6389886ddbfe43f8af1a5265932c271a",
      "3c8d45369080484cbc1c690fd8461a55",
      "0ab1e351d48e4eeb9166c9a47e7b35de",
      "31a88b8a083d49f78ba9fabd71a50903",
      "de66e1dfd4c24dc9af8a217bb6c60182",
      "c01012c58bf94a4cb9c2b243db79b431",
      "ad27e16c5d1045038f6d6e497b5666e4",
      "7fae8bd9bb1e41858c7b7935e4915222",
      "8d79dd3051ff44c7a1807dc12308495e",
      "e7e7da976d3140abac9b555bf646957b",
      "1322202ec40a4e5a830cd1ae08814e32",
      "50ef1906bafd4b83aaf423f7ccca3270",
      "faa402ec7acb4e0c836ce8626cbfa040",
      "a8efe201a76f44d9b47a88b8f1656f65",
      "5babb4e863974e1eb326c781ac1f4124",
      "5c0ef002797b449197053360a6817ead",
      "8f27048aee814180accce254566a1b5f",
      "d94798644b9c4adfb0868168856507fb"
     ]
    },
    "id": "cBdGEQf1eMgq",
    "outputId": "83c5d0c7-2e38-40f3-9795-9c905e22f66e"
   },
   "outputs": [],
   "source": [
    "load_dotenv()\n",
    "huggingface_hub.login(token=os.environ.get('_HF_TOKEN'))\n",
    "\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "print(f\"Using device {device}\")\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_path)\n",
    "\n",
    "base_model = AutoModelForCausalLM.from_pretrained(\n",
    "    model_path,\n",
    "    torch_dtype=torch.bfloat16,\n",
    "    low_cpu_mem_usage=True,\n",
    "    device_map=\"auto\",\n",
    "    cache_dir=\"../model/\",\n",
    ")\n",
    "base_model.lm_head = torch.nn.Linear(base_model.lm_head.in_features, base_model.lm_head.out_features, bias=True, dtype=base_model.lm_head.weight.dtype)\n",
    "\n",
    "if adapter_path is not None:\n",
    "    model = PeftModel.from_pretrained(base_model, adapter_path, force_download=True).merge_and_unload().to(device)\n",
    "\n",
    "model.eval()\n",
    "\n",
    "#print(model)\n",
    "\n",
    "action_ids = []\n",
    "for k in action_space:\n",
    "    tokens = tokenizer.encode(k, add_special_tokens=False)\n",
    "    if len(tokens) != 1:\n",
    "        raise ValueError(f\"Action '{k}' tokenized into {len(tokens)} tokens: {tokens}. Max length is one.\")\n",
    "    action_ids.append(tokens[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "17",
   "metadata": {
    "id": "ZfGsfSoKetNx"
   },
   "outputs": [],
   "source": [
    "def run_with_stream(env_func, length, stream, env_seed, action_func, random_length):\n",
    "    env = env_func(seed=env_seed)\n",
    "    rng = np.random.default_rng(seed=None)\n",
    "\n",
    "    rewards = []\n",
    "    actions = []\n",
    "    observations = []\n",
    "\n",
    "    group = 0\n",
    "\n",
    "    stream.add_struct([{\"role\": \"environment\", \"content\": \"\", \"episode\": 1, \"group\": group, \"reward\": 0.0}])\n",
    "\n",
    "    t = 0\n",
    "    while ((rng.uniform() >= (1/length)) and random_length) or (t < length):\n",
    "        while True:\n",
    "            if env.is_done():\n",
    "                obs = env.reset()\n",
    "                observations.append([obs])\n",
    "                actions.append([])\n",
    "                rewards.append([])\n",
    "\n",
    "                stream.add_struct([{\"role\": \"begin\", \"content\": \"\", \"episode\": t+1, \"group\": group, \"reward\": 0.0}])\n",
    "                if obs is not None:\n",
    "                    stream.add_struct([{\"role\": \"observation\", \"content\": obs, \"episode\": t+1, \"group\": group, \"reward\": 0.0}])\n",
    "                \n",
    "            action = action_func(stream, env)\n",
    "\n",
    "            stream.add_struct([{\"role\": \"action\", \"content\": action, \"episode\": t+1, \"group\": group, \"reward\": 0.0}])\n",
    "    \n",
    "            actions[-1].append(action)\n",
    "            obs, reward, terminated, truncated = env.step(action)\n",
    "            rewards[-1].append(reward)\n",
    "            observations[-1].append(obs)\n",
    "\n",
    "            if reward is not None:\n",
    "                stream.add_struct([{\"role\": \"reward\", \"content\": reward, \"episode\": t+1, \"group\": group, \"reward\": reward}])\n",
    "            if obs is not None:\n",
    "                stream.add_struct([{\"role\": \"observation\", \"content\": obs, \"episode\": t+1, \"group\": group, \"reward\": 0.0}])\n",
    "            if terminated or truncated:\n",
    "                stream.add_struct([{\"role\": \"end\", \"content\": \"\", \"episode\": t+1, \"group\": group, \"reward\": 0.0}])\n",
    "\n",
    "            if terminated or truncated:\n",
    "                break\n",
    "\n",
    "        t += 1\n",
    "\n",
    "    return rewards, actions, observations, env\n",
    "\n",
    "def action_func_transformer(stream, env, cache, model, device, action_ids):\n",
    "    prefix_ids = torch.tensor(stream.roles[\"action\"][\"header_ids\"], dtype=torch.long, device=device).unsqueeze(0)\n",
    "    \n",
    "    stream_length = stream.get_tokens_length()\n",
    "    target_length = min(stream_length, max_context_length - prefix_ids.shape[1])\n",
    "    d = np.stack([stream[-target_length:, \"tokens\"]])\n",
    "    data = {name: torch.from_numpy(d[name]).to(device) for name in d.dtype.names}\n",
    "\n",
    "    context_ids = data[\"ids\"]   # [1, seq_len]\n",
    "    input_ids = torch.cat([context_ids, prefix_ids], dim=-1)  # [1, T+prefix]\n",
    "\n",
    "    input_length = stream_length - cache[\"alignment_time\"]\n",
    "    input_length = min(input_length, input_ids.shape[1])\n",
    "    input_ids_partial = input_ids[:, -input_length:]\n",
    "    \n",
    "    #max_cache_length = max_context_length - input_length\n",
    "    #if cache[\"past_key_values\"].get_seq_length() > max_cache_length:\n",
    "        #c = cache[\"past_key_values\"]\n",
    "        #for layer in c.layers:\n",
    "            #layer.keys   = layer.keys[..., -max_cache_length:, :]\n",
    "            #layer.values = layer.values[..., -max_cache_length:, :]\n",
    "        \n",
    "    #cache_position = torch.arange(0, input_length, dtype=torch.long, device=device) + cache[\"cache_position\"]\n",
    "    \n",
    "    with torch.no_grad():\n",
    "        outputs = model(\n",
    "            input_ids=input_ids_partial,\n",
    "            past_key_values=cache[\"past_key_values\"],\n",
    "            #cache_position=cache_position,\n",
    "            use_cache=True,\n",
    "        )\n",
    "    logits = outputs.logits\n",
    "    cache[\"alignment_time\"] = stream_length\n",
    "    #cache[\"cache_position\"] = cache[\"cache_position\"] + input_length\n",
    "    \n",
    "    action_logits = logits[:, -1, action_ids]\n",
    "    action_idx = torch.argmax(action_logits, dim=-1)\n",
    "    action_token = action_ids[action_idx]\n",
    "    action = stream.detokenize([action_token])\n",
    "    return action\n",
    "\n",
    "def action_func_transformer_wo(stream, env, model, device, action_ids):\n",
    "    prefix_ids = torch.tensor(stream.roles[\"action\"][\"header_ids\"], dtype=torch.long, device=device).unsqueeze(0)\n",
    "    \n",
    "    stream_length = stream.get_tokens_length()\n",
    "    target_length = min(stream_length, max_context_length - prefix_ids.shape[1])\n",
    "    d = np.stack([stream[-target_length:, \"tokens\"]])\n",
    "    data = {name: torch.from_numpy(d[name]).to(device) for name in d.dtype.names}\n",
    "    \n",
    "    context_ids = data[\"ids\"]   # [1, seq_len]\n",
    "    input_ids = torch.cat([context_ids, prefix_ids], dim=-1)  # [1, T+prefix]\n",
    "\n",
    "    with torch.no_grad():\n",
    "        outputs = model(\n",
    "            input_ids=input_ids,\n",
    "        )\n",
    "    logits = outputs.logits\n",
    "    \n",
    "    action_logits = logits[:, -1, action_ids]\n",
    "    action_idx = torch.argmax(action_logits, dim=-1)\n",
    "    action_token = action_ids[action_idx]\n",
    "    action = stream.detokenize([action_token])\n",
    "    return action\n",
    "\n",
    "def action_func_random(stream, env, rng):\n",
    "    action = str(rng.choice(list(env.action_map.keys())))\n",
    "    return action\n",
    "\n",
    "def run_meta_RL(env_func, length, model, tokenizer, device, action_ids, env_seed):\n",
    "    stream = Stream(tokenizer=tokenizer, size_increment=max_context_length)\n",
    "    cache = {\"past_key_values\": DynamicCache(), \"alignment_time\": stream.get_tokens_length(), \"cache_position\": 0}\n",
    "    rng = np.random.default_rng(seed=None)\n",
    "\n",
    "    action_func = partial(action_func_random, rng=rng)\n",
    "\n",
    "    while stream.get_tokens_length() < max_context_length:\n",
    "        _, _, _, _ = run_with_stream(env_func, length, stream, None, action_func, True)\n",
    "\n",
    "    action_func = partial(action_func_transformer, cache=cache, model=model, device=device, action_ids=action_ids)\n",
    "    #action_func = partial(action_func_transformer_wo, model=model, device=device, action_ids=action_ids)\n",
    "    \n",
    "    rewards, actions, observations, env = run_with_stream(env_func, length, stream, env_seed, action_func, False)\n",
    "\n",
    "    #print(\"Stream\")\n",
    "    #for i, s in enumerate(stream[:, \"struct\"]):\n",
    "    #    print(f\"{i}: {s}\")\n",
    "    #print(\"Tokens\")\n",
    "    #tok = stream[:, \"tokens\"]\n",
    "    #for i, t in enumerate(tok[\"ids\"]):\n",
    "    #    print(f\"{tok['ref'][i]}: {t}\")\n",
    "    \n",
    "    return rewards, actions, observations, env\n",
    "\n",
    "run_meta_RL_pre = partial(run_meta_RL, model=model, tokenizer=tokenizer, device=device, action_ids=action_ids)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "18",
   "metadata": {
    "id": "aeKLdx3h8Xip",
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "# Run"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "19",
   "metadata": {},
   "outputs": [],
   "source": [
    "def collect_returns_over_episodes(run_func, func, length, trials):\n",
    "    \"\"\"\n",
    "    Returns a 2D array: (trials, episodes) with the return from each episode in each trial.\n",
    "    \"\"\"\n",
    "    rewards_array = []\n",
    "    actions_array = []\n",
    "    observations_array = []\n",
    "    env_array = []\n",
    "    for t in range(trials):\n",
    "        print(f\"\\rRun {t+1} of {trials}\", end=\"\")\n",
    "        rewards, actions, observations, env = run_func(env_func=func, length=length, env_seed=env_seed)\n",
    "        rewards_array.append(rewards)\n",
    "        actions_array.append(actions)\n",
    "        observations_array.append(observations)\n",
    "        env_array.append(env)\n",
    "    print(\"\\r\")\n",
    "    return rewards_array, actions_array, observations_array, env_array"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "20",
   "metadata": {
    "id": "2cc75041-8a60-4a9a-90c8-d0973d7a596c"
   },
   "outputs": [],
   "source": [
    "# For each method, collect returns over episodes\n",
    "algos = {\n",
    "    'Random': {\"func\": run_random, \"trials\": baseline_trials},\n",
    "    'meta-RL': {\"func\": run_meta_RL_pre, \"trials\": meta_RL_trials},\n",
    "    'Oracle': {\"func\": run_max, \"trials\": baseline_trials},\n",
    "}\n",
    "\n",
    "results = {}\n",
    "for name, details in algos.items():\n",
    "    print(f\"Running {name}...\")\n",
    "    if details[\"trials\"] > 0:\n",
    "        returns, actions, observations, env = collect_returns_over_episodes(details[\"func\"], func, length, details[\"trials\"])\n",
    "        results[name] = {\n",
    "            \"returns\": returns,\n",
    "            \"actions\": actions,\n",
    "            \"observations\": observations,\n",
    "            \"env\": env,\n",
    "        }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "21",
   "metadata": {},
   "outputs": [],
   "source": [
    "def safe_sum(lst):\n",
    "    return sum(x for x in lst if x is not None)\n",
    "\n",
    "def count(lst):\n",
    "    unique, counts = np.unique(lst, return_counts=True)\n",
    "    return dict(zip(unique.tolist(), counts.tolist()))\n",
    "\n",
    "def calc_results(result):\n",
    "    returns = np.array([[safe_sum(rr) for rr in r] for r in result[\"returns\"]])\n",
    "\n",
    "    observations = result[\"observations\"]\n",
    "    observations_at_timestep = [[] for _ in range(len(observations[0]))]\n",
    "    for trial in observations:\n",
    "        for timestep, episode in enumerate(trial):\n",
    "            observations_at_timestep[timestep].extend(episode)\n",
    "    observation_counts_at_timestep = [count(x) for x in observations_at_timestep]\n",
    "    \n",
    "    calculation = {\n",
    "        \"trials\": returns.shape[0],\n",
    "        \"avg_returns_at_timestep\": np.mean(returns, axis=0),\n",
    "        \"var_returns_at_timestep\": np.var(returns, axis=0),\n",
    "        \"avg_returns\": np.mean(returns, axis=1).mean(axis=0),\n",
    "        \"var_returns\": np.mean(returns, axis=1).var(axis=0),\n",
    "        \"observation_counts_at_timestep\": observation_counts_at_timestep,\n",
    "        \"grid\": result[\"env\"][0].gridmap,\n",
    "    }\n",
    "    return calculation\n",
    "\n",
    "def scale(x, x1, y1, x2, y2):\n",
    "    denom = x2 - x1\n",
    "    if np.any(denom == 0):\n",
    "        raise ValueError(\"x1 and x2 cannot be equal (division by zero)\")\n",
    "    return (y2 - y1) / denom * (x - x1) + y1\n",
    "\n",
    "ref1_calculation = calc_results(results['Random'])\n",
    "ref2_calculation = calc_results(results['Oracle'])\n",
    "\n",
    "calculations = {}\n",
    "for name, result in results.items():\n",
    "    print(f\"Calculating {name}...\")\n",
    "    calculation = calc_results(result)\n",
    "    calculation[\"avg_returns_at_timestep\"] = scale(calculation[\"avg_returns_at_timestep\"], ref1_calculation[\"avg_returns_at_timestep\"], 0.0, ref2_calculation[\"avg_returns_at_timestep\"], 1.0)\n",
    "    calculation[\"var_returns_at_timestep\"] = calculation[\"var_returns_at_timestep\"] / ((ref1_calculation[\"avg_returns_at_timestep\"] - ref2_calculation[\"avg_returns_at_timestep\"])**2)\n",
    "    calculation[\"avg_returns\"] = scale(calculation[\"avg_returns\"], ref1_calculation[\"avg_returns\"], 0.0, ref2_calculation[\"avg_returns\"], 1.0)\n",
    "    calculation[\"var_returns\"] = calculation[\"var_returns\"] / ((ref1_calculation[\"avg_returns\"] - ref2_calculation[\"avg_returns\"])**2)\n",
    "    calculations[name] = calculation"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "22",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "# Plotting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "23",
   "metadata": {},
   "outputs": [],
   "source": [
    "markers = {\n",
    "    'Random': 'o',              # Circle\n",
    "    'Epsilon-Greedy': 's',      # Square\n",
    "    'Thompson Sampling': '^',   # Triangle\n",
    "    'meta-RL': 'v',             # Down triangle\n",
    "    'Oracle': 'D'               # Diamond\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "24",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(8,6))\n",
    "plt.rcParams.update({'font.size': 12})\n",
    "\n",
    "for name, result in calculations.items():\n",
    "    x_vals = np.arange(result[\"avg_returns_at_timestep\"].size) + 1\n",
    "    y_mean = result[\"avg_returns_at_timestep\"]\n",
    "    y_conf_interval = 1.96 * (result[\"var_returns_at_timestep\"] / result[\"trials\"])**(1/2)\n",
    "    mean = result['avg_returns']\n",
    "    conf_interval = 1.96 * (result['var_returns'] / result[\"trials\"])**(1/2)\n",
    "\n",
    "    # Plot the mean curve\n",
    "    plt.plot(\n",
    "        x_vals,\n",
    "        y_mean,\n",
    "        label=f\"{name} ({mean:.3f} ±{conf_interval:.3f})\",\n",
    "        marker=markers[name],\n",
    "        markevery=1,\n",
    "    )\n",
    "\n",
    "    # Plot shaded region for ±1 std dev\n",
    "    plt.fill_between(\n",
    "        x_vals,\n",
    "        y_mean - y_conf_interval,\n",
    "        y_mean + y_conf_interval,\n",
    "        alpha=0.3,\n",
    "        label=None  # no extra legend entry\n",
    "    )\n",
    "\n",
    "plt.xlabel('Episode')\n",
    "plt.ylabel('Average Reward')\n",
    "plt.xlim(0, length + 1)\n",
    "plt.ylim(-0.2, 1.1)\n",
    "plt.legend(\n",
    "    loc='lower center',\n",
    "    bbox_to_anchor=(0.5, 0.0),\n",
    "    ncol=2,\n",
    "    frameon=True\n",
    ")\n",
    "plt.grid(True)\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "25",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_state_map(result, timesteps):\n",
    "    rows = len(result['grid'])\n",
    "    cols = len(result['grid'][0])\n",
    "    grid = [list(row) for row in result['grid']]\n",
    "\n",
    "    # Build a 2D numpy array of visitation counts\n",
    "    heatmap = np.zeros((rows, cols))\n",
    "    for t in timesteps:\n",
    "        for state, count in result['observation_counts_at_timestep'][t].items():\n",
    "            # state might be string or int; ensure int\n",
    "            s = int(state)\n",
    "            r = s // cols\n",
    "            c = s % cols\n",
    "            if 0 <= r < rows and 0 <= c < cols:\n",
    "                heatmap[r, c] += count  # Only fill valid cells\n",
    "            else:\n",
    "                warnings.warn(\"Invalid map\", UserWarning)\n",
    "\n",
    "    #heatmap = np.float32(heatmap > 0)\n",
    "    heatmap_norm = np.maximum(heatmap / heatmap.sum(), 1e-3)\n",
    "\n",
    "    fig, ax = plt.subplots(figsize=(cols, rows))\n",
    "    gray_cmap = LinearSegmentedColormap.from_list('custom_gray', ['1.0', '0.5'])\n",
    "    im = ax.imshow(heatmap_norm, cmap=gray_cmap, interpolation='nearest', norm=LogNorm(vmin=1e-3, vmax=1.0))\n",
    "\n",
    "    cell_names = {\n",
    "        'S': 'Start',\n",
    "        'H': 'Hole',\n",
    "        'G': 'Goal'\n",
    "    }\n",
    "\n",
    "    # Optionally, gray out holes\n",
    "    for r in range(rows):\n",
    "        for c in range(cols):\n",
    "            grid_letter = grid[r][c]\n",
    "            full_name = cell_names.get(grid_letter, \"\")\n",
    "            if full_name:\n",
    "                ax.text(c, r-0.15, f\"{full_name}\", ha='center', va='center', color='black', fontsize=12)\n",
    "            ax.text(c, r+0.18, f\"{(100*heatmap[r, c]/heatmap.sum()):.1f}%\", ha='center', va='center', color='black', fontsize=7)\n",
    "\n",
    "    ax.set_xticks([])\n",
    "    ax.set_yticks([])\n",
    "    ax.set_xticks(np.arange(-0.5, cols, 1), minor=True)\n",
    "    ax.set_yticks(np.arange(-0.5, rows, 1), minor=True)\n",
    "    ax.grid(which='minor', color='black', linestyle='-', linewidth=1)\n",
    "    ax.tick_params(axis='both', which='both', length=0, labelbottom=False, labelleft=False, bottom=False, left=False)\n",
    "    plt.show()\n",
    "\n",
    "plot_state_map(calculations['meta-RL'], list(range(0, 3)))\n",
    "plot_state_map(calculations['meta-RL'], list(range(3, 7)))\n",
    "plot_state_map(calculations['meta-RL'], list(range(7, 30)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "26",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "27",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [
    "xNdF5r8H8F_e",
    "Ff7I5RPB7-DQ",
    "eF0w5k4_8NBc"
   ],
   "gpuType": "T4",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
