{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "editable": true,
    "id": "-hfW094ImwvR",
    "jp-MarkdownHeadingCollapsed": true,
    "tags": []
   },
   "source": [
    "# Import Packages"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "editable": true,
    "id": "PioFoFBNmV9A",
    "tags": []
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import random\n",
    "import hashlib\n",
    "import gc\n",
    "from contextlib import contextmanager\n",
    "from types import SimpleNamespace\n",
    "from functools import partial\n",
    "import copy\n",
    "import time\n",
    "import math\n",
    "import itertools\n",
    "import tempfile\n",
    "import uuid\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "import gymnasium as gym\n",
    "from gymnasium.envs.registration import register\n",
    "import torch\n",
    "from torch.utils.data import DataLoader\n",
    "from collections import defaultdict\n",
    "from collections import deque\n",
    "\n",
    "from transformers import (\n",
    "    AutoTokenizer,\n",
    "    AutoModelForCausalLM,\n",
    "    StoppingCriteria,\n",
    "    StoppingCriteriaList,\n",
    "    PreTrainedTokenizer,\n",
    ")\n",
    "import transformers\n",
    "from accelerate import Accelerator\n",
    "import huggingface_hub\n",
    "from peft import PeftModel, get_peft_model, LoraConfig\n",
    "\n",
    "import wandb"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "editable": true,
    "id": "eJ1jKzOPnAA7",
    "jp-MarkdownHeadingCollapsed": true,
    "tags": []
   },
   "source": [
    "# Define Environment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "M7UUL9juYGm4"
   },
   "outputs": [],
   "source": [
    "class Environment:\n",
    "\n",
    "    def __init__(self, make_kwargs, random_action_prob=0.0):\n",
    "        self.make_kwargs = make_kwargs\n",
    "        self.random_action_prob = random_action_prob\n",
    "\n",
    "        self.env = gym.make(**self.make_kwargs)\n",
    "\n",
    "        self.done = True\n",
    "        self.observations = []\n",
    "        self.rewards = []\n",
    "\n",
    "        self.record = False\n",
    "        self.video = []\n",
    "        self.text = []\n",
    "\n",
    "        self.action_map = {}\n",
    "\n",
    "    def reset(self, record=False, **kwargs):\n",
    "        self.record = record\n",
    "        observation, _ = self.env.reset(**kwargs)\n",
    "        observation_mapped = self.map_observation(observation)\n",
    "        self.done = False\n",
    "\n",
    "        self.observations = []\n",
    "        self.rewards = []\n",
    "        self.observations.append(observation_mapped)\n",
    "\n",
    "        if self.record:\n",
    "            self.video = []\n",
    "            self.text = []\n",
    "            self.video.append(self.env.render())\n",
    "            self.text.append(f\"observation {observation_mapped}\")\n",
    "\n",
    "        return observation_mapped\n",
    "\n",
    "    def step(self, action, **kwargs):\n",
    "        if self.is_action_valid(action):\n",
    "            action_mapped = self.map_action(action)\n",
    "            random_action = self.random_action_prob > random.random()\n",
    "            action_taken = self.env.action_space.sample() if random_action else action_mapped\n",
    "            observation, reward, terminated, truncated, _ = self.env.step(action_taken, **kwargs)\n",
    "            observation_mapped = self.map_observation(observation)\n",
    "            if random_action:\n",
    "                observation_mapped = observation_mapped + '*'\n",
    "            reward_mapped = self.map_reward(reward)\n",
    "        else:\n",
    "            observation_mapped, reward_mapped, terminated, truncated = self.invalid_action_result(action)\n",
    "\n",
    "        self.done = terminated or truncated\n",
    "\n",
    "        self.observations.append(observation_mapped)\n",
    "        self.rewards.append(reward_mapped)\n",
    "\n",
    "        if self.record:\n",
    "            self.video.append(self.env.render())\n",
    "            self.text.append(f\"action {action} observation {observation_mapped} reward {reward_mapped} terminated {terminated} truncated {truncated}\")\n",
    "\n",
    "        return observation_mapped, reward_mapped, terminated, truncated\n",
    "\n",
    "    def is_done(self):\n",
    "        return self.done\n",
    "\n",
    "    def map_action(self, action):\n",
    "        return self.action_map[action]\n",
    "\n",
    "    def map_reward(self, reward):\n",
    "        return float(reward)\n",
    "\n",
    "    def map_observation(self, observation):\n",
    "        return str(observation)\n",
    "\n",
    "    def is_action_valid(self, action):\n",
    "        return action in self.action_map.keys()\n",
    "\n",
    "    def invalid_action_result(self, action):\n",
    "        observation = f\"Invalid - Pick From {list(self.action_map.keys())}\"\n",
    "        reward = 0.0\n",
    "        terminated = False\n",
    "        truncated = False\n",
    "        return observation, reward, terminated, truncated\n",
    "\n",
    "    @staticmethod\n",
    "    def pick_from_dict(map_dict, key):\n",
    "        try:\n",
    "            value = map_dict[key]\n",
    "        except KeyError:\n",
    "            key_list = list(map_dict.keys())\n",
    "            hash_bytes = hashlib.sha256(key.encode(\"utf-8\")).digest()\n",
    "            hash_int = int.from_bytes(hash_bytes, byteorder=\"big\")\n",
    "            index = hash_int % len(key_list)\n",
    "            corrected_key = key_list[index]\n",
    "            value = map_dict[corrected_key]\n",
    "        return value\n",
    "\n",
    "    def sample_action(self):\n",
    "        return random.choice(list(self.action_map.keys()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "VHYCtNyT9vG1"
   },
   "outputs": [],
   "source": [
    "class MultiArmedBanditGymEnv(gym.Env):\n",
    "    \"\"\"\n",
    "    Multi-Armed Bandit Environment with Bernoulli rewards and its own random number generator (rng).\n",
    "    \"\"\"\n",
    "    metadata = {\"render_modes\": [\"human\"]}\n",
    "\n",
    "    def __init__(self, k=10, probs=None, seed=None):\n",
    "        super().__init__()\n",
    "        self.k = k\n",
    "        self.action_space = gym.spaces.Discrete(k)\n",
    "        self.observation_space = gym.spaces.Discrete(1)  # Dummy observation\n",
    "        self.seed(seed)\n",
    "\n",
    "        # Each arm has a probability of reward=1 (success), must be in [0,1]\n",
    "        if probs is None:\n",
    "            self.arm_probs = self.np_random.uniform(0, 1, k)\n",
    "        else:\n",
    "            self.arm_probs = np.array(probs)\n",
    "        self.last_action = None\n",
    "\n",
    "    def seed(self, seed=None):\n",
    "        self.np_random, _ = gym.utils.seeding.np_random(seed)\n",
    "\n",
    "    def reset(self, seed=None, options=None):\n",
    "        if seed is not None:\n",
    "            self.seed(seed)\n",
    "        self.last_action = None\n",
    "        return 0, {}  # Dummy observation, info\n",
    "\n",
    "    def step(self, action):\n",
    "        assert self.action_space.contains(action), \"Invalid action\"\n",
    "        # Reward is 1 with probability arm_probs[action], else 0\n",
    "        reward = self.np_random.binomial(1, self.arm_probs[action])\n",
    "        self.last_action = action\n",
    "        done = False\n",
    "        return 0, reward, done, False, {}  # obs, reward, terminated, truncated, info\n",
    "\n",
    "    def render(self, mode=\"human\"):\n",
    "        print(f\"Last action: {self.last_action}\")\n",
    "\n",
    "    def close(self):\n",
    "        pass\n",
    "\n",
    "register(\n",
    "    id=\"MultiArmedBandit-v0\",  # Unique identifier for the environment\n",
    "    entry_point=MultiArmedBanditGymEnv,  # module:class\n",
    "    max_episode_steps=1,  # Bandit problems are usually one-step episodes\n",
    ")\n",
    "\n",
    "class MultiArmBandit(Environment):\n",
    "\n",
    "    def __init__(self, max_episode_steps=1, arms=10, random_action_prob=0.0, seed=None):\n",
    "        rng = np.random.default_rng(seed=seed)\n",
    "        probs = rng.uniform(0, 1, size=arms)\n",
    "\n",
    "        make_kwargs = {\n",
    "            \"id\": \"MultiArmedBandit-v0\",\n",
    "            \"max_episode_steps\": max_episode_steps,\n",
    "            \"k\": arms,\n",
    "            \"probs\": probs,\n",
    "        }\n",
    "        super().__init__(make_kwargs, random_action_prob=random_action_prob)\n",
    "\n",
    "        self.action_map = {str(i): i for i in range(arms)}\n",
    "\n",
    "    def reset(self, record=False, **kwargs):\n",
    "        return super().reset(record=False, **kwargs)\n",
    "\n",
    "    def map_observation(self, observation):\n",
    "        return None\n",
    "\n",
    "    def invalid_action_result(self, action):\n",
    "        observation = f\"Invalid - Pick From {list(self.action_map.keys())}\"\n",
    "        reward = 0.0\n",
    "        terminated = False\n",
    "        truncated = False\n",
    "        return observation, reward, terminated, truncated"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "editable": true,
    "id": "3iMlg5e0l4vo",
    "tags": []
   },
   "outputs": [],
   "source": [
    "class FrozenLake(Environment):\n",
    "\n",
    "    def __init__(self, max_episode_steps=100, is_slippery=False, gridmap=None, gridmap_kwargs=None, random_action_prob=0.0, seed=None):\n",
    "        if gridmap is not None:\n",
    "            self.gridmap = gridmap\n",
    "        elif gridmap_kwargs is not None:\n",
    "            self.gridmap = self.generate_unique_map(seed=seed, **gridmap_kwargs)\n",
    "        else:\n",
    "            raise ValueError(\"gridmap or gridmap_kwargs must be given\")\n",
    "        make_kwargs = {\n",
    "            \"id\": \"FrozenLake-v1\",\n",
    "            \"desc\": self.gridmap,\n",
    "            \"is_slippery\": is_slippery,\n",
    "            \"render_mode\": \"rgb_array\",\n",
    "            \"max_episode_steps\": max_episode_steps,\n",
    "        }\n",
    "        super().__init__(make_kwargs, random_action_prob=random_action_prob)\n",
    "\n",
    "        self.action_map = {\n",
    "            \"left\": 0,\n",
    "            \"down\": 1,\n",
    "            \"right\": 2,\n",
    "            \"up\": 3,\n",
    "        }\n",
    "\n",
    "    def map_reward(self, reward):\n",
    "        if reward == 0.0:\n",
    "            return None\n",
    "        new_reward = float(reward / (len(self.rewards) + 1))\n",
    "        return new_reward\n",
    "\n",
    "    def invalid_action_result(self, action):\n",
    "        observation = f\"Invalid - Pick From {list(self.action_map.keys())}\"\n",
    "        reward = 0.0\n",
    "        terminated = False\n",
    "        truncated = False\n",
    "        return observation, reward, terminated, truncated\n",
    "\n",
    "    @staticmethod\n",
    "    def map_is_valid(gridmap, min_hops):\n",
    "        rows = len(gridmap)\n",
    "        cols = len(gridmap[0])\n",
    "        board = [list(row) for row in gridmap]\n",
    "    \n",
    "        found_start = False\n",
    "        for r in range(rows):\n",
    "            for c in range(cols):\n",
    "                if board[r][c] == \"S\":\n",
    "                    found_start = True\n",
    "                    state = r * cols + c\n",
    "                    result = FrozenLake.find_path_to_goal(gridmap, state)\n",
    "                    if result is None:\n",
    "                        return False  # No path to goal from this 'S'\n",
    "                    _, actions = result\n",
    "                    if len(actions) < min_hops:\n",
    "                        return False  # Path exists but is too short\n",
    "        return found_start  # True if at least one 'S', otherwise False\n",
    "\n",
    "    @staticmethod\n",
    "    def generate_map(min_width=4, max_width=4, min_height=4, max_height=4, hole_prob=0.7, start_pos=0, start_pos_prob=None, goal_pos=15, goal_pos_prob=None, rng=None, seed=None):\n",
    "        if rng is None:\n",
    "            rng = random.Random(seed)\n",
    "\n",
    "        width = rng.randint(min_width, max_width)\n",
    "        height = rng.randint(min_height, max_height)\n",
    "\n",
    "        map_index = ['F'] * (width * height)\n",
    "        avalible_index = list(range(width * height))\n",
    "\n",
    "        if type(start_pos) is int:\n",
    "            start_pos = [start_pos]\n",
    "\n",
    "        if start_pos is None and start_pos_prob is None:\n",
    "            start_pos = [rng.choice(avalible_index)]\n",
    "        elif start_pos is None and start_pos_prob is not None:\n",
    "            start_pos = []\n",
    "            for i in avalible_index.copy():\n",
    "                if rng.random() < start_pos_prob:\n",
    "                    start_pos.append(i)\n",
    "\n",
    "        for p in start_pos:\n",
    "            map_index[p] = 'S'\n",
    "            avalible_index.remove(p)\n",
    "\n",
    "        if type(goal_pos) is int:\n",
    "            goal_pos = [goal_pos]\n",
    "\n",
    "        if goal_pos is None and goal_pos_prob is None:\n",
    "            goal_pos = [rng.choice(avalible_index)]\n",
    "        elif goal_pos is None and goal_pos_prob is not None:\n",
    "            goal_pos = []\n",
    "            for i in avalible_index.copy():\n",
    "                if rng.random() < goal_pos_prob:\n",
    "                    goal_pos.append(i)\n",
    "\n",
    "        for p in goal_pos:\n",
    "            map_index[p] = 'G'\n",
    "            avalible_index.remove(p)\n",
    "\n",
    "        for i in avalible_index.copy():\n",
    "            if rng.random() < hole_prob:\n",
    "                map_index[i] = 'H'\n",
    "                avalible_index.remove(i)\n",
    "\n",
    "        map = []\n",
    "        for i in range(height):\n",
    "            row = ''.join(map_index[i*width:(i+1)*width])\n",
    "            map.append(row)\n",
    "\n",
    "        return map\n",
    "\n",
    "    @staticmethod\n",
    "    def generate_valid_map(min_hops=0, seed=None, **kwargs):\n",
    "        rng = random.Random(seed)\n",
    "\n",
    "        while True:\n",
    "            map = FrozenLake.generate_map(rng=rng, seed=None, **kwargs)\n",
    "            if FrozenLake.map_is_valid(map, min_hops):\n",
    "                return map\n",
    "\n",
    "    @staticmethod\n",
    "    def generate_unique_map(other_gridmaps=[], **kwargs):\n",
    "        while True:\n",
    "            gridmap = FrozenLake.generate_valid_map(**kwargs)\n",
    "            if gridmap not in other_gridmaps:\n",
    "                return gridmap\n",
    "\n",
    "    @staticmethod\n",
    "    def find_path_to_goal(gridmap, state):\n",
    "        rows = len(gridmap)\n",
    "        cols = len(gridmap[0])\n",
    "        board = [list(row) for row in gridmap]\n",
    "    \n",
    "        start_r = state // cols\n",
    "        start_c = state % cols\n",
    "        start_pos = (start_r, start_c)\n",
    "    \n",
    "        # Find goal positions\n",
    "        goals = [(i, j) for i in range(rows) for j in range(cols) if board[i][j] == \"G\"]\n",
    "        if not goals:\n",
    "            return None\n",
    "    \n",
    "        # Ordered directions to match action_map indices\n",
    "        directions = [\n",
    "            (0, -1),  # left  -> 0\n",
    "            (1, 0),   # down  -> 1\n",
    "            (0, 1),   # right -> 2\n",
    "            (-1, 0),  # up    -> 3\n",
    "        ]\n",
    "    \n",
    "        queue = deque()\n",
    "        queue.append((start_pos, [], []))  # (current position, path_so_far, actions_so_far)\n",
    "        visited = set()\n",
    "    \n",
    "        while queue:\n",
    "            (r, c), path, actions = queue.popleft()\n",
    "            if (r, c) in goals:\n",
    "                full_path = [start_pos] + path\n",
    "                return full_path, actions\n",
    "            if (r, c) in visited:\n",
    "                continue\n",
    "            visited.add((r, c))\n",
    "            for action, (dr, dc) in enumerate(directions):\n",
    "                nr, nc = r + dr, c + dc\n",
    "                if 0 <= nr < rows and 0 <= nc < cols and board[nr][nc] != \"H\":\n",
    "                    queue.append(\n",
    "                        ((nr, nc), path + [(nr, nc)], actions + [action])\n",
    "                    )\n",
    "        return None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "w7N-aAc2NK_J"
   },
   "outputs": [],
   "source": [
    "class CliffWalking(Environment):\n",
    "\n",
    "    def __init__(self, max_episode_steps=100, is_slippery=False, random_action_prob=0.0):\n",
    "        make_kwargs = {\n",
    "            \"id\": \"CliffWalking-v1\",\n",
    "            \"is_slippery\": is_slippery,\n",
    "            \"render_mode\": \"rgb_array\",\n",
    "            \"max_episode_steps\": max_episode_steps,\n",
    "        }\n",
    "        super().__init__(make_kwargs, random_action_prob=random_action_prob)\n",
    "\n",
    "        self.action_map = {\n",
    "            \"up\": 0,\n",
    "            \"right\": 1,\n",
    "            \"down\": 2,\n",
    "            \"left\": 3,\n",
    "        }\n",
    "\n",
    "    def invalid_action_result(self, action):\n",
    "        observation = f\"Invalid - Pick From {list(self.action_map.keys())}\"\n",
    "        reward = -1.0\n",
    "        terminated = False\n",
    "        truncated = False\n",
    "        return observation, reward, terminated, truncated\n",
    "\n",
    "#CliffWalkingFunc = partial(\n",
    "#    CliffWalking,\n",
    "#    max_episode_steps=200,\n",
    "#    is_slippery=False,\n",
    "#    random_action_prob=0.0,\n",
    "#)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "_l_c8qADCtTT"
   },
   "outputs": [],
   "source": [
    "class Taxi(Environment):\n",
    "\n",
    "    def __init__(self, max_episode_steps=100, random_action_prob=0.0):\n",
    "        make_kwargs = {\n",
    "            \"id\": \"Taxi-v3\",\n",
    "            \"render_mode\": \"rgb_array\",\n",
    "            \"max_episode_steps\": max_episode_steps,\n",
    "        }\n",
    "        super().__init__(make_kwargs, random_action_prob=random_action_prob)\n",
    "\n",
    "        self.action_map = {\n",
    "            \"down\": 0,\n",
    "            \"up\": 1,\n",
    "            \"right\": 2,\n",
    "            \"left\": 3,\n",
    "            \"pick\": 4,\n",
    "            \"drop\": 5,\n",
    "        }\n",
    "\n",
    "    def map_observation(self, observation):\n",
    "        taxi_row, taxi_col, passenger_location, destination = self.decode_observation(observation)\n",
    "        return f\"{taxi_row},{taxi_col},{passenger_location},{destination}\"\n",
    "\n",
    "    @staticmethod\n",
    "    def decode_observation(obs):\n",
    "        \"\"\"\n",
    "        Given an observation integer obs = ((taxi_row * 5 + taxi_col) * 5 + passenger_location) * 4 + destination,\n",
    "        return the tuple (taxi_row, taxi_col, passenger_location, destination).\n",
    "        \"\"\"\n",
    "\n",
    "        # 1) The last component is destination (0..3):\n",
    "        destination = obs % 4\n",
    "        rest = obs // 4\n",
    "\n",
    "        # 2) The next component is passenger_location (0..4):\n",
    "        passenger_location = rest % 5\n",
    "        rest = rest // 5\n",
    "\n",
    "        # 3) Now rest = (taxi_row * 5 + taxi_col). Recover taxi_col and taxi_row:\n",
    "        taxi_col = rest % 5\n",
    "        taxi_row = rest // 5\n",
    "\n",
    "        return taxi_row, taxi_col, passenger_location, destination\n",
    "\n",
    "    def invalid_action_result(self, action):\n",
    "        observation = f\"Invalid - Pick From {list(self.action_map.keys())}\"\n",
    "        reward = -50.0\n",
    "        terminated = False\n",
    "        truncated = False\n",
    "        return observation, reward, terminated, truncated\n",
    "\n",
    "#TaxiFunc = partial(\n",
    "#    Taxi,\n",
    "#    max_episode_steps=200,\n",
    "#    random_action_prob=0.0,\n",
    "#)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "editable": true,
    "id": "TTX0QBHWmhvk",
    "jp-MarkdownHeadingCollapsed": true,
    "tags": []
   },
   "source": [
    "# Parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "editable": true,
    "id": "wnXcY12bmflk",
    "tags": []
   },
   "outputs": [],
   "source": [
    "N = 30\n",
    "X = 1024\n",
    "γ = 0.9\n",
    "base_X = 4096\n",
    "\n",
    "func = partial(\n",
    "    MultiArmBandit,\n",
    "    max_episode_steps=1,\n",
    "    arms=3,\n",
    "    random_action_prob=0.0,\n",
    ")\n",
    "\n",
    "bandit_config = SimpleNamespace(\n",
    "    run_name=f\"zbandit_{N}_{X}_{γ}_{uuid.uuid4()}\",\n",
    "    base_model_name=\"meta-llama/Llama-3.2-3B\",\n",
    "    load_adaptor_name=None,\n",
    "    adapter_params={\n",
    "        \"lora_r\": 32,\n",
    "        \"lora_alpha\": 32,\n",
    "    },\n",
    "    tokenizer_name=\"meta-llama/Llama-3.2-3B\",\n",
    "    rl_params={\n",
    "        \"step_discount\": 1.0,\n",
    "        \"episode_discount\": γ,\n",
    "        \"group_discount\": 0.0,\n",
    "        \"polyak_const\": 0.1,\n",
    "        \"reward_scale\": 10.0,\n",
    "    },\n",
    "    learning_rate=5e-5,\n",
    "    train_generate_length=X,\n",
    "    test_generate_length=X,\n",
    "    train_length=X,\n",
    "    train_pool_size=80,\n",
    "    train_record_size=0,\n",
    "    train_generations_per_step=0,\n",
    "    train_generate_batch_size=8,\n",
    "    test_pool_size=80,\n",
    "    test_record_size=8,\n",
    "    test_generations_per_step=1,\n",
    "    test_generate_batch_size=8,\n",
    "    trains_per_step=1,\n",
    "    train_batch_size=int(8 * base_X / X),\n",
    "    gradient_accumulation_steps=1,\n",
    "    weight_decay=0.0,\n",
    "    save_steps=0,\n",
    "    map_change_prob=1.0/N,\n",
    "    map_max_age=None,\n",
    "    train_random_prob=0.0,\n",
    "    optimizer_betas=(0.9, 0.999),\n",
    "    gradient_checkpointing=True,\n",
    "    train_env=func,\n",
    "    test_env=func,\n",
    "    action_space=[\"0\", \"1\", \"2\"],\n",
    "    warmup_length=int(20 * base_X),\n",
    "    train_steps=400,\n",
    "    train_on_last=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Zp5nMNpTpUdc"
   },
   "outputs": [],
   "source": [
    "N = 30\n",
    "X = 1024\n",
    "γ = 0.9\n",
    "base_X = 4096\n",
    "\n",
    "func = partial(\n",
    "    FrozenLake,\n",
    "    max_episode_steps=50,\n",
    "    is_slippery=False,\n",
    "    gridmap_kwargs={\n",
    "        \"other_gridmaps\": [],\n",
    "        \"min_width\": 3,\n",
    "        \"max_width\": 5,\n",
    "        \"min_height\": 3,\n",
    "        \"max_height\": 5,\n",
    "        \"hole_prob\": 0.2,\n",
    "        \"start_pos\": None,\n",
    "        \"start_pos_prob\": 0.1,\n",
    "        \"goal_pos\": None,\n",
    "        \"goal_pos_prob\": 0.1,\n",
    "        \"min_hops\": 4,\n",
    "    },\n",
    "    random_action_prob=0.0,\n",
    ")\n",
    "\n",
    "frozen_config = SimpleNamespace(\n",
    "    run_name=f\"bfrozen_{N}_{X}_{γ}_{uuid.uuid4()}\",\n",
    "    base_model_name=\"meta-llama/Llama-3.2-3B\",\n",
    "    load_adaptor_name=None,\n",
    "    adapter_params={\n",
    "        \"lora_r\": 32,\n",
    "        \"lora_alpha\": 32,\n",
    "    },\n",
    "    tokenizer_name=\"meta-llama/Llama-3.2-3B\",\n",
    "    rl_params={\n",
    "        \"step_discount\": 1.0,\n",
    "        \"episode_discount\": γ,\n",
    "        \"group_discount\": 0.0,\n",
    "        \"polyak_const\": 0.1,\n",
    "        \"reward_scale\": 400.0,\n",
    "    },\n",
    "    learning_rate=5e-5,\n",
    "    train_generate_length=X,\n",
    "    test_generate_length=X,\n",
    "    train_length=X,\n",
    "    train_pool_size=80,\n",
    "    train_record_size=0,\n",
    "    train_generations_per_step=0,\n",
    "    train_generate_batch_size=8,\n",
    "    test_pool_size=80,\n",
    "    test_record_size=8,\n",
    "    test_generations_per_step=1,\n",
    "    test_generate_batch_size=8,\n",
    "    trains_per_step=1,\n",
    "    train_batch_size=int(8 * base_X / X),\n",
    "    gradient_accumulation_steps=1,\n",
    "    weight_decay=0.0,\n",
    "    save_steps=0,\n",
    "    map_change_prob=1.0/N,\n",
    "    map_max_age=None,\n",
    "    train_random_prob=0.0,\n",
    "    optimizer_betas=(0.9, 0.999),\n",
    "    gradient_checkpointing=True,\n",
    "    train_env=func,\n",
    "    test_env=func,\n",
    "    action_space=[\"left\", \"right\", \"up\", \"down\"],\n",
    "    warmup_length=int(20 * base_X),\n",
    "    train_steps=1200,\n",
    "    train_on_last=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "mL7nvvJwpUoy"
   },
   "outputs": [],
   "source": [
     "# Active experiment configuration read by the cells below (bandit_config is the other option).\n",
     "config = frozen_config"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ZDbjL8tplFVh",
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "# Setup & Connect"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 208
    },
    "editable": true,
    "executionInfo": {
     "elapsed": 1902,
     "status": "ok",
     "timestamp": 1740606661440,
     "user": {
      "displayName": "",
      "userId": "15031539783321282003"
     },
     "user_tz": 360
    },
    "id": "_cj4aqRTmrg6",
    "outputId": "d01a8e67-91a6-4311-a17d-800ea340471c",
    "tags": []
   },
   "outputs": [],
   "source": [
    "huggingface_hub.login(token=os.environ.get('_HF_TOKEN'))\n",
    "wandb.login(key=os.environ.get('_WANDB_TOKEN'))\n",
    "\n",
    "hf_repo_name = f'username/repo_name-{config.run_name}'\n",
    "hf_repo_load_adaptor_name = f'username/repo_name-{config.load_adaptor_name}' if config.load_adaptor_name is not None else None\n",
    "\n",
    "wdb = wandb.init(project=\"project_name\", name=config.run_name, notes=\"\", dir=\"../\", save_code=True)\n",
    "wdb.config.update(config.__dict__)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "editable": true,
    "id": "37Q4B-uAnVr3",
    "jp-MarkdownHeadingCollapsed": true,
    "tags": []
   },
   "source": [
    "# Define Stream"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "D3jduq6Cl4Dn"
   },
   "outputs": [],
   "source": [
    "class NumpyMemmapBuffer:\n",
    "    def __init__(self, dtype, size_increment):\n",
    "        self.file = tempfile.NamedTemporaryFile(mode='w+b')\n",
    "        self.dtype = dtype\n",
    "        self.size_increment = size_increment\n",
    "        self.storage = np.require(np.memmap(self.file, dtype=self.dtype, mode='r+', shape=(self.size_increment,)), requirements=['O'])\n",
    "        self.index = 0\n",
    "\n",
    "    def __del__(self):\n",
    "        self.storage.flush()\n",
    "        del self.storage\n",
    "        self.file.close()\n",
    "\n",
    "    def __len__(self):\n",
    "        return self.index\n",
    "\n",
    "    def increase_size(self):\n",
    "        new_size = self.storage.shape[0] + self.size_increment\n",
    "        self.storage.resize((new_size,))\n",
    "\n",
    "    def add(self, data):\n",
    "        data = np.atleast_1d(data)\n",
    "        assert data.dtype == self.storage.dtype\n",
    "        assert len(data.shape) == 1\n",
    "        length = data.shape[0]\n",
    "        new_index = self.index + length\n",
    "        while self.storage.shape[0] < new_index:\n",
    "            self.increase_size()\n",
    "        self.storage[self.index:new_index] = data\n",
    "        self.index = new_index\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        return self.storage[:self.index][idx]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "v7wLB26SpIze"
   },
   "outputs": [],
   "source": [
    "class NumpyBuffer:\n",
    "    def __init__(self, dtype, size_increment):\n",
    "        self.dtype = dtype\n",
    "        self.size_increment = size_increment\n",
    "        self.storage = np.empty(self.size_increment, dtype=self.dtype)\n",
    "        self.index = 0\n",
    "\n",
    "    def __len__(self):\n",
    "        return self.index\n",
    "\n",
    "    def increase_size(self):\n",
    "        new_size = self.storage.shape[0] + self.size_increment\n",
    "        new_storage = np.empty(new_size, dtype=self.dtype)\n",
    "        new_storage[:self.index] = self.storage[:self.index]\n",
    "        self.storage = new_storage\n",
    "\n",
    "    def add(self, data):\n",
    "        data = np.atleast_1d(data)\n",
    "        assert data.dtype == self.storage.dtype\n",
    "        assert len(data.shape) == 1\n",
    "        length = data.shape[0]\n",
    "        new_index = self.index + length\n",
    "        while self.storage.shape[0] < new_index:\n",
    "            self.increase_size()\n",
    "        self.storage[self.index:new_index] = data\n",
    "        self.index = new_index\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        return self.storage[:self.index][idx]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "editable": true,
    "id": "pOX0bjwmmKEd",
    "tags": []
   },
   "outputs": [],
   "source": [
    "class Stream:\n",
    "    def __init__(self, tokenizer, size_increment):\n",
    "        self.struct = []\n",
    "        self.token_dtype = np.dtype([\n",
    "            (\"ids\", np.int64),\n",
    "            (\"types\", np.int64),\n",
    "            (\"episodes\", np.int64),\n",
    "            (\"groups\", np.int64),\n",
    "            (\"rewards\", np.float64),\n",
    "            (\"elements\", np.int64),\n",
    "        ])\n",
    "        self.tokens = NumpyBuffer(dtype=self.token_dtype, size_increment=size_increment)\n",
    "        self.tokenizer = tokenizer\n",
    "\n",
    "        self.roles = {\n",
    "            \"environment\": {\n",
    "                \"header\": \"environment\",\n",
    "                \"footer\": self.tokenizer.eos_token,\n",
    "                },\n",
    "            \"begin\": {\n",
    "                \"header\": \"begin\",\n",
    "                \"footer\": self.tokenizer.eos_token,\n",
    "                },\n",
    "            \"thought\": {\n",
    "                \"header\": \"thought:\",\n",
    "                \"footer\": self.tokenizer.eos_token,\n",
    "                },\n",
    "            \"observation\": {\n",
    "                \"header\": \"observation:\",\n",
    "                \"footer\": self.tokenizer.eos_token,\n",
    "                },\n",
    "            \"action\": {\n",
    "                \"header\": \"action:\",\n",
    "                \"footer\": self.tokenizer.eos_token,\n",
    "                },\n",
    "            \"reward\": {\n",
    "                \"header\": \"reward:\",\n",
    "                \"footer\": self.tokenizer.eos_token,\n",
    "                },\n",
    "            \"end\": {\n",
    "                \"header\": \"end\",\n",
    "                \"footer\": self.tokenizer.eos_token,\n",
    "                },\n",
    "        }\n",
    "\n",
    "        # Precompute header and footer token IDs for each role\n",
    "        type_id = 0\n",
    "        for role_name, role in self.roles.items():\n",
    "            role[\"header_ids\"] = self.tokenizer.encode(role[\"header\"], add_special_tokens=False)\n",
    "            role[\"footer_ids\"] = self.tokenizer.encode(role[\"footer\"], add_special_tokens=False)\n",
    "            role[\"header_type\"] = type_id\n",
    "            role[\"content_type\"] = type_id + 1\n",
    "            role[\"footer_type\"] = type_id + 2\n",
    "            type_id += 3\n",
    "\n",
    "    def add_struct(self, struct):\n",
    "        tokens = self.struct_to_tokens(struct)\n",
    "        struct_length = len(self.struct)\n",
    "        tokens[\"elements\"] = [x + struct_length for x in tokens[\"elements\"]]\n",
    "        self.struct.extend(struct)\n",
    "        self.tokens.add(tokens)\n",
    "\n",
    "    def __getitem__(self, key):\n",
    "        if isinstance(key, tuple):\n",
    "            idx, mode = key\n",
    "            if mode == \"struct\":\n",
    "                return self.struct[idx]\n",
    "            elif mode == \"tokens\":\n",
    "                return self.tokens[idx]\n",
    "        else:\n",
    "            return self.struct[key]\n",
    "\n",
    "    def get_tokens_length(self):\n",
    "        return len(self.tokens)\n",
    "\n",
    "    def tokenize(self, content):\n",
    "        return self.tokenizer.encode(content, add_special_tokens=False)\n",
    "\n",
    "    def detokenize(self, content_ids):\n",
    "        return self.tokenizer.decode(content_ids, skip_special_tokens=False)\n",
    "\n",
    "    def struct_to_tokens(self, struct):\n",
    "        ids = []\n",
    "        types = []\n",
    "        episodes = []\n",
    "        groups = []\n",
    "        rewards = []\n",
    "        elements = []\n",
    "\n",
    "        for i, seg in enumerate(struct):\n",
    "            role_name = seg[\"role\"]\n",
    "\n",
    "            try:\n",
    "                role = self.roles[role_name]\n",
    "            except KeyError:\n",
    "                raise KeyError(f\"Unknown role: {role_name}\")\n",
    "\n",
    "            content_ids = self.tokenize(str(seg[\"content\"]))\n",
    "            full_ids = role[\"header_ids\"] + content_ids + role[\"footer_ids\"]\n",
    "            ids.extend(full_ids)\n",
    "            types.extend(\n",
    "                [role[\"header_type\"]] * len(role[\"header_ids\"])\n",
    "                + [role[\"content_type\"]] * len(content_ids)\n",
    "                + [role[\"footer_type\"]] * len(role[\"footer_ids\"])\n",
    "            )\n",
    "            episodes.extend([int(seg[\"episode\"])] * len(full_ids))\n",
    "            groups.extend([int(seg[\"group\"])] * len(full_ids))\n",
    "            rewards.extend([float(seg[\"reward\"])] + [0.0] * (len(full_ids) - 1))\n",
    "            elements.extend([i] * len(full_ids))\n",
    "\n",
    "        tokens = np.zeros(len(ids), dtype=self.token_dtype)\n",
    "        tokens['ids'] = ids\n",
    "        tokens['types'] = types\n",
    "        tokens['episodes'] = episodes\n",
    "        tokens['groups'] = groups\n",
    "        tokens['rewards'] = rewards\n",
    "        tokens['elements'] = elements\n",
    "\n",
    "        return tokens"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "editable": true,
    "id": "MSX5xtncn0gf",
    "jp-MarkdownHeadingCollapsed": true,
    "tags": []
   },
   "source": [
    "# Define Agent"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "XRTDLn_sv7iL"
   },
   "outputs": [],
   "source": [
    "@torch.compile(mode=\"max-autotune\")\n",
    "def argnext(x, dim):\n",
    "    reversed_input = torch.flip(x, dims=[dim])\n",
    "    cumsum = (reversed_input.cumsum(dim=dim) * 2) + reversed_input\n",
    "    _, reversed_indices = cumsum.cummax(dim=dim)\n",
    "    indices_from_right = torch.flip(reversed_indices, dims=[dim])\n",
    "    indices = (x.size(dim) - 1) - indices_from_right\n",
    "    return indices"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "editable": true,
    "id": "ZLUmDZsbnufQ",
    "tags": []
   },
   "outputs": [],
   "source": [
    "class CustomTrainer():\n",
    "\n",
    "    def __init__(\n",
    "            self, model_new, output_dir=\"./\", learn_rate=0.001, betas=(0.9, 0.999), weight_decay=0.0, mixed_precision=\"bf16\",\n",
    "            per_device_train_batch_size=1, gradient_accumulation_steps=1,\n",
    "            gradient_checkpointing=False, gradient_checkpointing_kwargs={},\n",
    "            save_steps=0, hf_repo_name=None, rl_params={}, stream_roles={},\n",
    "        ):\n",
    "\n",
    "        self.model_new = model_new\n",
    "        self.learn_rate = learn_rate\n",
    "        self.betas = betas\n",
    "        self.weight_decay = weight_decay\n",
    "        self.output_dir = output_dir\n",
    "        self.per_device_train_batch_size = per_device_train_batch_size\n",
    "        self.gradient_checkpointing = gradient_checkpointing\n",
    "        self.gradient_checkpointing_kwargs = gradient_checkpointing_kwargs\n",
    "        self.save_steps = save_steps\n",
    "        self.rl_params = rl_params\n",
    "        self.stream_roles = stream_roles\n",
    "        self.hf_repo_name = hf_repo_name\n",
    "        self.global_step = 0\n",
    "        self.accelerator = Accelerator(\n",
    "            mixed_precision=mixed_precision,\n",
    "            gradient_accumulation_steps=gradient_accumulation_steps,\n",
    "            log_with=\"all\",\n",
    "            step_scheduler_with_optimizer=False,\n",
    "        )\n",
    "        self.init_trainer_done = False\n",
    "        self.init_generate_done = False\n",
    "\n",
    "    def init_trainer(self):\n",
    "        trainable_params = [p for p in self.model_new.parameters() if p.requires_grad]\n",
    "        self.optimizer = torch.optim.AdamW(params=trainable_params, lr=self.learn_rate, betas=self.betas, eps=1e-06, weight_decay=self.weight_decay)\n",
    "\n",
    "        self.lr_scheduler = transformers.get_constant_schedule(self.optimizer)\n",
    "\n",
    "        self.model_old = copy.deepcopy(self.model_new)\n",
    "\n",
    "        self.model_new, self.model_old, self.optimizer, self.lr_scheduler = self.accelerator.prepare(\n",
    "            self.model_new, self.model_old, self.optimizer, self.lr_scheduler\n",
    "        )\n",
    "\n",
    "        if self.gradient_checkpointing:\n",
    "            self.model_new.gradient_checkpointing_enable(gradient_checkpointing_kwargs=self.gradient_checkpointing_kwargs)\n",
    "\n",
    "        self.init_trainer_done = True\n",
    "\n",
    "    def init_generate(self):\n",
    "        self.model_new = self.accelerator.prepare(self.model_new)\n",
    "\n",
    "        self.init_generate_done = True\n",
    "\n",
    "    def train(self, inputs, action_ids):\n",
    "        if not self.init_trainer_done:\n",
    "            self.init_trainer()\n",
    "        torch.cuda.reset_peak_memory_stats()\n",
    "        logs = self._inner_train(inputs, action_ids)\n",
    "        peak_memory = torch.cuda.max_memory_reserved() / 1e9\n",
    "        logs[\"peak_memory\"] = peak_memory\n",
    "        #self.empty_device_cache()\n",
    "        return logs\n",
    "\n",
    "    def _inner_train(self, inputs, action_ids):\n",
    "        self.model_new.train()\n",
    "        self.model_old.eval()\n",
    "\n",
    "        all_metrics = defaultdict(list)\n",
    "\n",
    "        collate_keys = list(inputs.keys())\n",
    "        dataset = torch.utils.data.TensorDataset(*[inputs[k] for k in collate_keys])\n",
    "        dataloader = DataLoader(dataset, batch_size=self.per_device_train_batch_size, shuffle=True,)\n",
    "        dataloader = self.accelerator.prepare(dataloader)\n",
    "\n",
    "        with torch.no_grad():\n",
    "            action_ids = torch.tensor(action_ids, dtype=torch.long)\n",
    "\n",
    "        for batch in dataloader:\n",
    "\n",
    "            batch_dict = dict(zip(collate_keys, batch))\n",
    "\n",
    "            action_ids = action_ids.to(self.accelerator.device)\n",
    "\n",
    "            with self.accelerator.accumulate(self.model_new):\n",
    "                with self.accelerator.autocast():\n",
    "                    loss, metrics = self.compute_loss(self.model_new, self.model_old, batch_dict, action_ids)\n",
    "                self.accelerator.backward(loss)\n",
    "                self.optimizer.step()\n",
    "                self.optimizer.zero_grad()\n",
    "                if self.lr_scheduler is not None:\n",
    "                    self.lr_scheduler.step()\n",
    "\n",
    "            for key, value in metrics.items():\n",
    "                all_metrics[key].append(value)\n",
    "\n",
    "        all_metrics = self.accelerator.gather_for_metrics(all_metrics)\n",
    "\n",
    "        self.polyak_update()\n",
    "\n",
    "        if self.save_steps > 0 and self.global_step % self.save_steps == 0:\n",
    "            self.save_model()\n",
    "\n",
    "        self.global_step += 1\n",
    "\n",
    "        logs = {\n",
    "            \"global_step\": self.global_step,\n",
    "            \"learning_rate\": self.lr_scheduler.get_last_lr()[0],\n",
    "        }\n",
    "        logs.update({key: torch.cat(val).mean().item() for key, val in all_metrics.items()})\n",
    "        return logs\n",
    "\n",
    "    def empty_device_cache(self):\n",
    "        if torch.cuda.is_available():\n",
    "            gc.collect()\n",
    "            torch.cuda.empty_cache()\n",
    "            gc.collect()\n",
    "\n",
    "    #@torch.compile(mode=\"max-autotune\")\n",
    "    def compute_loss(self, model_new, model_old, inputs, allowed_actions):\n",
    "        # Prepare shifted token targets and attributes\n",
    "        # Detach to avoid gradient tracking into the buffer tensors\n",
    "        ids = inputs['ids'][:, 1:].detach()       # [B, T-1]\n",
    "        roles = inputs['types'][:, 1:].detach()\n",
    "        episodes = inputs['episodes'][:, 1:].detach()\n",
    "        groups = inputs['groups'][:, 1:].detach()\n",
    "        rewards = inputs['rewards'][:, 1:].detach()\n",
    "\n",
    "        # Compute NEW model logits for the next-token prediction\n",
    "        logits_new = model_new(input_ids=inputs[\"ids\"])['logits'][:, :-1, :]  # [B, T-1, V]\n",
    "\n",
    "        with torch.no_grad():\n",
    "            # Compute OLD model Q-values (logits) for each step, mask to allowed actions\n",
    "            logits_old = model_old(input_ids=inputs[\"ids\"])['logits'][:, :-1, :]  # [B, T-1, V]\n",
    "            logits_old_allowed_actions = logits_old[:, :, allowed_actions] # [B, T-1, AV]\n",
    "\n",
    "            # Pick max logit for each allowed action\n",
    "            logits_new_allowed_actions = logits_new[:, :, allowed_actions]\n",
    "            indices_of_max_values = logits_new_allowed_actions.argmax(dim=-1)  # [B, T-1]\n",
    "            max_old_values = logits_old_allowed_actions.gather(dim=-1, index=indices_of_max_values.unsqueeze(-1)).squeeze(-1)  # [B, T-1]\n",
    "\n",
    "            # Mask to action-timesteps only\n",
    "            is_action = (roles == self.stream_roles[\"action\"][\"content_type\"])  # [B, T-1]\n",
    "            #is_onpolicy = (allowed_actions[indices_of_max_values] == ids)  # [B, T-1]\n",
    "            #is_offpolicy_action = is_action & ~is_onpolicy\n",
    "            next_action_idx = argnext(is_action, dim=-1).roll(-1, dims=-1)\n",
    "            next_is_action = torch.gather(is_action, dim=-1, index=next_action_idx)\n",
    "            next_is_action[:, -1] = False\n",
    "            valid_action = is_action & next_is_action\n",
    "            not_valid_action = ~valid_action\n",
    "\n",
    "            cumrewards = torch.cumsum(rewards.to(torch.float64), dim=-1)\n",
    "\n",
    "            next_episodes = torch.gather(episodes, dim=-1, index=next_action_idx)\n",
    "            next_groups = torch.gather(groups, dim=-1, index=next_action_idx)\n",
    "            next_cumrewards = torch.gather(cumrewards, dim=-1, index=next_action_idx)\n",
    "            next_max_old_values = torch.gather(max_old_values, dim=-1, index=next_action_idx)\n",
    "\n",
    "            # Compute reward\n",
    "            delta_cumrewards = (next_cumrewards - cumrewards).float()\n",
    "            delta_cumrewards_scaled = delta_cumrewards * self.rl_params[\"reward_scale\"]\n",
    "\n",
    "            # Compute discounting\n",
    "            is_group_end = (groups != next_groups)\n",
    "            is_episode_end = (episodes != next_episodes) | (groups != next_groups)\n",
    "            discounts = (\n",
    "                self.rl_params[\"step_discount\"] * ~is_episode_end * ~is_group_end\n",
    "                + self.rl_params[\"episode_discount\"] * is_episode_end * ~is_group_end\n",
    "                + self.rl_params[\"group_discount\"] * is_group_end\n",
    "            )\n",
    "\n",
    "            # Compute target values (Bellman)\n",
    "            target_values = delta_cumrewards_scaled + discounts * next_max_old_values\n",
    "\n",
    "        # Logits for actions actually taken\n",
    "        current_values = torch.gather(logits_new, dim=-1, index=ids.unsqueeze(-1)).squeeze(-1)  # [B, T-1]\n",
    "\n",
    "        # Value loss: MSE(current, target)\n",
    "        value_losses = (current_values - target_values.detach()).pow(2)\n",
    "        loss = value_losses[valid_action].mean()\n",
    "\n",
    "        metrics = {\n",
    "            \"value_loss\": value_losses[valid_action],\n",
    "            \"current_value\": current_values[valid_action],\n",
    "            \"target_value\": target_values[valid_action],\n",
    "        }\n",
    "\n",
    "        return loss, metrics\n",
    "\n",
    "    def polyak_update(self):\n",
    "        source = self.get_params(self.model_new, only_trainable=True)\n",
    "        target = self.get_params(self.model_old, only_trainable=False)\n",
    "        self.polyak_dict(source, target, self.rl_params[\"polyak_const\"])\n",
    "\n",
    "    @staticmethod\n",
    "    def polyak_dict(source, target, tau):\n",
    "        for k in source.keys():\n",
    "            target[k].data.copy_(tau * source[k].data + (1 - tau) * target[k].data)\n",
    "\n",
    "    @staticmethod\n",
    "    def get_params(model, only_trainable=False):\n",
    "        if only_trainable:\n",
    "            params = {name: param for name, param in model.named_parameters() if param.requires_grad}\n",
    "        else:\n",
    "            params = {name: param for name, param in model.named_parameters()}\n",
    "        return params\n",
    "    \n",
    "    def save_model(self):\n",
    "        self._save_model(self.output_dir, self.hf_repo_name)\n",
    "\n",
    "    def _save_model(self, path, hf_repo_name):\n",
    "        unwrapped_model = self.accelerator.unwrap_model(self.model_new)\n",
    "        unwrapped_model.save_pretrained(\n",
    "            path,\n",
    "            is_main_process=self.accelerator.is_main_process,\n",
    "            save_function=self.accelerator.save,\n",
    "        )\n",
    "        print(f\"Pushing model to repo: {hf_repo_name}\")\n",
    "        unwrapped_model.push_to_hub(hf_repo_name)\n",
    "\n",
    "    def generate(self, inputs, random_prob, action_ids):\n",
    "        if not self.init_generate_done:\n",
    "            self.init_generate()\n",
    "        logs = self._inner_generate(inputs, random_prob, action_ids)\n",
    "        #self.empty_device_cache()\n",
    "        return logs\n",
    "\n",
    "    def _inner_generate(self, inputs, random_prob, action_ids):\n",
    "        self.model_new.eval()\n",
    "\n",
    "        collate_keys = list(inputs.keys())\n",
    "        dataset = torch.utils.data.TensorDataset(*[inputs[k] for k in collate_keys])\n",
    "        dataloader = DataLoader(dataset, batch_size=len(dataset))\n",
    "        dataloader = self.accelerator.prepare(dataloader)\n",
    "\n",
    "        with torch.no_grad():\n",
    "            action_ids = torch.tensor(action_ids, dtype=torch.long)\n",
    "\n",
    "        for batch in dataloader:\n",
    "\n",
    "            batch_dict = dict(zip(collate_keys, batch))\n",
    "\n",
    "            model_inputs = {\"input_ids\": batch_dict[\"ids\"]}\n",
    "\n",
    "            action_ids = action_ids.to(self.accelerator.device)\n",
    "\n",
    "            with self.accelerator.autocast():\n",
    "                with torch.no_grad():\n",
    "\n",
    "                    outputs = self.model_new(**model_inputs)\n",
    "\n",
    "                    logits = outputs['logits'][:, -1:, action_ids]\n",
    "\n",
    "                    output_ids = torch.argmax(logits, dim=-1)\n",
    "                    output_ids = action_ids[output_ids]\n",
    "\n",
    "                    random_mask = torch.rand_like(output_ids, dtype=torch.float32) < random_prob\n",
    "                    random_ids = action_ids[torch.randint_like(output_ids, 0, len(action_ids))]\n",
    "                    output_ids[random_mask] = random_ids[random_mask]\n",
    "\n",
    "            all_output_ids = self.accelerator.gather(output_ids)\n",
    "\n",
    "        return all_output_ids"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "1O_fBaSH0BC6"
   },
   "outputs": [],
   "source": [
    "class Agent:\n",
    "\n",
    "    def __init__(\n",
    "            self, hf_repo_name, run_name, action_ids,\n",
    "            learning_rate, weight_decay, train_batch_size, gradient_accumulation_steps, save_steps,\n",
    "            base_model_name, load_adaptor_name, adapter_params, optimizer_betas, rl_params, stream_roles, gradient_checkpointing\n",
    "        ):\n",
    "        self.base_model_name = base_model_name\n",
    "        self.action_ids = action_ids\n",
    "        self.action_prefix_ids = stream_roles[\"action\"][\"header_ids\"]\n",
    "        self.action_suffix_ids = stream_roles[\"action\"][\"footer_ids\"]\n",
    "\n",
    "        self.model_new = self.init_model(load_adaptor_name, adapter_params)\n",
    "        self.trainer = CustomTrainer(\n",
    "            model_new=self.model_new,\n",
    "            output_dir=f\"../trainer/{run_name}/\",\n",
    "            learn_rate=learning_rate,\n",
    "            betas=optimizer_betas,\n",
    "            weight_decay=weight_decay,\n",
    "            mixed_precision=\"bf16\",\n",
    "            per_device_train_batch_size=math.ceil(train_batch_size / gradient_accumulation_steps),\n",
    "            gradient_accumulation_steps=gradient_accumulation_steps,\n",
    "            gradient_checkpointing=gradient_checkpointing,\n",
    "            gradient_checkpointing_kwargs={\"use_reentrant\": False},\n",
    "            save_steps=save_steps,\n",
    "            hf_repo_name=hf_repo_name,\n",
    "            rl_params=rl_params,\n",
    "            stream_roles=stream_roles,\n",
    "        )\n",
    "\n",
    "    def init_model(self, load_adaptor_name, adapter_params):\n",
    "        base_model = AutoModelForCausalLM.from_pretrained(\n",
    "            self.base_model_name,\n",
    "            torch_dtype=torch.bfloat16,\n",
    "            low_cpu_mem_usage=True,\n",
    "            device_map=\"auto\",\n",
    "            cache_dir=\"../model/\",\n",
    "            #attn_implementation=\"flash_attention_2\",\n",
    "        )\n",
    "        base_model.lm_head = torch.nn.Linear(base_model.lm_head.in_features, base_model.lm_head.out_features, bias=True, dtype=base_model.lm_head.weight.dtype)\n",
    "        \n",
    "        if load_adaptor_name is not None:\n",
    "            base_model = PeftModel.from_pretrained(base_model, load_adaptor_name)\n",
    "\n",
    "        else:\n",
    "            peft_config = LoraConfig(\n",
    "                task_type=\"CAUSAL_LM\",\n",
    "                target_modules=[\"gate_proj\", \"down_proj\", \"up_proj\", \"q_proj\", \"v_proj\", \"k_proj\", \"o_proj\"],\n",
    "                modules_to_save=[\"lm_head\"],\n",
    "                r=adapter_params[\"lora_r\"],\n",
    "                lora_alpha=adapter_params[\"lora_alpha\"],\n",
    "                lora_dropout=0.0,\n",
    "                bias=\"none\",\n",
    "                use_rslora=True\n",
    "            )\n",
    "            model = get_peft_model(base_model, peft_config=peft_config, autocast_adapter_dtype=True)\n",
    "        \n",
    "        model.config.use_cache = False\n",
    "\n",
    "        return model\n",
    "\n",
    "    def generate(self, data, random_prob):\n",
    "        with torch.no_grad():\n",
    "            ids = data[\"ids\"]\n",
    "            prefix_ids = torch.tensor(self.action_prefix_ids).expand(ids.shape[0], -1)\n",
    "            prompt_ids = torch.cat([ids, prefix_ids], dim=-1)\n",
    "\n",
    "        new_data = {\"ids\": prompt_ids}\n",
    "        output_ids = self.trainer.generate(new_data, random_prob=random_prob, action_ids=self.action_ids)\n",
    "\n",
    "        result_data = output_ids.detach().cpu().tolist()\n",
    "\n",
    "        return result_data\n",
    "\n",
    "    def train(self, data):\n",
    "        metrics = self.trainer.train(data, action_ids=self.action_ids)\n",
    "        return metrics\n",
    "\n",
    "    def save(self):\n",
    "        self.trainer.save_model()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "editable": true,
    "id": "frlmbuudn-oj",
    "jp-MarkdownHeadingCollapsed": true,
    "tags": []
   },
   "source": [
    "# Define World"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "editable": true,
    "id": "M_DZjaDhn65a",
    "tags": []
   },
   "outputs": [],
   "source": [
    "class World:\n",
    "    def __init__(\n",
    "        self, action_space, test_pool_size, train_pool_size, test_generate_batch_size, train_generate_batch_size, train_batch_size,\n",
    "        map_change_prob, map_max_age, train_env, test_env, train_generate_length, test_generate_length, train_length,\n",
    "        learning_rate, weight_decay, gradient_accumulation_steps, save_steps, run_name,\n",
    "        base_model_name, load_adaptor_name, adapter_params, optimizer_betas, tokenizer_name, rl_params, gradient_checkpointing,\n",
    "        train_record_size, test_record_size, train_on_last, hf_repo_name=None,\n",
    "    ):\n",
    "        \"\"\"Create the tokenizer, the train/test game pools, the W&B metric layout and the Agent.\n",
    "\n",
    "        Selected args:\n",
    "            action_space: iterable of action strings; each must encode to exactly one token.\n",
    "            train_env, test_env: zero-argument factories that build a new environment.\n",
    "            train_pool_size: number of training games; must be at least 1 because the\n",
    "                first training stream is used to tokenize the action space.\n",
    "            train_record_size, test_record_size: games whose index is below this value\n",
    "                are recorded and get their own W&B step metric.\n",
    "            hf_repo_name: Hub repo for checkpoints. When omitted, a global variable\n",
    "                named `hf_repo_name` is used if one is defined (otherwise None).\n",
    "\n",
    "        Raises:\n",
    "            ValueError: if train_pool_size < 1 or an action is not a single token.\n",
    "\n",
    "        NOTE(review): W&B calls go through a global `wdb` (presumably the active\n",
    "        wandb run) that must exist before construction.\n",
    "        \"\"\"\n",
    "        if train_pool_size < 1:\n",
    "            raise ValueError(\"train_pool_size must be at least 1.\")\n",
    "        if hf_repo_name is None:\n",
    "            # Backward compatibility: callers used to provide this as a notebook global.\n",
    "            hf_repo_name = globals().get(\"hf_repo_name\")\n",
    "\n",
    "        self.test_pool_size = test_pool_size\n",
    "        self.train_pool_size = train_pool_size\n",
    "        self.test_generate_batch_size = test_generate_batch_size\n",
    "        self.train_generate_batch_size = train_generate_batch_size\n",
    "        self.train_batch_size = train_batch_size\n",
    "        self.map_change_prob = map_change_prob\n",
    "        self.map_max_age = map_max_age\n",
    "        self.train_generate_length = train_generate_length\n",
    "        self.test_generate_length = test_generate_length\n",
    "        self.train_length = train_length\n",
    "        self.train_env = train_env\n",
    "        self.test_env = test_env\n",
    "        self.train_record_size = train_record_size\n",
    "        self.test_record_size = test_record_size\n",
    "        self.train_on_last = train_on_last\n",
    "\n",
    "        self.tokenizer = AutoTokenizer.from_pretrained(\n",
    "            tokenizer_name,\n",
    "            cache_dir=\"../tokenizer/\"\n",
    "        )\n",
    "\n",
    "        wdb.define_metric(\"train/*\", step_metric=\"train/step\")\n",
    "        self.train_step = 0\n",
    "\n",
    "        wdb.define_metric(\"generate/*\", step_metric=\"generate/step\")\n",
    "        self.generate_step = 0\n",
    "\n",
    "        wdb.define_metric(\"train_games/*\", step_metric=\"train_games/episode\")\n",
    "        self.train_stats = self._new_stats()\n",
    "        self.train_log_interval = self.train_generate_batch_size\n",
    "        self.train_games = self._make_game_pool(\n",
    "            \"train\", self.train_pool_size, self.train_record_size, self.train_env, self.train_generate_length\n",
    "        )\n",
    "\n",
    "        wdb.define_metric(\"test_games/*\", step_metric=\"test_games/episode\")\n",
    "        self.test_stats = self._new_stats()\n",
    "        self.test_log_interval = self.test_generate_batch_size\n",
    "        self.test_games = self._make_game_pool(\n",
    "            \"test\", self.test_pool_size, self.test_record_size, self.test_env, self.test_generate_length\n",
    "        )\n",
    "\n",
    "        # Every action must be a single token so the agent can pick it with one argmax.\n",
    "        self.action_space = action_space\n",
    "        probe_stream = self.train_games[0][\"stream\"]\n",
    "        action_ids = []\n",
    "        for k in self.action_space:\n",
    "            tokens = probe_stream.tokenize(k)\n",
    "            if len(tokens) != 1:\n",
    "                raise ValueError(f\"Action '{k}' tokenized into {len(tokens)} tokens: {tokens}. Max length is one.\")\n",
    "            action_ids.append(tokens[0])\n",
    "\n",
    "        stream_roles = probe_stream.roles\n",
    "        # Number of action-header tokens that Agent.generate appends to each prompt.\n",
    "        self.prefix_length = len(stream_roles[\"action\"][\"header_ids\"])\n",
    "\n",
    "        self.agent = Agent(\n",
    "            hf_repo_name=hf_repo_name,\n",
    "            run_name=run_name,\n",
    "            action_ids=action_ids,\n",
    "            learning_rate=learning_rate,\n",
    "            weight_decay=weight_decay,\n",
    "            train_batch_size=train_batch_size,\n",
    "            gradient_accumulation_steps=gradient_accumulation_steps,\n",
    "            save_steps=save_steps,\n",
    "            base_model_name=base_model_name,\n",
    "            adapter_params=adapter_params,\n",
    "            load_adaptor_name=load_adaptor_name,\n",
    "            optimizer_betas=optimizer_betas,\n",
    "            rl_params=rl_params,\n",
    "            stream_roles=stream_roles,\n",
    "            gradient_checkpointing=gradient_checkpointing,\n",
    "        )\n",
    "\n",
    "    @staticmethod\n",
    "    def _new_stats():\n",
    "        \"\"\"Fresh running-statistics dict for one game pool.\"\"\"\n",
    "        return {\n",
    "            \"episode\": 0,\n",
    "            \"count\": 0,\n",
    "            \"cumulative_reward\": 0.0,\n",
    "            \"weighted_reward\": 0.0,\n",
    "            \"episode_length\": 0,\n",
    "        }\n",
    "\n",
    "    def _make_game_pool(self, prefix, pool_size, record_size, env_func, generate_length):\n",
    "        \"\"\"Build `pool_size` game slots named f\"{prefix}_game_{i}\"; the first `record_size` are recorded.\"\"\"\n",
    "        games = []\n",
    "        for i in range(pool_size):\n",
    "            name = f\"{prefix}_game_{i}\"\n",
    "            record_game = i < record_size\n",
    "            games.append({\n",
    "                \"env\": None,\n",
    "                \"env_func\": env_func,\n",
    "                \"stream\": Stream(tokenizer=self.tokenizer, size_increment=generate_length),\n",
    "                \"episode\": 0,\n",
    "                \"group\": 0,\n",
    "                \"group_episode\": 0,\n",
    "                \"name\": name,\n",
    "                \"record\": record_game,\n",
    "            })\n",
    "            if record_game:\n",
    "                wdb.define_metric(f\"{name}/*\", step_metric=f\"{name}/episode\")\n",
    "        return games\n",
    "\n",
    "    @staticmethod\n",
    "    def random_slice(stream, length):\n",
    "        stream_len = stream.get_tokens_length()\n",
    "        if length < 1 or length > stream_len:\n",
    "            raise ValueError(\"Requested length must be between 1 and the length of the stream.\")\n",
    "        start = random.randint(0, stream_len - length)\n",
    "        end = start + length\n",
    "        return stream[start:end, \"tokens\"]\n",
    "\n",
    "    def get_data(self, selected_games, cutoff, last):\n",
    "        min_length = min([g[\"stream\"].get_tokens_length() for g in selected_games])\n",
    "        target_length = min(min_length, cutoff)\n",
    "        \n",
    "        if last:\n",
    "            d = np.stack([g[\"stream\"][-target_length:, \"tokens\"] for g in selected_games])\n",
    "        else:\n",
    "            d = np.stack([self.random_slice(g[\"stream\"], target_length) for g in selected_games])\n",
    "        data = {name: torch.from_numpy(d[name]) for name in d.dtype.names}\n",
    "\n",
    "        return data\n",
    "\n",
    "    def get_data_min_len(self, selected_games):\n",
    "        min_length = min([g[\"stream\"].get_tokens_length() for g in selected_games])\n",
    "        return min_length\n",
    "\n",
    "    def train(self, selected_games, length):\n",
    "        times = [time.time_ns()]\n",
    "\n",
    "        # Get data\n",
    "        data = self.get_data(selected_games, length, self.train_on_last)\n",
    "        times.append(time.time_ns())\n",
    "\n",
    "        #Train\n",
    "        train_logs = self.agent.train(data)\n",
    "        times.append(time.time_ns())\n",
    "\n",
    "        logs = {\n",
    "            \"train/step\": self.train_step,\n",
    "            \"train/time_get_data\": (times[1] - times[0]) * 1e-9,\n",
    "            \"train/time_train\": (times[2] - times[1]) * 1e-9,\n",
    "            \"train/time_total\": (times[2] - times[0]) * 1e-9,\n",
    "            \"train/sample_size\": len(selected_games),\n",
    "        }\n",
    "        logs.update({f\"train/{k}\": v for k, v in train_logs.items()})\n",
    "        wdb.log(logs)\n",
    "        self.train_step += 1\n",
    "        print(f\"Train logs: {logs}\")\n",
    "\n",
    "    def generate(self, selected_games, random_prob, length, enable_env_log):\n",
    "        \"\"\"Advance every selected game by exactly one action.\n",
    "\n",
    "        Phases (each timed and logged under ``generate/``):\n",
    "          1. init_game: start an episode for each game whose env is missing\n",
    "             or done. A new env (and a new group) is created with\n",
    "             ``game[\"env_func\"]`` when there is no env, with probability\n",
    "             ``self.map_change_prob``, or once the group has run more than\n",
    "             ``self.map_max_age`` episodes.\n",
    "          2. get_data / generate / decode: build agent inputs and decode the\n",
    "             agent's output ids into actions. When ``random_prob >= 1.0`` the\n",
    "             agent is skipped and ``env.sample_action()`` is used instead.\n",
    "          3. env: apply the actions, append action/reward/observation/end\n",
    "             events to ``game[\"stream\"]`` and, if ``enable_env_log``, log each\n",
    "             episode that just finished.\n",
    "\n",
    "        Args:\n",
    "            selected_games: game dicts to step. A game should appear at most\n",
    "                once: a duplicate is stepped again with an action chosen\n",
    "                without seeing the previous step's outcome.\n",
    "            random_prob: forwarded to ``self.agent.generate``; ``>= 1.0``\n",
    "                bypasses the agent entirely.\n",
    "            length: context budget; ``self.prefix_length`` is subtracted\n",
    "                before calling ``self.get_data``.\n",
    "            enable_env_log: log per-episode metrics plus the aggregated\n",
    "                ``train_games/`` and ``test_games/`` statistics.\n",
    "        \"\"\"\n",
    "        times = [time.time_ns()]\n",
    "\n",
    "        # Phase 1: (re)start finished games, possibly on a fresh env.\n",
    "        for game in selected_games:\n",
    "            if (game[\"env\"] is None) or (game[\"env\"].is_done()):\n",
    "                game[\"episode\"] += 1\n",
    "                game[\"group_episode\"] += 1\n",
    "\n",
    "                if ((game[\"env\"] is None)\n",
    "                    or (random.random() < self.map_change_prob)\n",
    "                    or (self.map_max_age is not None and game[\"group_episode\"] > self.map_max_age)\n",
    "                ):\n",
    "                    # New group: fresh env, episode numbering restarts at 1.\n",
    "                    game[\"group_episode\"] = 1\n",
    "                    game[\"group\"] += 1\n",
    "                    game[\"env\"] = game[\"env_func\"]()\n",
    "                    game[\"stream\"].add_struct([{\"role\": \"environment\", \"content\": \"\", \"episode\": game[\"group_episode\"], \"group\": game[\"group\"], \"reward\": 0.0}])\n",
    "\n",
    "                observation = game[\"env\"].reset(record=game[\"record\"])\n",
    "                game[\"stream\"].add_struct([{\"role\": \"begin\", \"content\": \"\", \"episode\": game[\"group_episode\"], \"group\": game[\"group\"], \"reward\": 0.0}])\n",
    "                if observation is not None:\n",
    "                    game[\"stream\"].add_struct([{\"role\": \"observation\", \"content\": observation, \"episode\": game[\"group_episode\"], \"group\": game[\"group\"], \"reward\": 0.0}])\n",
    "\n",
    "        times.append(time.time_ns())\n",
    "\n",
    "        # Phase 2: choose one action per game.\n",
    "        if random_prob < 1.0:\n",
    "            data = self.get_data(selected_games, length - self.prefix_length, True)\n",
    "            times.append(time.time_ns())\n",
    "            action_ids_list = self.agent.generate(data, random_prob=random_prob)\n",
    "            times.append(time.time_ns())\n",
    "            action_list = self.tokenizer.batch_decode(action_ids_list, skip_special_tokens=False)\n",
    "            times.append(time.time_ns())\n",
    "        else:\n",
    "            # No agent call: keep `times` aligned with the agent branch\n",
    "            # (get_data and decode are ~0; sampling counts as \"generate\").\n",
    "            times.append(time.time_ns())\n",
    "            action_list = [game[\"env\"].sample_action() for game in selected_games]\n",
    "            times.append(time.time_ns())\n",
    "            times.append(time.time_ns())\n",
    "\n",
    "        # Phase 3: apply actions and record the resulting events.\n",
    "        for game, action in zip(selected_games, action_list):\n",
    "            env = game[\"env\"]\n",
    "            game[\"stream\"].add_struct([{\"role\": \"action\", \"content\": action, \"episode\": game[\"group_episode\"], \"group\": game[\"group\"], \"reward\": 0.0}])\n",
    "            next_obs, reward, terminated, truncated = env.step(action)\n",
    "            if reward is not None:\n",
    "                game[\"stream\"].add_struct([{\"role\": \"reward\", \"content\": reward, \"episode\": game[\"group_episode\"], \"group\": game[\"group\"], \"reward\": reward}])\n",
    "            if next_obs is not None:\n",
    "                game[\"stream\"].add_struct([{\"role\": \"observation\", \"content\": next_obs, \"episode\": game[\"group_episode\"], \"group\": game[\"group\"], \"reward\": 0.0}])\n",
    "            if terminated or truncated:\n",
    "                game[\"stream\"].add_struct([{\"role\": \"end\", \"content\": \"\", \"episode\": game[\"group_episode\"], \"group\": game[\"group\"], \"reward\": 0.0}])\n",
    "                if enable_env_log:\n",
    "                    self._log_finished_episode(game)\n",
    "\n",
    "        times.append(time.time_ns())\n",
    "\n",
    "        logs = {\n",
    "            \"generate/step\": self.generate_step,\n",
    "            \"generate/time_init_game\": (times[1] - times[0]) * 1e-9,\n",
    "            \"generate/time_get_data\": (times[2] - times[1]) * 1e-9,\n",
    "            \"generate/time_generate\": (times[3] - times[2]) * 1e-9,\n",
    "            \"generate/time_decode\": (times[4] - times[3]) * 1e-9,\n",
    "            \"generate/time_env\": (times[5] - times[4]) * 1e-9,\n",
    "            \"generate/time_total\": (times[5] - times[0]) * 1e-9,\n",
    "            \"generate/sample_size\": len(selected_games),\n",
    "            \"generate/random_prob\": random_prob,\n",
    "        }\n",
    "        wdb.log(logs)\n",
    "        self.generate_step += 1\n",
    "\n",
    "    def _log_finished_episode(self, game):\n",
    "        \"\"\"Log metrics for the episode `game` just finished and fold them into\n",
    "        the train/test aggregates of whichever pool `game` belongs to.\"\"\"\n",
    "        env = game[\"env\"]\n",
    "        episode_length = len(env.rewards)\n",
    "        cumulative_reward = sum(x for x in env.rewards if x is not None)\n",
    "\n",
    "        name = game[\"name\"]\n",
    "        logs = {\n",
    "            f\"{name}/episode\": game[\"episode\"],\n",
    "            f\"{name}/group_episode\": game[\"group_episode\"],\n",
    "            f\"{name}/episode_length\": episode_length,\n",
    "            f\"{name}/reward\": cumulative_reward,\n",
    "            f\"{name}/group\": game[\"group\"],\n",
    "        }\n",
    "        if env.record:\n",
    "            # Stack frames on a new leading time axis, then move the last axis\n",
    "            # to position -3. Presumably frames are (H, W, C), giving the\n",
    "            # (T, C, H, W) layout wandb.Video expects; confirm.\n",
    "            video = np.moveaxis(np.stack(env.video, axis=0), -1, -3)\n",
    "            logs[f\"{name}/video\"] = wandb.Video(data_or_path=video, fps=3, format='gif', caption=f\"episode{game['episode']}\")\n",
    "        wdb.log(logs)\n",
    "\n",
    "        # Identity membership, not `in`: `in` on a list of dicts compares them\n",
    "        # with ==, deep-comparing envs/streams, which is slow and not intended.\n",
    "        if any(g is game for g in self.train_games):\n",
    "            self._accumulate_episode_stats(self.train_stats, \"train_games\", self.train_log_interval, cumulative_reward, episode_length)\n",
    "        if any(g is game for g in self.test_games):\n",
    "            self._accumulate_episode_stats(self.test_stats, \"test_games\", self.test_log_interval, cumulative_reward, episode_length)\n",
    "\n",
    "    def _accumulate_episode_stats(self, stats, prefix, log_interval, cumulative_reward, episode_length):\n",
    "        \"\"\"Add one finished episode to the running sums in `stats`.\n",
    "\n",
    "        Once `log_interval` episodes have accumulated, their averages are\n",
    "        logged under `prefix/` and the sums are reset. `stats[\"episode\"]` is a\n",
    "        running episode index, incremented after every episode and never reset.\n",
    "        \"\"\"\n",
    "        stats[\"count\"] += 1\n",
    "        stats[\"cumulative_reward\"] += cumulative_reward\n",
    "        stats[\"weighted_reward\"] += cumulative_reward * episode_length\n",
    "        stats[\"episode_length\"] += episode_length\n",
    "        if stats[\"count\"] >= log_interval:\n",
    "            wdb.log({\n",
    "                f\"{prefix}/episode\": stats[\"episode\"],\n",
    "                f\"{prefix}/cumulative_reward\": stats[\"cumulative_reward\"] / stats[\"count\"],\n",
    "                # Length-weighted mean reward: sum(R_i * L_i) / sum(L_i).\n",
    "                f\"{prefix}/weighted_reward\": stats[\"weighted_reward\"] / stats[\"episode_length\"],\n",
    "                f\"{prefix}/episode_length\": stats[\"episode_length\"] / stats[\"count\"],\n",
    "            })\n",
    "            stats[\"count\"] = 0\n",
    "            stats[\"cumulative_reward\"] = 0.0\n",
    "            stats[\"weighted_reward\"] = 0.0\n",
    "            stats[\"episode_length\"] = 0.0\n",
    "        stats[\"episode\"] += 1\n",
    "\n",
    "    def run(self, steps=None, test_generate_iterations=1, train_generate_iterations=1, train_iterations=1,\n",
    "            test_generate_random_prob=0.0, train_generate_random_prob=0.0, end_at_length=None, enable_env_log=False):\n",
    "        \"\"\"Drive the generate/train loop.\n",
    "\n",
    "        Each outer step runs, in order: ``train_generate_iterations`` calls to\n",
    "        ``generate`` on the train pool, ``test_generate_iterations`` calls on\n",
    "        the test pool, then ``train_iterations`` calls to ``train``.\n",
    "\n",
    "        Args:\n",
    "            steps: number of outer steps, or ``None`` to loop until\n",
    "                ``end_at_length`` is reached (forever if that is also ``None``).\n",
    "            test_generate_iterations, train_generate_iterations: ``generate``\n",
    "                calls per step on the test / train pool.\n",
    "            train_iterations: ``train`` calls per step (0 disables training).\n",
    "            test_generate_random_prob, train_generate_random_prob: passed to\n",
    "                ``generate`` as ``random_prob``.\n",
    "            end_at_length: stop after the first step at which\n",
    "                ``self.get_data_min_len`` over all games is >= this value.\n",
    "            enable_env_log: passed to ``generate``.\n",
    "        \"\"\"\n",
    "        # `step_iter` rather than `iter` so the builtin is not shadowed.\n",
    "        step_iter = itertools.count() if steps is None else range(steps)\n",
    "        for _step in step_iter:\n",
    "\n",
    "            for _ in range(train_generate_iterations):\n",
    "                selected_games = self._sample_distinct_games(self.train_games, self.train_generate_batch_size)\n",
    "                self.generate(selected_games, random_prob=train_generate_random_prob, length=self.train_generate_length, enable_env_log=enable_env_log)\n",
    "\n",
    "            for _ in range(test_generate_iterations):\n",
    "                selected_games = self._sample_distinct_games(self.test_games, self.test_generate_batch_size)\n",
    "                self.generate(selected_games, random_prob=test_generate_random_prob, length=self.test_generate_length, enable_env_log=enable_env_log)\n",
    "\n",
    "            for _ in range(train_iterations):\n",
    "                # Training batches keep sampling with replacement (no env is stepped here).\n",
    "                selected_games = random.choices(self.train_games, k=self.train_batch_size)\n",
    "                self.train(selected_games, length=self.train_length)\n",
    "\n",
    "            if end_at_length is not None:\n",
    "                all_games = self.train_games + self.test_games\n",
    "                if self.get_data_min_len(all_games) >= end_at_length:\n",
    "                    break\n",
    "\n",
    "    @staticmethod\n",
    "    def _sample_distinct_games(pool, k):\n",
    "        \"\"\"Return min(k, len(pool)) distinct games from ``pool`` in random order.\n",
    "\n",
    "        ``generate`` steps each listed env once per entry, so a game must not\n",
    "        appear twice in one generation batch (``random.choices`` allowed that).\n",
    "        If ``k`` exceeds the pool size the batch is truncated to the whole pool.\n",
    "        \"\"\"\n",
    "        return random.sample(pool, k=min(k, len(pool)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "editable": true,
    "id": "8sLIUKrDoMpc",
    "jp-MarkdownHeadingCollapsed": true,
    "tags": []
   },
   "source": [
    "# Run World"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 1000,
     "referenced_widgets": [
      "c0df333307014c42af668d0e93430aae",
      "554fe2c25148490f9dad5c64e57947cc",
      "6a361090fb5543f789ea8adfb0945a1a",
      "06da1ad10f9b49a88d53534db2ca834b",
      "b89966b1b1c84357bb0d0edaef7240f3",
      "3b392e7b96c845bcb372dc45ecafe202",
      "b61adbb1079c47abb49805d21859a65b",
      "3a4b3477e82f4a5e8c328f9ddf8b32fb",
      "5a58c680ff5f480dacdc24d28dc5353c",
      "22b2b465f36a456bb640294de70b1d65",
      "ee5661fcfe8c4629bc151164dbb24bff",
      "6b1bb404dac74086a390724b7325cc31",
      "0ac5ae17b262414fa282fe773829f08d",
      "b644f37d349742a6bb5d9c4e0595d306",
      "dc1c392e19704d2fb046fcfe5553577e",
      "e42f054b1551459ca3391329d1b615a2",
      "077f8dfaf26144778f61f39a246db97a",
      "a9b0b69c143d415496517ddbe773f604",
      "6cd6370e2e0f47e6b01066501290dfd0",
      "8efb74b0211a42b6a40fa9e7a33fcea3",
      "0c0446d5395d4e2393782c6d2370d07f",
      "560a2543aab84135a5b977d394bfd403",
      "810cbadbb0214dee88b01072395f0bc6",
      "d0ea7f1437c8408b942b0c01d56094eb",
      "6ff72382c3fa461a8cab68af2d9bad56",
      "a9ade27394074a948397be58968fe58b",
      "1c05691767244bfab6e4200197753906",
      "ec4566ce307344eca4903b71263b2479",
      "408d128b284d47b3a54f97931c3f4f93",
      "b0c485c73a8a41b0b75389d1c6f01259",
      "78b9ad6c60ba43a793743a95d5e362cd",
      "851f1fef53e64a598ef7f9238f26f92a",
      "b3eb618cc10f4a8aaadf1104b87b8749",
      "2b157602f2604464a0ca6bbf5c3baf7c",
      "346f8ce45f03415d991297bf1381bd26",
      "4c75a9d0a91a45488f6beb420903485f",
      "b5a94f314df24061abd4f88510b670cf",
      "b19ec4bb799f4636bdd07faae37c62c6",
      "d880c36d9aca49b7be39ee77de3a185d",
      "d6779280f6fc4f5f87c979c980a969af",
      "f8b441c782094b8881ba47c04ae3c530",
      "f654747374704f689183cb4798f926dc",
      "ef83cd3e307f4e3ea81bd7abc71a7aca",
      "7e09b413160a4c7dbadddf664bb73f64",
      "8bd68963901f476da7d0e1da019d448d",
      "fbc14e15f757495ebfa7d3fafadcc527",
      "6f5d71d697034f6cb1cac7bc27479813",
      "3121765b978441a883fd4e97332dadbd",
      "3af01dce5ca141aa82d962c42e3b5992",
      "3c470fcc311540569eb4ab945d4fd535",
      "dbdb7708ac804f0eb18b1f1d6c90bcec",
      "20002e5f59084db388a3829eac800db6",
      "ace35051b1ec4275b7b4dc2102cb12a3",
      "56afa32243f14113b73224d481462494",
      "26a64929e2324b72b97a71dfb1249f72",
      "c2073445e9be4712a4dba6397f1221b4",
      "999123e0f01d441bb484c944f160c139",
      "95877d206cde4f8f9aa170c529404775",
      "a0d33d15abd443c48ffa2f1f54a6d15d",
      "45416db75a5a4a1d8c789fbf1b8f5ee5",
      "b3c50bc239344e7e84ffb3fbe9fac69f",
      "abfb98b802004485912a039299931f5e",
      "0064e8485eb045e5a2464b5a57347754",
      "f0db74839c5841609587e10903371815",
      "7addc057d95c4ef4b0289aabd2eba1a2",
      "6141bc360e7541d1b25287771ff1bd10",
      "9668f00a3e72487bace6a10a187c1f0c",
      "a1668b62e5514ef78546beb71b36eb78",
      "ce1bd76aa6eb430e8ffdd2c15330b0a4",
      "8a4d1e45ad484ccc8f77bb1ce43ea96e",
      "215ccfde24624b2fba47e67cbbaf27bd",
      "9184cec06fe84ef38f391da0b89b0eeb",
      "826bb8da77a14cb085ee804437edebaf",
      "95bf41f71d72413697d744d5f0f257e2",
      "4e6fc0bc13914724a43e788069baf52a",
      "2db1fa69a111492f8dfdb39af094242b",
      "20427d4dcf4f4cafbf4c2b99f74387a7",
      "451df28c720c44cdb9a965ab7d2fcab5",
      "648a2ddb70784377aa7deb90ce4c9e73",
      "b121fe28eb544890ae78729697507e9a",
      "9474b8b9cbfd4ef7bc5d1b7a5c51ad13",
      "d564a5d207584ec29beafdb7dc054a66",
      "2b67fdb82fc64e62b413df250bf9770b",
      "4499b46b9de14391acf8536d1f6a9410",
      "e65e1cc7d77544888e9bd073a21db639",
      "d0e9b63b79b940068f6b835492e5e669",
      "3dcd47982d9445278d98dfec99846df2",
      "c1c2f39a3b8c486aad3e2c782564bcc0",
      "2004e2b397e74a91ae88dcdfee63dc48",
      "3a20faff49b7429fabb89eb103c49ea8",
      "c5d12f448b7a4a8892bfde61fb83e037",
      "82ff7f13388f4ecab27d72423b95509b",
      "db056bc241e74c62a2359cc799ac1c13",
      "a8e4e8dc3273479387dac11df11c0bd1",
      "48de7e8c9a014efd8515b161fa8f68ba",
      "eec81341724a4a5f8ffb020c0833f4c0",
      "9ee1a499a6624efe9a2443872f1f8cdf",
      "1e72d428544e4bc0a211831f626c8288",
      "89d43ac4fa2f43d49980e90587ae556e",
      "97bf80ad85ed441eb76f8520053f9f38",
      "3b9aacb4fdbf401e83ce0e978f39343e",
      "3238c0d6139b4b418d155d3c8fe10dcb",
      "bbf52082a334411b818abee2ed763575",
      "b74f86b5e8bd46558f7ed5dd8bd84a39",
      "f004e5dc12fa436589bf388565a7661a",
      "d1ab6dd391174c7599770fcbd69a1e5a",
      "70f21f0ae6674f798e1392127a457094",
      "ffbbc1b20f4444a2a3bcb7eaa3d00f75",
      "cd3d5bf454e44a15994b7579a25297ae",
      "0529502e12d54d738c183f5ec767c080",
      "2e342db7a60e46ef8715bc2799f60fae",
      "82ddd28517f34441a31210c56dc6e492",
      "11792f4c40ba409996b17bd199a8070b",
      "4f1cbda40e114329a18100487aa11b9d",
      "c5c7c3e5661d4c0ba61930d17449680a",
      "0159077191204ba4ad6d977eec335bc5",
      "f39bb8572b014e899de2e6f4bece66c0",
      "582b345c50c044429b2aef7a7b19c1e4",
      "6699a2f3bc82415bb6a28bf1fbe943ed",
      "2be4b6c867a5457d8fa5b046beb0d0f8",
      "a9dba2bc4abb4b5996505d1507fb8c6e",
      "ad336369668946fe9749759d39a2af55",
      "20fc850c293b43c8b2782b9fd40cae91",
      "f21c2bab5e354882980db6ed03714579",
      "5130f0c81e2c4378881ec6b627f87322",
      "b760fd722ea64d82a6825bd0683cc405",
      "2d292223f89e47bfbc9ee535daafff0b",
      "3437ed92b4e8493f8161de824bd81fe3",
      "e27f5352f6274bc39707ddc9ad79f9a6",
      "6ec1cb0589e5465ab8537aa153b6d07d",
      "1604ea6db17440fc9b2b780d5fcea06d",
      "d17987dfe9074eaea8d52a5476e33cc9",
      "203ec338923646e08e2ccc0cf9324c73",
      "2e6e8115bc2949a3b1dc492b87d73f8f",
      "24a0763983894d0182a3790e416e6a27",
      "0ff8049a644b443f838780810d71eb09",
      "8c3842f70d634b4cbee04ab57e0a1f6b",
      "466833a5507143a391f90ffa843ea557",
      "a9d8216d8ca549faaaf360e7908a39a5",
      "afe9db4aad774187bd66868e6bd62269",
      "473afdf10fac4dbd82655f79f52dec41",
      "7df639bb912e4f21b8ce43413d94f195",
      "d55dfe14104341779f8860611b2d22c6",
      "9cbb696a040c482d9a3fb4ba9a3c8410",
      "ff7a45c7f49147c6a44e04caf36e743a",
      "7ff3e060d8ab4d9a9390488626e5da0c",
      "7427947f900d4322827fb21e85119f0d",
      "ce304923239d484f80e8db3e96f4228f",
      "13b927d01ef2435ca91f3736be44d843",
      "54b4aab6a3c34853857364a93fb1e6b4",
      "4622742c9ef042ed8f6213e053edf3b5",
      "bc7acd387b57415581a751db08262856",
      "775c042a121041bfb0d3c0faa1b341c6",
      "1832d92fde2c495da51808a5673f2c74",
      "1f6dfa0908064887a9a13bf98bd40c74",
      "9220c93f09214175ab8811cbb78e036b",
      "756fdd95b893463f9f9622c69bcbbc7e",
      "0010d235d6e44d16b53a9b976effbf66",
      "4daf6dd6302545c181e5dbb5f287f1bc",
      "e09bc00394064bfb85ed928ed238e058",
      "31e0246880634748a17b153278c1f6bb",
      "b4dfd82357b04220806f774d687a6f88",
      "2c6c146eff2f4373b60a58395c0d0f79",
      "274135b983d34a548c585b40f4124ce7",
      "3d1275321f394079868fda95faf06019",
      "433aaf5b889344e7a2dd3a34f325de79",
      "f521d28f2b6b4514ab2baa5825e54b24",
      "cded05320ff84cfdb5a2316fd841216a",
      "4831f9369f4049079a2ceae8cae6e9a9",
      "e03123212da44c3a8c832cfb2aaffa48",
      "db55049285e24859a5715780ca20217e",
      "3ef164d6d2244fd7a3ea937c734e7363",
      "c6d4cbf67e2745b2a70ca77bbd0cfd5d",
      "fd74605ba217482d814ea35b6d03f0ed",
      "670232727b6840cdad76dc145a45d151",
      "68aa0709ef6f4c6f99c0d7c8766c3878",
      "0c770b59d086433694b51ce1c5f73f6f",
      "0f4a142cda80499b86769f9343959925",
      "08708f385a864597aa07e8ea7df2c1e8",
      "28f173ead6f345eeabc37e790131059b",
      "af74b8db5dd44385926f51c4ab3492fb",
      "2e68b56ec3134e7aaf2272510356dddc",
      "42bcd0e7c55b4955b27b07c99bc4c7bb",
      "2a761c40364c4ab39b78d56e85b490c6",
      "46c52826a63c486b9d7db1d2bc6c3082",
      "2288beefb2044757b082b3b65ac29384",
      "c64e7ea972144677821a5b4e747f3e13",
      "809fccdc028e4cee8357155b7ad4a7a1",
      "dcc442aca5e24bd189a9a4a4fed3e79d",
      "2da1bddd28444959992a09b6d0137e22",
      "3801e2fbdbb0471384f97b7b24bf5415",
      "6205c25290ac4398a0f9021dc9acfce3",
      "690707b0dcba4f038cc6a82400cb6e1f",
      "00ca41f65067430aba58ef47c110e701",
      "acb88e259bf342d180af85ce58addbdd",
      "c39a2e19a2d24944a72c5d361d8d7bbe",
      "00ac76d9afe54bfcbbaebf224233cce9",
      "a5654a0a4f1b416ca4d66874fe16a7eb",
      "178699ab35ba4fb194eac5ef583b4cf9",
      "a8d435a7db764a3090a8aedbc3bb9d0a",
      "cf939d1c3dbe496fa9f5a62f89258259",
      "eb28130d012f4d7c84512e678af17a9e",
      "4641bb994c464f21936e07c0505f00cc",
      "bc715766a25f4fb3a271210c60e10258",
      "0af5e01cb9a84758aebe5e6dc9ac5e38",
      "5e539c5810a042aa9e20eb96fbbef206",
      "99dceb2b2daa4d48af2065eed4747dc6",
      "0a34a0f76d5640d9bec947b2568814fa",
      "dec3a0399f984a36a73d66abd087e58f",
      "0a3cfa30c7d44ce79ac58acb81341dd6",
      "57720cd13c9a49c395fcf28617477c32",
      "5843ae77abcd4b3bb91ced2e44bcb39c",
      "06d43f0186974e559c20ba4aaed76247",
      "d4a669a1390c4cb2b1d8c05bdfb6f6eb",
      "5b1cfd5df17042dd8158bec5f9acde8d",
      "6cb93ac09b7c44ffb32c045efb3ff5a7",
      "6036379a8e3e4ce798ab6c526cebec52",
      "4537baca1b6d478d90b86f9bf65308a7",
      "59e1d13d07334474ba8500867e0d25f8",
      "e38b7a8475054ea289e4c19b50873c5b",
      "2679f8b47a7b46fca327ee7dc760327d",
      "fa81d27ca5924801a433454306d3d858",
      "9e7d3e1f2b6446288b0110f00fb3cd85",
      "f67aa16ad63647a5b3926f4f181a8169",
      "5babee14563a4462b1a41c6dbff4b9e7",
      "3ca1efde7be94a1ab14e684372c564cc",
      "02062d264144471f96a08bbf8c80c319",
      "c7b9cf1114ca47a8a7072d8b3c9c951b",
      "a0a6937e7b4241cdb084d83a674b1d3f",
      "996f3de3f79c4001aa062f42cba4d4eb",
      "69ed58e2e84042e8be545ac9c4c95280",
      "e60bf3cbbcab4da6a1b3200a81cb7eff",
      "0202e77ed83e454085ae204478155c6d",
      "78b285cd89264f799322c7865b579cec",
      "2772e8d3269c4a74b729af1f6852872f",
      "561b606730dc45239762a84c989653ea",
      "455f229b3bdf425c835d6a1fd0e1f6dd",
      "35351148751f45958c950dc4bd4103d3",
      "2606694c94514832b6c7a38565315acf",
      "253b9f4119ba40d1aeb3a51bf5a21330",
      "eb2a9bb8b3754565813a1502e5fdbb56",
      "3224f34248034526a0e048a174e51bdc",
      "de9cc38072664f11ba924a4726458205",
      "1b9ebecd9da34a3b9087392a6e9f2a14",
      "b68c4503c36c4aecb9451e0ecd4be961",
      "364aa610dfb44226bbf03464852c5c30",
      "68738229143345f3b54ba20a1ac02217",
      "e01d04331e494e51a068f958a9675ac7",
      "86a95f3622bb49ae9d12ed5417722642",
      "5ec37b0ed3c04006be6ced6182419497",
      "89994c754744408e96272326679aec80",
      "f3973e31e0514dc19b22d97fa14db858",
      "1e551f15f79444878004883c5f31fc30",
      "cff53a3dc0874935ba830ab7b3cbfc64",
      "e78c0091050840c0befc21efd97fedce",
      "d2a997c6997444598ade4efe0bccd717",
      "19d026dab4b440af9066aaa6a813441d",
      "844e9e632e8a4a3bb661645d1033722d",
      "ec2cd1058eba4e84bb4f5655f23fa6b4",
      "40234a1bc21440bea70bcff61c2e270d",
      "ada3d0c8815e44219555eeb25000ac74",
      "ab8c55a530df4707ab98d34efc8b6add",
      "170d4d1f05b8427d831f5aa03da2ad93",
      "5c7d6b334e94449da421e05a17f77470",
      "386f3d71b4f14ceb92f71719eb8d96c3",
      "51ea12db1483490f91165fa7a311cc0e",
      "d779d689b35c4ac3b836d87a3bca98db",
      "509eff9d4dba400aa08a4de4674bc3be",
      "e315db183e264bbb8a57cfe6df79115d",
      "0ce383bcfa5c458082a6dd2e1ba87bcc",
      "fa65761beb834ec7990867463f759fb4",
      "7bf1ffa4fb034bac94b7f891fa4c7aca",
      "7de87fdb73bd4eebba81b30f0ffeb13d",
      "e3e7bf5b006242a2a5d3f5bde6a3a345",
      "b366fccc3854406ababc642602461a3b",
      "8f7badeda30f47918bcefbe4134085ee"
     ]
    },
    "editable": true,
    "id": "LMyy9a2tnvgl",
    "outputId": "676fb324-558c-4d04-8de0-e3e5d3ccc761",
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Each banner is a single print with a leading newline (same output as a\n",
    "# blank print() followed by the title).\n",
    "print(\"\\nConfig...\")\n",
    "print(config)\n",
    "\n",
    "# Every World argument is read from `config` under the same name, except the\n",
    "# adapter to load, which comes from `hf_repo_load_adaptor_name` (defined\n",
    "# outside this cell).\n",
    "print(\"\\nBuilding...\")\n",
    "world = World(\n",
    "    load_adaptor_name=hf_repo_load_adaptor_name,\n",
    "    **{\n",
    "        arg: getattr(config, arg)\n",
    "        for arg in (\n",
    "            \"action_space\",\n",
    "            \"test_pool_size\",\n",
    "            \"train_pool_size\",\n",
    "            \"test_generate_batch_size\",\n",
    "            \"train_generate_batch_size\",\n",
    "            \"train_batch_size\",\n",
    "            \"map_change_prob\",\n",
    "            \"map_max_age\",\n",
    "            \"train_env\",\n",
    "            \"test_env\",\n",
    "            \"train_generate_length\",\n",
    "            \"test_generate_length\",\n",
    "            \"train_length\",\n",
    "            \"learning_rate\",\n",
    "            \"weight_decay\",\n",
    "            \"gradient_accumulation_steps\",\n",
    "            \"save_steps\",\n",
    "            \"run_name\",\n",
    "            \"base_model_name\",\n",
    "            \"adapter_params\",\n",
    "            \"optimizer_betas\",\n",
    "            \"tokenizer_name\",\n",
    "            \"rl_params\",\n",
    "            \"gradient_checkpointing\",\n",
    "            \"train_record_size\",\n",
    "            \"test_record_size\",\n",
    "            \"train_on_last\",\n",
    "        )\n",
    "    },\n",
    ")\n",
    "\n",
    "# Warm-up: random actions on both pools and no training, until\n",
    "# world.get_data_min_len over all games reaches config.warmup_length.\n",
    "# NOTE(review): with steps=None the loop stops only via end_at_length, so\n",
    "# config.warmup_length must not be None.\n",
    "print(\"\\nWarming up...\")\n",
    "world.run(\n",
    "    steps=None,\n",
    "    test_generate_iterations=1,\n",
    "    train_generate_iterations=1,\n",
    "    train_iterations=0,\n",
    "    test_generate_random_prob=1.0,\n",
    "    train_generate_random_prob=1.0,\n",
    "    end_at_length=config.warmup_length,\n",
    ")\n",
    "\n",
    "# Dump the first training game's stream structure for inspection.\n",
    "struct = world.train_games[0][\"stream\"][:, \"struct\"]\n",
    "print()\n",
    "for idx, entry in enumerate(struct):\n",
    "    print(f\"{idx}: {entry}\")\n",
    "\n",
    "# Main loop: agent-driven generation (test pool fully greedy w.r.t. random_prob)\n",
    "# interleaved with training, then persist the agent.\n",
    "print(\"\\nRunning...\")\n",
    "world.run(\n",
    "    steps=config.train_steps,\n",
    "    test_generate_iterations=config.test_generations_per_step,\n",
    "    train_generate_iterations=config.train_generations_per_step,\n",
    "    train_iterations=config.trains_per_step,\n",
    "    test_generate_random_prob=0.0,\n",
    "    train_generate_random_prob=config.train_random_prob,\n",
    "    enable_env_log=True,\n",
    ")\n",
    "\n",
    "world.agent.save()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "editable": true,
    "id": "SamItMOlAPVv",
    "tags": []
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [
    "-hfW094ImwvR",
    "eJ1jKzOPnAA7",
    "TTX0QBHWmhvk",
    "ZDbjL8tplFVh",
    "37Q4B-uAnVr3",
    "MSX5xtncn0gf"
   ],
   "gpuType": "T4",
   "machine_shape": "hm",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
