{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "# Import Necessary Packages"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "GUrtasdlwqId"
      },
      "outputs": [],
      "source": [
        "# Standard Library\n",
        "import argparse\n",
        "import itertools\n",
        "import math\n",
        "import os\n",
        "import random\n",
        "import time\n",
        "from collections import deque\n",
        "from typing import Dict, Optional, OrderedDict, Tuple\n",
        "\n",
        "# Third-Party Libraries\n",
        "import gymnasium as gym\n",
        "from gymnasium import core, spaces\n",
        "from gymnasium.spaces import Box, Dict\n",
        "from gymnasium.wrappers import RescaleAction\n",
        "from dm_control import suite\n",
        "from scipy.stats import norm\n",
        "import numpy as np\n",
        "from tqdm import tqdm\n",
        "\n",
        "# PyTorch\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "import torch.nn.init as init\n",
        "from torch import nn as torch_nn  \n",
        "from torch import func as thf    \n",
        "from torch.distributions import Normal, TransformedDistribution\n",
        "from torch.distributions.transforms import TanhTransform\n",
        "from torch.nn.parameter import Parameter\n",
        "from torch.nn.modules.utils import _pair\n",
        "\n",
        "# Aliases\n",
        "import torch as th \n",
        "import gym  \n",
        "import dm_env  \n",
        "\n",
        "# Argument Parser for Jupyter compatibility\n",
        "parser = argparse.ArgumentParser()\n",
        "parser.add_argument(\"-f\", required=False)  \n",
        "args, unknown = parser.parse_known_args()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rYywNtri14tV"
      },
      "source": [
        "# Get the Model "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "id": "uG5PZt1Y13fT"
      },
      "outputs": [],
      "source": [
        "def get_model( env):\n",
        "    return RandomEnsembleDoubleQLearning (env, args)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lZci7w8L3v8J"
      },
      "source": [
        "# DMControl Environment\n",
        " Define a wrapper that converts DeepMind Control Suite (DMC) environments into a Gym-compatible format."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "ndagaTFC3xf_",
        "outputId": "1e34eac7-de16-4151-aa33-2f10a0907165"
      },
      "outputs": [],
      "source": [
        "\n",
        "TimeStep = Tuple[np.ndarray, float, bool, bool, dict]\n",
        "\n",
        "def dmc_spec2gym_space(spec):\n",
        "    if isinstance(spec, OrderedDict) or isinstance(spec, dict):\n",
        "        spec = copy.copy(spec)\n",
        "        for k, v in spec.items():\n",
        "            spec[k] = dmc_spec2gym_space(v)\n",
        "        return spaces.Dict(spec)\n",
        "    elif isinstance(spec, dm_env.specs.BoundedArray):\n",
        "        return spaces.Box(low=spec.minimum,\n",
        "                          high=spec.maximum,\n",
        "                          shape=spec.shape,\n",
        "                          dtype=spec.dtype)\n",
        "    elif isinstance(spec, dm_env.specs.Array):\n",
        "        return spaces.Box(low=-float('inf'),\n",
        "                          high=float('inf'),\n",
        "                          shape=spec.shape,\n",
        "                          dtype=spec.dtype)\n",
        "    else:\n",
        "        raise NotImplementedError\n",
        "\n",
        "\n",
        "class DMCEnv(core.Env):\n",
        "    def __init__(self,\n",
        "                 domain_name: Optional[str] = None,\n",
        "                 task_name: Optional[str] = None,\n",
        "                 env: Optional[dm_env.Environment] = None,\n",
        "                 task_kwargs: Optional[Dict] = {},\n",
        "                 environment_kwargs=None):\n",
        "        assert 'random' in task_kwargs, 'Please specify a seed, for deterministic behaviour.'\n",
        "        assert (\n",
        "            env is not None\n",
        "            or (domain_name is not None and task_name is not None)\n",
        "        ), 'You must provide either an environment or domain and task names.'\n",
        "\n",
        "        if env is None:\n",
        "            env = suite.load(\n",
        "                domain_name=domain_name,\n",
        "                task_name=task_name,\n",
        "                task_kwargs=task_kwargs,\n",
        "                environment_kwargs=environment_kwargs,\n",
        "                visualize_reward=True\n",
        "            )\n",
        "\n",
        "        self._env = env\n",
        "        self.domain_name = domain_name\n",
        "        self.task_name = task_name\n",
        "        self.action_space = dmc_spec2gym_space(self._env.action_spec())\n",
        "\n",
        "        self.observation_space = dmc_spec2gym_space(\n",
        "            self._env.observation_spec())\n",
        "\n",
        "    def __getattr__(self, name):\n",
        "        return getattr(self._env, name)\n",
        "\n",
        "    def step(self, action: np.ndarray) -> TimeStep:\n",
        "        assert self.action_space.contains(action)\n",
        "\n",
        "        time_step = self._env.step(action)\n",
        "        reward = time_step.reward or 0\n",
        "        done = time_step.last()\n",
        "        obs = time_step.observation\n",
        "\n",
        "        info  = {}\n",
        "        trunc = done and (time_step.discount == 1.0)\n",
        "        term = done and (time_step.discount != 1.0)\n",
        "        if trunc:\n",
        "            info['TimeLimit.truncated'] = True\n",
        "        return obs, reward, term, trunc, info\n",
        "\n",
        "    def reset(self, seed=None, options=None):\n",
        "        super().reset(seed=seed)\n",
        "        time_step = self._env.reset()\n",
        "        info = {}\n",
        "        return time_step.observation, info\n",
        "\n",
        "    def render(self,\n",
        "               mode='rgb_array',\n",
        "               height: int = 84,\n",
        "               width: int = 84,\n",
        "               camera_id: int = 0):\n",
        "        assert mode == 'rgb_array', 'only support rgb_array mode, given %s' % mode\n",
        "        return self._env.physics.render(height=height,\n",
        "                                        width=width,)\n",
        "       "
      ]
    },
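    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "A minimal usage sketch of the wrapper above: load a DMC task and take one random step through the Gym-style API (`cartpole`/`swingup` is just an example task from the suite). Note the termination convention: DMC signals time-limit episode ends with `discount == 1.0`, which the wrapper reports as truncation rather than termination."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Sketch: exercise DMCEnv with a random action.\n",
        "_demo_env = DMCEnv(domain_name='cartpole', task_name='swingup',\n",
        "                   task_kwargs={'random': 0})\n",
        "_obs, _info = _demo_env.reset()\n",
        "_act = _demo_env.action_space.sample()\n",
        "_obs, _rew, _term, _trunc, _info = _demo_env.step(_act)\n",
        "print(type(_obs), _rew, _term, _trunc)"
      ]
    },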
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "# Make Environment\n",
        "Defines wrappers and utilities to preprocess Gym environments. The `make_env()` function sets up a DeepMind Control Suite environment with various preprocessing options.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "id": "vWY4CKP0171s"
      },
      "outputs": [],
      "source": [
        "\n",
        "\n",
        "class SinglePrecision(gym.ObservationWrapper):\n",
        "    def __init__(self, env):\n",
        "        super().__init__(env)\n",
        "\n",
        "        if isinstance(self.observation_space, Box):\n",
        "            obs_space = self.observation_space\n",
        "            self.observation_space = Box(obs_space.low, obs_space.high,\n",
        "                                         obs_space.shape)\n",
        "        elif isinstance(self.observation_space, Dict):\n",
        "            obs_spaces = copy.copy(self.observation_space.spaces)\n",
        "            for k, v in obs_spaces.items():\n",
        "                obs_spaces[k] = Box(v.low, v.high, v.shape)\n",
        "            self.observation_space = Dict(obs_spaces)\n",
        "        else:\n",
        "            raise NotImplementedError\n",
        "\n",
        "    def observation(self, observation: np.ndarray) -> np.ndarray:\n",
        "        if isinstance(observation, np.ndarray):\n",
        "            return observation.astype(np.float32)\n",
        "        elif isinstance(observation, dict):\n",
        "            observation = copy.copy(observation)\n",
        "            for k, v in observation.items():\n",
        "                observation[k] = v.astype(np.float32)\n",
        "            return observation\n",
        "        \n",
        "class FlattenAction(gym.ActionWrapper):\n",
        "    \"\"\"Action wrapper that flattens the action.\"\"\"\n",
        "\n",
        "    def __init__(self, env):\n",
        "        super(FlattenAction, self).__init__(env)\n",
        "        self.action_space = gym.spaces.utils.flatten_space(self.env.action_space)\n",
        "\n",
        "    def action(self, action):\n",
        "        return gym.spaces.utils.unflatten(self.env.action_space, action)\n",
        "\n",
        "    def reverse_action(self, action):\n",
        "        return gym.spaces.utils.flatten(self.env.action_space, action)\n",
        "\n",
        "def make_env(env_name: str,\n",
        "             seed: int,\n",
        "             save_folder: Optional[str] = None,\n",
        "             add_episode_monitor: bool = True,\n",
        "             action_repeat: int = 1,\n",
        "             frame_stack: int = 1,\n",
        "             from_pixels: bool = False,\n",
        "             pixels_only: bool = True,\n",
        "             image_size: int = 84,\n",
        "             sticky: bool = False,\n",
        "             gray_scale: bool = False,\n",
        "             flatten: bool = True,\n",
        "             terminate_when_unhealthy: bool = True,\n",
        "             action_concat: int = 1,\n",
        "             obs_concat: int = 1,\n",
        "             continuous: bool = True,\n",
        "             ) -> gym.Env:\n",
        "\n",
        "    env_ids = list(gym.envs.registry.keys())\n",
        "\n",
        "    \n",
        "    if env_name in env_ids:\n",
        "        env = gym.make(env_name)\n",
        "        save_folder = None\n",
        "    else:\n",
        "        domain_name, task_name = env_name.split('-')\n",
        "        env = DMCEnv(domain_name=domain_name, task_name=task_name, task_kwargs={'random': seed})\n",
        "\n",
        "    if flatten and isinstance(env.observation_space, gym.spaces.Dict):\n",
        "        env = gym.wrappers.FlattenObservation(env)\n",
        "        env = FlattenAction(env)\n",
        "\n",
        "    if continuous:\n",
        "        env = RescaleAction(env, -1.0, 1.0)\n",
        "\n",
        "    env = SinglePrecision(env)\n",
        "    env.reset(seed=seed)\n",
        "    env.action_space.seed(seed)\n",
        "    env.observation_space.seed(seed)\n",
        "    np.random.seed(seed)\n",
        "    torch.manual_seed(seed)\n",
        "    return env\n",
        "\n",
        "class Experiment(object):\n",
        "    def __init__(self):\n",
        "        self.args = args\n",
        "        self.n_total_steps = 0\n",
        "        self.max_steps = 100000\n",
        "        # self.env = make_env('Ant-v4', 1)\n",
        "        # self.eval_env = make_env('Ant-v4', 101)\n",
        "        self.env = make_env('cartpole-swingup', 1)\n",
        "        self.eval_env = make_env('cartpole-swingup', 101)\n",
        "        self.agent = get_model( self.env)\n"
      ]
    },
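    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "A quick sanity check for `make_env` (a hedged sketch; the printed shapes correspond to `cartpole-swingup`): the wrapper stack should yield a flat `float32` observation space and actions rescaled to [-1, 1]."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Sketch: inspect the spaces produced by the wrapper stack.\n",
        "_check_env = make_env('cartpole-swingup', seed=0)\n",
        "print(_check_env.observation_space)  # flat Box, float32 (5-dimensional for cartpole-swingup)\n",
        "print(_check_env.action_space)       # Box rescaled to [-1, 1]"
      ]
    },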
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "VmIkEsMzyPdc"
      },
      "source": [
        "# Architectures \n",
        "This cell includes all the network design architectures, such as the critic and actor networks."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zfsulbl1yO81"
      },
      "outputs": [],
      "source": [
        "\n",
        "def tonumpy(x):\n",
        "    return x.data.cpu().numpy() \n",
        "\n",
        "class Critic(nn.Module):\n",
        "    def __init__(self, arch, args, n_state, n_action):\n",
        "        super(Critic, self).__init__()\n",
        "        self.args = args\n",
        "        self.args.device =\"cpu\"\n",
        "        self.model = arch(n_state, n_action, args.n_hidden).to(self.args.device)\n",
        "        self.target = arch(n_state, n_action, args.n_hidden).to(self.args.device)\n",
        "        self.init_target()\n",
        "        self.loss = nn.MSELoss()\n",
        "        self.optim = torch.optim.Adam(self.model.parameters(), args.learning_rate)\n",
        "        self.iter = 0\n",
        "        self.args.tau = 0.005\n",
        "\n",
        "    def set_writer(self, writer):\n",
        "        self.writer = writer\n",
        "\n",
        "    def init_target(self):\n",
        "        for target_param, local_param in zip(\n",
        "            self.target.parameters(), self.model.parameters()\n",
        "        ):\n",
        "            target_param.data.copy_(local_param.data)\n",
        "\n",
        "    @th.no_grad()\n",
        "    def update_target(self):\n",
        "        for target_param, local_param in zip(\n",
        "            self.target.parameters(), self.model.parameters()\n",
        "        ):\n",
        "            target_param.data.mul_(1.0 - self.args.tau)\n",
        "            target_param.data.add_(self.args.tau * local_param.data)\n",
        "\n",
        "    def Q(self, s, a):\n",
        "        return self.model(s, a)\n",
        "\n",
        "    def Q_t(self, s, a):\n",
        "        return self.target(s, a)\n",
        "\n",
        "    def update(self, s, a, y):  # y denotes bellman target\n",
        "        self.optim.zero_grad()\n",
        "        loss = self.loss(self.Q(s, a), y)\n",
        "        loss.backward()\n",
        "        self.optim.step()\n",
        "        self.iter += 1\n",
        "\n",
        "\n",
        "class CriticEnsemble(nn.Module):\n",
        "    def __init__(self, arch, args, n_state, n_action, critictype=Critic):\n",
        "        super(CriticEnsemble, self).__init__()\n",
        "        self.n_elements = self.args.n_critics\n",
        "        print(f\"Number of elements: {self.n_elements}\")\n",
        "        self.args = args\n",
        "        self.critics = [\n",
        "            critictype(arch, args, n_state, n_action) for _ in range(self.n_elements)\n",
        "        ]\n",
        "        self.gamma=0.99\n",
        "        self.iter = 0\n",
        "\n",
        "    def __getitem__(self, item):\n",
        "        return self.critics[item]\n",
        "\n",
        "    def set_writer(self, writer):\n",
        "        self.writer = writer\n",
        "        [critic.set_writer(writer) for critic in self.critics]\n",
        "\n",
        "    def Q(self, s, a):\n",
        "        return [critic.Q(s, a) for critic in self.critics]\n",
        "\n",
        "    def Q_t(self, s, a):\n",
        "        return [critic.Q_t(s, a) for critic in self.critics]\n",
        "\n",
        "    def update(self, s, a, y):\n",
        "        [critic.update(s, a, y) for critic in self.critics]\n",
        "        self.iter += 1\n",
        "\n",
        "    def update_target(self):\n",
        "        [critic.update_target() for critic in self.critics]\n",
        "\n",
        "    def reduce(self, q_val_list):\n",
        "        return torch.stack(q_val_list, dim=-1).min(dim=-1)[0]\n",
        "\n",
        "    @torch.no_grad()\n",
        "    def get_bellman_target(self, r, sp, done, actor):\n",
        "        alpha = actor.log_alpha.exp().detach() if hasattr(actor, \"log_alpha\") else 0\n",
        "        ap, ep = actor.act(sp)\n",
        "        qp = self.Q_t(sp, ap)\n",
        "        if ep is None:\n",
        "            ep = 0\n",
        "        qp_t = self.reduce(qp) - alpha * ep\n",
        "        y = r.unsqueeze(-1) + (self.args.gamma * qp_t * (1 - done.unsqueeze(-1)))\n",
        "        return y\n",
        " \n",
        "\n",
        "class ParallelCritic(nn.Module):\n",
        "    def __init__(self, arch, args, n_state, n_action):\n",
        "        super(ParallelCritic, self).__init__()\n",
        "        self.args = args\n",
        "        self.arch = arch\n",
        "        args.device = \"cpu\"\n",
        "        args.learning_rate=3e-4\n",
        "        self.model = arch(\n",
        "            n_state,\n",
        "            n_action,\n",
        "            depth=3,\n",
        "            width=256,\n",
        "            act=\"crelu\",\n",
        "            has_norm=not False,\n",
        "        ).to(args.device)\n",
        "        self.target = arch(\n",
        "            n_state,\n",
        "            n_action,\n",
        "            depth=3,\n",
        "            width=256,\n",
        "            act=\"crelu\",\n",
        "            has_norm=not False\n",
        "        ).to(args.device)\n",
        "        self.init_target()\n",
        "        self.loss = nn.HuberLoss()\n",
        "        self.optim = torch.optim.Adam(self.model.parameters(), args.learning_rate)\n",
        "        self.iter = 0\n",
        "\n",
        "    def set_writer(self, writer):\n",
        "        self.writer = writer\n",
        "\n",
        "    def init_target(self):\n",
        "        for target_param, local_param in zip(\n",
        "            self.target.parameters(), self.model.parameters()\n",
        "        ):\n",
        "            target_param.data.copy_(local_param.data)\n",
        "\n",
        "    def update_target(self):\n",
        "        for target_param, local_param in zip(\n",
        "            self.target.parameters(), self.model.parameters()\n",
        "        ):\n",
        "            target_param.data.mul_(1.0 - self.args.tau)\n",
        "            target_param.data.add_(self.args.tau * local_param.data)\n",
        "\n",
        "    def Q(self, s, a):\n",
        "        if a.shape == ():\n",
        "            a = a.view(1, 1)\n",
        "        return self.model(th.cat((s, a), -1))\n",
        "\n",
        "    def Q_t(self, s, a):\n",
        "        if a.shape == ():\n",
        "            a = a.view(1, 1)\n",
        "        return self.target(th.cat((s, a), -1))\n",
        "\n",
        "    def update(self, s, a, y):  # y denotes bellman target\n",
        "        self.optim.zero_grad()\n",
        "        loss = self.loss(self.Q(s, a), y)\n",
        "        loss.backward()\n",
        "        self.optim.step()\n",
        "        self.iter += 1\n",
        "\n",
        "\n",
        "class ParallelCritics(nn.Module):\n",
        "    def __init__(self, arch, args, n_state, n_action, critictype=ParallelCritic):\n",
        "        super(ParallelCritics, self).__init__()\n",
        "        self.n_members = 10\n",
        "        self.args = args\n",
        "        self.args.verbose = False\n",
        "        self.arch = arch\n",
        "        self.n_state = n_state\n",
        "        self.n_action = n_action\n",
        "        self.critictype = critictype\n",
        "        self.iter = 0\n",
        "        self.args.tau = 0.005\n",
        "        self.loss = self.critictype(\n",
        "            self.arch, self.args, self.n_state, self.n_action\n",
        "        ).loss\n",
        "        self.optim = self.critictype(\n",
        "            self.arch, self.args, self.n_state, self.n_action\n",
        "        ).optim\n",
        "\n",
        "        # Helperfunctions\n",
        "        self.expand = lambda x: (\n",
        "            x.expand(self.n_members, *x.shape) if len(x.shape) < 3 else x\n",
        "        )\n",
        "        self.reset()\n",
        "\n",
        "    def reset(self):\n",
        "        self.critics = [\n",
        "            self.critictype(self.arch, self.args, self.n_state, self.n_action)\n",
        "            for _ in range(self.n_members)\n",
        "        ]\n",
        "\n",
        "        self.critics_model = [\n",
        "            self.critictype(self.arch, self.args, self.n_state, self.n_action).model\n",
        "            for _ in range(self.n_members)\n",
        "        ]\n",
        "        self.critics_target = [\n",
        "            self.critictype(self.arch, self.args, self.n_state, self.n_action).target\n",
        "            for _ in range(self.n_members)\n",
        "        ]\n",
        "\n",
        "        self.params_model, self.buffers_model = thf.stack_module_state(\n",
        "            self.critics_model\n",
        "        )\n",
        "        self.params_target, self.buffers_target = thf.stack_module_state(\n",
        "            self.critics_target\n",
        "        )\n",
        "\n",
        "        self.base_model = copy.deepcopy(self.critics[0].model).to(\"meta\")\n",
        "        self.base_target = copy.deepcopy(self.critics[0].target).to(\"meta\")\n",
        "\n",
        "        def _fmodel(base_model, params, buffers, x):\n",
        "            return thf.functional_call(base_model, (params, buffers), (x,))\n",
        "\n",
        "        self.forward_model = thf.vmap(lambda p, b, x: _fmodel(self.base_model, p, b, x))\n",
        "        self.forward_target = thf.vmap(\n",
        "            lambda p, b, x: _fmodel(self.base_target, p, b, x)\n",
        "        )\n",
        "        self.optim = th.optim.Adam(\n",
        "            self.params_model.values(), lr=self.args.learning_rate\n",
        "        )\n",
        "\n",
        "    def reduce(self, q_val):\n",
        "        return q_val.min(0)[0]\n",
        "\n",
        "    def __getitem__(self, item):\n",
        "        return self.critics[item]\n",
        "\n",
        "    def unstack(self, target=False, single=True, net_id=None):\n",
        "        \"\"\"\n",
        "        Extract the single parameters back to the individual members\n",
        "        target: whether the target ensemble should be extracted or not\n",
        "        single: whether just the first member of the ensemble should be extracted\n",
        "        \"\"\"\n",
        "        params = self.params_target if target else self.params_model\n",
        "        if single and net_id is None:\n",
        "            net_id = 0\n",
        "\n",
        "        for key in params.keys():\n",
        "            if single:\n",
        "                tmp = (\n",
        "                    self.critics[net_id].model\n",
        "                    if not target\n",
        "                    else self.critics[net_id].target\n",
        "                )\n",
        "                for name in key.split(\".\"):\n",
        "                    tmp = getattr(tmp, name)\n",
        "                tmp.data.copy_(params[key][net_id])\n",
        "            else:\n",
        "                for net_id in range(self.n_members):\n",
        "                    tmp = (\n",
        "                        self.critics[net_id].model\n",
        "                        if not target\n",
        "                        else self.critics[net_id].target\n",
        "                    )\n",
        "                    for name in key.split(\".\"):\n",
        "                        tmp = getattr(tmp, name)\n",
        "                    tmp.data.copy_(params[key][net_id])\n",
        "                    if single:\n",
        "                        break\n",
        "\n",
        "    def set_writer(self, writer):\n",
        "        assert (\n",
        "            writer is None\n",
        "        ), \"For now nothing else is implemented for the parallel version\"\n",
        "        self.writer = writer\n",
        "        [critic.set_writer(writer) for critic in self.critics]\n",
        "\n",
        "    def Q(self, s, a):\n",
        "        if len(a.shape) == 1:\n",
        "            a = a.view(-1,1)\n",
        "        SA = self.expand(th.cat((s, a), -1))\n",
        "        return self.forward_model(self.params_model, self.buffers_model, SA)\n",
        "\n",
        "    @th.no_grad()\n",
        "    def Q_t(self, s, a):\n",
        "        SA = self.expand(th.cat((s, a), -1))\n",
        "        return self.forward_target(self.params_target, self.buffers_target, SA)\n",
        "\n",
        "    def update(self, s, a, y):  # y denotes bellman target\n",
        "        self.optim.zero_grad()\n",
        "        loss = self.loss(self.Q(s, a), self.expand(y))\n",
        "        loss.backward()\n",
        "        self.optim.step()\n",
        "        self.iter += 1\n",
        "\n",
        "    @torch.no_grad()\n",
        "    def update_target(self):\n",
        "        for key in self.params_model.keys():\n",
        "            self.params_target[key].data.mul_(1.0 - self.args.tau)\n",
        "            self.params_target[key].data.add_(\n",
        "                self.args.tau * self.params_model[key].data\n",
        "            )\n",
        "\n",
        "    @torch.no_grad()\n",
        "    def get_bellman_target(self, r, sp, done, actor):\n",
        "        alpha = actor.log_alpha.exp().detach() if hasattr(actor, \"log_alpha\") else 0\n",
        "        ap, ep = actor.act(sp)\n",
        "        qp = self.Q_t(sp, ap)\n",
        "        qp_t = self.reduce(qp) - alpha * (ep if ep is not None else 0)\n",
        "        y = r.unsqueeze(-1) + (self.args.gamma * qp_t * (1 - done.unsqueeze(-1)))\n",
        "        tqdm.write(f\"{y = }\")\n",
        "        return y\n",
        "\n",
        "\n",
        "class Actor(nn.Module):\n",
        "    def __init__(self, arch, args, n_state, n_action, has_target=False):\n",
        "        super().__init__()\n",
        "        self.model = arch(\n",
        "            n_state,\n",
        "            n_action,\n",
        "            depth=3,\n",
        "            width=256,\n",
        "            act=\"crelu\",\n",
        "            has_norm=not False,\n",
        "        )\n",
        "        self.optim = torch.optim.Adam(self.model.parameters(), args.learning_rate)\n",
        "        self.args = args\n",
        "        self.has_target = has_target\n",
        "        self.args.verbose = False\n",
        "        self.iter = 0\n",
        "        self.is_episode_end = False\n",
        "        self.states = []\n",
        "        self.print_freq = 500\n",
        "\n",
        "        if has_target:\n",
        "            self.target = arch(\n",
        "                n_state,\n",
        "                n_action,\n",
        "                depth=3,\n",
        "                width=256,\n",
        "                act=\"crelu\",\n",
        "                has_norm=not False,\n",
        "            )\n",
        "            self.init_target()\n",
        "\n",
        "    def init_target(self):\n",
        "        for target_param, local_param in zip(\n",
        "            self.target.parameters(), self.model.parameters()\n",
        "        ):\n",
        "            target_param.data.copy_(local_param.data)\n",
        "\n",
        "    def set_writer(self, writer):\n",
        "        self.writer = writer\n",
        "\n",
        "    def act(self, s, is_training=True):\n",
        "        a, e = self.model(\n",
        "            s, is_training=is_training\n",
        "        ) \n",
        "\n",
        "        if is_training:\n",
        "            if self.args.verbose and self.iter % self.print_freq == 0:\n",
        "                self.states.append(tonumpy(s))\n",
        "        return a, e\n",
        "    \n",
        "    def act_target(self, s):\n",
        "        a, e = self.target(s)\n",
        "        return a, e\n",
        "\n",
        "    def set_episode_status(self, is_end):\n",
        "        self.is_episode_end = is_end\n",
        "\n",
        "    @th.no_grad()\n",
        "    def update_target(self):\n",
        "        for target_param, local_param in zip(\n",
        "            self.target.parameters(), self.model.parameters()\n",
        "        ):\n",
        "            target_param.data.mul_(1.0 - self.tau)\n",
        "            target_param.data.add_(self.tau * local_param.data)\n",
        "\n",
        "    def loss(self, s, critics):\n",
        "        a, _ = self.act(s)\n",
        "        q_list = critics.Q(s, a)\n",
        "        q = critics.reduce(q_list)\n",
        "        return (-q).mean(), None\n",
        "\n",
        "    def update(self, s, critics):\n",
        "        self.optim.zero_grad()\n",
        "        loss, _ = self.loss(s, critics)\n",
        "        loss.backward()\n",
        "        self.optim.step()\n",
        "\n",
        "        if self.has_target:\n",
        "            self.update_target()\n",
        "\n",
        "        self.iter += 1\n",
        "    \n",
        "    def save_actor_params(self, path):\n",
        "        params = {\n",
        "            \"params_model\": self.model.state_dict(),\n",
        "        }\n",
        "\n",
        "        params_th = {\n",
        "            k: v if isinstance(v, torch.Tensor) else v  # Ensure the values are tensors\n",
        "            for k, v in params.items()\n",
        "        }\n",
        "\n",
        "        torch.save(params_th, path)\n",
        "\n",
        "\n",
        "class SoftActor(Actor):\n",
        "    def __init__(self, arch, args, n_state, n_action, has_target=False):\n",
        "        super(SoftActor, self).__init__(arch, args, n_state, n_action, has_target)\n",
        "        self.H_target = -n_action[0]\n",
        "        args.learning_rate=3e-4\n",
        "        args.alpha = 1\n",
        "        self.device = \"cpu\"\n",
        "        self.log_alpha = torch.tensor(\n",
        "            math.log(args.alpha), requires_grad=True , device=self.device\n",
        "        )\n",
        "        self.optim_alpha = torch.optim.Adam([self.log_alpha], args.learning_rate)\n",
        "\n",
        "    def update_alpha(self, e):\n",
        "        self.optim_alpha.zero_grad()\n",
        "        alpha_loss = -(self.log_alpha.exp() * (e + self.H_target).detach()).mean()\n",
        "        alpha_loss.backward()\n",
        "        self.optim_alpha.step()\n",
        "\n",
        "    def loss(self, s, critics):\n",
        "        a, e = self.act(s)\n",
        "        q_list = critics.Q(s, a)\n",
        "        q = critics.reduce(q_list)\n",
        "        return (-q + self.log_alpha.exp() * e).mean(), e\n",
        "\n",
        "    def update(self, s, critics):\n",
        "        self.optim.zero_grad()\n",
        "        loss, e = self.loss(s, critics)\n",
        "        loss.backward()\n",
        "        self.optim.step()\n",
        "        self.update_alpha(e)\n",
        "        self.iter += 1\n",
        "\n",
        "\n",
        "class CReLU(nn.Module):\n",
        "\n",
        "    def __init__(self):\n",
        "        super().__init__()\n",
        "\n",
        "    def forward(self, x):\n",
        "        x = torch.cat((x, -x), -1)\n",
        "        return F.relu(x)\n",
        "\n",
        "\n",
        "def create_net(d_in, d_out, depth, width, act=\"crelu\", has_norm=True, n_elements=1):\n",
        "    assert depth > 0, \"Need at least one layer\"\n",
        "\n",
        "    double_width = False\n",
        "    if act == \"crelu\":\n",
        "        act = CReLU\n",
        "        double_width = True\n",
        "    elif act == \"relu\":\n",
        "        act = nn.ReLU\n",
        "    else:\n",
        "        raise NotImplementedError(f\"{act} is not implemented\")\n",
        "\n",
        "    if depth == 1:\n",
        "        arch = nn.Linear(d_in, d_out)\n",
        "    elif depth == 2:\n",
        "        arch = nn.Sequential(\n",
        "            nn.Linear(d_in, width),\n",
        "            (\n",
        "                nn.LayerNorm(width, elementwise_affine=False)\n",
        "                if has_norm\n",
        "                else nn.Identity()\n",
        "            ),\n",
        "            act(),\n",
        "            nn.Linear(2 * width if double_width else width, d_out),\n",
        "        )\n",
        "    else:\n",
        "        in_layer = nn.Linear(d_in, width)\n",
        "        if n_elements > 1:\n",
        "            out_layer = nn.Linear(2 * width if double_width else width, d_out, n_elements)\n",
        "        else:\n",
        "            out_layer = nn.Linear(2 * width if double_width else width, d_out)\n",
        "\n",
        "        hidden = list(\n",
        "            itertools.chain.from_iterable(\n",
        "                [\n",
        "                    [\n",
        "                        (\n",
        "                            nn.LayerNorm(width, elementwise_affine=False)\n",
        "                            if has_norm\n",
        "                            else nn.Identity()\n",
        "                        ),\n",
        "                        act(),\n",
        "                        nn.Linear(2 * width if double_width else width, width),\n",
        "                    ]\n",
        "                    for _ in range(depth - 1)\n",
        "                ]\n",
        "            )\n",
        "        )[:-1]\n",
        "        arch = nn.Sequential(in_layer, *hidden, out_layer)\n",
        "\n",
        "    return arch\n",
        "\n",
        "\n",
        "\n",
        "class SquashedGaussianHead(nn.Module):\n",
        "    def __init__(self, n, upper_clamp=-2.0):\n",
        "        super().__init__()\n",
        "        self._n = n\n",
        "        self._upper_clamp = upper_clamp\n",
        "\n",
        "    def forward(self, x, is_training=True):\n",
        "        # bt means before tanh\n",
        "        mean_bt = x[..., : self._n]\n",
        "        log_var_bt = (x[..., self._n :]).clamp(-10, -self._upper_clamp)  # clamp added\n",
        "        std_bt = log_var_bt.exp().sqrt()\n",
        "        dist_bt = Normal(mean_bt, std_bt)\n",
        "        transform = TanhTransform(cache_size=1)\n",
        "        dist = TransformedDistribution(dist_bt, transform)\n",
        "        if is_training:\n",
        "            y = dist.rsample()\n",
        "            y_logprob = dist.log_prob(y).sum(dim=-1, keepdim=True)\n",
        "        else:\n",
        "            y_samples = dist.rsample((100,))\n",
        "            y = y_samples.mean(dim=0)\n",
        "            y_logprob = None\n",
        "\n",
        "        return y, y_logprob  # dist\n",
        "       \n",
        "    \n",
        "class ActorNet(nn.Module):\n",
        "    def __init__(\n",
        "        self,\n",
        "        dim_obs,\n",
        "        dim_act,\n",
        "        depth=3,\n",
        "        width=256,\n",
        "        act=\"crelu\",\n",
        "        has_norm=True,\n",
        "        upper_clamp=None,\n",
        "    ):\n",
        "        super().__init__()\n",
        "\n",
        "        self.arch = create_net(\n",
        "            dim_obs[0], dim_act[0], depth, width, act, has_norm\n",
        "        ).append(nn.Tanh())\n",
        "\n",
        "    def forward(self, x, is_training=None):\n",
        "        out = self.arch(x).clamp(-0.9999, 0.9999)\n",
        "        return out, None\n",
        "\n",
        "\n",
        "class ActorNetEnsemble(ActorNet):\n",
        "    def __init__(\n",
        "        self,\n",
        "        dim_obs,\n",
        "        dim_act,\n",
        "        depth=3,\n",
        "        width=256,\n",
        "        act=\"crelu\",\n",
        "        has_norm=True,\n",
        "        upper_clamp=None,\n",
        "        n_elements=10\n",
        "    ):\n",
        "        super(ActorNetEnsemble, self).__init__(dim_obs, dim_act, depth, width, act, has_norm, upper_clamp)\n",
        "\n",
        "        self.dim_act = dim_act\n",
        "        self.arch = create_net(\n",
        "            dim_obs[0], dim_act[0]*n_elements, depth, width, act, has_norm\n",
        "        ).append(nn.Tanh())\n",
        "        self.n_elements = n_elements\n",
        "\n",
        "    def forward(self, x, is_training=None):\n",
        "        out = self.arch(x).clamp(-0.9999, 0.9999)\n",
        "        out = out.view(-1, self.n_elements, self.dim_act[0])\n",
        "        return out, None\n",
        "    \n",
        "\n",
        "class Critic(nn.Module):\n",
        "    def __init__(self, arch, args, n_state, n_action):\n",
        "        super().__init__()\n",
        "        self.args = args\n",
        "        self.arch = arch\n",
        "        self.args.depth_critic = 3\n",
        "        self.args.width_critic = 256\n",
        "        self.args.act_critic = \"crelu\"\n",
        "        self.args.no_norm_critic = False\n",
        "        self.args.device = \"cpu\"\n",
        "        self.args.learning_rate = 3e-4\n",
        "        self.model = arch(\n",
        "            n_state,\n",
        "            n_action,\n",
        "            depth=self.args.depth_critic,\n",
        "            width=self.args.width_critic,\n",
        "            act=self.args.act_critic,\n",
        "            has_norm=not self.args.no_norm_critic,\n",
        "        ).to(self.args.device)\n",
        "        self.target = arch(\n",
        "            n_state,\n",
        "            n_action,\n",
        "            depth=self.args.depth_critic,\n",
        "            width=self.args.width_critic,\n",
        "            act=self.args.act_critic,\n",
        "            has_norm=not self.args.no_norm_critic,\n",
        "        ).to(self.args.device)\n",
        "        self.init_target()\n",
        "        # self.loss = nn.MSELoss()\n",
        "        self.loss = nn.HuberLoss()\n",
        "        self.optim = torch.optim.Adam(self.model.parameters(), self.args.learning_rate)\n",
        "        self.iter = 0\n",
        "\n",
        "    def set_writer(self, writer):\n",
        "        self.writer = writer\n",
        "\n",
        "    def init_target(self):\n",
        "        for target_param, local_param in zip(\n",
        "            self.target.parameters(), self.model.parameters()\n",
        "        ):\n",
        "            target_param.data.copy_(local_param.data)\n",
        "\n",
        "    def update_target(self):\n",
        "        for target_param, local_param in zip(\n",
        "            self.target.parameters(), self.model.parameters()\n",
        "        ):\n",
        "            target_param.data.mul_(1.0 - self.args.tau)\n",
        "            target_param.data.add_(self.args.tau * local_param.data)\n",
        "\n",
        "    def Q(self, s, a):\n",
        "        if a.shape == ():\n",
        "            a = a.view(1, 1)\n",
        "        return self.model(th.cat((s, a), -1))\n",
        "\n",
        "    def Q_t(self, s, a):\n",
        "        if a.shape == ():\n",
        "            a = a.view(1, 1)\n",
        "        return self.target(th.cat((s, a), -1))\n",
        "\n",
        "    def update(self, s, a, y):  # y denotes bellman target\n",
        "        self.optim.zero_grad()\n",
        "        loss = self.loss(self.Q(s, a), y)\n",
        "        loss.backward()\n",
        "        self.optim.step()\n",
        "        self.iter += 1\n",
        "\n",
        "\n",
        "class Critics(nn.Module):\n",
        "    def __init__(self, arch, args, n_state, n_action, critictype=Critic):\n",
        "        super().__init__()\n",
        "        self.args = args\n",
        "        self.args.gamma = 0.99\n",
        "        self.args.tau = 0.005\n",
        "        self.args.verbose = False\n",
        "        self.args.buffer_size = 100000\n",
        "        self.args.learning_rate = 3e-4\n",
        "        self.args.depth_critic = 3\n",
        "        self.args.width_critic = 256\n",
        "        self.args.act_critic = \"crelu\"\n",
        "        self.args.no_norm_critic = False\n",
        "        self.args.device = \"cpu\"\n",
        "        self.args.learning_rate = 3e-4\n",
        "        self.n_members = 10\n",
        "        self.arch = arch\n",
        "        self.n_state = n_state\n",
        "        self.n_action = n_action\n",
        "        self.critictype = critictype\n",
        "        self.iter = 0\n",
        "        # self.loss = nn.MSELoss()\n",
        "        self.loss = self.critictype(\n",
        "            self.arch, self.args, self.n_state, self.n_action\n",
        "        ).loss\n",
        "        self.optim = self.critictype(\n",
        "            self.arch, self.args, self.n_state, self.n_action\n",
        "        ).optim\n",
        "\n",
        "        # Helperfunctions\n",
        "        self.expand = lambda x: (\n",
        "            x.expand(self.n_members, *x.shape) if len(x.shape) < 3 else x\n",
        "        )\n",
        "        # self.reduce = lambda q_val: q_val.min(0)[0]\n",
        "\n",
        "        self.reset()\n",
        "\n",
        "    def reset(self):\n",
        "        self.critics = [\n",
        "            self.critictype(self.arch, self.args, self.n_state, self.n_action)\n",
        "            for _ in range(self.n_members)\n",
        "        ]\n",
        "\n",
        "        self.critics_model = [\n",
        "            self.critictype(self.arch, self.args, self.n_state, self.n_action).model\n",
        "            for _ in range(self.n_members)\n",
        "        ]\n",
        "        self.critics_target = [\n",
        "            self.critictype(self.arch, self.args, self.n_state, self.n_action).target\n",
        "            for _ in range(self.n_members)\n",
        "        ]\n",
        "\n",
        "        self.params_model, self.buffers_model = thf.stack_module_state(\n",
        "            self.critics_model\n",
        "        )\n",
        "        self.params_target, self.buffers_target = thf.stack_module_state(\n",
        "            self.critics_target\n",
        "        )\n",
        "\n",
        "        self.base_model = copy.deepcopy(self.critics[0].model).to(\"meta\")\n",
        "        self.base_target = copy.deepcopy(self.critics[0].target).to(\"meta\")\n",
        "\n",
        "        def _fmodel(base_model, params, buffers, x):\n",
        "            return thf.functional_call(base_model, (params, buffers), (x,))\n",
        "\n",
        "        self.forward_model = thf.vmap(lambda p, b, x: _fmodel(self.base_model, p, b, x))\n",
        "        self.forward_target = thf.vmap(\n",
        "            lambda p, b, x: _fmodel(self.base_target, p, b, x)\n",
        "        )\n",
        "        self.optim = th.optim.Adam(\n",
        "            self.params_model.values(), lr=self.args.learning_rate\n",
        "        )\n",
        "\n",
        "    def reduce(self, q_val):\n",
        "        return q_val.min(0)[0]\n",
        "\n",
        "    def __getitem__(self, item):\n",
        "        return self.critics[item]\n",
        "\n",
        "    def unstack(self, target=False, single=True, net_id=None):\n",
        "        \"\"\"\n",
        "        Extract the single parameters back to the individual members\n",
        "        target: whether the target ensemble should be extracted or not\n",
        "        single: whether just the first member of the ensemble should be extracted\n",
        "        \"\"\"\n",
        "        params = self.params_target if target else self.params_model\n",
        "        if single and net_id is None:\n",
        "            net_id = 0\n",
        "\n",
        "        for key in params.keys():\n",
        "            if single:\n",
        "                tmp = (\n",
        "                    self.critics[net_id].model\n",
        "                    if not target\n",
        "                    else self.critics[net_id].target\n",
        "                )\n",
        "                for name in key.split(\".\"):\n",
        "                    tmp = getattr(tmp, name)\n",
        "                tmp.data.copy_(params[key][net_id])\n",
        "            else:\n",
        "                for net_id in range(self.n_members):\n",
        "                    tmp = (\n",
        "                        self.critics[net_id].model\n",
        "                        if not target\n",
        "                        else self.critics[net_id].target\n",
        "                    )\n",
        "                    for name in key.split(\".\"):\n",
        "                        tmp = getattr(tmp, name)\n",
        "                    tmp.data.copy_(params[key][net_id])\n",
        "                    if single:\n",
        "                        break\n",
        "\n",
        "    def set_writer(self, writer):\n",
        "        assert (\n",
        "            writer is None\n",
        "        ), \"For now nothing else is implemented for the parallel version\"\n",
        "        self.writer = writer\n",
        "        [critic.set_writer(writer) for critic in self.critics]\n",
        "\n",
        "    def Q(self, s, a):\n",
        "        SA = self.expand(th.cat((s, a), -1))\n",
        "        return self.forward_model(self.params_model, self.buffers_model, SA)\n",
        "\n",
        "    @th.no_grad()\n",
        "    def Q_t(self, s, a):\n",
        "        SA = self.expand(th.cat((s, a), -1))\n",
        "        return self.forward_target(self.params_target, self.buffers_target, SA)\n",
        "\n",
        "    def update(self, s, a, y):  # y denotes bellman target\n",
        "        self.optim.zero_grad()\n",
        "        loss = self.loss(self.Q(s, a), self.expand(y))\n",
        "        loss.backward()\n",
        "        self.optim.step()\n",
        "        self.iter += 1\n",
        "\n",
        "    @torch.no_grad()\n",
        "    def update_target(self):\n",
        "        for key in self.params_model.keys():\n",
        "            self.params_target[key].data.mul_(1.0 - self.args.tau)\n",
        "            self.params_target[key].data.add_(\n",
        "                self.args.tau * self.params_model[key].data\n",
        "            )\n",
        "\n",
        "    @torch.no_grad()\n",
        "    def get_bellman_target(self, r, sp, done, actor):\n",
        "        alpha = actor.log_alpha.exp().detach() if hasattr(actor, \"log_alpha\") else 0\n",
        "        ap, ep = actor.act(sp)\n",
        "        qp = self.Q_t(sp, ap)\n",
        "        qp_t = self.reduce(qp) - alpha * (ep if ep is not None else 0)\n",
        "        y = r.unsqueeze(-1) + (self.args.gamma * qp_t * (1 - done.unsqueeze(-1)))\n",
        "        return y\n",
        "\n",
        "    def save_params(self, path):\n",
        "        self.unstack(target=False, single=False, net_id=None)\n",
        "        self.unstack(target=True, single=False, net_id=None)\n",
        "        params_list = []\n",
        "        for i in range(len(self.critics)):\n",
        "            params_list.append(self.load_params(self.critics[i]))\n",
        "        torch.save(params_list, path)\n",
        "\n",
        "    def load_params(self, critic):\n",
        "        params = {\n",
        "            \"params_model\": critic.model.state_dict(),\n",
        "            \"params_target\": critic.target.state_dict(),\n",
        "            \"optim\": self.optim.state_dict(),\n",
        "        }\n",
        "        params_th = {\n",
        "            k: v if isinstance(v, torch.Tensor) else v  # Ensure the values are tensors\n",
        "            for k, v in params.items()\n",
        "        }\n",
        "        return params_th\n",
        "\n",
        "class ActorNetProbabilistic(nn.Module):\n",
        "    def __init__(\n",
        "        self,\n",
        "        dim_obs,\n",
        "        dim_act,\n",
        "        depth=3,\n",
        "        width=256,\n",
        "        act=\"crelu\",\n",
        "        has_norm=True,\n",
        "        upper_clamp=-2.0,\n",
        "    ):\n",
        "        super().__init__()\n",
        "        self.dim_act = dim_act\n",
        "\n",
        "        self.arch = create_net(dim_obs[0], 2 * dim_act[0], depth, width, act, has_norm)\n",
        "\n",
        "        self.head = SquashedGaussianHead(self.dim_act[0], upper_clamp)\n",
        "\n",
        "    def forward(self, x, is_training=True):\n",
        "        f = self.arch(x)\n",
        "        return self.head(f, is_training)\n",
        "\n",
        "class CriticNet(nn.Module):\n",
        "    def __init__(\n",
        "        self, dim_obs, dim_act, depth=3, width=256, act=\"crelu\", has_norm=True\n",
        "    ):\n",
        "        super().__init__()\n",
        "\n",
        "        self.arch = create_net(\n",
        "            dim_obs[0] + dim_act[0], 1, depth, width, act=act, has_norm=has_norm\n",
        "        )\n",
        "\n",
        "    def forward(self, xu):\n",
        "        return self.arch(xu)\n",
        "\n"
      ]
    },
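    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The `get_bellman_target` methods above implement the soft Bellman backup\n",
        "\n",
        "$$y = r + \\gamma\\,(1 - d)\\,\\Big(\\min_i Q^{\\text{target}}_i(s', a') - \\alpha \\log \\pi(a' \\mid s')\\Big), \\qquad a' \\sim \\pi(\\cdot \\mid s'),$$\n",
        "\n",
        "where the minimum over the critic ensemble gives a conservative value estimate and the entropy term only appears for the soft (SAC-style) actor. The sketch below is a quick shape check for the probabilistic actor head and a critic network; the dimensions are illustrative, not tied to any environment."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Sketch: shape check with a batch of 4 fake 5-d states and 1-d actions.\n",
        "_actor_net = ActorNetProbabilistic(dim_obs=(5,), dim_act=(1,))\n",
        "_critic_net = CriticNet(dim_obs=(5,), dim_act=(1,))\n",
        "_s = torch.randn(4, 5)\n",
        "_a, _logp = _actor_net(_s)  # squashed action in (-1, 1) and its log-prob\n",
        "_q = _critic_net(torch.cat((_s, _a), -1))\n",
        "print(_a.shape, _logp.shape, _q.shape)  # (4, 1) (4, 1) (4, 1)"
      ]
    },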
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fwb7579n0a1B"
      },
      "source": [
        "# Experience Memory\n",
        "Stores past experiences and allows efficient sampling for training an agent."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "tn-ca8fx0cL9"
      },
      "outputs": [],
      "source": [
        "\n",
        "class ExperienceMemoryTorch:\n",
        "    \"\"\"Fixed-size buffer to store experience tuples.\"\"\"\n",
        "\n",
        "    field_names = [\"state\", \"action\", \"reward\", \"next_state\", \"terminated\", \"step\"]\n",
        "\n",
        "    def __init__(self, args):\n",
        "        self.device = \"cpu\"\n",
        "        self.buffer_size = 100000\n",
        "        self.dims = args.dims\n",
        "        self.reset()\n",
        "\n",
        "    def reset(self, buffer_size=None):\n",
        "        if buffer_size is not None:\n",
        "            self.buffer_size = buffer_size\n",
        "        self.data_size = 0\n",
        "        self.pointer = 0\n",
        "        self.memory = {\n",
        "            field: th.empty(self.dims[field], device=self.device)\n",
        "            for field in self.field_names\n",
        "        }\n",
        "\n",
        "    def add(self, state, action, reward, next_state, terminated, step):\n",
        "        for field, value in zip(\n",
        "            self.field_names, [state, action, reward, next_state, terminated, step]\n",
        "        ):\n",
        "            self.memory[field][self.pointer] = value\n",
        "        self.pointer = (self.pointer + 1) % self.buffer_size\n",
        "        self.data_size = min(self.data_size + 1, self.buffer_size)\n",
        "\n",
        "    def sample_by_index(self, index):\n",
        "        return tuple(self.memory[field][index] for field in self.field_names)\n",
        "\n",
        "    def sample_by_index_fields(self, index, fields):\n",
        "        if len(fields) == 1:\n",
        "            return self.memory[fields[0]][index]  # return a tensor\n",
        "        return tuple(self.memory[field][index] for field in fields)\n",
        "\n",
        "    def sample_random(self, batch_size):\n",
        "        index = th.randint(self.data_size, (batch_size,))\n",
        "        return self.sample_by_index(index)\n",
        "\n",
        "    @staticmethod\n",
        "    def set_diff_1d(t1, t2, assume_unique=False):\n",
        "        \"\"\"\n",
        "        Set difference of two 1D tensors.\n",
        "        Returns the unique values in t1 that are not in t2.\n",
        "        Source: https://stackoverflow.com/questions/55110047/finding-non-intersection-of-two-pytorch-tensors/72898627#72898627\n",
        "        \"\"\"\n",
        "        if not assume_unique:\n",
        "            t1 = torch.unique(t1)\n",
        "            t2 = torch.unique(t2)\n",
        "        return t1[(t1[:, None] != t2).all(dim=1)]\n",
        "\n",
        "    def filter_by_nonterminal_steps_with_horizon(self, horizon):\n",
        "        all_indices = th.arange(self.data_size - horizon + 1)\n",
        "        terminal_indices = th.argwhere(self.memory[\"terminated\"] == True)\n",
        "        if terminal_indices.size == 0:\n",
        "            return all_indices\n",
        "\n",
        "        terminal_with_horizon_indices = th.tensor(\n",
        "            [\n",
        "                th.arange(terminal - horizon + 2, terminal + 1)\n",
        "                for terminal in terminal_indices\n",
        "            ]\n",
        "        ).flatten()\n",
        "        nonterminal_indices = th.setdiff1d(all_indices, terminal_with_horizon_indices)\n",
        "        return nonterminal_indices\n",
        "\n",
        "    def sample_random_sequence_snippet(self, batch_size, sequence_length):\n",
        "        non_terminal_indices = self.filter_by_nonterminal_steps_with_horizon(\n",
        "            sequence_length\n",
        "        )\n",
        "        indices = th.randint(non_terminal_indices, (batch_size,))\n",
        "        output = []\n",
        "        # TODO: Why this loop?\n",
        "        for i in range(sequence_length):\n",
        "            output.append(self.sample_by_index(indices + i))\n",
        "        return output\n",
        "\n",
        "    def sample_all(self):\n",
        "        return self.sample_by_index(range(self.data_size))\n",
        "\n",
        "    def clone(self, other_memory):\n",
        "        self.data_size = other_memory.data_size\n",
        "        self.memory = copy.deepcopy(other_memory.memory)\n",
        "\n",
        "    def extend(self, other_memory):\n",
        "        for field in self.field_names:\n",
        "            self.memory[field].extend(other_memory.memory[field])\n",
        "        self.data_size = len(self.memory[field])\n",
        "\n",
        "    def __len__(self):\n",
        "        return self.data_size\n",
        "\n",
        "    @property\n",
        "    def size(self):\n",
        "        return self.data_size\n",
        "\n",
        "    def save(self, path):\n",
        "        th.save(self.memory, os.path.join(path, \"experience_memory.pt\"))\n",
        "\n",
        "    def get_last_observation(self):\n",
        "        return self.sample_by_index([-1])\n",
        "\n",
        "    def get_last_observations(self, batch_size):\n",
        "        return self.sample_by_index(range(-batch_size, 0))\n"
      ]
    },
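    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "A short sketch of the buffer API (the dimensions below are hypothetical, not tied to any environment): fill the buffer with random transitions and draw a batch."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "from types import SimpleNamespace\n",
        "\n",
        "# Sketch: a tiny buffer with 5-d states and 1-d actions.\n",
        "_margs = SimpleNamespace(dims={\n",
        "    'state': (1000, 5), 'action': (1000, 1), 'next_state': (1000, 5),\n",
        "    'reward': (1000,), 'terminated': (1000,), 'step': (1000,),\n",
        "})\n",
        "_mem = ExperienceMemoryTorch(_margs)\n",
        "_mem.reset(buffer_size=1000)  # keep buffer_size consistent with the dims above\n",
        "for _t in range(32):\n",
        "    _mem.add(th.randn(5), th.randn(1), 0.0, th.randn(5), False, _t)\n",
        "_s, _a, _r, _sp, _term, _step = _mem.sample_random(batch_size=8)\n",
        "print(_s.shape, _a.shape, _r.shape)"
      ]
    },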
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5375-_K6Dv4f"
      },
      "source": [
        "# Agent\n",
        "Defines a reinforcement learning agent that interacts with the environment, stores experiences, and updates its models using soft or hard updates."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "XFQOR6k3DwqW"
      },
      "outputs": [],
      "source": [
        "\n",
        "args, unknown = parser.parse_known_args()\n",
        "\n",
        "def totorch(x, dtype=th.float32, device=\"cpu\"):\n",
        "    return th.as_tensor(x, dtype=dtype, device=device)\n",
        "\n",
        "\n",
        "class Agent(nn.Module):\n",
        "    def __init__(self, env, args):\n",
        "        super(Agent, self).__init__()\n",
        "        self.args = args\n",
        "        self.device =\"cpu\" \n",
        "        args.buffer_size=100000\n",
        "        self.tau = 0.005\n",
        "        self.gamma = 0.99 \n",
        "        self.env = env\n",
        "        self.dim_obs, self.dim_act = (\n",
        "            self.env.observation_space.shape,\n",
        "            self.env.action_space.shape,\n",
        "        )\n",
        "        print(f\"INFO: dim_obs = {self.dim_obs} dim_act = {self.dim_act}\")\n",
        "        self.dim_obs_flat, self.dim_act_flat = np.prod(self.dim_obs), np.prod(\n",
        "            self.dim_act\n",
        "        )\n",
        "        self._u_min = totorch(self.env.action_space.low, device=self.device)\n",
        "        self._u_max = totorch(self.env.action_space.high, device=self.device)\n",
        "        self._x_min = totorch(self.env.observation_space.low, device=self.device)\n",
        "        self._x_max = totorch(self.env.observation_space.high, device=self.device)\n",
        "\n",
        "        self._gamma = self.gamma\n",
        "        self._tau = self.tau\n",
        "\n",
        "        args.dims = {\n",
        "            \"state\": (args.buffer_size, self.dim_obs_flat),\n",
        "            \"action\": (args.buffer_size, self.dim_act_flat),\n",
        "            \"next_state\": (args.buffer_size, self.dim_obs_flat),\n",
        "            \"reward\": (args.buffer_size),\n",
        "            \"terminated\": (args.buffer_size),\n",
        "            \"step\": (args.buffer_size),\n",
        "        }\n",
        "\n",
        "        self.experience_memory = ExperienceMemoryTorch(args)\n",
        "\n",
        "    def set_writer(self, writer):\n",
        "        self.writer = writer\n",
        "\n",
        "    def _soft_update(self, local_model, target_model):\n",
        "        for target_param, local_param in zip(\n",
        "            target_model.parameters(), local_model.parameters()\n",
        "        ):\n",
        "            target_param.data.mul_(1.0 - self.args.tau)\n",
        "            target_param.data.add_(self.args.tau * local_param.data)\n",
        "\n",
        "    def _hard_update(self, local_model, target_model):\n",
        "\n",
        "        for target_param, local_param in zip(\n",
        "            target_model.parameters(), local_model.parameters()\n",
        "        ):\n",
        "            target_param.data.copy_(local_param.data)\n",
        "\n",
        "    def learn(self, max_iter=1):\n",
        "        raise NotImplementedError(f\"learn() not implemented for {self.name} agent\")\n",
        "\n",
        "    def select_action(self, warmup=False, exploit=False):\n",
        "        raise NotImplementedError(\n",
        "            f\"select_action() not implemented for {self.name} agent\"\n",
        "        )\n",
        "\n",
        "    def store_transition(self, s, a, r, sp, terminated, truncated, step):\n",
        "        self.experience_memory.add(s, a, r, sp, terminated, step)\n",
        "        self.actor.set_episode_status(terminated or truncated)\n"
      ]
    },
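    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "`_soft_update` above is standard Polyak averaging, `theta_target <- (1 - tau) * theta_target + tau * theta`. A minimal standalone check (hypothetical toy modules, not the agent's networks) that a single soft update moves the target exactly a fraction `tau` of the way toward the online parameters:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import torch\n",
        "import torch.nn as nn\n",
        "\n",
        "tau = 0.005\n",
        "online, target = nn.Linear(4, 4), nn.Linear(4, 4)\n",
        "\n",
        "before = target.weight.data.clone()\n",
        "with torch.no_grad():\n",
        "    for tp, lp in zip(target.parameters(), online.parameters()):\n",
        "        tp.mul_(1.0 - tau).add_(tau * lp)\n",
        "\n",
        "# The updated target is the tau-weighted blend of old target and online net.\n",
        "expected = (1.0 - tau) * before + tau * online.weight.data\n",
        "print(torch.allclose(target.weight.data, expected))  # True\n"
      ]
    },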
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "# Actor-Critic Agent\n",
        "Creates an actor-critic agent that learns by updating the critic network and adjusting the actor network based on feedback from the critic."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 8,
      "metadata": {},
      "outputs": [],
      "source": [
        "\n",
        "class ActorCritic(Agent):\n",
        "    _agent_name = \"AC\"\n",
        "\n",
        "    def __init__(\n",
        "        self,\n",
        "        env,\n",
        "        args,\n",
        "        actor_nn,\n",
        "        critic_nn,\n",
        "        CriticEnsembleType=CriticEnsemble,\n",
        "        ActorType=Actor,\n",
        "    ):\n",
        "        super(ActorCritic, self).__init__(env, args)\n",
        "        self.critics = CriticEnsembleType(critic_nn, args, self.dim_obs, self.dim_act)\n",
        "        self.actor = ActorType(actor_nn, args, self.dim_obs, self.dim_act)\n",
        "        self.n_iter = 0\n",
        "        self.policy_delay = 1\n",
        "        self.args.batch_size = 256\n",
        "\n",
        "    def set_writer(self, writer):\n",
        "        self.writer = writer\n",
        "        self.actor.set_writer(writer)\n",
        "        self.critics.set_writer(writer)\n",
        "\n",
        "    def learn(self, max_iter=5):\n",
        "        if self.args.batch_size > len(self.experience_memory):\n",
        "            return None\n",
        "\n",
        "        for ii in range(max_iter):\n",
        "            s, a, r, sp, done, step = self.experience_memory.sample_random(\n",
        "                self.args.batch_size\n",
        "            )\n",
        "            y = self.critics.get_bellman_target(r, sp, done, self.actor)\n",
        "            self.critics.update(s, a, y)\n",
        "\n",
        "            if self.n_iter % self.policy_delay == 0:\n",
        "                self.actor.update(s, self.critics)\n",
        "            self.critics.update_target()\n",
        "            self.n_iter += 1\n",
        "\n",
        "    @torch.no_grad()\n",
        "    def select_action(self, s, is_training=True):\n",
        "        # s to device\n",
        "        a, _ = self.actor.act(s, is_training=is_training)\n",
        "        return a\n",
        "\n",
        "    def Q_value(self, s, a):\n",
        "        \n",
        "        if len(s.shape) == 1:\n",
        "            s = s[None]\n",
        "        if len(a.shape) == 1:\n",
        "            a = a[None]\n",
        "        if isinstance(self.critics, ParallelCritics):\n",
        "            self.critics.unstack(target=False, single=True)\n",
        "        \n",
        "        q = self.critics[0].Q(s, a)\n",
        "        return q.item()\n"
      ]
    },
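    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "For reference, the target `y` passed to `critics.update` is the usual one-step TD target, `y = r + gamma * (1 - done) * Q_target(s', pi(s'))`; the exact ensemble reduction lives in `get_bellman_target`, defined with the critic ensemble earlier. A minimal numeric sketch with made-up rewards and Q-values:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import torch\n",
        "\n",
        "gamma = 0.99\n",
        "r = torch.tensor([1.0, 0.5, 2.0])      # rewards for a batch of 3 transitions\n",
        "done = torch.tensor([0.0, 0.0, 1.0])   # 1.0 where the episode terminated\n",
        "q_sp = torch.tensor([10.0, 8.0, 6.0])  # target-critic estimate of Q(s', pi(s'))\n",
        "\n",
        "# Terminal transitions bootstrap nothing: their target is just the reward.\n",
        "y = r + gamma * (1.0 - done) * q_sp\n",
        "print(y)  # tensor([10.9000,  8.4200,  2.0000])\n"
      ]
    },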
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DXN8_ffH0k0Z"
      },
      "source": [
        "# REDQ\n",
        "REDQ agent used for training process of actor critic agent in order to collect environment transisions for constructing train and test datasets."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 9,
      "metadata": {
        "id": "ffwBGyfi0lyy"
      },
      "outputs": [],
      "source": [
        "class REDQCritics(Critics):\n",
        "    def __init__(self, arch, args, n_state, n_action, critictype=Critic):\n",
        "        super().__init__(arch, args, n_state, n_action, critictype)\n",
        "        self.args = args\n",
        "        self.n_in_target = 2\n",
        "        \n",
        "        \n",
        "\n",
        "    def reduce(self, q_val_list):\n",
        "        i_targets = torch.randint(0, self.n_members, (self.n_in_target,))\n",
        "        return torch.stack([q_val_list[i] for i in i_targets], dim=-1).min(-1)[0]\n",
        "        \n",
        "class RandomEnsembleDoubleQLearning(ActorCritic):\n",
        "    _agent_name = \"REDQ\"\n",
        "\n",
        "    def __init__(self, env, args, actor_nn=ActorNetProbabilistic, critic_nn=CriticNet):\n",
        "        super().__init__(\n",
        "            env,\n",
        "            args,\n",
        "            actor_nn,\n",
        "            critic_nn,\n",
        "            CriticEnsembleType=REDQCritics,\n",
        "            ActorType=SoftActor,\n",
        "        )\n",
        "        self.args.explore_noise = 0.1\n",
        "        self.actor.c = self.args.explore_noise\n",
        "        self.args = args"
      ]
    },
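    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "`reduce` above implements REDQ's in-target minimization: sample a random subset of `n_in_target = 2` critics from the ensemble and take the elementwise minimum of their estimates, which controls overestimation independently of the ensemble size. A standalone sketch with made-up Q-values (10 critics, batch of 4):"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import torch\n",
        "\n",
        "n_members, n_in_target, batch = 10, 2, 4\n",
        "q_val_list = [torch.randn(batch) for _ in range(n_members)]  # one estimate per critic\n",
        "\n",
        "# Draw the subset with replacement, as reduce does above, then take the\n",
        "# elementwise minimum over just those members.\n",
        "i_targets = torch.randint(0, n_members, (n_in_target,))\n",
        "q_min = torch.stack([q_val_list[i] for i in i_targets], dim=-1).min(-1)[0]\n",
        "print(i_targets, q_min.shape)  # e.g. tensor([7, 2]) torch.Size([4])\n"
      ]
    },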
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "x4t4hxd6xNUO"
      },
      "source": [
        "# Control Experiments\n",
        "Manage the training and evaluation of an agent in a control task, and its interactions with the environment.\n",
        "\n",
        "Assign \"saveparams\" to True in orde to save the policy parameters during agent training. After training completed, put \"validationrounds\"  True and save 200 validation rounds to collect dataset (100 for training and 100 for test)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "X6664Fo4xOqC"
      },
      "outputs": [],
      "source": [
        "from types import SimpleNamespace\n",
        "\n",
        "class ControlExperiment(Experiment):\n",
        "    def __init__(self):\n",
        "        super(ControlExperiment, self).__init__()\n",
        "        self.args = SimpleNamespace()\n",
        "        self.args.verbose=False\n",
        "        self.eval_reward=0\n",
        "        self.late_eval_reward=0\n",
        "        self.is_break=False\n",
        "        self.device_str =\"cpu\"\n",
        "        self.optimizer_args = {\"lr\": 4e-3}\n",
        "        self.n_total_steps = 0\n",
        "        self.args.max_steps = 300000\n",
        "        self.args.eval_frequency=2000\n",
        "        self.args.eval_episodes=10\n",
        "        self.args.gamma=0.99\n",
        "        self.args.warmup_steps=10000\n",
        "        self.args.learn_frequency=1\n",
        "        self.args.max_iter=5\n",
        "        self.args.n_critics=10\n",
        "        self.args.alpha=1\n",
        "        self.args.progress=False\n",
        "        self.args.reset_frequency=0\n",
        "        self.args.depth_critic=3\n",
        "        self.args.width_critic=256\n",
        "        self.args.device = \"cpu\"\n",
        "        self.args.saveparams= False \n",
        "        self.args.validationrounds= False \n",
        "        \n",
        "\n",
        "    def train(self):\n",
        "        time_start = time.time()\n",
        "        information_dict = {\n",
        "            \"episode_rewards\": th.zeros(1000000),\n",
        "            \"episode_steps\": th.zeros(1000000),\n",
        "            \"step_rewards\": np.empty((2 * self.args.max_steps), dtype=object),\n",
        "        }\n",
        "\n",
        "        s, _ = self.env.reset()\n",
        "        s = totorch(s) \n",
        "        r_cum = np.zeros(1)\n",
        "        episode = 0\n",
        "        e_step = 0\n",
        "        self.last_saved_step = 0\n",
        "\n",
        "        for step in tqdm(\n",
        "            range(self.args.max_steps), leave=True, disable=not self.args.progress\n",
        "        ):\n",
        "            e_step += 1\n",
        "\n",
        "            if (\n",
        "                step > self.args.warmup_steps\n",
        "                and self.args.reset_frequency > 0\n",
        "                and step % self.args.reset_frequency == 0\n",
        "            ):\n",
        "                self.agent.critics.reset()\n",
        "                self.agent.to(self.args.device)\n",
        "\n",
        "            if step % self.args.eval_frequency == 0:\n",
        "                self.eval(step)\n",
        "\n",
        "            if step < self.args.warmup_steps:\n",
        "                a = self.env.action_space.sample()\n",
        "                a = totorch(np.clip(a, -1.0, 1.0), device=self.args.device)\n",
        "\n",
        "            else:\n",
        "                \n",
        "                a = self.agent.select_action(s.to(\"cpu\")).clip(-1.0, 1.0)\n",
        "\n",
        "            sp, r, done, truncated, info = self.env.step(tonumpy(a))\n",
        "            sp = totorch(sp, device=self.args.device)\n",
        "\n",
        "            if self.args.verbose and \"sp\" in self.args.env:\n",
        "                print(\"X pos: \", info[\"x_pos\"], \"Action norm: \", info[\"action_norm\"])\n",
        "                # TODO: Write this instead into a file!\n",
        "\n",
        "            self.agent.store_transition(s, a, r, sp, done, truncated, step + 1)\n",
        "            #self.agent.to(self.args.device)\n",
        "\n",
        "            information_dict[\"step_rewards\"][self.n_total_steps + step] = (\n",
        "                episode,\n",
        "                step,\n",
        "                r,\n",
        "            )\n",
        "\n",
        "            s = sp  # Update state\n",
        "            r_cum += r  # Update cumulative reward\n",
        "\n",
        "            if (\n",
        "                step >= self.args.warmup_steps\n",
        "                and (step % self.args.learn_frequency) == 0\n",
        "            ):\n",
        "                #self.agent.to(self.args.device)\n",
        "                self.agent.learn(max_iter=self.args.max_iter)\n",
        "                \n",
        "            if self.args.saveparams:\n",
        "                next_save_step = ((self.last_saved_step // 50000) + 1) * 50000  # Compute next 50000 milestone\n",
        "                \n",
        "                if self.last_saved_step < next_save_step <= step or step == self.args.max_steps - 1:  # First step after passing 50000, 100000, etc.\n",
        "                    self.agent.critics.save_params(\n",
        "                        f\"_logs/{self.args.env}/{self.args.model}/seed_0{self.args.seed}/params_{step}.pth\"\n",
        "                    )\n",
        "                    self.agent.actor.save_actor_params(\n",
        "                        f\"_logs/{self.args.env}/{self.args.model}/seed_0{self.args.seed}/Actor_params_{step}.pth\"\n",
        "                    )\n",
        "                    self.last_saved_step = next_save_step  # Update the last saved milestone\n",
        "\n",
        "            if done or truncated:\n",
        "\n",
        "                information_dict[\"episode_rewards\"][episode] = r_cum.item()\n",
        "                information_dict[\"episode_steps\"][episode] = step\n",
        "                print('Episode:', episode, ' Reward: %.3f' % np.mean(r_cum), 'N-steps: %d' % step)\n",
        "                s, _ = self.env.reset()\n",
        "                s = totorch(s, device=self.args.device)\n",
        "                r_cum = np.zeros(1)\n",
        "                episode += 1\n",
        "                e_step = 0\n",
        "\n",
        "\n",
        "        self.eval(step)\n",
        "        time_end = time.time()\n",
        "    \n",
        "    \n",
        "    \n",
        "    @torch.no_grad()\n",
        "    def eval(self, n_step):\n",
        "        self.agent.eval()\n",
        "        results = th.zeros(self.args.eval_episodes)\n",
        "        q_values = th.zeros((self.args.eval_episodes, 2))\n",
        "        avg_reward = th.zeros(self.args.eval_episodes)\n",
        "        collect_infos = {}\n",
        "        performance_eval_dict = {\n",
        "            \"episode_info\": np.empty((2 * self.args.max_steps), dtype=object),\n",
        "            \"trajectory\": [],\n",
        "        }\n",
        "\n",
        "        for episode in range(self.args.eval_episodes):\n",
        "            collect_infos[episode] = []\n",
        "            s, info = self.eval_env.reset()\n",
        "            s = totorch(s)\n",
        "            step = 0\n",
        "            a = self.agent.select_action(s, is_training=False)\n",
        "            q_values[episode] = self.agent.Q_value(totorch(s, device=self.args.device), totorch(a, device=self.args.device))\n",
        "            done = False\n",
        "\n",
        "            while not done:\n",
        "                s = totorch(s)\n",
        "                a = self.agent.select_action(s, is_training=False)\n",
        "\n",
        "                sp, r, term, trunc, info = self.eval_env.step(tonumpy(a))\n",
        "                collect_infos[episode].append(info)\n",
        "\n",
        "                if self.args.validationrounds:\n",
        "                    performance_eval_dict[\"trajectory\"].append(\n",
        "                        (episode, step, s, a, sp, r, term, trunc, info)\n",
        "                    )\n",
        "\n",
        "                done = term or trunc\n",
        "                s = totorch(sp, device=self.args.device)\n",
        "                results[episode] += r\n",
        "                avg_reward[episode] += self.args.gamma**step * r\n",
        "                step += 1\n",
        "\n",
        "            if self.args.validationrounds:\n",
        "                performance_eval_dict[\"episode_info\"][episode] = (\n",
        "                    episode,\n",
        "                    avg_reward[episode],\n",
        "                )\n",
        "\n",
        "        self.agent.actor.states = []\n",
        "\n",
        "        self.agent.train()\n"
      ]
    },
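    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The `saveparams` branch in `train` checkpoints at every 50,000-step milestone rather than on a fixed modulus, so a milestone is still caught even if the step counter jumps past it. A standalone sketch of that schedule logic (hypothetical step values):"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "INTERVAL = 50_000\n",
        "\n",
        "def should_save(step, last_saved, max_steps=300_000):\n",
        "    next_milestone = (last_saved // INTERVAL + 1) * INTERVAL\n",
        "    return last_saved < next_milestone <= step or step == max_steps - 1\n",
        "\n",
        "last_saved = 0\n",
        "for step in range(300_000):\n",
        "    if should_save(step, last_saved):\n",
        "        last_saved = (last_saved // INTERVAL + 1) * INTERVAL\n",
        "        print(\"checkpoint at step\", step)\n",
        "# checkpoint at step 50000, 100000, 150000, 200000, 250000, 299999\n"
      ]
    },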
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "# Run All"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 399
        },
        "id": "BRFN2y0505pX",
        "outputId": "3d0fc500-6be0-4910-b140-fe099cde25ed"
      },
      "outputs": [],
      "source": [
        "exp = ControlExperiment()\n",
        "exp.train()"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "display_name": "dm",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.9.20"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
