{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "ae26dd79",
   "metadata": {},
   "source": [
    "# Distributional Active Inference (DAIF)\n",
    "\n",
    "**Supplementary Material: Reference Implementation**\n",
    "\n",
    "---\n",
    "\n",
    "This notebook provides a complete, self-contained implementation of the DAIF algorithm for continuous control tasks from DeepMind Control Suite. The code is organized into the following sections:\n",
    "\n",
    "1. **Environment Setup** — Virtual environment and dependencies \n",
    "2. **Setup & Configuration** — Imports, hyperparameters, and reproducibility\n",
    "3. **Environment Wrapper** — DeepMind Control Suite integration with Gymnasium\n",
    "4. **Replay Buffer** — Experience storage using TensorDict\n",
    "5. **Network Architectures** — Actor, Quantile Critic, and Ensemble modules\n",
    "6. **DAIF Loss Function** — The core distributional critic loss\n",
    "7. **Agent Implementation** — Full DAIF agent with actor-critic updates\n",
    "8. **Training Loop** — Experiment runner with evaluation"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "26377afc",
   "metadata": {},
   "source": [
    "---\n",
    "\n",
    "## 1. Environment Setup\n",
    "\n",
    "We begin by setting up a virtual environment and installing all necessary dependencies. This ensures that the code runs in an isolated environment with the correct package versions.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "194683b0",
   "metadata": {},
   "source": [
    "```bash\n",
    "# Create and activate a virtual environment\n",
    "conda create -n daif python=3.10 -y\n",
    "conda activate daif\n",
    "\n",
    "# Install required packages\n",
    "pip install gymnasium[mujoco]==1.1.1 torch==2.8.0 torchrl==0.8.1 tqdm==4.67.1 dm-control==1.0.31 mujoco==3.3.3\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fad9dae6",
   "metadata": {},
   "source": [
    "---\n",
    "\n",
    "## 2. Setup & Configuration\n",
    "\n",
    "Importing required libraries and defining all hyperparameters in a single location for easy modification."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6f56fc7c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Core libraries\n",
    "import torch\n",
    "from torch import nn as nn, func as thf\n",
    "from torch.distributions import Normal, Gamma, TransformedDistribution\n",
    "\n",
    "import functools\n",
    "from tensordict import TensorDict\n",
    "from torchrl.data import (\n",
    "    LazyMemmapStorage,\n",
    "    LazyTensorStorage,\n",
    "    TensorDictReplayBuffer,\n",
    ")\n",
    "import numpy as np\n",
    "\n",
    "# Environment\n",
    "from gymnasium import core, spaces, wrappers\n",
    "import dm_env\n",
    "from dm_control import suite\n",
    "\n",
    "# Utilities\n",
    "from collections import OrderedDict\n",
    "from typing import Tuple, Dict, Optional, Union, Any, Literal\n",
    "from abc import ABC\n",
    "from typing import TypeVar, Generic\n",
    "from collections.abc import Iterator\n",
    "\n",
    "import copy\n",
    "import gc\n",
    "import warnings\n",
    "import time\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ae98d2f4",
   "metadata": {},
   "source": [
    "### Hyperparameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d35a8b9c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reproducibility\n",
    "SEED = 1\n",
    "\n",
    "# Environment configuration\n",
    "ROBOT_NAME = \"quadruped\"\n",
    "TASK = \"run\"\n",
    "\n",
    "# Exploration noise (TD3-style)\n",
    "POLICY_NOISE = 0.1\n",
    "TARGET_POLICY_NOISE = 0.2\n",
    "TARGET_POLICY_NOISE_CLIP = 0.5\n",
    "POLICY_DELAY = 2\n",
    "\n",
    "# DAIF-specific parameters\n",
    "NUM_QUANTILES = 8\n",
    "REGULARIZATION_COEFF = 0.001\n",
    "\n",
    "# Network architecture\n",
    "DEPTH = 3\n",
    "WIDTH = 256\n",
    "\n",
    "# Optimization\n",
    "LEARNING_RATE = 3e-4\n",
    "BATCH_SIZE = 256\n",
    "GAMMA = 0.99\n",
    "TAU = 0.005\n",
    "\n",
    "# Training schedule\n",
    "MAX_STEPS = 1_000_000\n",
    "WARMUP_STEPS = 10_000\n",
    "BUFFER_SIZE = 1_000_000\n",
    "\n",
    "# Evaluation\n",
    "EVAL_FREQUENCY = 20_000\n",
    "EVAL_EPISODES = 10\n",
    "\n",
    "# Device selection\n",
    "DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "print(f\"Using device: {DEVICE}\")"
   ]
  },
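  {
   "cell_type": "markdown",
   "id": "c4e19a72",
   "metadata": {},
   "source": [
    "The noise and delay parameters above use TD3 naming. As a hedged sketch of how such parameters are conventionally used in TD3 (an assumption about this notebook; the agent code is authoritative), let $\\pi$ be the actor, $\\pi'$ the target actor, $\\sigma$ = `POLICY_NOISE`, $\\tilde{\\sigma}$ = `TARGET_POLICY_NOISE`, and $c$ = `TARGET_POLICY_NOISE_CLIP`. These symbols are notation introduced here.\n",
    "\n",
    "$$a = \\mathrm{clip}\\big(\\pi(s) + \\epsilon,\\, -1,\\, 1\\big), \\qquad \\epsilon \\sim \\mathcal{N}(0, \\sigma^2 I)$$\n",
    "\n",
    "$$\\tilde{a} = \\mathrm{clip}\\big(\\pi'(s') + \\mathrm{clip}(\\tilde{\\epsilon},\\, -c,\\, c),\\, -1,\\, 1\\big), \\qquad \\tilde{\\epsilon} \\sim \\mathcal{N}(0, \\tilde{\\sigma}^2 I)$$\n",
    "\n",
    "In standard TD3 the actor is updated once every `POLICY_DELAY` critic updates, and target parameters track the online parameters through a soft update with mixing coefficient `TAU`:\n",
    "\n",
    "$$\\theta' \\leftarrow \\texttt{TAU}\\,\\theta + (1 - \\texttt{TAU})\\,\\theta'$$"
   ]
  },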
  {
   "cell_type": "markdown",
   "id": "e9e5feff",
   "metadata": {},
   "source": [
    "### Seeding for Reproducibility\n",
    "\n",
    "We set seeds across all random number generators to ensure reproducible results."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "30f9c647",
   "metadata": {},
   "outputs": [],
   "source": [
    "def set_seed(seed: int) -> None:\n",
    "    \"\"\"Set random seeds for reproducibility across PyTorch and CUDA.\"\"\"\n",
    "    torch.manual_seed(seed)\n",
    "    torch.cuda.manual_seed_all(seed)\n",
    "    torch.use_deterministic_algorithms(True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e5a24cad",
   "metadata": {},
   "source": [
    "---\n",
    "\n",
    "## 3. Environment Wrapper\n",
    "\n",
    "We wrap the DeepMind Control Suite environments to provide a Gymnasium-compatible interface. This enables:\n",
    "\n",
    "- **Unified API**: Standard `step()` and `reset()` methods\n",
    "- **Observation flattening**: Dictionary observations are converted to flat vectors\n",
    "- **Action rescaling**: Actions are normalized to [-1, 1]"
   ]
  },
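  {
   "cell_type": "markdown",
   "id": "a1c7e3f0",
   "metadata": {},
   "source": [
    "As a worked illustration of the action rescaling, an affine map from the normalized range to the native bounds $[a_{\\min}, a_{\\max}]$ of one action dimension can be written as follows. The symbols $a$, $a_{\\min}$, $a_{\\max}$, and $a_{\\text{env}}$ are notation introduced here and do not appear in the code.\n",
    "\n",
    "$$a_{\\text{env}} = a_{\\min} + \\frac{(a + 1)\\,(a_{\\max} - a_{\\min})}{2}, \\qquad a \\in [-1, 1]$$\n",
    "\n",
    "For example, with $a_{\\min} = 0$ and $a_{\\max} = 4$, the normalized action $a = 0.5$ maps to $a_{\\text{env}} = 0 + 1.5 \\cdot 4 / 2 = 3$. This is a sketch that assumes finite bounds; the wrapper applied in `make_env` is the authoritative implementation."
   ]
  },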
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1caeb692",
   "metadata": {},
   "outputs": [],
   "source": [
    "TimeStep = Tuple[Dict[str, np.ndarray], float, bool, bool, dict]\n",
    "\n",
    "\n",
    "def dmc_spec2gym_space(\n",
    "    spec: Union[dm_env.specs.Array, Dict, OrderedDict],\n",
    ") -> spaces.Space:\n",
    "    \"\"\"Convert DeepMind Control Suite specs to Gymnasium spaces.\"\"\"\n",
    "    if isinstance(spec, (OrderedDict, dict)):\n",
    "        spec = copy.copy(spec)\n",
    "        for k, v in spec.items():\n",
    "            spec[k] = dmc_spec2gym_space(v)\n",
    "        return spaces.Dict(spec)\n",
    "    elif isinstance(spec, dm_env.specs.BoundedArray):\n",
    "        return spaces.Box(\n",
    "            low=np.full(spec.shape, spec.minimum),\n",
    "            high=np.full(spec.shape, spec.maximum),\n",
    "            shape=spec.shape,\n",
    "            dtype=spec.dtype,\n",
    "        )\n",
    "    elif isinstance(spec, dm_env.specs.Array):\n",
    "        return spaces.Box(\n",
    "            low=-float(\"inf\"), high=float(\"inf\"), shape=spec.shape, dtype=spec.dtype\n",
    "        )\n",
    "    else:\n",
    "        raise NotImplementedError(f\"Unsupported spec type: {type(spec)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2d4ded5",
   "metadata": {},
   "outputs": [],
   "source": [
    "class DMCEnv(core.Env):\n",
    "    \"\"\"\n",
    "    Gymnasium wrapper for DeepMind Control Suite environments.\n",
    "    \n",
    "    This wrapper handles the conversion between DMC's TimeStep format\n",
    "    and Gymnasium's (obs, reward, terminated, truncated, info) format.\n",
    "    \"\"\"\n",
    "    \n",
    "    def __init__(\n",
    "        self,\n",
    "        domain_name: Optional[str] = None,\n",
    "        task_name: Optional[str] = None,\n",
    "        env: Optional[dm_env.Environment] = None,\n",
    "        task_kwargs: Optional[Dict] = None,\n",
    "        environment_kwargs: Optional[Dict] = None,\n",
    "    ) -> None:\n",
    "        task_kwargs = {} if task_kwargs is None else task_kwargs\n",
    "\n",
    "        assert (\n",
    "            \"random\" in task_kwargs\n",
    "        ), \"Please specify a seed in task_kwargs['random'] for deterministic behaviour.\"\n",
    "        assert env is not None or (\n",
    "            domain_name is not None and task_name is not None\n",
    "        ), \"You must provide either an environment or domain and task names.\"\n",
    "\n",
    "        if env is None:\n",
    "            env = suite.load(\n",
    "                domain_name=domain_name,\n",
    "                task_name=task_name,\n",
    "                task_kwargs=task_kwargs,\n",
    "                environment_kwargs=environment_kwargs,\n",
    "                visualize_reward=True,\n",
    "            )\n",
    "\n",
    "        self._env: dm_env.Environment = env\n",
    "        self.domain_name: Optional[str] = domain_name\n",
    "        self.task_name: Optional[str] = task_name\n",
    "        self.action_space: spaces.Space = dmc_spec2gym_space(self._env.action_spec())\n",
    "        self.observation_space: spaces.Space = dmc_spec2gym_space(\n",
    "            self._env.observation_spec()\n",
    "        )\n",
    "\n",
    "    def __getattr__(self, name: str):\n",
    "        return getattr(self._env, name)\n",
    "\n",
    "    def step(self, action: np.ndarray) -> TimeStep:\n",
    "        assert self.action_space.contains(action), \"Action not in action_space.\"\n",
    "\n",
    "        time_step = self._env.step(action)\n",
    "        reward: float = time_step.reward or 0.0\n",
    "        done: bool = time_step.last()\n",
    "        obs: Dict[str, np.ndarray] = time_step.observation\n",
    "\n",
    "        info: Dict = {}\n",
    "        trunc: bool = done and (time_step.discount == 1.0)\n",
    "        term: bool = done and (time_step.discount != 1.0)\n",
    "        if trunc:\n",
    "            info[\"TimeLimit.truncated\"] = True\n",
    "\n",
    "        return obs, reward, term, trunc, info\n",
    "\n",
    "    def reset(\n",
    "        self, seed: Optional[int] = None, options: Optional[Dict] = None\n",
    "    ) -> Tuple[Dict[str, np.ndarray], Dict]:\n",
    "        super().reset(seed=seed)\n",
    "        time_step = self._env.reset()\n",
    "        info: Dict = {}\n",
    "        return time_step.observation, info\n",
    "\n",
    "    def render(\n",
    "        self,\n",
    "        mode: str = \"rgb_array\",\n",
    "        height: int = 84,\n",
    "        width: int = 84,\n",
    "        camera_id: int = 0,\n",
    "    ) -> np.ndarray:\n",
    "        assert mode == \"rgb_array\", f\"Only support rgb_array mode, got {mode}.\"\n",
    "        return self._env.physics.render(height=height, width=width, camera_id=camera_id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2d5bc61d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def make_env(robot_name: str, task_name: str, seed: int):\n",
    "    \"\"\"\n",
    "    Create a fully configured environment with:\n",
    "    - Flattened observations (dictionary -> vector)\n",
    "    - Rescaled actions to [-1, 1]\n",
    "    - Seeded random number generators\n",
    "    \"\"\"\n",
    "    env = DMCEnv(\n",
    "        domain_name=robot_name, task_name=task_name, task_kwargs={\"random\": seed}\n",
    "    )\n",
    "    env = wrappers.FlattenObservation(env)\n",
    "    env = wrappers.RescaleAction(env, -1.0, 1.0)\n",
    "\n",
    "    env.reset(seed=seed)\n",
    "    env.action_space.seed(seed)\n",
    "    env.observation_space.seed(seed)\n",
    "    np.random.seed(seed)\n",
    "    torch.manual_seed(seed)\n",
    "    return env"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c672982e",
   "metadata": {},
   "source": [
    "---\n",
    "\n",
    "## 4. Replay Buffer\n",
    "\n",
    "The replay buffer stores experience tuples and supports efficient batch sampling. Key features:\n",
    "\n",
    "- **Lazy storage**: Memory is allocated on-demand\n",
    "- **Device-aware**: Supports both CPU (memory-mapped) and GPU storage\n",
    "- **TensorDict integration**: Structured experience storage"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "85725cd6",
   "metadata": {},
   "outputs": [],
   "source": [
    "class ReplayBuffer:\n",
    "    \"\"\"\n",
    "    Experience replay buffer using TorchRL's TensorDictReplayBuffer.\n",
    "    \n",
    "    Supports both CPU (LazyMemmapStorage) and GPU (LazyTensorStorage) backends.\n",
    "    \"\"\"\n",
    "    \n",
    "    def __init__(\n",
    "        self,\n",
    "        device: torch.device,\n",
    "        storing_device: torch.device,\n",
    "        buffer_size: int,\n",
    "        print_gc_warning: bool = False,\n",
    "    ) -> None:\n",
    "        self.device = device\n",
    "        self.storing_device = storing_device\n",
    "        self.buffer_size = buffer_size\n",
    "        self.print_gc_warning = print_gc_warning\n",
    "        self.reset()\n",
    "\n",
    "    def _get_storage(\n",
    "        self, buffer_size: int, device: torch.device\n",
    "    ) -> LazyMemmapStorage | LazyTensorStorage:\n",
    "        if device == \"cpu\":\n",
    "            return LazyMemmapStorage(buffer_size, device=device)\n",
    "        elif device == \"cuda\":\n",
    "            return LazyTensorStorage(buffer_size, device=device)\n",
    "        else:\n",
    "            raise NotImplementedError(f\"No storage support for device type: {device}\")\n",
    "\n",
    "    def reset(self, buffer_size: int | None = None) -> None:\n",
    "        if buffer_size is not None:\n",
    "            self.buffer_size = buffer_size\n",
    "\n",
    "        self.data_size = 0\n",
    "        self.pointer = 0\n",
    "\n",
    "        storage = self._get_storage(self.buffer_size, self.storing_device)\n",
    "        self.memory = TensorDictReplayBuffer(storage=storage)\n",
    "\n",
    "    def add(self, experience: TensorDict) -> None:\n",
    "        try:\n",
    "            self.memory.add(experience)\n",
    "        except OSError:\n",
    "            warnings.warn(\n",
    "                \"Failed to add experience to replay buffer, triggering manual GC\",\n",
    "                stacklevel=2,\n",
    "            )\n",
    "            gc.collect()\n",
    "            self.memory.add(experience)\n",
    "\n",
    "        self.data_size = min(self.data_size + 1, self.buffer_size)\n",
    "        self.pointer = (self.pointer + 1) % self.buffer_size\n",
    "\n",
    "    def add_batch(self, batch: TensorDict) -> None:\n",
    "        try:\n",
    "            self.memory.extend(batch)\n",
    "        except OSError:\n",
    "            warnings.warn(\n",
    "                \"Failed to add experience to replay buffer, triggering manual GC\",\n",
    "                stacklevel=2,\n",
    "            )\n",
    "            gc.collect()\n",
    "            self.memory.extend(batch)\n",
    "\n",
    "        self.data_size = min(self.data_size + len(batch), self.buffer_size)\n",
    "        self.pointer = (self.pointer + len(batch)) % self.buffer_size\n",
    "\n",
    "    def sample_batch(self, batch_size: int) -> TensorDict:\n",
    "        batch = self.memory.sample(batch_size).to(self.device).clone()\n",
    "        return batch\n",
    "\n",
    "    def sample_random(self, batch_size: int) -> TensorDict:\n",
    "        return self.sample_batch(batch_size)\n",
    "\n",
    "    def sample_by_index(self, indices: list | torch.Tensor | range) -> TensorDict:\n",
    "        if isinstance(indices, range):\n",
    "            indices = torch.tensor(list(indices), device=self.storing_device)\n",
    "        elif isinstance(indices, list):\n",
    "            indices = torch.tensor(indices, device=self.storing_device)\n",
    "        return self.memory.storage[indices].to(self.device).clone()\n",
    "\n",
    "    def sample_by_index_fields(\n",
    "        self, indices: list | torch.Tensor | range, fields: list\n",
    "    ) -> TensorDict:\n",
    "        if isinstance(indices, range):\n",
    "            indices = torch.tensor(list(indices), device=self.storing_device)\n",
    "        elif isinstance(indices, list):\n",
    "            indices = torch.tensor(indices, device=self.storing_device)\n",
    "        return self.memory.storage[indices].select(fields).to(self.device).clone()\n",
    "\n",
    "    def sample_all(self) -> TensorDict:\n",
    "        return self.sample_by_index(range(self.data_size))\n",
    "\n",
    "    def __len__(self) -> int:\n",
    "        return self.data_size\n",
    "\n",
    "    @property\n",
    "    def size(self) -> int:\n",
    "        return self.data_size\n",
    "\n",
    "    def save(self, path: str) -> None:\n",
    "        metadata = {\n",
    "            \"buffer_size\": self.buffer_size,\n",
    "            \"data_size\": self.data_size,\n",
    "            \"pointer\": self.pointer,\n",
    "        }\n",
    "        torch.save(metadata, path + \".metadata\")\n",
    "\n",
    "    def load(self, path: str) -> None:\n",
    "        metadata = torch.load(path + \".metadata\")\n",
    "        self.buffer_size = metadata[\"buffer_size\"]\n",
    "        self.data_size = metadata[\"data_size\"]\n",
    "        self.pointer = metadata[\"pointer\"]\n",
    "\n",
    "    def create_epoch_iterator(self, batch_size: int, n_epochs: int = 1) -> Iterator:\n",
    "        total_samples = self.data_size\n",
    "\n",
    "        def batch_generator():\n",
    "            for _ in range(n_epochs):\n",
    "                indices = torch.arange(total_samples, device=self.storing_device)\n",
    "                for i in range(0, total_samples, batch_size):\n",
    "                    batch_indices = indices[i : min(i + batch_size, total_samples)]\n",
    "                    yield self.sample_by_index(batch_indices)\n",
    "\n",
    "        self.epoch_iterator = batch_generator()\n",
    "        return self.epoch_iterator\n",
    "\n",
    "    def get_next_batch(self, batch_size: int) -> TensorDict:\n",
    "        if self.epoch_iterator is not None:\n",
    "            return next(self.epoch_iterator)\n",
    "        return self.sample_batch(batch_size)\n",
    "\n",
    "    def calculate_num_batches(self, batch_size: int) -> int:\n",
    "        return (self.data_size + batch_size - 1) // batch_size\n",
    "\n",
    "    def get_steps_and_iterator(\n",
    "        self, n_epochs: int, max_iter: int, batch_size: int\n",
    "    ) -> int:\n",
    "        if n_epochs > 0:\n",
    "            n_batches = self.calculate_num_batches(batch_size)\n",
    "            n_steps = n_epochs * n_batches\n",
    "            self.create_epoch_iterator(batch_size, n_epochs)\n",
    "        else:\n",
    "            n_steps = max_iter\n",
    "            self.epoch_iterator = None\n",
    "\n",
    "        return n_steps"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "45dd8490",
   "metadata": {},
   "source": [
    "---\n",
    "\n",
    "## 5. Network Architectures\n",
    "\n",
    "### Utility Functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c3dd4894",
   "metadata": {},
   "outputs": [],
   "source": [
    "def totorch(x, dtype=torch.float32, device=\"cuda\") -> torch.Tensor:\n",
    "    \"\"\"Convert numpy array or scalar to PyTorch tensor.\"\"\"\n",
    "    return torch.as_tensor(x, dtype=dtype, device=device)\n",
    "\n",
    "\n",
    "def tonumpy(x):\n",
    "    \"\"\"Convert PyTorch tensor to numpy array.\"\"\"\n",
    "    return x.data.cpu().numpy()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d3e9e54e",
   "metadata": {},
   "source": [
    "### Multi-Layer Perceptron (MLP)\n",
    "\n",
    "A standard feedforward network with ReLU activations used as a building block."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5bb330f1",
   "metadata": {},
   "outputs": [],
   "source": [
    "class MLP(nn.Module):\n",
    "    \"\"\"\n",
    "    Multi-layer perceptron with configurable depth and width.\n",
    "    Uses ReLU activations between layers.\n",
    "    \"\"\"\n",
    "    \n",
    "    def __init__(\n",
    "        self,\n",
    "        dim_in: int,\n",
    "        dim_out: int,\n",
    "        depth: int,\n",
    "        width: int,\n",
    "    ) -> None:\n",
    "        super().__init__()\n",
    "        assert depth > 0, \"Need at least one layer\"\n",
    "\n",
    "        self.activation_fn = nn.ReLU\n",
    "        width_multiplier = 1\n",
    "        effective_width = width * width_multiplier\n",
    "\n",
    "        layers = []\n",
    "\n",
    "        if depth == 1:\n",
    "            layers.append(nn.Linear(dim_in, dim_out))\n",
    "        else:\n",
    "            layers.append(nn.Linear(dim_in, width))\n",
    "            for i in range(depth - 1):\n",
    "                layers.append(self.activation_fn())\n",
    "                if i == depth - 2:\n",
    "                    layers.append(nn.Linear(effective_width, dim_out))\n",
    "                else:\n",
    "                    layers.append(nn.Linear(effective_width, width))\n",
    "\n",
    "        self.model = nn.Sequential(*layers)\n",
    "\n",
    "    def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
    "        return self.model(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d7aaa697",
   "metadata": {},
   "source": [
    "### Actor Network\n",
    "\n",
    "The actor outputs deterministic actions bounded to [-1, 1] via tanh activation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "60630640",
   "metadata": {},
   "outputs": [],
   "source": [
    "class ActorNet(nn.Module):\n",
    "    \"\"\"\n",
    "    Deterministic policy network.\n",
    "    \n",
    "    Outputs actions in [-1, 1] via tanh squashing.\n",
    "    \"\"\"\n",
    "    \n",
    "    def __init__(\n",
    "        self,\n",
    "        dim_state: int,\n",
    "        dim_act: int,\n",
    "        depth: int = 3,\n",
    "        width: int = 256,\n",
    "    ) -> None:\n",
    "        super().__init__()\n",
    "\n",
    "        self.dim_act = dim_act\n",
    "\n",
    "        self.arch = nn.Sequential(\n",
    "            MLP(dim_state, dim_act, depth, width),\n",
    "            nn.Tanh(),\n",
    "        )\n",
    "\n",
    "    def forward(\n",
    "        self, x: torch.Tensor, is_training: bool | None = None\n",
    "    ) -> dict[str, torch.Tensor]:\n",
    "        out = self.arch(x)\n",
    "        return {\"action\": out}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a6c69f88",
   "metadata": {},
   "source": [
    "### Quantile Critic Network\n",
    "\n",
    "The core of DAIF: a critic that outputs distributional value estimates using implicit quantile networks (IQN). For each quantile level $\\tau$, the network outputs parameters $(\\mu, \\alpha, \\beta)$ of a distributional estimate.\n",
    "\n",
    "**Architecture:**\n",
    "1. State-action embedding via linear layer + LayerNorm + ReLU\n",
    "2. Quantile embedding via cosine features: $\\phi(\\tau) = \\cos(\\pi \\cdot i \\cdot \\tau)$ for $i = 1, \\ldots, d$\n",
    "3. Hadamard product of embeddings\n",
    "4. Output layer predicting $(\\mu, \\log\\alpha, \\log\\beta)$"
   ]
  },
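  {
   "cell_type": "markdown",
   "id": "e8b2f6d1",
   "metadata": {},
   "source": [
    "The architecture steps above can be summarized in equations. This is a sketch read off the code cell below: the symbols $W_{sa}$, $W_\\tau$, $W_o$, $W$, $h$, $e$, $z$, $u_\\alpha$, and $u_\\beta$ are notation introduced here, biases are omitted, $\\mathrm{LN}$ denotes LayerNorm, and $\\sigma$ is the logistic sigmoid.\n",
    "\n",
    "$$h = \\mathrm{ReLU}\\big(\\mathrm{LN}(W_{sa}\\,[s; a])\\big)$$\n",
    "\n",
    "$$e(\\tau) = \\sigma\\big(\\mathrm{LN}(W_\\tau\\,\\phi(\\tau))\\big), \\qquad \\phi_i(\\tau) = \\cos(\\pi\\, i\\, \\tau)$$\n",
    "\n",
    "$$z(\\tau) = \\mathrm{ReLU}\\big(\\mathrm{LN}(W_o\\,(h \\odot e(\\tau)))\\big)$$\n",
    "\n",
    "$$(\\mu,\\, u_\\alpha,\\, u_\\beta) = W z(\\tau), \\qquad \\alpha = e^{u_\\alpha} + 10.001, \\qquad \\beta = e^{u_\\beta} + 10.001$$\n",
    "\n",
    "The same state-action embedding $h$ is reused for every quantile level, so evaluating additional levels only repeats the last three steps."
   ]
  },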
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2af58450",
   "metadata": {},
   "outputs": [],
   "source": [
    "class QuantileCriticNet(nn.Module):\n",
    "    \"\"\"\n",
    "    Implicit Quantile Network (IQN) based critic.\n",
    "    \n",
    "    Outputs distributional parameters (mu, alpha, beta) for each quantile level.\n",
    "    Uses cosine embedding for quantile levels following the IQN paper.\n",
    "    \"\"\"\n",
    "    \n",
    "    def __init__(\n",
    "        self,\n",
    "        dim_state: int,\n",
    "        dim_act: int,\n",
    "        width: int = 256,\n",
    "    ) -> None:\n",
    "        super().__init__()\n",
    "        self.embedding_dim = 128\n",
    "        \n",
    "        # State-action encoder\n",
    "        self.base_arch = nn.Sequential(\n",
    "            nn.Linear(dim_state + dim_act, width), \n",
    "            nn.LayerNorm(width), \n",
    "            nn.ReLU()\n",
    "        )\n",
    "        \n",
    "        # Quantile encoder (cosine features -> embedding)\n",
    "        self.tau_arch = nn.Sequential(\n",
    "            nn.Linear(self.embedding_dim, width), \n",
    "            nn.LayerNorm(width), \n",
    "            nn.Sigmoid()\n",
    "        )\n",
    "        \n",
    "        # Combined processing\n",
    "        self.out_arch = nn.Sequential(\n",
    "            nn.Linear(width, width),\n",
    "            nn.LayerNorm(width),\n",
    "            nn.ReLU(),\n",
    "        )\n",
    "        \n",
    "        # Output: mu, log_alpha, log_beta\n",
    "        self.out = nn.Linear(width, 3)\n",
    "\n",
    "        # Cosine basis for quantile embedding\n",
    "        self.const_vec = torch.from_numpy(np.arange(1, self.embedding_dim + 1)).float()\n",
    "        self.const_vec = nn.Parameter(self.const_vec, requires_grad=False)\n",
    "\n",
    "    def evidence(self, x: torch.Tensor) -> torch.Tensor:\n",
    "        \"\"\"Convert log-space parameters to positive values.\"\"\"\n",
    "        return torch.exp(x)\n",
    "\n",
    "    def forward(self, xtau: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:\n",
    "        x, tau = xtau\n",
    "        \n",
    "        # Encode state-action pair\n",
    "        sa_embedding = self.base_arch(x)\n",
    "\n",
    "        # Encode quantile level using cosine features: cos(π * i * τ)\n",
    "        tau_embedding = torch.cos(tau.unsqueeze(-1) * self.const_vec * np.pi)\n",
    "        tau_embedding = self.tau_arch(tau_embedding)\n",
    "        \n",
    "        # Hadamard product of embeddings\n",
    "        x = sa_embedding.unsqueeze(1) * tau_embedding\n",
    "        x = self.out_arch(x)\n",
    "        \n",
    "        # Output distributional parameters\n",
    "        output = self.out(x)\n",
    "        mu, log_alpha, log_beta = output.chunk(3, dim=-1)\n",
    "        \n",
    "        # Ensure alpha, beta > 10 for numerical stability\n",
    "        alpha = self.evidence(log_alpha) + 10.0 + 1e-3\n",
    "        beta = self.evidence(log_beta) + 10.0 + 1e-3\n",
    "\n",
    "        return (\n",
    "            mu.squeeze(-1),\n",
    "            alpha.squeeze(-1),\n",
    "            beta.squeeze(-1),\n",
    "        )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "345a75a9",
   "metadata": {},
   "source": [
    "---\n",
    "\n",
    "## 5. DAIF Loss Function\n",
    "\n",
    "The DAIF loss combines quantile regression with a distributional parametric form. The loss function penalizes deviations from the target quantiles using the asymmetric quantile loss."
   ]
  },
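  {
   "cell_type": "markdown",
   "id": "f3a97c5b",
   "metadata": {},
   "source": [
    "Reading the terms off the code cell below, each pair of predicted quantile and target atom contributes the following log-likelihood. Here $\\delta = y - \\mu$ is the difference between a target value $y$ and the predicted quantile $\\mu$ at level $\\tau$, and $\\psi$ is the digamma function. The symbols $\\ell$, $y$, $\\delta$, $\\lambda$, and $\\rho_\\tau$ are notation introduced here.\n",
    "\n",
    "$$\\ell = \\log\\big(\\tau (1 - \\tau)\\big) + \\psi(\\alpha) - \\log \\beta - \\frac{\\alpha}{2\\beta}\\Big(|\\delta| + (2\\tau - 1)\\,\\delta\\Big)$$\n",
    "\n",
    "Since $|\\delta| + (2\\tau - 1)\\,\\delta = 2\\rho_\\tau(\\delta)$ with the pinball loss $\\rho_\\tau(\\delta) = \\delta\\,(\\tau - \\mathbb{1}[\\delta < 0])$, one possible reading, offered as an interpretation rather than a statement of the authors' derivation, is that $\\ell$ is the expectation of an asymmetric Laplace log-density under a Gamma-distributed scale $\\lambda \\sim \\mathrm{Gamma}(\\alpha, \\beta)$ with shape $\\alpha$ and rate $\\beta$:\n",
    "\n",
    "$$\\ell = \\mathbb{E}_{\\lambda}\\Big[\\log \\lambda + \\log\\big(\\tau (1 - \\tau)\\big) - \\lambda\\,\\rho_\\tau(\\delta)\\Big], \\qquad \\mathbb{E}[\\log \\lambda] = \\psi(\\alpha) - \\log \\beta, \\qquad \\mathbb{E}[\\lambda] = \\frac{\\alpha}{\\beta}$$\n",
    "\n",
    "The function returns $-\\ell$ multiplied by the weight, without reducing over any dimension."
   ]
  },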
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cad30659",
   "metadata": {},
   "outputs": [],
   "source": [
    "@torch.compile\n",
    "def daif_loss(\n",
    "    pred_mu: torch.Tensor,\n",
    "    pred_alpha: torch.Tensor,\n",
    "    pred_beta: torch.Tensor,\n",
    "    target: torch.Tensor,\n",
    "    tau: torch.Tensor,\n",
    "    weight: torch.Tensor,\n",
    ") -> torch.Tensor:\n",
    "    \"\"\"\n",
    "    Compute the DAIF distributional critic loss.\n",
    "    \n",
    "    Args:\n",
    "        pred_mu: Predicted quantile values [n_ensemble x n_batch x n_atoms]\n",
    "        pred_alpha: Concentration parameter alpha\n",
    "        pred_beta: Rate parameter beta\n",
    "        target: Target quantile values [n_batch x n_atoms]\n",
    "        tau: Quantile levels [n_batch x n_atoms]\n",
    "        weight: Importance weights for each quantile\n",
    "    \n",
    "    Returns:\n",
    "        Loss tensor [n_ensemble x n_batch x n_quantile x n_quantile]\n",
    "    \"\"\"\n",
    "    # Reshape for broadcasting\n",
    "    pred_mu = pred_mu.unsqueeze(3)  # [... n_atoms x 1]\n",
    "    pred_alpha = pred_alpha.unsqueeze(3)\n",
    "    pred_beta = pred_beta.unsqueeze(3)\n",
    "    target = target.unsqueeze(2)  # [... 1 x n_atoms]\n",
    "    eps = 1e-8\n",
    "    \n",
    "    # Handle tau shape\n",
    "    if len(tau.shape) == 1:\n",
    "        assert tau.shape[0] == pred_mu.shape[2]\n",
    "        tau = tau.view(1, 1, tau.shape[0], 1)\n",
    "    elif len(tau.shape) == 2:\n",
    "        assert tau.shape[0] == pred_mu.shape[1]\n",
    "        tau = tau.view(1, tau.shape[0], tau.shape[1], 1)\n",
    "\n",
    "    # Handle weight shape\n",
    "    if len(weight.shape) == 1:\n",
    "        assert weight.shape[0] == pred_mu.shape[2]\n",
    "        weight = weight.view(1, 1, 1, weight.shape[0])\n",
    "    elif len(weight.shape) == 2:\n",
    "        assert weight.shape[0] == pred_mu.shape[1]\n",
    "        weight = weight.view(1, weight.shape[0], 1, weight.shape[1])\n",
    "\n",
    "    # DAIF loss components\n",
    "    tau_term = torch.log(tau * (1.0 - tau))\n",
    "    digamma_term = -torch.log(pred_beta) + torch.digamma(pred_alpha)\n",
    "    difference = target - pred_mu\n",
    "    \n",
    "    # Asymmetric quantile loss weighted by alpha/beta\n",
    "    error_term = (\n",
    "        -(pred_alpha)\n",
    "        / (2 * pred_beta).clamp(min=eps)\n",
    "        * (torch.abs(difference) + (2.0 * tau - 1.0) * difference)\n",
    "    )\n",
    "\n",
    "    log_likelihood = tau_term + digamma_term + error_term\n",
    "    loss = -log_likelihood * weight\n",
    "\n",
    "    return loss"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d9b33d26",
   "metadata": {},
   "source": [
    "---\n",
    "\n",
    "## 7. Agent Implementation\n",
    "\n",
    "### Ensemble Module\n",
    "\n",
    "The ensemble module enables efficient parallel evaluation of multiple critic networks using `torch.vmap`. This provides better gradient estimates and reduces overestimation bias."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "533b375d",
   "metadata": {},
   "outputs": [],
   "source": [
    "T = TypeVar(\"T\", bound=nn.Module)\n",
    "\n",
    "\n",
    "class Ensemble(Generic[T], nn.Module, ABC):\n",
    "    \"\"\"\n",
    "    Vectorized ensemble of neural networks using torch.vmap.\n",
    "    \n",
    "    Enables efficient parallel forward passes through multiple networks\n",
    "    with a single function call.\n",
    "    \"\"\"\n",
    "    \n",
    "    def __init__(\n",
    "        self,\n",
    "        n_members: int,\n",
    "        prototype: T,\n",
    "        models: list[T],\n",
    "        device: Literal[\"cpu\", \"cuda\"] = \"cpu\",\n",
    "    ) -> None:\n",
    "        super().__init__()\n",
    "\n",
    "        self.n_members = n_members\n",
    "        self.prototype = prototype\n",
    "        self.device = device\n",
    "\n",
    "        # Stack parameters from all ensemble members\n",
    "        self.params, self.buffers = thf.stack_module_state(models)\n",
    "        self.base_model = copy.deepcopy(models[0]).to(\"meta\")\n",
    "\n",
    "        # Register stacked parameters\n",
    "        for name, param in self.params.items():\n",
    "            self.register_parameter(\n",
    "                f\"stacked-{name.replace('.', '_')}\", nn.Parameter(param)\n",
    "            )\n",
    "        for name, buffer in self.buffers.items():\n",
    "            self.register_buffer(f\"stacked-{name.replace('.', '_')}\", buffer)\n",
    "\n",
    "        # Create vmapped forward function\n",
    "        def _fmodel(\n",
    "            base_model: nn.Module,\n",
    "            params: dict[str, torch.Tensor],\n",
    "            buffers: dict[str, torch.Tensor],\n",
    "            x: torch.Tensor,\n",
    "        ) -> torch.Tensor:\n",
    "            return thf.functional_call(base_model, (params, buffers), (x,))\n",
    "\n",
    "        self.forward_model = thf.vmap(\n",
    "            functools.partial(_fmodel, self.base_model), randomness=\"different\"\n",
    "        )\n",
    "\n",
    "    def forward(self, input: torch.Tensor) -> torch.Tensor:\n",
    "        return self.forward_model(self.params, self.buffers, self.expand(input))\n",
    "\n",
    "    def expand(\n",
    "        self, x: torch.Tensor, force: bool = False\n",
    "    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:\n",
    "        \"\"\"Expand input to match ensemble dimension.\"\"\"\n",
    "        f = lambda x: (\n",
    "            x.expand(self.n_members, *x.shape) if (force or len(x.shape) < 3) else x\n",
    "        )\n",
    "        if hasattr(x, \"shape\"):\n",
    "            return f(x)\n",
    "        elif isinstance(x, tuple):\n",
    "            a, b = x\n",
    "            return (f(a), f(b))\n",
    "        else:\n",
    "            raise ValueError(f\"{x} is not a valid argument\")\n",
    "\n",
    "    def _get_single_member(self, index: int = 0) -> T:\n",
    "        \"\"\"Extract a single network from the ensemble.\"\"\"\n",
    "        single_model = copy.deepcopy(self.prototype)\n",
    "\n",
    "        for name, param in single_model.named_parameters():\n",
    "            stacked_param = self.params[name]\n",
    "            param.data.copy_(stacked_param[index])\n",
    "\n",
    "        for name, buffer in single_model.named_buffers():\n",
    "            stacked_buffer = self.buffers[name]\n",
    "            buffer.data.copy_(stacked_buffer[index])\n",
    "\n",
    "        return single_model\n",
    "\n",
    "    def _get_all_members(self) -> nn.ModuleList:\n",
    "        return nn.ModuleList(\n",
    "            [self._get_single_member(i) for i in range(self.n_members)]\n",
    "        )\n",
    "\n",
    "    def __getitem__(self, index: int) -> T:\n",
    "        return self._get_single_member(index)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "feac31c2",
   "metadata": {},
   "source": [
    "### DAIF Actor\n",
    "\n",
    "The actor uses TD3-style exploration with Gaussian noise during training and target policy smoothing for computing Bellman targets."
   ]
  },
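  {
   "cell_type": "markdown",
   "id": "daif-note-actor-noise",
   "metadata": {},
   "source": [
    "As a rough sketch of the two noise mechanisms described above, the behavior action and the smoothed target action can be written as follows. The symbols $\\pi_\\theta$, $\\pi_{\\theta'}$, $\\sigma$, $\\tilde{\\sigma}$ and $c$ are notation introduced here for `model`, `target`, `POLICY_NOISE`, `TARGET_POLICY_NOISE` and `TARGET_POLICY_NOISE_CLIP`; they are assumptions for illustration rather than the paper's notation.\n",
    "\n",
    "$$a = \\operatorname{clip}\\big(\\pi_\\theta(s) + \\epsilon,\\, -1,\\, 1\\big), \\qquad \\epsilon \\sim \\mathcal{N}(0, \\sigma^2 I)$$\n",
    "\n",
    "$$\\tilde{a} = \\operatorname{clip}\\big(\\pi_{\\theta'}(s') + \\operatorname{clip}(\\tilde{\\epsilon}, -c, c),\\, -1,\\, 1\\big), \\qquad \\tilde{\\epsilon} \\sim \\mathcal{N}(0, \\tilde{\\sigma}^2 I)$$\n",
    "\n",
    "The action bounds $\\pm 1$ correspond to `action_limit_low` and `action_limit_high` in the code below, and both noise scales are standard deviations, matching the second argument of `torch.normal`."
   ]
  },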
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c9e7511b",
   "metadata": {},
   "outputs": [],
   "source": [
    "class DAIFActor(nn.Module):\n",
    "    \"\"\"\n",
    "    DAIF Actor with TD3-style exploration.\n",
    "    \n",
    "    Features:\n",
    "    - Gaussian exploration noise during training\n",
    "    - Target policy smoothing for stable Bellman targets\n",
    "    - Soft target network updates\n",
    "    \"\"\"\n",
    "    \n",
    "    def __init__(self, dim_state: int, dim_act: int) -> None:\n",
    "        super().__init__()\n",
    "        self.device = DEVICE\n",
    "        self.dim_state, self.dim_act = dim_state, dim_act\n",
    "        self._tau = TAU\n",
    "        self._gamma = GAMMA\n",
    "        self.policy_noise = POLICY_NOISE\n",
    "        self.target_policy_noise = TARGET_POLICY_NOISE\n",
    "        self.target_policy_noise_clip = TARGET_POLICY_NOISE_CLIP\n",
    "\n",
    "        self.action_limit_low = totorch(-1, device=self.device)\n",
    "        self.action_limit_high = totorch(1, device=self.device)\n",
    "        self.reset()\n",
    "\n",
    "    def reset(self) -> None:\n",
    "        self.model = ActorNet(\n",
    "            self.dim_state, self.dim_act, depth=DEPTH, width=WIDTH,\n",
    "        ).to(self.device)\n",
    "\n",
    "        self.optim = torch.optim.Adam(self.model.parameters(), lr=LEARNING_RATE)\n",
    "\n",
    "        self.target = ActorNet(\n",
    "            self.dim_state, self.dim_act, depth=DEPTH, width=WIDTH,\n",
    "        ).to(self.device)\n",
    "        self.target.load_state_dict(self.model.state_dict())\n",
    "\n",
    "    def act(self, state: torch.Tensor, is_training: bool = True) -> dict:\n",
    "        \"\"\"Select action with optional exploration noise.\"\"\"\n",
    "        action_dict = self.model(state, is_training=is_training)\n",
    "        action = action_dict[\"action\"]\n",
    "        action_dict[\"action_wo_noise\"] = action\n",
    "        \n",
    "        if is_training:\n",
    "            noise = torch.normal(0, self.policy_noise, action.shape).to(self.device)\n",
    "            action += noise\n",
    "            action = torch.clip(action, self.action_limit_low, self.action_limit_high)\n",
    "            action_dict[\"action\"] = action\n",
    "        return action_dict\n",
    "\n",
    "    def act_target(self, state: torch.Tensor) -> dict:\n",
    "        \"\"\"Target policy with smoothing noise (TD3 style).\"\"\"\n",
    "        action_dict = self.target(state, is_training=False)\n",
    "        action = action_dict[\"action\"]\n",
    "        \n",
    "        # Add clipped noise for target smoothing\n",
    "        noise = torch.normal(0, self.target_policy_noise, action.shape).to(self.device)\n",
    "        noise = torch.clip(noise, -self.target_policy_noise_clip, self.target_policy_noise_clip)\n",
    "        action += noise\n",
    "        action = torch.clip(action, self.action_limit_low, self.action_limit_high)\n",
    "        action_dict[\"action\"] = action\n",
    "        return action_dict\n",
    "    \n",
    "    @torch.no_grad()\n",
    "    def update_target(self) -> None:\n",
    "        \"\"\"Soft update of target network.\"\"\"\n",
    "        for target_param, local_param in zip(\n",
    "            self.target.parameters(), self.model.parameters()\n",
    "        ):\n",
    "            target_param.data.mul_(1.0 - self._tau)\n",
    "            target_param.data.add_(self._tau * local_param.data)\n",
    "    \n",
    "    def update(self, state: torch.Tensor, critics) -> None:\n",
    "        \"\"\"Update actor to maximize expected Q-value.\"\"\"\n",
    "        self.optim.zero_grad()\n",
    "        loss = self.loss(state, critics)\n",
    "        loss.backward()\n",
    "        self.optim.step()\n",
    "\n",
    "    def loss(self, state: torch.Tensor, critics) -> torch.Tensor:\n",
    "        \"\"\"Actor loss: negative expected Q-value.\"\"\"\n",
    "        batch_size = state.shape[0]\n",
    "        act_dict = self.act(state, is_training=False)\n",
    "        action = act_dict[\"action\"]\n",
    "\n",
    "        tau, tau_hat, presum_tau = critics.get_tau(batch_size=batch_size)\n",
    "        z_values, alpha_values, beta_values = critics.Q(state, action, tau_hat)\n",
    "        \n",
    "        # Compute expected Q-value by integrating over quantiles\n",
    "        q_values = torch.sum(z_values * presum_tau, dim=-1, keepdim=True)\n",
    "        q = torch.mean(q_values, dim=0)\n",
    "        \n",
    "        return -q.mean()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "18859d42",
   "metadata": {},
   "source": [
    "### Critic Base Classes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "40debc74",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Critic(nn.Module):\n",
    "    \"\"\"Single critic network with optional target network.\"\"\"\n",
    "    \n",
    "    def __init__(self, dim_state: int, dim_act: int) -> None:\n",
    "        super().__init__()\n",
    "\n",
    "        self.device = DEVICE\n",
    "        self.has_target = True\n",
    "        self._tau = TAU\n",
    "        self._gamma = GAMMA\n",
    "\n",
    "        self.model = QuantileCriticNet(dim_state, dim_act, width=WIDTH).to(self.device)\n",
    "\n",
    "        if self.has_target:\n",
    "            self.target = QuantileCriticNet(dim_state, dim_act, width=WIDTH).to(self.device)\n",
    "            self.init_target()\n",
    "\n",
    "    def reduce(self, q_val: torch.Tensor) -> torch.Tensor:\n",
    "        return q_val\n",
    "\n",
    "    def Q(self, state: torch.Tensor, action: torch.Tensor | None = None) -> torch.Tensor:\n",
    "        return self.model(self._prepare_input(state, action))\n",
    "\n",
    "    def forward(self, stateaction: torch.Tensor) -> torch.Tensor:\n",
    "        return self.model(stateaction)\n",
    "\n",
    "    @staticmethod\n",
    "    def _prepare_input(state: torch.Tensor, action: torch.Tensor) -> torch.Tensor:\n",
    "        if action.shape == ():\n",
    "            action = action.view(1, 1)\n",
    "        return torch.cat((state, action), -1)\n",
    "\n",
    "    def init_target(self) -> None:\n",
    "        assert self.has_target, \"There is no target network to initialize\"\n",
    "        for target_param, local_param in zip(\n",
    "            self.target.parameters(), self.model.parameters()\n",
    "        ):\n",
    "            target_param.data.copy_(local_param.data)\n",
    "\n",
    "    def update_target(self) -> None:\n",
    "        assert self.has_target, \"There is no target network to update\"\n",
    "        for target_param, local_param in zip(\n",
    "            self.target.parameters(), self.model.parameters()\n",
    "        ):\n",
    "            target_param.data.mul_(1.0 - self._tau)\n",
    "            target_param.data.add_(self._tau * local_param.data)\n",
    "\n",
    "    def Q_t(self, state: torch.Tensor, action: torch.Tensor | None = None) -> torch.Tensor:\n",
    "        assert self.has_target, \"There is no target network to evaluate\"\n",
    "        return self.target(self._prepare_input(state, action))\n",
    "\n",
    "    def __getitem__(self) -> \"Critic\":\n",
    "        return self"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7dd5ce79",
   "metadata": {},
   "outputs": [],
   "source": [
    "class CriticEnsemble(nn.Module, ABC):\n",
    "    \"\"\"\n",
    "    Ensemble of critics for reduced overestimation bias.\n",
    "    \n",
    "    Uses vectorized evaluation via torch.vmap for efficiency.\n",
    "    \"\"\"\n",
    "    \n",
    "    def __init__(self, dim_state: int, dim_act: int) -> None:\n",
    "        super().__init__()\n",
    "\n",
    "        self.n_members = 2\n",
    "        self.dim_state = dim_state\n",
    "        self.dim_act = dim_act\n",
    "        self.has_target = True\n",
    "        self.device = DEVICE\n",
    "        self._tau = TAU\n",
    "        self._gamma = GAMMA\n",
    "\n",
    "        # Create ensemble of critic networks\n",
    "        self.model_ensemble = Ensemble[Critic](\n",
    "            n_members=self.n_members,\n",
    "            prototype=Critic(dim_state, dim_act).model,\n",
    "            models=[Critic(dim_state, dim_act).model for _ in range(self.n_members)],\n",
    "            device=self.device,\n",
    "        )\n",
    "\n",
    "        self.optim = torch.optim.Adam(\n",
    "            [self.model_ensemble.params[key] for key in self.model_ensemble.params.keys()],\n",
    "            lr=LEARNING_RATE\n",
    "        )\n",
    "\n",
    "        if self.has_target:\n",
    "            self.target_ensemble = Ensemble[Critic](\n",
    "                n_members=self.n_members,\n",
    "                prototype=Critic(dim_state, dim_act).target,\n",
    "                models=[Critic(dim_state, dim_act).target for _ in range(self.n_members)],\n",
    "                device=self.device,\n",
    "            )\n",
    "            self.target_ensemble.load_state_dict(self.model_ensemble.state_dict())\n",
    "\n",
    "        self.iter = 0\n",
    "\n",
    "    def reset(self) -> None:\n",
    "        self.model_ensemble = Ensemble[Critic](\n",
    "            n_members=self.n_members,\n",
    "            prototype=Critic(self.config, self.dim_state, self.dim_act).model,\n",
    "            models=[Critic(self.config, self.dim_state, self.dim_act).model\n",
    "                    for _ in range(self.n_members)],\n",
    "            device=self.device,\n",
    "        )\n",
    "\n",
    "        if self.has_target:\n",
    "            self.target_ensemble = copy.deepcopy(self.model_ensemble)\n",
    "\n",
    "        self.optim = torch.optim.Adam(\n",
    "            [self.model_ensemble.params[key] for key in self.model_ensemble.params.keys()],\n",
    "            lr=LEARNING_RATE\n",
    "        )\n",
    "\n",
    "    def reduce(self, q_val: torch.Tensor, reduce_type: str) -> torch.Tensor:\n",
    "        if reduce_type == \"min\":\n",
    "            return q_val.min(0).values\n",
    "        elif reduce_type == \"mean\":\n",
    "            return q_val.mean(0)\n",
    "        else:\n",
    "            raise ValueError(f\"Unknown reduction method: {reduce_type}\")\n",
    "\n",
    "    def _get_single_critic(self, index: int = 0) -> Critic:\n",
    "        single_critic = Critic(self.config, self.dim_state, self.dim_act)\n",
    "        single_critic.model.load_state_dict(self.model_ensemble[index].state_dict())\n",
    "        if self.has_target:\n",
    "            single_critic.target.load_state_dict(self.target_ensemble[index].state_dict())\n",
    "        return single_critic\n",
    "\n",
    "    def __getitem__(self, index: int) -> Critic:\n",
    "        return self._get_single_critic(index)\n",
    "\n",
    "    def Q(self, state: torch.Tensor, action: torch.Tensor | None = None) -> torch.Tensor:\n",
    "        if action is None:\n",
    "            sa = state\n",
    "        else:\n",
    "            sa = torch.cat((state, action), -1)\n",
    "        return self.model_ensemble(sa)\n",
    "\n",
    "    @torch.no_grad()\n",
    "    def Q_t(self, state: torch.Tensor, action: torch.Tensor | None = None) -> torch.Tensor:\n",
    "        assert self.has_target, \"There is no target network to evaluate\"\n",
    "        if action is None:\n",
    "            sa = state\n",
    "        else:\n",
    "            sa = torch.cat((state, action), -1)\n",
    "        return self.target_ensemble(sa)\n",
    "\n",
    "    def update(self, state: torch.Tensor, action: torch.Tensor, y: torch.Tensor) -> None:\n",
    "        self.optim.zero_grad()\n",
    "        loss = self.loss(self.Q(state, action), self.model_ensemble.expand(y))\n",
    "        loss = loss.sum(0).mean() if self.n_members > 1 else loss.mean()\n",
    "        loss.backward()\n",
    "        self.optim.step()\n",
    "        self.iter += 1\n",
    "\n",
    "    @torch.no_grad()\n",
    "    def update_target(self) -> None:\n",
    "        assert self.has_target, \"There is no target network to update\"\n",
    "        for key in self.model_ensemble.params.keys():\n",
    "            self.target_ensemble.params[key].data.mul_(1.0 - self._tau)\n",
    "            self.target_ensemble.params[key].data.add_(\n",
    "                self._tau * self.model_ensemble.params[key].data\n",
    "            )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "43749719",
   "metadata": {},
   "source": [
    "### DAIF Critic\n",
    "\n",
    "The DAIF critic extends the ensemble with:\n",
    "- **Implicit quantile sampling**: Random quantile levels for each batch\n",
    "- **Bayesian regularization**: Priors on the distributional parameters\n",
    "- **Minimum across ensemble**: Overestimation mitigation"
   ]
  },
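  {
   "cell_type": "markdown",
   "id": "daif-note-quantile-target",
   "metadata": {},
   "source": [
    "The quantile sampling and the ensemble minimum listed above can be sketched as follows. The notation is introduced here for illustration and is an assumption, not the paper's: $N$ denotes `NUM_QUANTILES`, $K$ the number of ensemble members, $\\mu_k$ the first output of the $k$-th target critic, $\\tilde{a}$ the action proposed by the target policy, and $d$ the termination flag.\n",
    "\n",
    "$$u_i \\sim \\mathcal{U}(0.1, 1.1), \\qquad w_i = \\frac{u_i}{\\sum_{j=1}^{N} u_j}, \\qquad \\tau_i = \\sum_{j=1}^{i} w_j, \\qquad \\hat{\\tau}_i = \\frac{\\tau_{i-1} + \\tau_i}{2}, \\quad \\tau_0 = 0$$\n",
    "\n",
    "$$y_i = r + \\gamma\\,(1 - d)\\,\\min_{k = 1, \\dots, K} \\mu_k\\big(s', \\tilde{a}, \\hat{\\tau}'_i\\big)$$\n",
    "\n",
    "Here $w_i$, $\\tau_i$ and $\\hat{\\tau}_i$ correspond to `presum_tau`, `tau` and `tau_hat` in `get_tau`, and $y_i$ to `z_target` in `get_bellman_target`; the levels $\\hat{\\tau}'$ are sampled afresh for the next state. Because the widths $w_i$ sum to one, a scalar value estimate per ensemble member is the weighted sum $\\sum_i w_i\\, \\mu_k(s, a, \\hat{\\tau}_i)$, which is the quantity the actor loss averages over members."
   ]
  },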
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b0aee7c6",
   "metadata": {},
   "outputs": [],
   "source": [
    "class DAIFCritic(CriticEnsemble):\n",
    "    \"\"\"\n",
    "    DAIF Critic with distributional value estimation.\n",
    "    \n",
    "    Uses implicit quantile networks with Bayesian regularization priors\n",
    "    on the distributional parameters (alpha, beta).\n",
    "    \"\"\"\n",
    "    \n",
    "    def __init__(self, dim_state: int, dim_act: int) -> None:\n",
    "        super().__init__(dim_state, dim_act)\n",
    "        self.num_quantiles = NUM_QUANTILES\n",
    "        self.reg_coeff = REGULARIZATION_COEFF\n",
    "\n",
    "        # Prior distributions for regularization\n",
    "        self.prior_mu = Normal(\n",
    "            torch.tensor(0.0, device=self.device),\n",
    "            torch.tensor(1000.0, device=self.device),\n",
    "        )\n",
    "\n",
    "        self.prior_alpha = TransformedDistribution(\n",
    "            Gamma(\n",
    "                torch.tensor(10.0, device=self.device),\n",
    "                torch.tensor(0.1, device=self.device),\n",
    "            ),\n",
    "            [torch.distributions.transforms.AffineTransform(loc=10.0, scale=1.0, cache_size=1)],\n",
    "        )\n",
    "\n",
    "        self.prior_beta = TransformedDistribution(\n",
    "            Gamma(\n",
    "                torch.tensor(10.0, device=self.device),\n",
    "                torch.tensor(0.1, device=self.device),\n",
    "            ),\n",
    "            [torch.distributions.transforms.AffineTransform(loc=10.0, scale=1.0, cache_size=1)],\n",
    "        )\n",
    "\n",
    "    def get_tau(self, batch_size):\n",
    "        \"\"\"\n",
    "        Sample random quantile levels for implicit quantile regression.\n",
    "        \n",
    "        Returns:\n",
    "            tau: Cumulative quantile levels\n",
    "            tau_hat: Midpoint quantile levels (for evaluation)\n",
    "            presum_tau: Quantile widths (for integration)\n",
    "        \"\"\"\n",
    "        presum_tau = torch.rand(batch_size, self.num_quantiles, device=self.device) + 0.1\n",
    "        presum_tau /= presum_tau.sum(dim=-1, keepdim=True)\n",
    "        tau = torch.cumsum(presum_tau, dim=1)\n",
    "        \n",
    "        with torch.no_grad():\n",
    "            tau_hat = torch.zeros_like(tau)\n",
    "            tau_hat[:, 0:1] = tau[:, 0:1] / 2.0\n",
    "            tau_hat[:, 1:] = (tau[:, 1:] + tau[:, :-1]) / 2.0\n",
    "        \n",
    "        return tau, tau_hat, presum_tau\n",
    "\n",
    "    def Q(self, state: torch.Tensor, action: torch.Tensor, tau: torch.Tensor) -> torch.Tensor:\n",
    "        \"\"\"Compute distributional Q-values for given quantile levels.\"\"\"\n",
    "        sa = torch.cat((state, action), -1)\n",
    "\n",
    "        # Vectorize over the sample dimension\n",
    "        output_mu, output_alpha, output_beta = torch.vmap(\n",
    "            lambda a, b: self.model_ensemble((a, b)), in_dims=(None, 1), out_dims=2\n",
    "        )(sa, tau.unsqueeze(1))\n",
    "        \n",
    "        return (\n",
    "            output_mu.squeeze(2),\n",
    "            output_alpha.squeeze(2),\n",
    "            output_beta.squeeze(2),\n",
    "        )\n",
    "\n",
    "    def Q_t(self, state: torch.Tensor, action: torch.Tensor, tau: torch.Tensor) -> torch.Tensor:\n",
    "        \"\"\"Target network evaluation.\"\"\"\n",
    "        sa = torch.cat((state, action), -1)\n",
    "        \n",
    "        output_mu, output_alpha, output_beta = torch.vmap(\n",
    "            lambda a, b: self.target_ensemble((a, b)), in_dims=(None, 1), out_dims=2\n",
    "        )(sa, tau.unsqueeze(1))\n",
    "        \n",
    "        return (\n",
    "            output_mu.squeeze(2),\n",
    "            output_alpha.squeeze(2),\n",
    "            output_beta.squeeze(2),\n",
    "        )\n",
    "\n",
    "    def get_regularization(self, mu, alpha, beta):\n",
    "        \"\"\"Compute log-probability under prior distributions.\"\"\"\n",
    "        reg_mu = self.prior_mu.log_prob(mu).unsqueeze(-1)\n",
    "        reg_alpha = self.prior_alpha.log_prob(alpha).unsqueeze(-1)\n",
    "        reg_beta = self.prior_beta.log_prob(beta).unsqueeze(-1)\n",
    "        return reg_mu, reg_alpha, reg_beta\n",
    "\n",
    "    @torch.no_grad()\n",
    "    def get_bellman_target(\n",
    "        self,\n",
    "        reward: torch.Tensor,\n",
    "        next_state: torch.Tensor,\n",
    "        done: torch.Tensor,\n",
    "        actor: DAIFActor,\n",
    "    ) -> torch.Tensor:\n",
    "        \"\"\"Compute distributional Bellman target with clipped double-Q.\"\"\"\n",
    "        batch_size = reward.shape[0]\n",
    "        act_dict = actor.act_target(next_state)\n",
    "        next_action = act_dict[\"action\"]\n",
    "\n",
    "        next_tau, next_tau_hat, next_presum_tau = self.get_tau(batch_size=batch_size)\n",
    "        mu, alpha, beta = self.Q_t(next_state, next_action, next_tau_hat)\n",
    "\n",
    "        # Take minimum across ensemble members (conservative estimate)\n",
    "        min_values, min_indices = torch.min(mu, dim=0, keepdim=True)\n",
    "        z_next_values = min_values.squeeze(0)\n",
    "\n",
    "        # Bellman equation\n",
    "        z_target = reward.unsqueeze(-1) + (1 - done.unsqueeze(-1)) * (self._gamma * z_next_values)\n",
    "\n",
    "        return z_target, next_presum_tau\n",
    "\n",
    "    def update(\n",
    "        self,\n",
    "        state: torch.Tensor,\n",
    "        action: torch.Tensor,\n",
    "        y: tuple[torch.Tensor, torch.Tensor],\n",
    "    ) -> None:\n",
    "        \"\"\"Update critic with DAIF loss + regularization.\"\"\"\n",
    "        self.optim.zero_grad()\n",
    "        \n",
    "        batch_size = state.shape[0]\n",
    "        tau, tau_hat, presum_tau = self.get_tau(batch_size=batch_size)\n",
    "        pred_mu, pred_alpha, pred_beta = self.Q(state, action, tau_hat)\n",
    "\n",
    "        y, target_tau = y\n",
    "        \n",
    "        # DAIF distributional loss\n",
    "        loss = daif_loss(\n",
    "            pred_mu, pred_alpha, pred_beta,\n",
    "            self.model_ensemble.expand(y),\n",
    "            tau_hat, target_tau,\n",
    "        )\n",
    "\n",
    "        # Bayesian regularization\n",
    "        reg_mu, reg_alpha, reg_beta = self.get_regularization(pred_mu, pred_alpha, pred_beta)\n",
    "        reg_loss = reg_mu + reg_alpha + reg_beta\n",
    "\n",
    "        loss = loss - self.reg_coeff * reg_loss\n",
    "        loss = loss.sum(-1).mean(axis=(1, 2)).sum(0)\n",
    "        loss.backward()\n",
    "        self.optim.step()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cd79ab63",
   "metadata": {},
   "source": [
    "### DAIF Agent\n",
    "\n",
    "The complete DAIF agent combining actor and critic with delayed policy updates."
   ]
  },
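  {
   "cell_type": "markdown",
   "id": "daif-note-update-schedule",
   "metadata": {},
   "source": [
    "A sketch of the update schedule, assuming $\\rho$ denotes `TAU` and $D$ denotes `POLICY_DELAY` (symbols chosen here to avoid a clash with the quantile levels used by the critic): at learning iteration $n$ the critic is always updated, while the actor and its target network are updated only when $n \\bmod D = 0$. Each target network tracks its online network by Polyak averaging:\n",
    "\n",
    "$$\\theta' \\leftarrow (1 - \\rho)\\,\\theta' + \\rho\\,\\theta$$\n",
    "\n",
    "With $D = 2$, for example, the actor would be updated at iterations $n = 0, 2, 4, \\dots$, whereas the critic targets are refreshed at every iteration."
   ]
  },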
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "676258f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "class DAIF(nn.Module, ABC):\n",
    "    \"\"\"\n",
    "    Distributional Active Inference (DAIF).\n",
    "    \n",
    "    Combines:\n",
    "    - Distributional critic with implicit quantile networks\n",
    "    - TD3-style delayed policy updates and target smoothing\n",
    "    - Bayesian regularization on distributional parameters\n",
    "    \"\"\"\n",
    "    _agent_name = \"DAIF\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        env,\n",
    "        critic_type: type = DAIFCritic,\n",
    "        actor_type: type = DAIFActor,\n",
    "    ) -> None:\n",
    "        super().__init__()\n",
    "        self.device = DEVICE\n",
    "        self.dim_state = env.observation_space.shape[0]\n",
    "        self.dim_act = env.action_space.shape[0]\n",
    "        self._gamma = GAMMA\n",
    "        self._tau = TAU\n",
    "\n",
    "        self.experience_memory = ReplayBuffer(DEVICE, DEVICE, BUFFER_SIZE)\n",
    "        self.critic = critic_type(self.dim_state, self.dim_act)\n",
    "        self.actor = actor_type(self.dim_state, self.dim_act)\n",
    "        self.policy_delay: int = POLICY_DELAY\n",
    "        self.n_iter: int = 0\n",
    "\n",
    "    def generate_transition(self, **kwargs):\n",
    "        \"\"\"Create a TensorDict transition for storage.\"\"\"\n",
    "        device = self.experience_memory.storing_device\n",
    "        transition = TensorDict(\n",
    "            {\n",
    "                \"state\": kwargs[\"state\"].to(device),\n",
    "                \"action\": kwargs[\"action\"].to(device),\n",
    "                \"reward\": float(kwargs[\"reward\"]),\n",
    "                \"next_state\": kwargs[\"next_state\"].to(device),\n",
    "                \"terminated\": float(kwargs[\"terminated\"]),\n",
    "                \"truncated\": float(kwargs[\"truncated\"]),\n",
    "                \"step\": kwargs[\"step\"],\n",
    "            }\n",
    "        )\n",
    "        return transition\n",
    "\n",
    "    @torch.no_grad()\n",
    "    def select_action(self, state: torch.Tensor, is_training: bool = True) -> torch.Tensor:\n",
    "        \"\"\"Select action using the current policy.\"\"\"\n",
    "        act_dict = self.actor.act(state, is_training=is_training)\n",
    "        return act_dict\n",
    "\n",
    "    def store_transition(self, transition: tuple[Any, ...]) -> None:\n",
    "        \"\"\"Store transition in replay buffer.\"\"\"\n",
    "        self.experience_memory.add(transition)\n",
    "\n",
    "    def learn(self, max_iter: int = 1, n_epochs: int = 0) -> None:\n",
    "        \"\"\"Perform learning updates.\"\"\"\n",
    "        if BATCH_SIZE > len(self.experience_memory):\n",
    "            return None\n",
    "\n",
    "        n_steps = self.experience_memory.get_steps_and_iterator(n_epochs, max_iter, BATCH_SIZE)\n",
    "\n",
    "        for _ in range(n_steps):\n",
    "            batch = self.experience_memory.get_next_batch(BATCH_SIZE)\n",
    "\n",
    "            # Update critic\n",
    "            bellman_target = self.critic.get_bellman_target(\n",
    "                batch[\"reward\"], batch[\"next_state\"], batch[\"terminated\"], self.actor\n",
    "            )\n",
    "            self.critic.update(batch[\"state\"], batch[\"action\"], bellman_target)\n",
    "\n",
    "            # Delayed policy update (TD3 style)\n",
    "            if self.n_iter % self.policy_delay == 0:\n",
    "                self.actor.update(batch[\"state\"], self.critic)\n",
    "                self.actor.update_target()\n",
    "\n",
    "            self.critic.update_target()\n",
    "            self.n_iter += 1\n",
    "\n",
    "        return None"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c78e8b61",
   "metadata": {},
   "source": [
    "---\n",
    "\n",
    "## 8. Training Loop\n",
    "\n",
    "The experiment runner handles:\n",
    "- Environment interaction with warmup period\n",
    "- Periodic evaluation on a separate environment\n",
    "- Progress tracking and logging"
   ]
  },
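  {
   "cell_type": "markdown",
   "id": "daif-note-eval-aulc",
   "metadata": {},
   "source": [
    "As a sketch of the evaluation bookkeeping in the loop below, assume $T$ denotes `MAX_STEPS` and $E$ denotes `EVAL_FREQUENCY` (symbols introduced here for illustration). An evaluation runs whenever the step index is a multiple of $E$, plus once after the final step, so the loop records\n",
    "\n",
    "$$M = \\left\\lceil \\frac{T}{E} \\right\\rceil + 1$$\n",
    "\n",
    "evaluation scores $\\bar{R}_1, \\dots, \\bar{R}_M$, each the mean undiscounted return over `EVAL_EPISODES` episodes. The reported area under the learning curve is their plain average:\n",
    "\n",
    "$$\\text{AULC} = \\frac{1}{M} \\sum_{m=1}^{M} \\bar{R}_m$$"
   ]
  },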
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2095c715",
   "metadata": {},
   "outputs": [],
   "source": [
    "class ControlExperiment(object):\n",
    "    \"\"\"\n",
    "    Training loop for continuous control experiments.\n",
    "    \n",
    "    Handles:\n",
    "    - Environment interaction with warmup period\n",
    "    - Periodic evaluation\n",
    "    - Progress tracking\n",
    "    \"\"\"\n",
    "    \n",
    "    def __init__(self):\n",
    "        self.n_total_steps: int = 0\n",
    "        self.max_steps: int = MAX_STEPS\n",
    "        self.warmup_steps: int = WARMUP_STEPS\n",
    "        self.device = DEVICE\n",
    "\n",
    "        # Separate environments for training and evaluation\n",
    "        self.env = make_env(robot_name=ROBOT_NAME, task_name=TASK, seed=SEED)\n",
    "        self.eval_env = make_env(robot_name=ROBOT_NAME, task_name=TASK, seed=SEED + 100)\n",
    "\n",
    "        self.agent = DAIF(self.env)\n",
    "\n",
    "    def train(self) -> None:\n",
    "        \"\"\"Main training loop.\"\"\"\n",
    "        time_start = time.time()\n",
    "\n",
    "        episode_rewards = []\n",
    "        episode_steps = []\n",
    "        eval_rewards = []\n",
    "\n",
    "        # Initialize environment\n",
    "        state, _ = self.env.reset()\n",
    "        state = totorch(state, device=self.device)\n",
    "        r_cum = np.zeros(1)\n",
    "        episode = 0\n",
    "        e_step = 0\n",
    "\n",
    "        # Training loop\n",
    "        for step in tqdm(range(self.max_steps), leave=True):\n",
    "            e_step += 1\n",
    "\n",
    "            # Periodic evaluation\n",
    "            if step % EVAL_FREQUENCY == 0:\n",
    "                avg_eval_reward = self.eval()\n",
    "                eval_rewards.append(avg_eval_reward)\n",
    "                tqdm.write(f\"EVAL| Step {step:7d} | Eval Reward: {avg_eval_reward:10.3f}\")\n",
    "\n",
    "            # Action selection: random during warmup, policy-based after\n",
    "            if step < self.warmup_steps:\n",
    "                action = self.env.action_space.sample()\n",
    "                action = totorch(np.clip(action, -1.0, 1.0), device=self.device)\n",
    "                act_dict = {\"action\": action}\n",
    "            else:\n",
    "                act_dict = self.agent.select_action(state)\n",
    "                action = act_dict[\"action\"].clip(-1.0, 1.0)\n",
    "\n",
    "            # Environment step\n",
    "            next_state, reward, terminated, truncated, info = self.env.step(tonumpy(action))\n",
    "            next_state = totorch(next_state, device=self.device)\n",
    "\n",
    "            # Store transition\n",
    "            transition_kwargs = {\n",
    "                **act_dict,\n",
    "                \"state\": state,\n",
    "                \"next_state\": next_state,\n",
    "                \"reward\": reward,\n",
    "                \"terminated\": terminated,\n",
    "                \"truncated\": truncated,\n",
    "                \"step\": step + 1,\n",
    "            }\n",
    "            transition = self.agent.generate_transition(**transition_kwargs)\n",
    "            self.agent.store_transition(transition)\n",
    "\n",
    "            state = next_state\n",
    "            r_cum += reward\n",
    "\n",
    "            # Learning updates after warmup\n",
    "            if step >= self.warmup_steps:\n",
    "                self.agent.learn(max_iter=1)\n",
    "\n",
    "            # Episode termination\n",
    "            if terminated or truncated:\n",
    "                episode_rewards.append(r_cum.item())\n",
    "                episode_steps.append(step)\n",
    "\n",
    "                state, _ = self.env.reset()\n",
    "                state = totorch(state, device=self.device)\n",
    "                r_cum = np.zeros(1)\n",
    "                episode += 1\n",
    "                e_step = 0\n",
    "\n",
    "        # Final evaluation\n",
    "        eval_rewards.append(self.eval())\n",
    "        \n",
    "        time_end = time.time()\n",
    "        print(f\"Training time: {time_end - time_start:.2f} seconds\")\n",
    "        \n",
    "        aulc = np.mean(eval_rewards)\n",
    "        final_reward = eval_rewards[-1]\n",
    "        print(f\"AULC: {aulc:.3f}, Final Eval Reward: {final_reward:.3f}\")\n",
    "\n",
    "    @torch.no_grad()\n",
    "    def eval(self):\n",
    "        \"\"\"Evaluate current policy without exploration noise.\"\"\"\n",
    "        self.agent.eval()\n",
    "\n",
    "        # Save RNG states for reproducibility\n",
    "        torch_rng_state = torch.get_rng_state()\n",
    "        if torch.cuda.is_available():\n",
    "            cuda_rng_state = torch.cuda.get_rng_state_all()\n",
    "\n",
    "        # Set deterministic seed for evaluation\n",
    "        eval_seed = SEED + 12345\n",
    "        torch.manual_seed(eval_seed)\n",
    "        if torch.cuda.is_available():\n",
    "            torch.cuda.manual_seed(eval_seed)\n",
    "            torch.cuda.manual_seed_all(eval_seed)\n",
    "\n",
    "        results = torch.zeros(EVAL_EPISODES)\n",
    "\n",
    "        for episode in range(EVAL_EPISODES):\n",
    "            s, info = self.eval_env.reset()\n",
    "            s = totorch(s, device=self.device)\n",
    "            step = 0\n",
    "            done = False\n",
    "\n",
    "            while not done:\n",
    "                a = self.agent.select_action(s, is_training=False)[\"action\"]\n",
    "                sp, r, term, trunc, info = self.eval_env.step(tonumpy(a))\n",
    "                done = term or trunc\n",
    "                s = totorch(sp, device=self.device)\n",
    "                results[episode] += r\n",
    "                step += 1\n",
    "\n",
    "        self.agent.train()\n",
    "\n",
    "        # Restore RNG states\n",
    "        torch.set_rng_state(torch_rng_state)\n",
    "        if torch.cuda.is_available():\n",
    "            torch.cuda.set_rng_state_all(cuda_rng_state)\n",
    "\n",
    "        return results.mean().item()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2bed480e",
   "metadata": {},
   "source": [
    "---\n",
    "\n",
    "## Run Training\n",
    "\n",
    "Execute the cell below to start training. The agent will:\n",
    "\n",
    "1. Collect random transitions during warmup (10,000 steps)\n",
    "2. Train for 1,000,000 total environment steps\n",
    "3. Evaluate every 20,000 steps (10 episodes per evaluation)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4067fd3d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize and run the experiment\n",
    "experimenter = ControlExperiment()\n",
    "experimenter.train()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "daif",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
