{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0",
   "metadata": {
    "id": "xNdF5r8H8F_e",
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "# Import Packages"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1",
   "metadata": {
    "executionInfo": {
     "elapsed": 46,
     "status": "ok",
     "timestamp": 1753529019710,
     "user": {
      "displayName": "",
      "userId": "15031539783321282003"
     },
     "user_tz": 300
    },
    "id": "EUXSTQOd_Ajw"
   },
   "outputs": [],
   "source": [
    "from functools import partial\n",
    "import gymnasium as gym\n",
    "from gymnasium.envs.registration import register\n",
    "import hashlib\n",
    "import random\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "import huggingface_hub\n",
    "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
    "from peft import PeftModel\n",
    "import os\n",
    "from dotenv import load_dotenv\n",
    "from transformers import StaticCache, SlidingWindowCache, DynamicCache"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "# Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3",
   "metadata": {
    "id": "yM7_l-DSGAXn"
   },
   "outputs": [],
   "source": [
    "# Experiment configuration read by the cells that follow.\n",
    "# NOTE(review): presumably the number of independent runs per baseline; not used in the cells shown here - confirm.\n",
    "baseline_trials = 10000\n",
    "# NOTE(review): presumably the number of runs for the transformer agent; not used in the cells shown here - confirm.\n",
    "meta_RL_trials = 1000\n",
    "# Number of reset/step iterations performed by each run_* function.\n",
    "length = 30\n",
    "# Number of bandit arms handed to the environment factory.\n",
    "arms = 3\n",
    "# Upper bound on the tokens (stream context plus the action header) fed to the model per decision.\n",
    "max_context_length = 64\n",
    "# Identifier of the base model and tokenizer passed to from_pretrained.\n",
    "model_path=\"meta-llama/Llama-3.2-3B\"\n",
    "# Identifier of the PEFT adapter merged into the base model (placeholder value - replace before running).\n",
    "adapter_path=\"username/path_to_trained_adaptor\"\n",
    "# NOTE(review): presumably passed as env_seed to the run_* functions; with None, MultiArmBandit draws fresh arm probabilities on every construction.\n",
    "env_seed=None\n",
    "# Action strings the model may choose from; each must encode to exactly one token (checked in the model-loading cell).\n",
    "action_space=[\"0\", \"1\", \"2\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4",
   "metadata": {
    "id": "XeU7Fz8L8CwF",
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "# Environment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5",
   "metadata": {
    "executionInfo": {
     "elapsed": 30,
     "status": "ok",
     "timestamp": 1753528979771,
     "user": {
      "displayName": "",
      "userId": "15031539783321282003"
     },
     "user_tz": 300
    },
    "id": "euOHwHF29NcT"
   },
   "outputs": [],
   "source": [
    "class Environment:\n",
    "    \"\"\"Text-facing wrapper around a Gymnasium environment.\n",
    "\n",
    "    External actions are keys of ``action_map`` (empty here; subclasses fill it in)\n",
    "    and are translated to native actions. Observations and rewards pass through\n",
    "    ``map_observation`` / ``map_reward`` before being returned and logged in\n",
    "    ``observations`` / ``rewards``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, make_kwargs, random_action_prob=0.0):\n",
    "        \"\"\"Build the wrapped environment.\n",
    "\n",
    "        Args:\n",
    "            make_kwargs: Keyword arguments forwarded to ``gym.make``.\n",
    "            random_action_prob: Probability that ``step`` replaces a valid action\n",
    "                with one sampled from the native action space.\n",
    "        \"\"\"\n",
    "        self.make_kwargs = make_kwargs\n",
    "        self.random_action_prob = random_action_prob\n",
    "\n",
    "        self.env = gym.make(**self.make_kwargs)\n",
    "\n",
    "        # No episode is active until reset() is called.\n",
    "        self.done = True\n",
    "        self.observations = []\n",
    "        self.rewards = []\n",
    "\n",
    "        self.record = False\n",
    "        self.video = []\n",
    "        self.text = []\n",
    "\n",
    "        self.action_map = {}\n",
    "\n",
    "    def reset(self, record=False, **kwargs):\n",
    "        \"\"\"Start a new episode and return the mapped initial observation.\n",
    "\n",
    "        Args:\n",
    "            record: When true, rendered frames and text lines are collected in\n",
    "                ``video`` / ``text`` for this episode.\n",
    "            **kwargs: Forwarded to the wrapped environment's ``reset``.\n",
    "        \"\"\"\n",
    "        self.record = record\n",
    "        observation, _ = self.env.reset(**kwargs)\n",
    "        observation_mapped = self.map_observation(observation)\n",
    "        self.done = False\n",
    "\n",
    "        self.observations = []\n",
    "        self.rewards = []\n",
    "        self.observations.append(observation_mapped)\n",
    "\n",
    "        if self.record:\n",
    "            self.video = []\n",
    "            self.text = []\n",
    "            self.video.append(self.env.render())\n",
    "            self.text.append(f\"observation {observation_mapped}\")\n",
    "\n",
    "        return observation_mapped\n",
    "\n",
    "    def step(self, action, **kwargs):\n",
    "        \"\"\"Apply ``action`` and return ``(observation, reward, terminated, truncated)``.\n",
    "\n",
    "        An invalid action is not sent to the wrapped environment; the result of\n",
    "        ``invalid_action_result`` is returned instead. When a valid action is\n",
    "        randomly overridden, ``'*'`` is appended to the mapped observation (or\n",
    "        returned on its own when there is no observation).\n",
    "        \"\"\"\n",
    "        if self.is_action_valid(action):\n",
    "            action_mapped = self.map_action(action)\n",
    "            random_action = self.random_action_prob > random.random()\n",
    "            action_taken = self.env.action_space.sample() if random_action else action_mapped\n",
    "            observation, reward, terminated, truncated, _ = self.env.step(action_taken, **kwargs)\n",
    "            observation_mapped = self.map_observation(observation)\n",
    "            if random_action:\n",
    "                # map_observation may return None (no observation); the marker then\n",
    "                # stands alone instead of raising TypeError on None + str.\n",
    "                observation_mapped = '*' if observation_mapped is None else observation_mapped + '*'\n",
    "            reward_mapped = self.map_reward(reward)\n",
    "        else:\n",
    "            observation_mapped, reward_mapped, terminated, truncated = self.invalid_action_result(action)\n",
    "\n",
    "        self.done = terminated or truncated\n",
    "\n",
    "        self.observations.append(observation_mapped)\n",
    "        self.rewards.append(reward_mapped)\n",
    "\n",
    "        if self.record:\n",
    "            self.video.append(self.env.render())\n",
    "            self.text.append(f\"action {action} observation {observation_mapped} reward {reward_mapped} terminated {terminated} truncated {truncated}\")\n",
    "\n",
    "        return observation_mapped, reward_mapped, terminated, truncated\n",
    "\n",
    "    def is_done(self):\n",
    "        \"\"\"Return True when no episode is active (before reset or after termination/truncation).\"\"\"\n",
    "        return self.done\n",
    "\n",
    "    def map_action(self, action):\n",
    "        \"\"\"Translate an external action key into the native action.\"\"\"\n",
    "        return self.action_map[action]\n",
    "\n",
    "    def map_reward(self, reward):\n",
    "        \"\"\"Convert the native reward to a plain float.\"\"\"\n",
    "        return float(reward)\n",
    "\n",
    "    def map_observation(self, observation):\n",
    "        \"\"\"Convert the native observation to its string form.\"\"\"\n",
    "        return str(observation)\n",
    "\n",
    "    def is_action_valid(self, action):\n",
    "        \"\"\"Return True when ``action`` is a key of ``action_map``.\"\"\"\n",
    "        return action in self.action_map.keys()\n",
    "\n",
    "    def invalid_action_result(self, action):\n",
    "        \"\"\"Return the step result used for an invalid action: a hint listing the valid keys, zero reward, episode not ended.\"\"\"\n",
    "        observation = f\"Invalid - Pick From {list(self.action_map.keys())}\"\n",
    "        reward = 0.0\n",
    "        terminated = False\n",
    "        truncated = False\n",
    "        return observation, reward, terminated, truncated\n",
    "\n",
    "    @staticmethod\n",
    "    def pick_from_dict(map_dict, key):\n",
    "        \"\"\"Look up ``key`` in ``map_dict``; an unknown key falls back to an entry chosen from the key's SHA-256 hash.\n",
    "\n",
    "        The fallback is deterministic for a given key and dict ordering. ``key``\n",
    "        must be a string for the fallback, since it is UTF-8 encoded for hashing.\n",
    "        \"\"\"\n",
    "        try:\n",
    "            value = map_dict[key]\n",
    "        except KeyError:\n",
    "            key_list = list(map_dict.keys())\n",
    "            hash_bytes = hashlib.sha256(key.encode(\"utf-8\")).digest()\n",
    "            hash_int = int.from_bytes(hash_bytes, byteorder=\"big\")\n",
    "            index = hash_int % len(key_list)\n",
    "            corrected_key = key_list[index]\n",
    "            value = map_dict[corrected_key]\n",
    "        return value\n",
    "\n",
    "    def sample_action(self):\n",
    "        \"\"\"Return a uniformly random key of ``action_map``.\"\"\"\n",
    "        return random.choice(list(self.action_map.keys()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6",
   "metadata": {
    "executionInfo": {
     "elapsed": 47,
     "status": "ok",
     "timestamp": 1753528979814,
     "user": {
      "displayName": "",
      "userId": "15031539783321282003"
     },
     "user_tz": 300
    },
    "id": "DbgmSZtq9VA0"
   },
   "outputs": [],
   "source": [
    "class MultiArmedBanditGymEnv(gym.Env):\n",
    "    \"\"\"\n",
    "    Multi-Armed Bandit Environment with Bernoulli rewards and its own random number generator (rng).\n",
    "\n",
    "    The observation is a constant dummy value (0). Each step pays 1 with the chosen\n",
    "    arm's success probability and 0 otherwise; the environment itself never reports\n",
    "    termination or truncation.\n",
    "    \"\"\"\n",
    "    metadata = {\"render_modes\": [\"human\"]}\n",
    "\n",
    "    def __init__(self, k=10, probs=None, seed=None):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            k: Number of arms.\n",
    "            probs: Optional success probability for each arm; drawn uniformly from\n",
    "                [0, 1) with the environment's rng when omitted.\n",
    "            seed: Seed for the environment's random number generator.\n",
    "\n",
    "        Raises:\n",
    "            ValueError: If ``probs`` does not hold exactly ``k`` values in [0, 1].\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "        self.k = k\n",
    "        self.action_space = gym.spaces.Discrete(k)\n",
    "        self.observation_space = gym.spaces.Discrete(1)  # Dummy observation\n",
    "        self.seed(seed)\n",
    "\n",
    "        # Each arm has a probability of reward=1 (success), must be in [0,1]\n",
    "        if probs is None:\n",
    "            self.arm_probs = self.np_random.uniform(0, 1, k)\n",
    "        else:\n",
    "            arm_probs = np.array(probs, dtype=float)\n",
    "            if arm_probs.shape != (k,):\n",
    "                raise ValueError(f\"probs must contain exactly {k} values, got shape {arm_probs.shape}\")\n",
    "            # Written as a positive test so NaN entries are rejected as well.\n",
    "            if not np.all((arm_probs >= 0.0) & (arm_probs <= 1.0)):\n",
    "                raise ValueError(\"probs must lie in [0, 1]\")\n",
    "            self.arm_probs = arm_probs\n",
    "        self.last_action = None\n",
    "\n",
    "    def seed(self, seed=None):\n",
    "        \"\"\"Replace the environment's random number generator with one built from ``seed``.\"\"\"\n",
    "        self.np_random, _ = gym.utils.seeding.np_random(seed)\n",
    "\n",
    "    def reset(self, seed=None, options=None):\n",
    "        \"\"\"Clear the last action and return ``(0, {})``; reseeds the rng only when ``seed`` is given.\"\"\"\n",
    "        if seed is not None:\n",
    "            self.seed(seed)\n",
    "        self.last_action = None\n",
    "        return 0, {}  # Dummy observation, info\n",
    "\n",
    "    def step(self, action):\n",
    "        \"\"\"Pull arm ``action`` and return ``(0, reward, False, False, {})`` with a Bernoulli reward.\"\"\"\n",
    "        assert self.action_space.contains(action), \"Invalid action\"\n",
    "        # Reward is 1 with probability arm_probs[action], else 0\n",
    "        reward = self.np_random.binomial(1, self.arm_probs[action])\n",
    "        self.last_action = action\n",
    "        done = False\n",
    "        return 0, reward, done, False, {}  # obs, reward, terminated, truncated, info\n",
    "\n",
    "    def render(self, mode=\"human\"):\n",
    "        \"\"\"Print the most recently pulled arm.\"\"\"\n",
    "        print(f\"Last action: {self.last_action}\")\n",
    "\n",
    "    def close(self):\n",
    "        \"\"\"Nothing to release.\"\"\"\n",
    "        pass\n",
    "\n",
    "# Register the bandit under an id so that gym.make can construct it by name.\n",
    "register(\n",
    "    id=\"MultiArmedBandit-v0\",  # Unique identifier for the environment\n",
    "    entry_point=MultiArmedBanditGymEnv,  # The class object itself (not a \"module:class\" string)\n",
    "    max_episode_steps=1,  # Step limit recorded at registration; MultiArmBandit also passes max_episode_steps to gym.make\n",
    ")\n",
    "\n",
    "class MultiArmBandit(Environment):\n",
    "    \"\"\"Bernoulli multi-armed bandit exposed through the ``Environment`` wrapper.\n",
    "\n",
    "    Arm probabilities are drawn uniformly from [0, 1) with a generator built from\n",
    "    ``seed``. Actions are the strings \"0\" .. str(arms - 1) and no observation is\n",
    "    reported (``map_observation`` always returns ``None``).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, max_episode_steps=1, arms=10, random_action_prob=0.0, seed=None):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            max_episode_steps: Step limit passed to ``gym.make``.\n",
    "            arms: Number of arms.\n",
    "            random_action_prob: Probability that a valid action is replaced by a random one.\n",
    "            seed: Seed used to draw the arm probabilities. It is not forwarded to the\n",
    "                Gymnasium environment, so the reward draws are not seeded by it.\n",
    "        \"\"\"\n",
    "        rng = np.random.default_rng(seed=seed)\n",
    "        probs = rng.uniform(0, 1, size=arms)\n",
    "\n",
    "        make_kwargs = {\n",
    "            \"id\": \"MultiArmedBandit-v0\",\n",
    "            \"max_episode_steps\": max_episode_steps,\n",
    "            \"k\": arms,\n",
    "            \"probs\": probs,\n",
    "        }\n",
    "        super().__init__(make_kwargs, random_action_prob=random_action_prob)\n",
    "\n",
    "        self.action_map = {str(i): i for i in range(arms)}\n",
    "\n",
    "    def reset(self, record=False, **kwargs):\n",
    "        \"\"\"Start a new episode, forwarding the caller's ``record`` flag to the base class.\"\"\"\n",
    "        return super().reset(record=record, **kwargs)\n",
    "\n",
    "    def map_observation(self, observation):\n",
    "        \"\"\"Discard the dummy observation; the bandit reports none.\"\"\"\n",
    "        return None\n",
    "\n",
    "    def invalid_action_result(self, action):\n",
    "        \"\"\"Return the step result for an invalid action: a hint listing the valid keys, zero reward, episode not ended.\"\"\"\n",
    "        observation = f\"Invalid - Pick From {list(self.action_map.keys())}\"\n",
    "        reward = 0.0\n",
    "        terminated = False\n",
    "        truncated = False\n",
    "        return observation, reward, terminated, truncated"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7",
   "metadata": {
    "executionInfo": {
     "elapsed": 20,
     "status": "ok",
     "timestamp": 1753528979837,
     "user": {
      "displayName": "",
      "userId": "15031539783321282003"
     },
     "user_tz": 300
    },
    "id": "1QEyqQ8Y9dbV"
   },
   "outputs": [],
   "source": [
    "# Environment factory with the episode length and random-action rate pre-bound.\n",
    "# The run_* functions expect a callable they can invoke as func(arms=..., seed=...).\n",
    "func = partial(\n",
    "    MultiArmBandit,\n",
    "    max_episode_steps=1,\n",
    "    random_action_prob=0.0,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8",
   "metadata": {
    "id": "Ff7I5RPB7-DQ",
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "# Stream"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9",
   "metadata": {
    "executionInfo": {
     "elapsed": 24,
     "status": "ok",
     "timestamp": 1753528979863,
     "user": {
      "displayName": "",
      "userId": "15031539783321282003"
     },
     "user_tz": 300
    },
    "id": "HOMAxsBa70hA"
   },
   "outputs": [],
   "source": [
    "class NumpyBuffer:\n",
    "    \"\"\"Append-only 1-D NumPy array whose capacity grows in multiples of a fixed step.\"\"\"\n",
    "\n",
    "    def __init__(self, dtype, size_increment):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            dtype: NumPy dtype of the stored elements.\n",
    "            size_increment: Number of slots allocated initially and added on every\n",
    "                growth step; must be positive.\n",
    "\n",
    "        Raises:\n",
    "            ValueError: If ``size_increment`` is not positive.\n",
    "        \"\"\"\n",
    "        if size_increment <= 0:\n",
    "            # A non-positive step could never make room for new data.\n",
    "            raise ValueError(f\"size_increment must be positive, got {size_increment}\")\n",
    "        self.dtype = dtype\n",
    "        self.size_increment = size_increment\n",
    "        self.storage = np.empty(self.size_increment, dtype=self.dtype)\n",
    "        self.index = 0\n",
    "\n",
    "    def __len__(self):\n",
    "        \"\"\"Number of elements stored (not the allocated capacity).\"\"\"\n",
    "        return self.index\n",
    "\n",
    "    def _resize(self, new_size):\n",
    "        \"\"\"Reallocate the backing array to ``new_size`` slots, keeping the stored elements.\"\"\"\n",
    "        new_storage = np.empty(new_size, dtype=self.dtype)\n",
    "        new_storage[:self.index] = self.storage[:self.index]\n",
    "        self.storage = new_storage\n",
    "\n",
    "    def increase_size(self):\n",
    "        \"\"\"Grow the backing array by one ``size_increment``.\"\"\"\n",
    "        self._resize(self.storage.shape[0] + self.size_increment)\n",
    "\n",
    "    def add(self, data):\n",
    "        \"\"\"Append a scalar or 1-D array whose dtype equals the buffer's dtype.\"\"\"\n",
    "        data = np.atleast_1d(data)\n",
    "        assert data.dtype == self.storage.dtype\n",
    "        assert len(data.shape) == 1\n",
    "        new_index = self.index + data.shape[0]\n",
    "        shortfall = new_index - self.storage.shape[0]\n",
    "        if shortfall > 0:\n",
    "            # Grow once by the number of whole increments needed (ceiling division)\n",
    "            # rather than reallocating and copying once per increment.\n",
    "            steps = -(-shortfall // self.size_increment)\n",
    "            self._resize(self.storage.shape[0] + steps * self.size_increment)\n",
    "        self.storage[self.index:new_index] = data\n",
    "        self.index = new_index\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        \"\"\"Index into the stored elements only; unused capacity is never exposed.\"\"\"\n",
    "        return self.storage[:self.index][idx]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "10",
   "metadata": {
    "editable": true,
    "executionInfo": {
     "elapsed": 40,
     "status": "ok",
     "timestamp": 1753528979891,
     "user": {
      "displayName": "",
      "userId": "15031539783321282003"
     },
     "user_tz": 300
    },
    "id": "pOX0bjwmmKEd",
    "tags": []
   },
   "outputs": [],
   "source": [
    "class Stream:\n",
    "    \"\"\"Interaction log kept both as role-tagged segments and as a flat token table.\n",
    "\n",
    "    ``struct`` is the list of segment dicts as they were added. ``tokens`` is a\n",
    "    structured buffer with one record per token: its id, type, episode, group,\n",
    "    reward, and the index in ``struct`` of the segment it came from.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, tokenizer, size_increment):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            tokenizer: Object providing ``encode``, ``decode`` and ``eos_token``.\n",
    "            size_increment: Growth step handed to the token ``NumpyBuffer``.\n",
    "        \"\"\"\n",
    "        self.struct = []\n",
    "        self.token_dtype = np.dtype([\n",
    "            (\"ids\", np.int64),\n",
    "            (\"types\", np.int64),\n",
    "            (\"episodes\", np.int64),\n",
    "            (\"groups\", np.int64),\n",
    "            (\"rewards\", np.float64),\n",
    "            (\"elements\", np.int64),\n",
    "        ])\n",
    "        self.tokens = NumpyBuffer(dtype=self.token_dtype, size_increment=size_increment)\n",
    "        self.tokenizer = tokenizer\n",
    "\n",
    "        # Every segment is rendered as header + content + footer; the footer is\n",
    "        # always the tokenizer's end-of-sequence token.\n",
    "        self.roles = {\n",
    "            \"environment\": {\n",
    "                \"header\": \"environment\",\n",
    "                \"footer\": self.tokenizer.eos_token,\n",
    "                },\n",
    "            \"begin\": {\n",
    "                \"header\": \"begin\",\n",
    "                \"footer\": self.tokenizer.eos_token,\n",
    "                },\n",
    "            \"thought\": {\n",
    "                \"header\": \"thought:\",\n",
    "                \"footer\": self.tokenizer.eos_token,\n",
    "                },\n",
    "            \"observation\": {\n",
    "                \"header\": \"observation:\",\n",
    "                \"footer\": self.tokenizer.eos_token,\n",
    "                },\n",
    "            \"action\": {\n",
    "                \"header\": \"action:\",\n",
    "                \"footer\": self.tokenizer.eos_token,\n",
    "                },\n",
    "            \"reward\": {\n",
    "                \"header\": \"reward:\",\n",
    "                \"footer\": self.tokenizer.eos_token,\n",
    "                },\n",
    "            \"end\": {\n",
    "                \"header\": \"end\",\n",
    "                \"footer\": self.tokenizer.eos_token,\n",
    "                },\n",
    "        }\n",
    "\n",
    "        # Precompute header and footer token IDs for each role. Each role owns three\n",
    "        # consecutive type ids (header, content, footer) in declaration order.\n",
    "        for position, role in enumerate(self.roles.values()):\n",
    "            role[\"header_ids\"] = self.tokenizer.encode(role[\"header\"], add_special_tokens=False)\n",
    "            role[\"footer_ids\"] = self.tokenizer.encode(role[\"footer\"], add_special_tokens=False)\n",
    "            role[\"header_type\"] = 3 * position\n",
    "            role[\"content_type\"] = 3 * position + 1\n",
    "            role[\"footer_type\"] = 3 * position + 2\n",
    "\n",
    "    def add_struct(self, struct):\n",
    "        \"\"\"Append the segments in ``struct`` to the log and their tokens to the token table.\"\"\"\n",
    "        tokens = self.struct_to_tokens(struct)\n",
    "        # Segment indices are local to `struct`; shift them to positions in self.struct.\n",
    "        tokens[\"elements\"] += len(self.struct)\n",
    "        self.struct.extend(struct)\n",
    "        self.tokens.add(tokens)\n",
    "\n",
    "    def __getitem__(self, key):\n",
    "        \"\"\"Index the log.\n",
    "\n",
    "        ``stream[idx]`` and ``stream[idx, \"struct\"]`` index the segment list;\n",
    "        ``stream[idx, \"tokens\"]`` indexes the token table.\n",
    "\n",
    "        Raises:\n",
    "            ValueError: If the mode is neither \"struct\" nor \"tokens\".\n",
    "        \"\"\"\n",
    "        if isinstance(key, tuple):\n",
    "            idx, mode = key\n",
    "            if mode == \"struct\":\n",
    "                return self.struct[idx]\n",
    "            if mode == \"tokens\":\n",
    "                return self.tokens[idx]\n",
    "            # Fail loudly instead of silently returning None for a mistyped mode.\n",
    "            raise ValueError(f\"Unknown mode: {mode!r}; expected 'struct' or 'tokens'\")\n",
    "        return self.struct[key]\n",
    "\n",
    "    def get_tokens_length(self):\n",
    "        \"\"\"Number of tokens currently stored.\"\"\"\n",
    "        return len(self.tokens)\n",
    "\n",
    "    def tokenize(self, content):\n",
    "        \"\"\"Encode text to token ids without adding special tokens.\"\"\"\n",
    "        return self.tokenizer.encode(content, add_special_tokens=False)\n",
    "\n",
    "    def detokenize(self, content_ids):\n",
    "        \"\"\"Decode token ids to text, keeping special tokens.\"\"\"\n",
    "        return self.tokenizer.decode(content_ids, skip_special_tokens=False)\n",
    "\n",
    "    def struct_to_tokens(self, struct):\n",
    "        \"\"\"Convert segments to a structured token array without storing anything.\n",
    "\n",
    "        Each segment needs the keys \"role\", \"content\", \"episode\", \"group\" and\n",
    "        \"reward\". The segment's reward is attached to its first token only; the\n",
    "        remaining tokens carry 0.0. ``elements`` holds indices local to ``struct``.\n",
    "\n",
    "        Raises:\n",
    "            KeyError: If a segment names a role that is not defined.\n",
    "        \"\"\"\n",
    "        ids = []\n",
    "        types = []\n",
    "        episodes = []\n",
    "        groups = []\n",
    "        rewards = []\n",
    "        elements = []\n",
    "\n",
    "        for i, seg in enumerate(struct):\n",
    "            role_name = seg[\"role\"]\n",
    "\n",
    "            role = self.roles.get(role_name)\n",
    "            if role is None:\n",
    "                raise KeyError(f\"Unknown role: {role_name}\")\n",
    "\n",
    "            content_ids = self.tokenize(str(seg[\"content\"]))\n",
    "            full_ids = role[\"header_ids\"] + content_ids + role[\"footer_ids\"]\n",
    "            ids.extend(full_ids)\n",
    "            types.extend(\n",
    "                [role[\"header_type\"]] * len(role[\"header_ids\"])\n",
    "                + [role[\"content_type\"]] * len(content_ids)\n",
    "                + [role[\"footer_type\"]] * len(role[\"footer_ids\"])\n",
    "            )\n",
    "            episodes.extend([int(seg[\"episode\"])] * len(full_ids))\n",
    "            groups.extend([int(seg[\"group\"])] * len(full_ids))\n",
    "            rewards.extend([float(seg[\"reward\"])] + [0.0] * (len(full_ids) - 1))\n",
    "            elements.extend([i] * len(full_ids))\n",
    "\n",
    "        tokens = np.zeros(len(ids), dtype=self.token_dtype)\n",
    "        tokens['ids'] = ids\n",
    "        tokens['types'] = types\n",
    "        tokens['episodes'] = episodes\n",
    "        tokens['groups'] = groups\n",
    "        tokens['rewards'] = rewards\n",
    "        tokens['elements'] = elements\n",
    "\n",
    "        return tokens"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "11",
   "metadata": {
    "id": "eF0w5k4_8NBc",
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "# Baselines"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "12",
   "metadata": {
    "executionInfo": {
     "elapsed": 13,
     "status": "ok",
     "timestamp": 1753528979902,
     "user": {
      "displayName": "",
      "userId": "15031539783321282003"
     },
     "user_tz": 300
    },
    "id": "g4QTS1fnGy0W"
   },
   "outputs": [],
   "source": [
    "def run_random(func, length, arms, env_seed):\n",
    "    \"\"\"Baseline that pulls a uniformly random arm on every step.\n",
    "\n",
    "    Args:\n",
    "        func: Environment factory called as ``func(arms=..., seed=...)``.\n",
    "        length: Number of reset/step iterations to run.\n",
    "        arms: Number of arms; actions are drawn from ``range(arms)``.\n",
    "        env_seed: Seed forwarded to the environment factory.\n",
    "\n",
    "    Returns:\n",
    "        ``(rewards, actions, arm_probs)``: the per-step rewards, the action strings\n",
    "        taken, and the arm probabilities read from the wrapped environment.\n",
    "    \"\"\"\n",
    "    bandit = func(arms=arms, seed=env_seed)\n",
    "    # The action generator is not seeded by env_seed.\n",
    "    generator = np.random.default_rng(seed=None)\n",
    "\n",
    "    reward_log, action_log = [], []\n",
    "    for _ in range(length):\n",
    "        bandit.reset()\n",
    "        choice = str(generator.integers(arms))\n",
    "        action_log.append(choice)\n",
    "        _, payout, _, _ = bandit.step(choice)\n",
    "        reward_log.append(payout)\n",
    "\n",
    "    return reward_log, action_log, bandit.env.unwrapped.arm_probs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13",
   "metadata": {
    "executionInfo": {
     "elapsed": 14,
     "status": "ok",
     "timestamp": 1753528979917,
     "user": {
      "displayName": "",
      "userId": "15031539783321282003"
     },
     "user_tz": 300
    },
    "id": "-P3zDV_w-8X0"
   },
   "outputs": [],
   "source": [
    "def run_epsilon_greedy(func, length, arms, env_seed, epsilon=0.1):\n",
    "    \"\"\"Epsilon-greedy baseline using incremental sample-mean value estimates.\n",
    "\n",
    "    Args:\n",
    "        func: Environment factory called as ``func(arms=..., seed=...)``.\n",
    "        length: Number of reset/step iterations to run.\n",
    "        arms: Number of arms.\n",
    "        env_seed: Seed forwarded to the environment factory.\n",
    "        epsilon: Probability of pulling a uniformly random arm instead of a\n",
    "            greedy one. Defaults to 0.1.\n",
    "\n",
    "    Returns:\n",
    "        ``(rewards, actions, arm_probs)``: the per-step rewards, the action strings\n",
    "        taken, and the arm probabilities read from the wrapped environment.\n",
    "    \"\"\"\n",
    "    env = func(arms=arms, seed=env_seed)\n",
    "    # The agent's generator is not seeded by env_seed.\n",
    "    rng = np.random.default_rng(seed=None)\n",
    "\n",
    "    q_values = np.zeros(arms)    # Estimated values for each arm\n",
    "    n_pulls = np.zeros(arms)     # Number of times each arm was selected\n",
    "\n",
    "    rewards = []\n",
    "    actions = []\n",
    "\n",
    "    for _ in range(length):\n",
    "        env.reset()\n",
    "\n",
    "        # Epsilon-greedy action selection\n",
    "        if rng.uniform() < epsilon:\n",
    "            action_num = rng.integers(arms)\n",
    "        else:\n",
    "            # Break ties between equally valued arms at random.\n",
    "            max_indices = np.flatnonzero(q_values == np.max(q_values))\n",
    "            action_num = rng.choice(max_indices)\n",
    "        action = str(action_num)\n",
    "\n",
    "        actions.append(action)\n",
    "        _, reward, _, _ = env.step(action)\n",
    "        rewards.append(reward)\n",
    "\n",
    "        # Update Q-value for the selected arm (incremental mean)\n",
    "        n_pulls[action_num] += 1\n",
    "        q_values[action_num] += (reward - q_values[action_num]) / n_pulls[action_num]\n",
    "\n",
    "    return rewards, actions, env.env.unwrapped.arm_probs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "14",
   "metadata": {
    "executionInfo": {
     "elapsed": 182,
     "status": "ok",
     "timestamp": 1753528980105,
     "user": {
      "displayName": "",
      "userId": "15031539783321282003"
     },
     "user_tz": 300
    },
    "id": "wC7F6HaO-8bK"
   },
   "outputs": [],
   "source": [
    "def run_thompson_sampling(func, length, arms, env_seed):\n",
    "    \"\"\"Thompson-sampling baseline with an independent Beta(1, 1) prior on every arm.\n",
    "\n",
    "    Args:\n",
    "        func: Environment factory called as ``func(arms=..., seed=...)``.\n",
    "        length: Number of reset/step iterations to run.\n",
    "        arms: Number of arms.\n",
    "        env_seed: Seed forwarded to the environment factory.\n",
    "\n",
    "    Returns:\n",
    "        ``(rewards, actions, arm_probs)``: the per-step rewards, the action strings\n",
    "        taken, and the arm probabilities read from the wrapped environment.\n",
    "    \"\"\"\n",
    "    bandit = func(arms=arms, seed=env_seed)\n",
    "    # The agent's generator is not seeded by env_seed.\n",
    "    generator = np.random.default_rng(seed=None)\n",
    "\n",
    "    # Beta posterior parameters per arm: observed successes + 1 and failures + 1.\n",
    "    successes = np.ones(arms)\n",
    "    failures = np.ones(arms)\n",
    "\n",
    "    reward_log, action_log = [], []\n",
    "    for _ in range(length):\n",
    "        bandit.reset()\n",
    "\n",
    "        # Draw one plausible success rate per arm and play the best draw.\n",
    "        arm = np.argmax(generator.beta(successes, failures))\n",
    "        choice = str(arm)\n",
    "        action_log.append(choice)\n",
    "\n",
    "        _, payout, _, _ = bandit.step(choice)\n",
    "        reward_log.append(payout)\n",
    "\n",
    "        # A positive reward counts as a success for the chosen arm, anything else as a failure.\n",
    "        counts = successes if payout > 0 else failures\n",
    "        counts[arm] += 1\n",
    "\n",
    "    return reward_log, action_log, bandit.env.unwrapped.arm_probs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "15",
   "metadata": {
    "executionInfo": {
     "elapsed": 13,
     "status": "ok",
     "timestamp": 1753528980120,
     "user": {
      "displayName": "",
      "userId": "15031539783321282003"
     },
     "user_tz": 300
    },
    "id": "3aSjJ_bs-8eL"
   },
   "outputs": [],
   "source": [
    "def run_max(func, length, arms, env_seed):\n",
    "    \"\"\"Oracle baseline that always pulls the arm with the highest true probability.\n",
    "\n",
    "    Args:\n",
    "        func: Environment factory called as ``func(arms=..., seed=...)``.\n",
    "        length: Number of reset/step iterations to run.\n",
    "        arms: Number of arms.\n",
    "        env_seed: Seed forwarded to the environment factory.\n",
    "\n",
    "    Returns:\n",
    "        ``(rewards, actions, arm_probs)``: the per-step rewards, the action strings\n",
    "        taken, and the arm probabilities read from the wrapped environment.\n",
    "    \"\"\"\n",
    "    bandit = func(arms=arms, seed=env_seed)\n",
    "\n",
    "    reward_log, action_log = [], []\n",
    "    for _ in range(length):\n",
    "        bandit.reset()\n",
    "\n",
    "        # Read the environment's true probabilities and pick the best arm.\n",
    "        true_probs = bandit.env.unwrapped.arm_probs\n",
    "        best = str(np.argmax(true_probs))\n",
    "        action_log.append(best)\n",
    "\n",
    "        _, payout, _, _ = bandit.step(best)\n",
    "        reward_log.append(payout)\n",
    "\n",
    "    return reward_log, action_log, bandit.env.unwrapped.arm_probs"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "16",
   "metadata": {
    "id": "XFoJ96Ah8cYh",
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "# Meta-RL"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "17",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 591,
     "referenced_widgets": [
      "cdc43c5e72314048b9ba42921bdae13f",
      "0ce586a5a7794d25a896c715dcb6d2c5",
      "a510a866c9a94d89b78592826099ae12",
      "95420d7a4457429e812ebd8df713982b",
      "bf2f75a0fdab459e98bedea30d01d27a",
      "f61aacc262714418ab0f275249c12da5",
      "49bfd14afa40429cbf7f452a851a15b4",
      "c0eca66030f346f497cab28fb3d9bb37",
      "cd66fab1e03d4ba7b096e5d7cba960d6",
      "114b740337db498c9aef61098ec8ff61",
      "227777048e054cbd8f13638fda2c7442",
      "d4f3671a3b44440982dc46472a2da12b",
      "873f20a82e2340429cea2f1a48b4bf54",
      "b96bf7d5e922471db3d49f0acda1a13b",
      "2b00fa2d36334dc7a7fc296db0c578c4",
      "0517378f3e8a4dce9cd040a7289c15c0",
      "65b5fc640bee45e7a6c86719eba893b2",
      "66b3269f36e3401cb84926ceab65ef50",
      "a83460b1d4004b149ce407df5c05c7c4",
      "567023535c2b4a8da26601689eb8e19f",
      "56a6407352414743b123341423c78d48",
      "33758e4a1df34c588f334ea4cf01a678",
      "d6019ffc0b284ca8843bb5a3d2cc53d7",
      "f343031ef359460daf8fed0968b4fe4a",
      "2bfb50d38dca405b91d73cd682dd8be8",
      "5af0e713bc4c41a092e76969c0560557",
      "d928085a1fef4c3fada04007caf7724a",
      "0b66e67f15634e7e9e29bd36629fce58",
      "2d548b5cf67e4690a7294e13e0d75696",
      "1dec6971d41b41cf91a2d36793dfa6a6",
      "e0f24675e6ab43d68029472745699a1f",
      "138fabf1598f4dadb5bb8889d9cce168",
      "16723cbe1a064467ac367cf30aa47861",
      "980ace0a44ec4464b52f8f2947b99fc2",
      "2c1791b0e5904bcba832e78d6bbe7ab6",
      "4fe4ae6e9b7a4cb9974b25f0c353c7ed",
      "99cda24ade694ceab00fa544b63421f3",
      "055f9fe449d84304b2fd88efe0fe39fc",
      "70c0cae9a8a246bf8054f00ddeff5c2b",
      "b45b9c5178994841884cf5d23d8b0b33",
      "b18d9c143214465a98b82f80db99c8b0",
      "76b57e0ab63f4adb95b516e6a8b31595",
      "51bd48cf4d0941bd8106256a29e8a723",
      "bceb2586f91741559c21e34bfe7a325c",
      "3df732e64b0049cc9a26a6e30d03ae02",
      "11753fc2407d473da19373a22c5d2fcc",
      "ce402a1cc4b440b69a322467eedc303c",
      "3b4d285f89e34894a472817d52c8aab4",
      "aa308070bbe34f45a2732543ceb86a76",
      "45ab9852986f4e38a9e3ba15669db79b",
      "16afce42d19645c890f3038d2e0e17ff",
      "6b63270a429840c59060e9a1fda10ef5",
      "4f0ad5654e6648ffaed10d6738268f7d",
      "20ae8cb2168a483b8e30370a97308a61",
      "ba96312e1e6f493cad6d010ed5e7e313",
      "eac45d7bb3544e65899969e868c53398",
      "0cb4ddfbbbf045d1940d35aa51cc7ccb",
      "a8f1eff795054b27b41b2aa1a66ed281",
      "311edc4861254d05bc7a00243e12dd27",
      "3cf973a1184640248c7f70e6a31fe87f",
      "3f24896064224740a7df3c2289f69a30",
      "b20b1de2b0a54f23808915d60da1f8f2",
      "d47065e2ba3c47e2bfead8659703219a",
      "be93a76c1d7c4e039fa1f4ccb9acb853",
      "4770842163bc4e9f893c588e1980a7ca",
      "adf13c2c1f79408a99edb5decd4d369f",
      "4e41ea515fb54a08bf552734543e33e9",
      "9192b7b54ad44579ac0a813bbc83275d",
      "2afc78dea0c64ce2b2b69252b91ca21e",
      "2f3e765468bc4b7a95b5d584ae7d92a1",
      "4ee0aba64d444b1b8adaaf2ebc25de36",
      "10ab9e82a71a4b438dfc8a5cd024308d",
      "926024d4ecdc47a0b1cf8d26832dc6df",
      "d6e542bdf9fd460da6f980dafc491774",
      "a14f7e4ac4d54edd8ab15f7701908f84",
      "127ed116c77c45bdbe57d874ad5b7536",
      "67b51117b6404189a66a319ccecf6f6b",
      "bacc36d3bebf4541939903097e4bf9f9",
      "122a982e414840a3bf2fc392e7b60fdf",
      "2d759cdeed264a2c93339506bd490ed1",
      "221971a1808d4862bfc2f7b37ef100c4",
      "502702bf74ea48f099f464ec864d2ed5",
      "224917b35cc54e20865abd5df5e10589",
      "2de17e43c3464b248d568ca3d4af27bf",
      "a0af209cfe2744acb523c48e14454732",
      "a806d6a828814303b715a02684d8eddb",
      "4679d80ec11b4d9890fab8d7ba58f2f1",
      "94a3c57ded6b4606894aec25ea41c871",
      "c88e32c1017a4e38ba5c2665f9c089b9",
      "dd07226be92e47b9b0627c365ed80e60",
      "966ac06517294959a5a43a7f3b56f2c6",
      "3c619bce5e1c4f63a7aad87283928e1e",
      "95bb28cd843e4e67ae5ccd008954a2b9",
      "6bcf7370896b494798040a37b4551ed1",
      "7a533b32014b45779edfa351125ecb0f",
      "86960950865f4b00acabf72da49b4b15",
      "74175c03e98a4f16b5f8ca62c6fadef0",
      "6dcc34143ec74363bc28fb58e5d2be5c",
      "6389886ddbfe43f8af1a5265932c271a",
      "3c8d45369080484cbc1c690fd8461a55",
      "0ab1e351d48e4eeb9166c9a47e7b35de",
      "31a88b8a083d49f78ba9fabd71a50903",
      "de66e1dfd4c24dc9af8a217bb6c60182",
      "c01012c58bf94a4cb9c2b243db79b431",
      "ad27e16c5d1045038f6d6e497b5666e4",
      "7fae8bd9bb1e41858c7b7935e4915222",
      "8d79dd3051ff44c7a1807dc12308495e",
      "e7e7da976d3140abac9b555bf646957b",
      "1322202ec40a4e5a830cd1ae08814e32",
      "50ef1906bafd4b83aaf423f7ccca3270",
      "faa402ec7acb4e0c836ce8626cbfa040",
      "a8efe201a76f44d9b47a88b8f1656f65",
      "5babb4e863974e1eb326c781ac1f4124",
      "5c0ef002797b449197053360a6817ead",
      "8f27048aee814180accce254566a1b5f",
      "d94798644b9c4adfb0868168856507fb"
     ]
    },
    "id": "cBdGEQf1eMgq",
    "outputId": "83c5d0c7-2e38-40f3-9795-9c905e22f66e"
   },
   "outputs": [],
   "source": [
    "# Read environment variables from a local .env file and log in to the Hugging Face Hub.\n",
    "load_dotenv()\n",
    "huggingface_hub.login(token=os.environ.get('_HF_TOKEN'))\n",
    "\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "print(f\"Using device {device}\")\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_path)\n",
    "\n",
    "base_model = AutoModelForCausalLM.from_pretrained(\n",
    "    model_path,\n",
    "    torch_dtype=torch.bfloat16,\n",
    "    low_cpu_mem_usage=True,\n",
    "    device_map=\"auto\",\n",
    "    cache_dir=\"../model/\",\n",
    ")\n",
    "# Replace the output head with a same-shaped linear layer that has a bias term.\n",
    "# The new layer is freshly initialised, so its weights are untrained at this point.\n",
    "# NOTE(review): presumably the adapter checkpoint supplies the trained head - confirm.\n",
    "base_model.lm_head = torch.nn.Linear(base_model.lm_head.in_features, base_model.lm_head.out_features, bias=True, dtype=base_model.lm_head.weight.dtype)\n",
    "\n",
    "if adapter_path is not None:\n",
    "    model = PeftModel.from_pretrained(base_model, adapter_path, force_download=True).merge_and_unload().to(device)\n",
    "else:\n",
    "    # Fall back to the base model so that `model` is defined even when no adapter is\n",
    "    # configured; .to(device) also moves the newly created head onto the device.\n",
    "    model = base_model.to(device)\n",
    "\n",
    "model.eval()\n",
    "\n",
    "# Map every action string to its single token id; multi-token actions are rejected.\n",
    "action_ids = []\n",
    "for k in action_space:\n",
    "    tokens = tokenizer.encode(k, add_special_tokens=False)\n",
    "    if len(tokens) != 1:\n",
    "        raise ValueError(f\"Action '{k}' tokenized into {len(tokens)} tokens: {tokens}. Max length is one.\")\n",
    "    action_ids.append(tokens[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "18",
   "metadata": {
    "id": "ZfGsfSoKetNx"
   },
   "outputs": [],
   "source": [
    "def run_with_stream(func, length, arms, stream, env_seed, action_func):\n",
    "    \"\"\"Run an agent that reads the interaction history from ``stream``.\n",
    "\n",
    "    Every event (environment header, episode begin, observation, action, reward,\n",
    "    episode end) is appended to ``stream`` as a one-segment struct, and\n",
    "    ``action_func(stream)`` is asked for the next action.\n",
    "\n",
    "    Args:\n",
    "        func: Environment factory called as ``func(arms=..., seed=...)``.\n",
    "        length: Number of reset/step iterations to run; iteration ``t`` is logged\n",
    "            as episode ``t + 1``.\n",
    "        arms: Number of arms.\n",
    "        stream: Object with an ``add_struct`` method that receives the segments.\n",
    "        env_seed: Seed forwarded to the environment factory.\n",
    "        action_func: Callable that takes the stream and returns the action.\n",
    "\n",
    "    Returns:\n",
    "        ``(rewards, actions, arm_probs)``: the per-step rewards, the actions taken,\n",
    "        and the arm probabilities read from the wrapped environment.\n",
    "    \"\"\"\n",
    "    env = func(arms=arms, seed=env_seed)\n",
    "    group = 1\n",
    "\n",
    "    def emit(role, content, episode, reward=0.0):\n",
    "        # Append a single segment to the stream.\n",
    "        stream.add_struct([{\"role\": role, \"content\": content, \"episode\": episode, \"group\": group, \"reward\": reward}])\n",
    "\n",
    "    rewards, actions = [], []\n",
    "\n",
    "    emit(\"environment\", \"\", 1)\n",
    "\n",
    "    for step_index in range(length):\n",
    "        episode = step_index + 1\n",
    "        obs = env.reset()\n",
    "\n",
    "        emit(\"begin\", \"\", episode)\n",
    "        if obs is not None:\n",
    "            emit(\"observation\", obs, episode)\n",
    "\n",
    "        action = action_func(stream)\n",
    "        emit(\"action\", action, episode)\n",
    "        actions.append(action)\n",
    "\n",
    "        next_obs, reward, terminated, truncated = env.step(action)\n",
    "        rewards.append(reward)\n",
    "\n",
    "        if reward is not None:\n",
    "            emit(\"reward\", reward, episode, reward=reward)\n",
    "        if next_obs is not None:\n",
    "            emit(\"observation\", next_obs, episode)\n",
    "        if terminated or truncated:\n",
    "            emit(\"end\", \"\", episode)\n",
    "\n",
    "    return rewards, actions, env.env.unwrapped.arm_probs\n",
    "\n",
    "def action_func_transformer(stream, cache, model, device, action_ids):\n",
    "    prefix_ids = torch.tensor(stream.roles[\"action\"][\"header_ids\"], dtype=torch.long, device=device).unsqueeze(0)\n",
    "    \n",
    "    stream_length = stream.get_tokens_length()\n",
    "    target_length = min(stream_length, max_context_length - prefix_ids.shape[1])\n",
    "    d = np.stack([stream[-target_length:, \"tokens\"]])\n",
    "    data = {name: torch.from_numpy(d[name]).to(device) for name in d.dtype.names}\n",
    "\n",
    "    context_ids = data[\"ids\"]   # [1, seq_len]\n",
    "    input_ids = torch.cat([context_ids, prefix_ids], dim=-1)  # [1, T+prefix]\n",
    "\n",
    "    input_length = stream_length - cache[\"alignment_time\"]\n",
    "    input_length = min(input_length, input_ids.shape[1])\n",
    "    input_ids_partial = input_ids[:, -input_length:]\n",
    "    \n",
    "    #max_cache_length = max_context_length - input_length\n",
    "    #if cache[\"past_key_values\"].get_seq_length() > max_cache_length:\n",
    "        #c = cache[\"past_key_values\"]\n",
    "        #for layer in c.layers:\n",
    "            #layer.keys   = layer.keys[..., -max_cache_length:, :]\n",
    "            #layer.values = layer.values[..., -max_cache_length:, :]\n",
    "        \n",
    "    #cache_position = torch.arange(0, input_length, dtype=torch.long, device=device) + cache[\"cache_position\"]\n",
    "    \n",
    "    with torch.no_grad():\n",
    "        outputs = model(\n",
    "            input_ids=input_ids_partial,\n",
    "            past_key_values=cache[\"past_key_values\"],\n",
    "            #cache_position=cache_position,\n",
    "            use_cache=True,\n",
    "        )\n",
    "    logits = outputs.logits\n",
    "    cache[\"alignment_time\"] = stream_length\n",
    "    #cache[\"cache_position\"] = cache[\"cache_position\"] + input_length\n",
    "    \n",
    "    action_logits = logits[:, -1, action_ids]\n",
    "    action_idx = torch.argmax(action_logits, dim=-1)\n",
    "    action_token = action_ids[action_idx]\n",
    "    action = stream.detokenize([action_token])\n",
    "    return action\n",
    "\n",
    "def action_func_transformer_wo(stream, model, device, action_ids):\n",
    "    prefix_ids = torch.tensor(stream.roles[\"action\"][\"header_ids\"], dtype=torch.long, device=device).unsqueeze(0)\n",
    "    \n",
    "    stream_length = stream.get_tokens_length()\n",
    "    target_length = min(stream_length, max_context_length - prefix_ids.shape[1])\n",
    "    d = np.stack([stream[-target_length:, \"tokens\"]])\n",
    "    data = {name: torch.from_numpy(d[name]).to(device) for name in d.dtype.names}\n",
    "    \n",
    "    context_ids = data[\"ids\"]   # [1, seq_len]\n",
    "    input_ids = torch.cat([context_ids, prefix_ids], dim=-1)  # [1, T+prefix]\n",
    "\n",
    "    with torch.no_grad():\n",
    "        outputs = model(\n",
    "            input_ids=input_ids,\n",
    "        )\n",
    "    logits = outputs.logits\n",
    "    \n",
    "    action_logits = logits[:, -1, action_ids]\n",
    "    action_idx = torch.argmax(action_logits, dim=-1)\n",
    "    action_token = action_ids[action_idx]\n",
    "    action = stream.detokenize([action_token])\n",
    "    return action\n",
    "\n",
    "def action_func_random(stream, rng):\n",
    "    return str(rng.integers(arms))\n",
    "\n",
    "def run_meta_RL(func, length, arms, model, tokenizer, device, action_ids, env_seed):\n",
    "    stream = Stream(tokenizer=tokenizer, size_increment=max_context_length)\n",
    "    cache = {\"past_key_values\": DynamicCache(), \"alignment_time\": stream.get_tokens_length(), \"cache_position\": 0}\n",
    "    rng = np.random.default_rng(seed=None)\n",
    "\n",
    "    action_func = partial(action_func_random, rng=rng)\n",
    "\n",
    "    while stream.get_tokens_length() < max_context_length:\n",
    "        _, _, _ = run_with_stream(func, length, arms, stream, None, action_func)\n",
    "\n",
    "    action_func = partial(action_func_transformer, cache=cache, model=model, device=device, action_ids=action_ids)\n",
    "    #action_func = partial(action_func_transformer_wo, model=model, device=device, action_ids=action_ids)\n",
    "    \n",
    "    rewards, actions, arm_probs = run_with_stream(func, length, arms, stream, env_seed, action_func)\n",
    "\n",
    "    #print(\"Stream\")\n",
    "    #for i, s in enumerate(stream[:, \"struct\"]):\n",
    "    #    print(f\"{i}: {s}\")\n",
    "    #print(\"Tokens\")\n",
    "    #tok = stream[:, \"tokens\"]\n",
    "    #for i, t in enumerate(tok[\"ids\"]):\n",
    "    #    print(f\"{tok['ref'][i]}: {t}\")\n",
    "    \n",
    "    return rewards, actions, arm_probs\n",
    "\n",
    "run_meta_RL_pre = partial(run_meta_RL, model=model, tokenizer=tokenizer, device=device, action_ids=action_ids, env_seed=env_seed)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "19",
   "metadata": {
    "id": "aeKLdx3h8Xip",
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "# Run"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "20",
   "metadata": {},
   "outputs": [],
   "source": [
    "def softmax(x, temperature=1.0, axis=-1):\n",
    "    # Temperature scaling\n",
    "    x_scaled = x / temperature\n",
    "    # Subtract max for numerical stability\n",
    "    x_max = np.max(x_scaled, axis=axis, keepdims=True)\n",
    "    e_x = np.exp(x_scaled - x_max)\n",
    "    softmax_out = e_x / np.sum(e_x, axis=axis, keepdims=True)\n",
    "    return softmax_out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "21",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_kl(actions_x, probs_y):\n",
    "    \"\"\"\n",
    "    actions_x: array of shape (trials, length)\n",
    "    probs_y: array of shape (trials, arms)\n",
    "    \"\"\"\n",
    "    eps = 1e-12\n",
    "    actions_x = actions_x.astype(int)\n",
    "    probs_y_safe = np.clip(probs_y, eps, 1.0)\n",
    "\n",
    "    # Extract q = probs_y[trial, action] for each (trial, time)\n",
    "    trials, length = actions_x.shape\n",
    "    rows = np.arange(trials)[:, None]  # shape (trials, 1)\n",
    "    action_probs = probs_y_safe[rows, actions_x]  # shape (trials, length)\n",
    "    #print(action_probs)\n",
    "\n",
    "    # KL divergence between delta(a) and probs_y is: log(1) - log(q) = -log(q)\n",
    "    kls = -np.log(action_probs)\n",
    "\n",
    "    return kls.mean(axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "22",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_entropy(actions, arms):\n",
    "    \"\"\"\n",
    "    actions: array of shape (trials, length)\n",
    "    arms: int with number of arms\n",
    "    \"\"\"\n",
    "    trials, length = actions.shape\n",
    "\n",
    "    arm_labels = np.array([str(a) for a in range(arms)])\n",
    "    action_str = actions.astype(str)\n",
    "    matches = (action_str[None, :, :] == arm_labels[:, None, None])  # (num_arms, trials, length)\n",
    "    probs = matches.mean(axis=1)  # (num_arms, length)\n",
    "\n",
    "    mask = probs > 0\n",
    "    log_probs = np.zeros_like(probs)\n",
    "    log_probs[mask] = np.log(probs[mask])\n",
    "    entropies = -np.sum(probs * log_probs, axis=0)\n",
    "\n",
    "    return entropies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "23",
   "metadata": {},
   "outputs": [],
   "source": [
    "def collect_returns_over_episodes(run_func, func, length, arms, trials):\n",
    "    \"\"\"\n",
    "    Returns a 2D array: (trials, episodes) with the return from each episode in each trial.\n",
    "    \"\"\"\n",
    "    rewards_array = np.empty((trials, length), dtype=np.float64)\n",
    "    actions_array = np.empty((trials, length), dtype=np.dtype('U'))\n",
    "    arm_probs_array = np.empty((trials, arms), dtype=np.float64)\n",
    "    for t in range(trials):\n",
    "        print(f\"\\rRun {t+1} of {trials}\", end=\"\")\n",
    "        rewards, actions, arm_probs = run_func(func, length=length, arms=arms, env_seed=env_seed)\n",
    "        rewards_array[t, :] = rewards\n",
    "        actions_array[t, :] = actions\n",
    "        arm_probs_array[t, :] = arm_probs\n",
    "    print(\"\\r\")\n",
    "    return rewards_array, actions_array, arm_probs_array"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "24",
   "metadata": {
    "id": "2cc75041-8a60-4a9a-90c8-d0973d7a596c"
   },
   "outputs": [],
   "source": [
    "# For each method, collect returns over episodes\n",
    "algos = {\n",
    "    'Random': {\"func\": run_random, \"trials\": baseline_trials},\n",
    "    'Epsilon-Greedy': {\"func\": run_epsilon_greedy, \"trials\": baseline_trials},\n",
    "    'Thompson Sampling': {\"func\": run_thompson_sampling, \"trials\": baseline_trials},\n",
    "    'meta-RL': {\"func\": run_meta_RL_pre, \"trials\": meta_RL_trials},\n",
    "    'Oracle': {\"func\": run_max, \"trials\": baseline_trials},\n",
    "}\n",
    "\n",
    "results = {}\n",
    "for name, details in algos.items():\n",
    "    print(f\"Running {name}...\")\n",
    "    if details[\"trials\"] > 0:\n",
    "        returns, actions, arm_probs = collect_returns_over_episodes(details[\"func\"], func, length, arms, details[\"trials\"])\n",
    "        results[name] = {\n",
    "            \"returns\": returns,\n",
    "            \"actions\": actions,\n",
    "            \"arm_probs\": arm_probs,\n",
    "        }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "25",
   "metadata": {},
   "outputs": [],
   "source": [
    "def calc_results(result):\n",
    "    oracle_probs = softmax(result[\"arm_probs\"], temperature=0.000000001, axis=-1)\n",
    "    calculation = {\n",
    "        \"trials\": result[\"returns\"].shape[0],\n",
    "        \"avg_returns_at_timestep\": np.mean(result[\"returns\"], axis=0),\n",
    "        \"var_returns_at_timestep\": np.var(result[\"returns\"], axis=0),\n",
    "        \"avg_returns\": np.mean(result[\"returns\"], axis=1).mean(axis=0),\n",
    "        \"var_returns\": np.mean(result[\"returns\"], axis=1).var(axis=0),\n",
    "        \"kl_to_oracle_at_timestep\": compute_kl(result[\"actions\"], oracle_probs),\n",
    "        \"entropy_at_timestep\": compute_entropy(result[\"actions\"], arms),\n",
    "    }\n",
    "    return calculation\n",
    "\n",
    "def scale(x, x1, y1, x2, y2):\n",
    "    denom = x2 - x1\n",
    "    if np.any(denom == 0):\n",
    "        raise ValueError(\"x1 and x2 cannot be equal (division by zero)\")\n",
    "    return (y2 - y1) / denom * (x - x1) + y1\n",
    "\n",
    "ref1_calculation = calc_results(results['Random'])\n",
    "ref2_calculation = calc_results(results['Oracle'])\n",
    "\n",
    "calculations = {}\n",
    "for name, result in results.items():\n",
    "    print(f\"Calculating {name}...\")\n",
    "    calculation = calc_results(result)\n",
    "    calculation[\"avg_returns_at_timestep\"] = scale(calculation[\"avg_returns_at_timestep\"], ref1_calculation[\"avg_returns_at_timestep\"], 0.0, ref2_calculation[\"avg_returns_at_timestep\"], 1.0)\n",
    "    calculation[\"var_returns_at_timestep\"] = calculation[\"var_returns_at_timestep\"] / ((ref1_calculation[\"avg_returns_at_timestep\"] - ref2_calculation[\"avg_returns_at_timestep\"])**2)\n",
    "    calculation[\"avg_returns\"] = scale(calculation[\"avg_returns\"], ref1_calculation[\"avg_returns\"], 0.0, ref2_calculation[\"avg_returns\"], 1.0)\n",
    "    calculation[\"var_returns\"] = calculation[\"var_returns\"] / ((ref1_calculation[\"avg_returns\"] - ref2_calculation[\"avg_returns\"])**2)\n",
    "    calculations[name] = calculation"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "26",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "# Plotting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "27",
   "metadata": {},
   "outputs": [],
   "source": [
    "markers = {\n",
    "    'Random': 'o',              # Circle\n",
    "    'Epsilon-Greedy': 's',      # Square\n",
    "    'Thompson Sampling': '^',   # Triangle\n",
    "    'meta-RL': 'v',             # Down triangle\n",
    "    'Oracle': 'D'               # Diamond\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "28",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(8,6))\n",
    "plt.rcParams.update({'font.size': 12})\n",
    "\n",
    "for name, result in calculations.items():\n",
    "    x_vals = np.arange(result[\"avg_returns_at_timestep\"].size) + 1\n",
    "    y_mean = result[\"avg_returns_at_timestep\"]\n",
    "    y_conf_interval = 1.96 * (result[\"var_returns_at_timestep\"] / result[\"trials\"])**(1/2)\n",
    "    mean = result['avg_returns']\n",
    "    conf_interval = 1.96 * (result['var_returns'] / result[\"trials\"])**(1/2)\n",
    "\n",
    "    # Plot the mean curve\n",
    "    plt.plot(\n",
    "        x_vals,\n",
    "        y_mean,\n",
    "        label=f\"{name} ({mean:.3f} ±{conf_interval:.3f})\",\n",
    "        marker=markers[name],\n",
    "        markevery=1,\n",
    "    )\n",
    "\n",
    "    # Plot shaded region for ±1 std dev\n",
    "    plt.fill_between(\n",
    "        x_vals,\n",
    "        y_mean - y_conf_interval,\n",
    "        y_mean + y_conf_interval,\n",
    "        alpha=0.3,\n",
    "        label=None  # no extra legend entry\n",
    "    )\n",
    "\n",
    "plt.xlabel('Episode')\n",
    "plt.ylabel('Average Reward')\n",
    "plt.xlim(0, length + 1)\n",
    "plt.ylim(-0.35, 1.1)\n",
    "plt.legend(\n",
    "    loc='lower center',\n",
    "    bbox_to_anchor=(0.5, 0.0),\n",
    "    ncol=2,\n",
    "    frameon=True\n",
    ")\n",
    "plt.grid(True)\n",
    "plt.tight_layout()\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "29",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(8,6))\n",
    "plt.rcParams.update({'font.size': 12})\n",
    "\n",
    "for name, result in calculations.items():\n",
    "    x_vals = np.arange(result[\"kl_to_oracle_at_timestep\"].size) + 1\n",
    "    plt.plot(\n",
    "        x_vals,\n",
    "        result[\"kl_to_oracle_at_timestep\"],\n",
    "        label=f\"{name}\",\n",
    "        marker=markers[name],\n",
    "        markevery=1,\n",
    "    )\n",
    "plt.xlabel('Episode')\n",
    "plt.ylabel('KL Divergence to Oracle')\n",
    "plt.xlim(0, length + 1)\n",
    "plt.ylim(-5.0, 20.0)\n",
    "plt.legend(\n",
    "    loc='lower center',\n",
    "    bbox_to_anchor=(0.5, 0.0),\n",
    "    ncol=3,\n",
    "    frameon=True\n",
    ")\n",
    "plt.grid(True)\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "30",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(8,6))\n",
    "plt.rcParams.update({'font.size': 12})\n",
    "\n",
    "for name, result in calculations.items():\n",
    "    x_vals = np.arange(result[\"entropy_at_timestep\"].size) + 1\n",
    "    plt.plot(\n",
    "        x_vals,\n",
    "        result[\"entropy_at_timestep\"],\n",
    "        label=f\"{name}\",\n",
    "        marker=markers[name],\n",
    "        markevery=1,\n",
    "    )\n",
    "plt.xlabel('Episode')\n",
    "plt.ylabel('Entropy')\n",
    "plt.xlim(0, length + 1)\n",
    "plt.ylim(-0.2, 1.3)\n",
    "plt.legend(\n",
    "    loc='lower center',\n",
    "    bbox_to_anchor=(0.5, 0.0),\n",
    "    ncol=3,\n",
    "    frameon=True\n",
    ")\n",
    "plt.grid(True)\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "31",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "# Analysis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "32",
   "metadata": {},
   "outputs": [],
   "source": [
    "for name, result in results.items():\n",
    "    if name == \"meta-RL\":\n",
    "        count = {}\n",
    "        for a, r in zip(result[\"actions\"], result[\"returns\"]):\n",
    "            key = (a[0].item(), r[0].item(), a[1].item(), r[1].item(), a[2].item(),)\n",
    "            if key in count:\n",
    "                count[key] += 1\n",
    "            else:\n",
    "                count[key] = 1\n",
    "        for k, v in count.items():\n",
    "            print(f\"{k}: {v}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "33",
   "metadata": {},
   "outputs": [],
   "source": [
    "for name, result in results.items():\n",
    "    if name == \"meta-RL\":\n",
    "        count = {}\n",
    "        l = 28\n",
    "        for a, r, na in zip(result[\"actions\"][:, :-1], result[\"returns\"][:, :-1], result[\"actions\"][:, 1:]):\n",
    "            ind = \"same\" if a[l].item()==na[l].item() else \"diff\"\n",
    "            key = (r[l].item(), ind)\n",
    "            if key in count:\n",
    "                count[key] += 1\n",
    "            else:\n",
    "                count[key] = 1\n",
    "        for k, v in count.items():\n",
    "            print(f\"{k}: {v}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "34",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [
    "xNdF5r8H8F_e",
    "Ff7I5RPB7-DQ",
    "eF0w5k4_8NBc"
   ],
   "gpuType": "T4",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
