{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "# Import Necessary Packages"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "id": "GUrtasdlwqId"
      },
      "outputs": [],
      "source": [
        "import argparse\n",
        "import copy\n",
        "import math\n",
        "import os\n",
        "import random\n",
        "import time\n",
        "from collections import deque\n",
        "from typing import Dict, Optional, OrderedDict, Tuple\n",
        "\n",
        "import gymnasium as gym\n",
        "from gymnasium import core, spaces\n",
        "from gymnasium.spaces import Box, Dict\n",
        "from gymnasium.wrappers import RescaleAction\n",
        "\n",
        "import numpy as np\n",
        "import matplotlib.pyplot as plt\n",
        "from scipy.stats import norm\n",
        "\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "import torch.nn.init as init\n",
        "from torch.nn.parameter import Parameter\n",
        "from torch.nn.modules.utils import _pair\n",
        "from torch.distributions import Normal, TransformedDistribution\n",
        "from torch.distributions.transforms import TanhTransform\n",
        "\n",
        "import dm_env\n",
        "from dm_control import suite\n",
        "\n",
        "# Argument parser. The resulting `args` namespace is read by later cells\n",
        "# (e.g. get_model, Experiment); the critic/actor constructors also write\n",
        "# hyper-parameters onto it.\n",
        "parser = argparse.ArgumentParser()\n",
        "# NOTE(review): presumably `-f` absorbs the connection-file flag that the\n",
        "# Jupyter kernel is launched with.\n",
        "parser.add_argument(\"-f\", required=False)  # For compatibility with notebooks\n",
        "# parse_known_args collects any other unrecognised flags in `unknown` instead\n",
        "# of erroring out.\n",
        "args, unknown = parser.parse_known_args()\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rYywNtri14tV"
      },
      "source": [
        "# Get the Model "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "id": "uG5PZt1Y13fT"
      },
      "outputs": [],
      "source": [
        "def get_model(env):\n",
        "    \"\"\"Build the agent for `env` from the module-level `args` namespace.\n",
        "\n",
        "    NOTE(review): `DEASoftActorCritic` is not defined in this part of the\n",
        "    notebook (presumably a later cell). It is only looked up when this\n",
        "    function is called, so that cell must have been run first.\n",
        "    \"\"\"\n",
        "    return DEASoftActorCritic(env, args)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lZci7w8L3v8J"
      },
      "source": [
        "# DMControl Environment\n",
        "Defines a wrapper that exposes DeepMind Control Suite (DMC) environments through the Gymnasium `Env` interface."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "ndagaTFC3xf_",
        "outputId": "1e34eac7-de16-4151-aa33-2f10a0907165"
      },
      "outputs": [],
      "source": [
        "\n",
        "# Return type of DMCEnv.step: (observation, reward, terminated, truncated, info).\n",
        "TimeStep = Tuple[np.ndarray, float, bool, bool, dict]\n",
        "\n",
        "def dmc_spec2gym_space(spec):\n",
        "    \"\"\"Convert a dm_env spec, possibly a (nested) dict of specs, to a gymnasium space.\n",
        "\n",
        "    - dict / OrderedDict -> spaces.Dict, converting every value recursively\n",
        "    - dm_env.specs.BoundedArray -> spaces.Box with the spec's bounds\n",
        "    - dm_env.specs.Array -> unbounded spaces.Box\n",
        "    Shape and dtype are taken from the spec in both Box cases.\n",
        "\n",
        "    Raises:\n",
        "        NotImplementedError: for any other spec type (the message names it).\n",
        "    \"\"\"\n",
        "    if isinstance(spec, OrderedDict) or isinstance(spec, dict):\n",
        "        # Shallow copy: leaves the caller's mapping untouched and keeps its type\n",
        "        # (an OrderedDict stays ordered).\n",
        "        spec = copy.copy(spec)\n",
        "        for k, v in spec.items():\n",
        "            spec[k] = dmc_spec2gym_space(v)\n",
        "        return spaces.Dict(spec)\n",
        "    # BoundedArray is a subclass of Array, so it has to be tested first.\n",
        "    elif isinstance(spec, dm_env.specs.BoundedArray):\n",
        "        return spaces.Box(low=spec.minimum,\n",
        "                          high=spec.maximum,\n",
        "                          shape=spec.shape,\n",
        "                          dtype=spec.dtype)\n",
        "    elif isinstance(spec, dm_env.specs.Array):\n",
        "        return spaces.Box(low=-float('inf'),\n",
        "                          high=float('inf'),\n",
        "                          shape=spec.shape,\n",
        "                          dtype=spec.dtype)\n",
        "    else:\n",
        "        raise NotImplementedError(\n",
        "            f'Unsupported dm_env spec type: {type(spec).__name__}')\n",
        "\n",
        "\n",
        "class DMCEnv(core.Env):\n",
        "    \"\"\"Gymnasium `Env` adapter around a dm_control / dm_env environment.\n",
        "\n",
        "    Pass either a ready-made dm_env `env`, or `domain_name` and `task_name` to\n",
        "    load one with `suite.load`. `task_kwargs` must contain a 'random' entry\n",
        "    (the seed); this is asserted even when `env` is given. Attributes not found\n",
        "    on the wrapper are looked up on the wrapped dm_env environment.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self,\n",
        "                 domain_name: Optional[str] = None,\n",
        "                 task_name: Optional[str] = None,\n",
        "                 env: Optional[dm_env.Environment] = None,\n",
        "                 task_kwargs: Optional[dict] = None,\n",
        "                 environment_kwargs=None):\n",
        "        # `None` instead of a shared mutable `{}` default; omitting task_kwargs\n",
        "        # still fails the seed assertion below.\n",
        "        if task_kwargs is None:\n",
        "            task_kwargs = {}\n",
        "        assert 'random' in task_kwargs, 'Please specify a seed, for deterministic behaviour.'\n",
        "        assert (\n",
        "            env is not None\n",
        "            or (domain_name is not None and task_name is not None)\n",
        "        ), 'You must provide either an environment or domain and task names.'\n",
        "\n",
        "        if env is None:\n",
        "            env = suite.load(\n",
        "                domain_name=domain_name,\n",
        "                task_name=task_name,\n",
        "                task_kwargs=task_kwargs,\n",
        "                environment_kwargs=environment_kwargs,\n",
        "                visualize_reward=True\n",
        "            )\n",
        "\n",
        "        self._env = env\n",
        "        self.domain_name = domain_name\n",
        "        self.task_name = task_name\n",
        "        self.action_space = dmc_spec2gym_space(self._env.action_spec())\n",
        "\n",
        "        self.observation_space = dmc_spec2gym_space(\n",
        "            self._env.observation_spec())\n",
        "\n",
        "    def __getattr__(self, name):\n",
        "        \"\"\"Forward unknown attribute lookups to the wrapped dm_env environment.\"\"\"\n",
        "        # __getattr__ only runs after normal lookup fails. If `_env` itself is\n",
        "        # missing (before __init__ assigns it, or on an instance created by\n",
        "        # copy/unpickling), `self._env` would re-enter here forever; fail fast.\n",
        "        if name == '_env':\n",
        "            raise AttributeError(name)\n",
        "        return getattr(self._env, name)\n",
        "\n",
        "    def step(self, action: np.ndarray) -> TimeStep:\n",
        "        \"\"\"Apply `action` and return (obs, reward, terminated, truncated, info).\n",
        "\n",
        "        On the final time step, discount == 1.0 is reported as truncation (and\n",
        "        info['TimeLimit.truncated'] is set); any other discount as termination.\n",
        "        \"\"\"\n",
        "        assert self.action_space.contains(action)\n",
        "\n",
        "        time_step = self._env.step(action)\n",
        "        reward = time_step.reward or 0  # a None reward is reported as 0\n",
        "        done = time_step.last()\n",
        "        obs = time_step.observation\n",
        "\n",
        "        info = {}\n",
        "        trunc = done and (time_step.discount == 1.0)\n",
        "        term = done and (time_step.discount != 1.0)\n",
        "        if trunc:\n",
        "            info['TimeLimit.truncated'] = True\n",
        "        return obs, reward, term, trunc, info\n",
        "\n",
        "    def reset(self, seed=None, options=None):\n",
        "        \"\"\"Reset the wrapped env and return (observation, info).\n",
        "\n",
        "        `seed` is only passed to `core.Env.reset` (the wrapper's own RNG);\n",
        "        `options` is ignored. The dm_env environment is reset without a seed.\n",
        "        \"\"\"\n",
        "        super().reset(seed=seed)\n",
        "        time_step = self._env.reset()\n",
        "        info = {}\n",
        "        return time_step.observation, info\n",
        "\n",
        "    def render(self,\n",
        "               mode='rgb_array',\n",
        "               height: int = 84,\n",
        "               width: int = 84,\n",
        "               camera_id: int = 0):\n",
        "        \"\"\"Render an RGB frame from `camera_id` through the dm_control physics.\"\"\"\n",
        "        assert mode == 'rgb_array', 'only support rgb_array mode, given %s' % mode\n",
        "        return self._env.physics.render(height=height,\n",
        "                                        width=width,\n",
        "                                        camera_id=camera_id)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "# Make Environment\n",
        "Defines wrappers and utilities to preprocess Gym environments. The `make_env()` function sets up a DeepMind Control Suite environment with various preprocessing options.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "id": "vWY4CKP0171s"
      },
      "outputs": [],
      "source": [
        "\n",
        "class SinglePrecision(gym.ObservationWrapper):\n",
        "    \"\"\"Cast observations (arrays, or dicts of arrays) to float32.\n",
        "\n",
        "    Box observation spaces, or Dict spaces of Boxes, are rebuilt with the same\n",
        "    bounds and shape and gymnasium's default Box dtype (float32). Any other\n",
        "    space or observation type raises NotImplementedError.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, env):\n",
        "        super().__init__(env)\n",
        "\n",
        "        if isinstance(self.observation_space, Box):\n",
        "            obs_space = self.observation_space\n",
        "            self.observation_space = Box(obs_space.low, obs_space.high,\n",
        "                                         obs_space.shape)\n",
        "        elif isinstance(self.observation_space, Dict):\n",
        "            # `Dict` here is gymnasium.spaces.Dict (it shadows typing.Dict).\n",
        "            obs_spaces = copy.copy(self.observation_space.spaces)\n",
        "            for k, v in obs_spaces.items():\n",
        "                obs_spaces[k] = Box(v.low, v.high, v.shape)\n",
        "            self.observation_space = Dict(obs_spaces)\n",
        "        else:\n",
        "            raise NotImplementedError\n",
        "\n",
        "    def observation(self, observation: np.ndarray) -> np.ndarray:\n",
        "        \"\"\"Return `observation` with every array cast to float32.\n",
        "\n",
        "        Dict observations are shallow-copied so the wrapped env's dict is not\n",
        "        modified.\n",
        "\n",
        "        Raises:\n",
        "            NotImplementedError: for any other observation type, instead of\n",
        "                silently returning None.\n",
        "        \"\"\"\n",
        "        if isinstance(observation, np.ndarray):\n",
        "            return observation.astype(np.float32)\n",
        "        if isinstance(observation, dict):\n",
        "            observation = copy.copy(observation)\n",
        "            for k, v in observation.items():\n",
        "                observation[k] = v.astype(np.float32)\n",
        "            return observation\n",
        "        raise NotImplementedError(\n",
        "            f'Unsupported observation type: {type(observation).__name__}')\n",
        "        \n",
        "class FlattenAction(gym.ActionWrapper):\n",
        "    \"\"\"Present the wrapped env's action space as its flattened equivalent.\n",
        "\n",
        "    Flat actions from the agent are unflattened back into the wrapped env's\n",
        "    action space before being forwarded.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, env):\n",
        "        super().__init__(env)\n",
        "        inner_space = self.env.action_space\n",
        "        self.action_space = gym.spaces.utils.flatten_space(inner_space)\n",
        "\n",
        "    def action(self, action):\n",
        "        \"\"\"Flat action -> action in the wrapped env's space.\"\"\"\n",
        "        inner_space = self.env.action_space\n",
        "        return gym.spaces.utils.unflatten(inner_space, action)\n",
        "\n",
        "    def reverse_action(self, action):\n",
        "        \"\"\"Action in the wrapped env's space -> flat action.\"\"\"\n",
        "        inner_space = self.env.action_space\n",
        "        return gym.spaces.utils.flatten(inner_space, action)\n",
        "\n",
        "def make_env(env_name: str,\n",
        "             seed: int,\n",
        "             save_folder: Optional[str] = None,\n",
        "             add_episode_monitor: bool = True,\n",
        "             action_repeat: int = 1,\n",
        "             frame_stack: int = 1,\n",
        "             from_pixels: bool = False,\n",
        "             pixels_only: bool = True,\n",
        "             image_size: int = 84,\n",
        "             sticky: bool = False,\n",
        "             gray_scale: bool = False,\n",
        "             flatten: bool = True,\n",
        "             terminate_when_unhealthy: bool = True,\n",
        "             action_concat: int = 1,\n",
        "             obs_concat: int = 1,\n",
        "             continuous: bool = True,\n",
        "             ) -> gym.Env:\n",
        "    \"\"\"Create a seeded, preprocessed DeepMind Control Suite environment.\n",
        "\n",
        "    Args:\n",
        "        env_name: 'domain-task', e.g. 'cartpole-swingup'.\n",
        "        seed: forwarded to the dm_control task (as task_kwargs['random']),\n",
        "            `env.reset`, the action/observation space RNGs and the global\n",
        "            NumPy and torch RNGs.\n",
        "        flatten: flatten Dict observations into one array and wrap the action\n",
        "            space with FlattenAction.\n",
        "        continuous: rescale actions to [-1, 1] with RescaleAction.\n",
        "\n",
        "    NOTE(review): save_folder, add_episode_monitor, action_repeat,\n",
        "    frame_stack, from_pixels, pixels_only, image_size, sticky, gray_scale,\n",
        "    terminate_when_unhealthy, action_concat and obs_concat are accepted but\n",
        "    currently ignored.\n",
        "\n",
        "    Returns:\n",
        "        The wrapped env; array observations are cast to float32.\n",
        "\n",
        "    Raises:\n",
        "        ValueError: if `env_name` does not split into exactly two parts on '-'.\n",
        "    \"\"\"\n",
        "    name_parts = env_name.split('-')\n",
        "    if len(name_parts) != 2:\n",
        "        raise ValueError(\n",
        "            f\"env_name must look like 'domain-task', got {env_name!r}\")\n",
        "    domain_name, task_name = name_parts\n",
        "    env = DMCEnv(domain_name=domain_name, task_name=task_name, task_kwargs={'random': seed})\n",
        "\n",
        "    if flatten and isinstance(env.observation_space, gym.spaces.Dict):\n",
        "        env = gym.wrappers.FlattenObservation(env)\n",
        "        env = FlattenAction(env)\n",
        "\n",
        "    if continuous:\n",
        "        env = RescaleAction(env, -1.0, 1.0)\n",
        "\n",
        "    env = SinglePrecision(env)\n",
        "    env.reset(seed=seed)\n",
        "    env.action_space.seed(seed)\n",
        "    env.observation_space.seed(seed)\n",
        "    # These seed process-wide RNGs, so whichever env is created last decides\n",
        "    # the global NumPy/torch seed.\n",
        "    np.random.seed(seed)\n",
        "    torch.manual_seed(seed)\n",
        "    return env\n",
        "\n",
        "class Experiment(object):\n",
        "    \"\"\"Training env, evaluation env and agent for one run.\n",
        "\n",
        "    Defaults: 'cartpole-swingup', training seed 1, evaluation seed 101 and a\n",
        "    budget of 100000 steps.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, env_name='cartpole-swingup', seed=1, eval_seed=101,\n",
        "                 max_steps=100000):\n",
        "        self.args = args\n",
        "        self.n_total_steps = 0\n",
        "        self.max_steps = max_steps\n",
        "        # make_env reseeds the global NumPy/torch RNGs, so after the second call\n",
        "        # they are left seeded with `eval_seed`.\n",
        "        self.env = make_env(env_name, seed)\n",
        "        self.eval_env = make_env(env_name, eval_seed)\n",
        "        self.agent = get_model(self.env)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "VmIkEsMzyPdc"
      },
      "source": [
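        "# Architectures\n",
        "This cell defines the network architectures and their training wrappers, including the MLP builder (`create_net`), the critics (`Critic`, `Critics`, `CriticEnsemble`) and the actors (`Actor` and the entropy-regularized `SoftActor`)."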
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zfsulbl1yO81"
      },
      "outputs": [],
      "source": [
        "import torch as th\n",
        "from tqdm import tqdm\n",
        "from torch import nn, func as thf\n",
        "import itertools\n",
        "def tonumpy(x):\n",
        "    \"\"\"Return tensor `x` as a NumPy array, moved to CPU and without autograd history.\"\"\"\n",
        "    cpu_tensor = x.data.cpu()\n",
        "    return cpu_tensor.numpy()\n",
        "\n",
        "class Critic(nn.Module):\n",
        "    \"\"\"Single Q-network plus a target copy updated by Polyak averaging.\n",
        "\n",
        "    NOTE(review): `Critic` is defined twice in this cell. This first definition\n",
        "    is the default `critictype` of `Critics` (bound when `Critics` is defined);\n",
        "    later uses of the name `Critic` get the second definition.\n",
        "\n",
        "    Side effect: overwrites n_hidden, learning_rate, tau and device on `args`.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, arch, args, n_state, n_action):\n",
        "        # Zero-argument super(): `super(Critic, self)` resolves the global name\n",
        "        # `Critic` at call time, which the second definition rebinds to another\n",
        "        # class, so it raised TypeError for instances of this class.\n",
        "        super().__init__()\n",
        "        self.args = args\n",
        "        args.n_hidden = 256\n",
        "        args.learning_rate = 3e-4\n",
        "        args.tau = 0.005\n",
        "        args.device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "        self.model = arch(n_state, n_action, args.n_hidden).to(args.device)\n",
        "        self.target = arch(n_state, n_action, args.n_hidden).to(args.device)\n",
        "        self.init_target()\n",
        "        self.loss = nn.MSELoss()\n",
        "        self.optim = torch.optim.Adam(self.model.parameters(), args.learning_rate)\n",
        "        self.iter = 0\n",
        "\n",
        "    def set_writer(self, writer):\n",
        "        \"\"\"Store a logging writer (not read by the methods of this class).\"\"\"\n",
        "        self.writer = writer\n",
        "\n",
        "    def init_target(self):\n",
        "        \"\"\"Copy the online network's parameters into the target network.\"\"\"\n",
        "        for target_param, local_param in zip(\n",
        "            self.target.parameters(), self.model.parameters()\n",
        "        ):\n",
        "            target_param.data.copy_(local_param.data)\n",
        "\n",
        "    @th.no_grad()\n",
        "    def update_target(self):\n",
        "        \"\"\"Soft update: target = (1 - tau) * target + tau * model.\"\"\"\n",
        "        for target_param, local_param in zip(\n",
        "            self.target.parameters(), self.model.parameters()\n",
        "        ):\n",
        "            target_param.data.mul_(1.0 - self.args.tau)\n",
        "            target_param.data.add_(self.args.tau * local_param.data)\n",
        "\n",
        "    def Q(self, s, a):\n",
        "        \"\"\"Q-values from the online network.\"\"\"\n",
        "        return self.model(s, a)\n",
        "\n",
        "    def Q_t(self, s, a):\n",
        "        \"\"\"Q-values from the target network (gradient tracking is not disabled).\"\"\"\n",
        "        return self.target(s, a)\n",
        "\n",
        "    def update(self, s, a, y):  # y denotes bellman target\n",
        "        \"\"\"One optimiser step on MSE(Q(s, a), y).\"\"\"\n",
        "        self.optim.zero_grad()\n",
        "        loss = self.loss(self.Q(s, a), y)\n",
        "        loss.backward()\n",
        "        self.optim.step()\n",
        "        self.iter += 1\n",
        "\n",
        "class Critics(nn.Module):\n",
        "    \"\"\"Ensemble of `n_members` critics evaluated in one vectorised call.\n",
        "\n",
        "    Member parameters are stacked with `torch.func.stack_module_state` into\n",
        "    `params_model` / `params_target` and evaluated with `vmap` over a\n",
        "    \"meta\"-device copy of the architecture. The per-member `critictype`\n",
        "    objects in `self.critics` only receive trained weights through `unstack`.\n",
        "\n",
        "    NOTE(review): the stacked parameters live in plain dicts, so they are not\n",
        "    registered with this nn.Module (`.parameters()` / `.to()` do not see them).\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, arch, args, n_state, n_action, critictype=Critic):\n",
        "        super().__init__()\n",
        "        self.args = args\n",
        "        self.n_members = 2\n",
        "        self.arch = arch\n",
        "        self.n_state = n_state\n",
        "        self.n_action = n_action\n",
        "        self.critictype = critictype\n",
        "        self.iter = 0\n",
        "        # Reuse the loss of a single prototype critic. No optimiser is taken\n",
        "        # from a critic here: reset() always creates `self.optim`.\n",
        "        self.loss = self.critictype(\n",
        "            self.arch, self.args, self.n_state, self.n_action\n",
        "        ).loss\n",
        "\n",
        "        # Helper: add a leading ensemble axis of size n_members to inputs with\n",
        "        # fewer than 3 dims; 3-d inputs are assumed to already have it.\n",
        "        self.expand = lambda x: (\n",
        "            x.expand(self.n_members, *x.shape) if len(x.shape) < 3 else x\n",
        "        )\n",
        "\n",
        "        self.reset()\n",
        "\n",
        "    def reset(self):\n",
        "        \"\"\"(Re)build the members, their stacked parameters and the optimiser.\"\"\"\n",
        "        self.critics = [\n",
        "            self.critictype(self.arch, self.args, self.n_state, self.n_action)\n",
        "            for _ in range(self.n_members)\n",
        "        ]\n",
        "        # Take both stacks from the same members so each stacked target starts\n",
        "        # as the copy its critic made of its own online network (assuming\n",
        "        # `critictype` initialises its target that way, as `Critic` does).\n",
        "        # Drawing them from separate critic instances would leave the targets\n",
        "        # unrelated to the online networks.\n",
        "        self.critics_model = [critic.model for critic in self.critics]\n",
        "        self.critics_target = [critic.target for critic in self.critics]\n",
        "\n",
        "        self.params_model, self.buffers_model = thf.stack_module_state(\n",
        "            self.critics_model\n",
        "        )\n",
        "        self.params_target, self.buffers_target = thf.stack_module_state(\n",
        "            self.critics_target\n",
        "        )\n",
        "\n",
        "        # Weight-less structural copies; the actual weights are supplied per\n",
        "        # call through functional_call.\n",
        "        self.base_model = copy.deepcopy(self.critics[0].model).to(\"meta\")\n",
        "        self.base_target = copy.deepcopy(self.critics[0].target).to(\"meta\")\n",
        "\n",
        "        def _fmodel(base_model, params, buffers, x):\n",
        "            return thf.functional_call(base_model, (params, buffers), (x,))\n",
        "\n",
        "        self.forward_model = thf.vmap(lambda p, b, x: _fmodel(self.base_model, p, b, x))\n",
        "        self.forward_target = thf.vmap(\n",
        "            lambda p, b, x: _fmodel(self.base_target, p, b, x)\n",
        "        )\n",
        "        self.optim = th.optim.Adam(\n",
        "            self.params_model.values(), lr=self.args.learning_rate\n",
        "        )\n",
        "\n",
        "    def reduce(self, q_val):\n",
        "        \"\"\"Element-wise minimum over the ensemble axis (dim 0).\"\"\"\n",
        "        return q_val.min(0)[0]\n",
        "\n",
        "    def __getitem__(self, item):\n",
        "        \"\"\"Return member critic `item`.\"\"\"\n",
        "        return self.critics[item]\n",
        "\n",
        "    def unstack(self, target=False, single=True, net_id=None):\n",
        "        \"\"\"\n",
        "        Copy the stacked parameters back into the individual member critics.\n",
        "        target: copy the target stack into `critic.target` instead of the\n",
        "            online stack into `critic.model`\n",
        "        single: copy only member `net_id` (member 0 when `net_id` is None);\n",
        "            when False every member is copied and `net_id` is ignored\n",
        "        \"\"\"\n",
        "        params = self.params_target if target else self.params_model\n",
        "        if single:\n",
        "            member_ids = [0 if net_id is None else net_id]\n",
        "        else:\n",
        "            member_ids = range(self.n_members)\n",
        "\n",
        "        for key, stacked in params.items():\n",
        "            for member in member_ids:\n",
        "                tmp = (\n",
        "                    self.critics[member].target\n",
        "                    if target\n",
        "                    else self.critics[member].model\n",
        "                )\n",
        "                # Walk the dotted parameter name down to the tensor.\n",
        "                for name in key.split(\".\"):\n",
        "                    tmp = getattr(tmp, name)\n",
        "                tmp.data.copy_(stacked[member])\n",
        "\n",
        "    def set_writer(self, writer):\n",
        "        assert (\n",
        "            writer is None\n",
        "        ), \"For now nothing else is implemented for the parallel version\"\n",
        "        self.writer = writer\n",
        "        for critic in self.critics:\n",
        "            critic.set_writer(writer)\n",
        "\n",
        "    def Q(self, s, a):\n",
        "        \"\"\"Q-values of every member for (s, a), stacked along dim 0.\"\"\"\n",
        "        SA = self.expand(th.cat((s, a), -1))\n",
        "        return self.forward_model(self.params_model, self.buffers_model, SA)\n",
        "\n",
        "    @th.no_grad()\n",
        "    def Q_t(self, s, a):\n",
        "        \"\"\"Target-network Q-values of every member, without gradients.\"\"\"\n",
        "        SA = self.expand(th.cat((s, a), -1))\n",
        "        return self.forward_target(self.params_target, self.buffers_target, SA)\n",
        "\n",
        "    def update(self, s, a, y):  # y denotes bellman target\n",
        "        \"\"\"One Adam step on the MSE between every member's Q(s, a) and y.\"\"\"\n",
        "        self.optim.zero_grad()\n",
        "        loss = self.loss(self.Q(s, a), self.expand(y))\n",
        "        loss.backward()\n",
        "        self.optim.step()\n",
        "        self.iter += 1\n",
        "\n",
        "    @torch.no_grad()\n",
        "    def update_target(self):\n",
        "        \"\"\"Polyak update of the stacked target parameters with rate args.tau.\"\"\"\n",
        "        for key in self.params_model.keys():\n",
        "            self.params_target[key].data.mul_(1.0 - self.args.tau)\n",
        "            self.params_target[key].data.add_(\n",
        "                self.args.tau * self.params_model[key].data\n",
        "            )\n",
        "\n",
        "    @torch.no_grad()\n",
        "    def get_bellman_target(self, r, sp, done, actor):\n",
        "        \"\"\"y = r + gamma * (1 - done) * (min_i Q_t_i(sp, ap) - alpha * ep).\n",
        "\n",
        "        (ap, ep) = actor.act(sp); alpha = exp(actor.log_alpha) when the actor\n",
        "        has one, else 0; ep = None is treated as 0. gamma is args.gamma.\n",
        "        \"\"\"\n",
        "        alpha = actor.log_alpha.exp().detach() if hasattr(actor, \"log_alpha\") else 0\n",
        "        ap, ep = actor.act(sp)\n",
        "        qp = self.Q_t(sp, ap)\n",
        "        qp_t = self.reduce(qp) - alpha * (ep if ep is not None else 0)\n",
        "        y = r.unsqueeze(-1) + (self.args.gamma * qp_t * (1 - done.unsqueeze(-1)))\n",
        "        return y\n",
        "\n",
        "    def save_params(self, path):\n",
        "        \"\"\"Sync the stacks into the member critics, then torch.save one dict per member.\"\"\"\n",
        "        self.unstack(target=False, single=False, net_id=None)\n",
        "        self.unstack(target=True, single=False, net_id=None)\n",
        "        params_list = [self.load_params(critic) for critic in self.critics]\n",
        "        torch.save(params_list, path)\n",
        "\n",
        "    def load_params(self, critic):\n",
        "        \"\"\"Collect one member's state for saving (despite the name, loads nothing).\n",
        "\n",
        "        Returns a dict with the member's model and target state_dicts and the\n",
        "        ensemble's shared optimiser state_dict.\n",
        "        \"\"\"\n",
        "        return {\n",
        "            \"params_model\": critic.model.state_dict(),\n",
        "            \"params_target\": critic.target.state_dict(),\n",
        "            \"optim\": self.optim.state_dict(),\n",
        "        }\n",
        "\n",
        "class Critic(nn.Module):\n",
        "    \"\"\"Q-network with a Polyak-averaged target copy.\n",
        "\n",
        "    Redefines the `Critic` declared earlier in this cell; later references to\n",
        "    the name (e.g. the default `critictype` of `CriticEnsemble`) get this class.\n",
        "    The constructor overwrites n_hidden, learning_rate, tau and device on `args`.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, arch, args, n_state, n_action):\n",
        "        super().__init__()\n",
        "        self.args = args\n",
        "        # Fixed hyper-parameters, written back onto the shared namespace.\n",
        "        args.n_hidden, args.learning_rate, args.tau = 256, 3e-4, 0.005\n",
        "        use_cuda = torch.cuda.is_available()\n",
        "        args.device = torch.device(\"cuda\" if use_cuda else \"cpu\")\n",
        "        self.arch = arch\n",
        "        hidden, device = args.n_hidden, args.device\n",
        "        self.model = arch(n_state, n_action, hidden).to(device)\n",
        "        self.target = arch(n_state, n_action, hidden).to(device)\n",
        "        self.init_target()\n",
        "        self.loss = nn.MSELoss()\n",
        "        self.optim = torch.optim.Adam(self.model.parameters(), args.learning_rate)\n",
        "        self.iter = 0\n",
        "\n",
        "    def set_writer(self, writer):\n",
        "        \"\"\"Remember a logging writer.\"\"\"\n",
        "        self.writer = writer\n",
        "\n",
        "    def init_target(self):\n",
        "        \"\"\"Hard-copy the online weights into the target network.\"\"\"\n",
        "        pairs = zip(self.target.parameters(), self.model.parameters())\n",
        "        for dst, src in pairs:\n",
        "            dst.data.copy_(src.data)\n",
        "\n",
        "    @th.no_grad()\n",
        "    def update_target(self):\n",
        "        \"\"\"Polyak step: target = (1 - tau) * target + tau * model.\"\"\"\n",
        "        tau = self.args.tau\n",
        "        for dst, src in zip(self.target.parameters(), self.model.parameters()):\n",
        "            dst.data.mul_(1.0 - tau)\n",
        "            dst.data.add_(tau * src.data)\n",
        "\n",
        "    def Q(self, s, a):\n",
        "        \"\"\"Online-network Q(s, a).\"\"\"\n",
        "        return self.model(s, a)\n",
        "\n",
        "    def Q_t(self, s, a):\n",
        "        \"\"\"Target-network Q(s, a) (gradient tracking is not disabled here).\"\"\"\n",
        "        return self.target(s, a)\n",
        "\n",
        "    def update(self, s, a, y):  # y denotes bellman target\n",
        "        \"\"\"One Adam step on MSE(Q(s, a), y).\"\"\"\n",
        "        self.optim.zero_grad()\n",
        "        prediction = self.Q(s, a)\n",
        "        self.loss(prediction, y).backward()\n",
        "        self.optim.step()\n",
        "        self.iter += 1\n",
        "\n",
        "class CriticEnsemble(nn.Module):\n",
        "    \"\"\"Ensemble of `args.n_critics` independently updated critics.\n",
        "\n",
        "    NOTE(review): members are kept in a plain list, so they are not registered\n",
        "    as sub-modules (`.parameters()`, `.to()` and `state_dict()` skip them).\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, arch, args, n_state, n_action, critictype=Critic):\n",
        "        super(CriticEnsemble, self).__init__()\n",
        "        self.n_elements = args.n_critics\n",
        "        self.args = args\n",
        "        self.critics = [\n",
        "            critictype(arch, args, n_state, n_action) for _ in range(self.n_elements)\n",
        "        ]\n",
        "        # Discount used by get_bellman_target when `args` has no `gamma`.\n",
        "        self.gamma = 0.99\n",
        "        self.iter = 0\n",
        "\n",
        "    def __getitem__(self, item):\n",
        "        \"\"\"Return member critic `item`.\"\"\"\n",
        "        return self.critics[item]\n",
        "\n",
        "    def set_writer(self, writer):\n",
        "        self.writer = writer\n",
        "        [critic.set_writer(writer) for critic in self.critics]\n",
        "\n",
        "    def Q(self, s, a):\n",
        "        \"\"\"List with each member's online Q(s, a).\"\"\"\n",
        "        return [critic.Q(s, a) for critic in self.critics]\n",
        "\n",
        "    def Q_t(self, s, a):\n",
        "        \"\"\"List with each member's target Q(s, a).\"\"\"\n",
        "        return [critic.Q_t(s, a) for critic in self.critics]\n",
        "\n",
        "    def update(self, s, a, y):\n",
        "        \"\"\"Regress every member towards the same Bellman target y.\"\"\"\n",
        "        for critic in self.critics:\n",
        "            critic.update(s, a, y)\n",
        "        self.iter += 1\n",
        "\n",
        "    def update_target(self):\n",
        "        for critic in self.critics:\n",
        "            critic.update_target()\n",
        "\n",
        "    def reduce(self, q_val_list):\n",
        "        \"\"\"Element-wise minimum across the members' Q-value tensors.\"\"\"\n",
        "        return torch.stack(q_val_list, dim=-1).min(dim=-1)[0]\n",
        "\n",
        "    @torch.no_grad()\n",
        "    def get_bellman_target(self, r, sp, done, actor):\n",
        "        \"\"\"y = r + gamma * (1 - done) * (min_i Q_t_i(sp, ap) - alpha * ep).\n",
        "\n",
        "        (ap, ep) = actor.act(sp); alpha = exp(actor.log_alpha) when the actor\n",
        "        has one, else 0; ep = None is treated as 0. gamma is args.gamma when\n",
        "        that attribute exists, otherwise self.gamma.\n",
        "        \"\"\"\n",
        "        alpha = actor.log_alpha.exp().detach() if hasattr(actor, \"log_alpha\") else 0\n",
        "        ap, ep = actor.act(sp)\n",
        "        qp = self.Q_t(sp, ap)\n",
        "        if ep is None:\n",
        "            ep = 0\n",
        "        qp_t = self.reduce(qp) - alpha * ep\n",
        "        gamma = getattr(self.args, \"gamma\", self.gamma)\n",
        "        y = r.unsqueeze(-1) + (gamma * qp_t * (1 - done.unsqueeze(-1)))\n",
        "        return y\n",
        " \n",
        "\n",
        "class Actor(nn.Module):\n",
        "    \"\"\"Policy trained to maximise the critics' reduced Q-value.\n",
        "\n",
        "    `arch` is called as arch(n_state, n_action, depth=3, width=256,\n",
        "    act=\"crelu\", has_norm=True) and the resulting module must return a pair\n",
        "    `(action, extra)` when called on a batch of states.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, arch, args, n_state, n_action, has_target=False):\n",
        "        super().__init__()\n",
        "        self.model = arch(\n",
        "            n_state,\n",
        "            n_action,\n",
        "            depth=3,\n",
        "            width=256,\n",
        "            act=\"crelu\",\n",
        "            has_norm=True,\n",
        "        )\n",
        "        self.optim = torch.optim.Adam(self.model.parameters(), args.learning_rate)\n",
        "        self.args = args\n",
        "        self.has_target = has_target\n",
        "        # Polyak rate used by update_target(); it must exist or that method\n",
        "        # raises AttributeError. Taken from args.tau when present, else the\n",
        "        # same 0.005 that the critics write onto `args`.\n",
        "        self.tau = getattr(args, \"tau\", 0.005)\n",
        "        self.iter = 0\n",
        "        self.is_episode_end = False\n",
        "        self.states = []\n",
        "        self.print_freq = 500\n",
        "\n",
        "        if has_target:\n",
        "            self.target = arch(\n",
        "                n_state,\n",
        "                n_action,\n",
        "                depth=3,\n",
        "                width=256,\n",
        "                act=\"crelu\",\n",
        "                has_norm=True,\n",
        "            )\n",
        "            self.init_target()\n",
        "\n",
        "    def init_target(self):\n",
        "        \"\"\"Copy the policy's parameters into the target policy.\"\"\"\n",
        "        for target_param, local_param in zip(\n",
        "            self.target.parameters(), self.model.parameters()\n",
        "        ):\n",
        "            target_param.data.copy_(local_param.data)\n",
        "\n",
        "    def set_writer(self, writer):\n",
        "        self.writer = writer\n",
        "\n",
        "    def act(self, s, is_training=True):\n",
        "        \"\"\"Return (action, extra) from the policy for states `s`.\n",
        "\n",
        "        In training mode with `args.verbose` truthy, `s` is also appended (as a\n",
        "        NumPy array) to `self.states` while `self.iter` is a multiple of\n",
        "        `print_freq`. A missing `args.verbose` is treated as False.\n",
        "        \"\"\"\n",
        "        a, e = self.model(\n",
        "            s, is_training=is_training\n",
        "        )\n",
        "\n",
        "        if is_training:\n",
        "            if getattr(self.args, \"verbose\", False) and self.iter % self.print_freq == 0:\n",
        "                self.states.append(tonumpy(s))\n",
        "        return a, e\n",
        "\n",
        "    def act_target(self, s):\n",
        "        \"\"\"Return (action, extra) from the target policy (requires has_target=True).\"\"\"\n",
        "        a, e = self.target(s)\n",
        "        return a, e\n",
        "\n",
        "    def set_episode_status(self, is_end):\n",
        "        self.is_episode_end = is_end\n",
        "\n",
        "    @th.no_grad()\n",
        "    def update_target(self):\n",
        "        \"\"\"Soft update: target = (1 - tau) * target + tau * model.\"\"\"\n",
        "        for target_param, local_param in zip(\n",
        "            self.target.parameters(), self.model.parameters()\n",
        "        ):\n",
        "            target_param.data.mul_(1.0 - self.tau)\n",
        "            target_param.data.add_(self.tau * local_param.data)\n",
        "\n",
        "    def loss(self, s, critics):\n",
        "        \"\"\"Policy loss -mean(critics.reduce(critics.Q(s, a))); returns (loss, None).\"\"\"\n",
        "        a, _ = self.act(s)\n",
        "        q_list = critics.Q(s, a)\n",
        "        q = critics.reduce(q_list)\n",
        "        return (-q).mean(), None\n",
        "\n",
        "    def update(self, s, critics):\n",
        "        \"\"\"One optimiser step on `loss`, then a soft target update if enabled.\"\"\"\n",
        "        self.optim.zero_grad()\n",
        "        loss, _ = self.loss(s, critics)\n",
        "        loss.backward()\n",
        "        self.optim.step()\n",
        "\n",
        "        if self.has_target:\n",
        "            self.update_target()\n",
        "\n",
        "        self.iter += 1\n",
        "\n",
        "\n",
        "class SoftActor(Actor):\n",
        "    \"\"\"Actor with a learned entropy temperature alpha = exp(log_alpha).\n",
        "\n",
        "    The target entropy is -n_action[0]. The constructor also writes\n",
        "    learning_rate = 3e-4 and alpha = 1 onto `args` (after the base class has\n",
        "    already built its policy optimiser).\n",
        "\n",
        "    NOTE(review): unlike Actor.update, update() here never calls\n",
        "    update_target(), even when has_target=True.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, arch, args, n_state, n_action, has_target=False):\n",
        "        super().__init__(arch, args, n_state, n_action, has_target)\n",
        "        self.H_target = -n_action[0]\n",
        "        args.learning_rate = 3e-4\n",
        "        args.alpha = 1\n",
        "        self.device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "        initial_log_alpha = math.log(args.alpha)\n",
        "        self.log_alpha = torch.tensor(\n",
        "            initial_log_alpha, requires_grad=True, device=self.device\n",
        "        )\n",
        "        self.optim_alpha = torch.optim.Adam([self.log_alpha], args.learning_rate)\n",
        "\n",
        "    def update_alpha(self, e):\n",
        "        \"\"\"One Adam step on the temperature loss -mean(alpha * (e + H_target)).\"\"\"\n",
        "        self.optim_alpha.zero_grad()\n",
        "        entropy_gap = (e + self.H_target).detach()\n",
        "        alpha_loss = -(self.log_alpha.exp() * entropy_gap).mean()\n",
        "        alpha_loss.backward()\n",
        "        self.optim_alpha.step()\n",
        "\n",
        "    def loss(self, s, critics):\n",
        "        \"\"\"Soft policy loss mean(alpha * e - reduced Q); also returns `e`.\"\"\"\n",
        "        a, e = self.act(s)\n",
        "        q = critics.reduce(critics.Q(s, a))\n",
        "        alpha = self.log_alpha.exp()\n",
        "        return (-q + alpha * e).mean(), e\n",
        "\n",
        "    def update(self, s, critics):\n",
        "        \"\"\"Policy step, then a temperature step using the same `e`.\"\"\"\n",
        "        self.optim.zero_grad()\n",
        "        policy_loss, e = self.loss(s, critics)\n",
        "        policy_loss.backward()\n",
        "        self.optim.step()\n",
        "        self.update_alpha(e)\n",
        "        self.iter += 1\n",
        "\n",
        "\n",
        "class CReLU(nn.Module):\n",
        "    \"\"\"Concatenated ReLU: relu(cat([x, -x], dim=-1)), doubling the last dimension.\"\"\"\n",
        "\n",
        "    def forward(self, x):\n",
        "        mirrored = torch.cat((x, -x), -1)\n",
        "        return F.relu(mirrored)\n",
        "\n",
        "\n",
        "def create_net(d_in, d_out, depth, width, act=\"crelu\", has_norm=True, n_elements=1):\n",
        "    \"\"\"Build an MLP with `depth` Linear layers.\n",
        "\n",
        "    depth == 1: Linear(d_in, d_out). Otherwise:\n",
        "    Linear(d_in, width) -> [norm -> activation -> Linear] * (depth - 1), where the\n",
        "    last Linear maps to d_out and norm is a non-affine LayerNorm(width) (or\n",
        "    Identity when has_norm is False). With act=\"crelu\" the activation doubles\n",
        "    the feature size, so each Linear after it takes 2 * width inputs.\n",
        "\n",
        "    NOTE(review): `n_elements` currently has no effect (see out_layer below).\n",
        "\n",
        "    Raises:\n",
        "        AssertionError: if depth is less than 1.\n",
        "        NotImplementedError: if `act` is not \"crelu\" or \"relu\".\n",
        "    \"\"\"\n",
        "    assert depth > 0, \"Need at least one layer\"\n",
        "\n",
        "    double_width = False\n",
        "    if act == \"crelu\":\n",
        "        act_cls = CReLU\n",
        "        double_width = True\n",
        "    elif act == \"relu\":\n",
        "        act_cls = nn.ReLU\n",
        "    else:\n",
        "        raise NotImplementedError(f\"{act} is not implemented\")\n",
        "\n",
        "    # Input size of every Linear that follows an activation.\n",
        "    post_act_width = 2 * width if double_width else width\n",
        "\n",
        "    def _norm():\n",
        "        return (\n",
        "            nn.LayerNorm(width, elementwise_affine=False)\n",
        "            if has_norm\n",
        "            else nn.Identity()\n",
        "        )\n",
        "\n",
        "    if depth == 1:\n",
        "        arch = nn.Linear(d_in, d_out)\n",
        "    elif depth == 2:\n",
        "        arch = nn.Sequential(\n",
        "            nn.Linear(d_in, width),\n",
        "            _norm(),\n",
        "            act_cls(),\n",
        "            nn.Linear(post_act_width, d_out),\n",
        "        )\n",
        "    else:\n",
        "        in_layer = nn.Linear(d_in, width)\n",
        "        # n_elements is deliberately not passed here: nn.Linear's third\n",
        "        # positional parameter is `bias`, so passing n_elements (> 1) only\n",
        "        # ever meant bias=True, i.e. exactly this layer.\n",
        "        out_layer = nn.Linear(post_act_width, d_out)\n",
        "\n",
        "        # Layers are created in this exact order (including a final hidden\n",
        "        # Linear that is then dropped) so that random parameter initialisation\n",
        "        # stays the same for a given seed.\n",
        "        hidden = []\n",
        "        for _ in range(depth - 1):\n",
        "            hidden += [_norm(), act_cls(), nn.Linear(post_act_width, width)]\n",
        "        hidden = hidden[:-1]\n",
        "        arch = nn.Sequential(in_layer, *hidden, out_layer)\n",
        "\n",
        "    return arch\n",
        "\n",
        "class ParallelCriticNet(nn.Module):\n",
        "    \"\"\"Single-output MLP over inputs of width dim_obs[0] + dim_act[0].\n",
        "\n",
        "    Presumably the input is a state and an action concatenated along the last axis.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, dim_obs, dim_act, depth=3, width=256, act=\"crelu\", has_norm=True):\n",
        "        super().__init__()\n",
        "        n_inputs = dim_obs[0] + dim_act[0]\n",
        "        self.arch = create_net(n_inputs, 1, depth, width, act=act, has_norm=has_norm)\n",
        "\n",
        "    def forward(self, xu):\n",
        "        \"\"\"Return the network output for `xu` (last dimension of size 1).\"\"\"\n",
        "        return self.arch(xu)\n",
        "\n",
        "\n",
        "# The deterministic ActorNet is defined once, further down in this cell (after SquashedGaussianHead).\n",
        "\n",
        "\n",
        "class SquashedGaussianHead(nn.Module):\n",
        "    \"\"\"Tanh-squashed diagonal Gaussian policy head.\n",
        "\n",
        "    The last axis of the input holds `n` pre-tanh means followed by `n` pre-tanh\n",
        "    log-variances. The log-variances are clamped to [-10, -upper_clamp], which is\n",
        "    [-10, 2] with the default upper_clamp=-2.0.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, n, upper_clamp=-2.0):\n",
        "        super().__init__()\n",
        "        self._n = n\n",
        "        self._upper_clamp = upper_clamp\n",
        "\n",
        "    def forward(self, x, is_training=True):\n",
        "        \"\"\"Sample an action.\n",
        "\n",
        "        Returns:\n",
        "            (y, y_logprob). When `is_training` is truthy, y is one reparameterised\n",
        "            sample and y_logprob is its log-density summed over the last axis (keepdim).\n",
        "            Otherwise y is the mean of 100 reparameterised samples and y_logprob is None.\n",
        "        \"\"\"\n",
        "        n = self._n\n",
        "        mean_pre_tanh = x[..., :n]\n",
        "        log_var_pre_tanh = x[..., n:].clamp(-10, -self._upper_clamp)\n",
        "        std_pre_tanh = log_var_pre_tanh.exp().sqrt()\n",
        "        # cache_size=1 lets log_prob reuse the pre-tanh value of the sample it has just\n",
        "        # been given, instead of recovering it through atanh.\n",
        "        policy = TransformedDistribution(\n",
        "            Normal(mean_pre_tanh, std_pre_tanh), TanhTransform(cache_size=1)\n",
        "        )\n",
        "        if not is_training:\n",
        "            return policy.rsample((100,)).mean(dim=0), None\n",
        "        y = policy.rsample()\n",
        "        return y, policy.log_prob(y).sum(dim=-1, keepdim=True)\n",
        "       \n",
        "    \n",
        "class ActorNet(nn.Module):\n",
        "    \"\"\"Deterministic actor: MLP followed by tanh, outputs clamped to [-0.9999, 0.9999].\"\"\"\n",
        "\n",
        "    def __init__(\n",
        "        self,\n",
        "        dim_obs,\n",
        "        dim_act,\n",
        "        depth=3,\n",
        "        width=256,\n",
        "        act=\"crelu\",\n",
        "        has_norm=True,\n",
        "        upper_clamp=None,\n",
        "    ):\n",
        "        \"\"\"\n",
        "        Args:\n",
        "            dim_obs, dim_act: observation / action shapes; only the first entry is used.\n",
        "            depth, width, act, has_norm: forwarded to create_net.\n",
        "            upper_clamp: unused; accepted so the signature matches ActorNetProbabilistic.\n",
        "        \"\"\"\n",
        "        super().__init__()\n",
        "        body = create_net(dim_obs[0], dim_act[0], depth, width, act, has_norm)\n",
        "        if not isinstance(body, nn.Sequential):\n",
        "            # depth == 1 yields a bare nn.Linear, which has no .append().\n",
        "            body = nn.Sequential(body)\n",
        "        self.arch = body.append(nn.Tanh())\n",
        "\n",
        "    def forward(self, x, is_training=None):\n",
        "        \"\"\"Return (action, None); `is_training` is ignored.\"\"\"\n",
        "        out = self.arch(x).clamp(-0.9999, 0.9999)\n",
        "        return out, None\n",
        "\n",
        "    \n",
        "\n",
        "class ActorNetProbabilistic(nn.Module):\n",
        "    \"\"\"Stochastic actor: an MLP emits 2 * dim_act[0] features that a\n",
        "    SquashedGaussianHead turns into a tanh-squashed sample and (when training)\n",
        "    its log-probability.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, dim_obs, dim_act, depth=3, width=256, act=\"crelu\", has_norm=True, upper_clamp=-2.0):\n",
        "        super().__init__()\n",
        "        self.dim_act = dim_act\n",
        "        n_act = dim_act[0]\n",
        "        # The head reads the first n_act features as pre-tanh means and the rest as\n",
        "        # pre-tanh log-variances.\n",
        "        self.arch = create_net(dim_obs[0], 2 * n_act, depth, width, act, has_norm)\n",
        "        self.head = SquashedGaussianHead(n_act, upper_clamp)\n",
        "\n",
        "    def forward(self, x, is_training=True):\n",
        "        \"\"\"Return (action, log_prob); log_prob is None when `is_training` is falsy.\"\"\"\n",
        "        return self.head(self.arch(x), is_training)\n",
        "    \n",
        "\n",
        "class CriticNet(nn.Module):\n",
        "    \"\"\"Single-output MLP over inputs of width dim_obs[0] + dim_act[0].\n",
        "\n",
        "    Presumably the input is a state and an action concatenated along the last axis.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, dim_obs, dim_act, depth=3, width=256, act=\"crelu\", has_norm=True):\n",
        "        super().__init__()\n",
        "        n_inputs = dim_obs[0] + dim_act[0]\n",
        "        self.arch = create_net(n_inputs, 1, depth, width, act=act, has_norm=has_norm)\n",
        "\n",
        "    def forward(self, xu):\n",
        "        \"\"\"Return the network output for `xu` (last dimension of size 1).\"\"\"\n",
        "        return self.arch(xu)\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fwb7579n0a1B"
      },
      "source": [
        "# Experience Memory\n",
        "Stores past experiences and allows efficient sampling for training an agent."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {
        "id": "tn-ca8fx0cL9"
      },
      "outputs": [],
      "source": [
        "class ExperienceMemoryTorch:\n",
        "    \"\"\"Fixed-size ring buffer of transitions stored in preallocated torch tensors.\n",
        "\n",
        "    Each field in `field_names` is backed by one tensor whose shape is `args.dims[field]`\n",
        "    with the first dimension replaced by `buffer_size`. Once the buffer is full, new\n",
        "    transitions overwrite the oldest ones.\n",
        "    \"\"\"\n",
        "\n",
        "    field_names = [\"state\", \"action\", \"reward\", \"next_state\", \"terminated\", \"step\"]\n",
        "\n",
        "    def __init__(self, args):\n",
        "        \"\"\"\n",
        "        Args:\n",
        "            args: namespace with `dims` (field -> shape tuple or bare int length) and,\n",
        "                optionally, `buffer_size` (capacity; 100000 when absent).\n",
        "        \"\"\"\n",
        "        self.device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "        self.buffer_size = getattr(args, \"buffer_size\", 100000)\n",
        "        self.dims = args.dims\n",
        "        self.reset()\n",
        "\n",
        "    def _field_shape(self, field):\n",
        "        # dims entries are a shape tuple or a bare int (1-D fields such as \"reward\");\n",
        "        # the leading dimension always follows the current buffer_size.\n",
        "        shape = self.dims[field]\n",
        "        shape = tuple(shape) if isinstance(shape, (tuple, list, torch.Size)) else (shape,)\n",
        "        return (self.buffer_size,) + shape[1:]\n",
        "\n",
        "    def _recent_indices(self, count):\n",
        "        # Buffer slots of the `count` most recently written transitions, oldest first.\n",
        "        return (self.pointer - count + torch.arange(count)) % self.buffer_size\n",
        "\n",
        "    def reset(self, buffer_size=None):\n",
        "        \"\"\"Discard all stored data and reallocate; optionally change the capacity.\"\"\"\n",
        "        if buffer_size is not None:\n",
        "            self.buffer_size = buffer_size\n",
        "        self.data_size = 0\n",
        "        self.pointer = 0\n",
        "        self.memory = {\n",
        "            field: torch.empty(self._field_shape(field), device=self.device)\n",
        "            for field in self.field_names\n",
        "        }\n",
        "\n",
        "    def add(self, state, action, reward, next_state, terminated, step):\n",
        "        \"\"\"Write one transition at the write pointer and advance it (wrapping around).\"\"\"\n",
        "        values = (state, action, reward, next_state, terminated, step)\n",
        "        for field, value in zip(self.field_names, values):\n",
        "            self.memory[field][self.pointer] = value\n",
        "        self.pointer = (self.pointer + 1) % self.buffer_size\n",
        "        self.data_size = min(self.data_size + 1, self.buffer_size)\n",
        "\n",
        "    def sample_by_index(self, index):\n",
        "        \"\"\"Return one tensor per field (in `field_names` order) indexed by `index`.\"\"\"\n",
        "        return tuple(self.memory[field][index] for field in self.field_names)\n",
        "\n",
        "    def sample_by_index_fields(self, index, fields):\n",
        "        \"\"\"Like sample_by_index but for selected fields; a single field gives a tensor.\"\"\"\n",
        "        if len(fields) == 1:\n",
        "            return self.memory[fields[0]][index]  # return a tensor\n",
        "        return tuple(self.memory[field][index] for field in fields)\n",
        "\n",
        "    def sample_random(self, batch_size):\n",
        "        \"\"\"Uniformly sample `batch_size` stored transitions, with replacement.\"\"\"\n",
        "        index = torch.randint(self.data_size, (batch_size,))\n",
        "        return self.sample_by_index(index)\n",
        "\n",
        "    @staticmethod\n",
        "    def set_diff_1d(t1, t2, assume_unique=False):\n",
        "        \"\"\"\n",
        "        Set difference of two 1D tensors.\n",
        "        Returns the unique values in t1 that are not in t2.\n",
        "        Source: https://stackoverflow.com/questions/55110047/finding-non-intersection-of-two-pytorch-tensors/72898627#72898627\n",
        "        \"\"\"\n",
        "        if not assume_unique:\n",
        "            t1 = torch.unique(t1)\n",
        "            t2 = torch.unique(t2)\n",
        "        return t1[(t1[:, None] != t2).all(dim=1)]\n",
        "\n",
        "    def filter_by_nonterminal_steps_with_horizon(self, horizon):\n",
        "        \"\"\"Start indices i whose window i .. i + horizon - 1 has no terminal before its end.\n",
        "\n",
        "        A window may end on a terminal transition but not contain one earlier. Only the\n",
        "        filled slots 0 .. data_size - 1 are inspected; wrap-around is not considered.\n",
        "        \"\"\"\n",
        "        all_indices = torch.arange(max(self.data_size - horizon + 1, 0))\n",
        "        terminated = self.memory[\"terminated\"][: self.data_size]\n",
        "        terminal_indices = torch.argwhere(terminated == 1).flatten().cpu()\n",
        "        if terminal_indices.numel() == 0:\n",
        "            return all_indices\n",
        "\n",
        "        # A terminal at t blocks the starts t - horizon + 2 .. t.\n",
        "        offsets = torch.arange(horizon - 1)\n",
        "        blocked = (terminal_indices[:, None] - offsets[None, :]).flatten()\n",
        "        return self.set_diff_1d(all_indices, blocked)\n",
        "\n",
        "    def sample_random_sequence_snippet(self, batch_size, sequence_length):\n",
        "        \"\"\"Sample `batch_size` windows of `sequence_length` consecutive transitions.\n",
        "\n",
        "        Returns a list with one sample_by_index tuple per time offset\n",
        "        0 .. sequence_length - 1, each covering the whole batch of windows.\n",
        "\n",
        "        Raises:\n",
        "            ValueError: if no start index satisfies the terminal constraint.\n",
        "        \"\"\"\n",
        "        candidates = self.filter_by_nonterminal_steps_with_horizon(sequence_length)\n",
        "        if candidates.numel() == 0:\n",
        "            raise ValueError(\"no valid sequence start indices in the buffer\")\n",
        "        starts = candidates[torch.randint(candidates.numel(), (batch_size,))]\n",
        "        return [self.sample_by_index(starts + offset) for offset in range(sequence_length)]\n",
        "\n",
        "    def sample_all(self):\n",
        "        \"\"\"Return every field restricted to slots 0 .. data_size - 1.\"\"\"\n",
        "        return self.sample_by_index(range(self.data_size))\n",
        "\n",
        "    def clone(self, other_memory):\n",
        "        \"\"\"Become a deep copy of `other_memory`'s data and ring-buffer bookkeeping.\"\"\"\n",
        "        self.buffer_size = other_memory.buffer_size\n",
        "        self.data_size = other_memory.data_size\n",
        "        self.pointer = other_memory.pointer\n",
        "        self.memory = copy.deepcopy(other_memory.memory)\n",
        "\n",
        "    def extend(self, other_memory):\n",
        "        \"\"\"Append `other_memory`'s transitions (oldest first) as if added one by one.\n",
        "\n",
        "        If it holds more than `buffer_size` transitions only the most recent survive.\n",
        "        \"\"\"\n",
        "        count = other_memory.data_size\n",
        "        if count == 0:\n",
        "            return\n",
        "        src = (other_memory.pointer - count + torch.arange(count)) % other_memory.buffer_size\n",
        "        if count > self.buffer_size:\n",
        "            src = src[-self.buffer_size:]\n",
        "            count = self.buffer_size\n",
        "        dst = (self.pointer + torch.arange(count)) % self.buffer_size\n",
        "        for field in self.field_names:\n",
        "            target = self.memory[field]\n",
        "            target[dst] = other_memory.memory[field][src].to(target.device)\n",
        "        self.pointer = (self.pointer + count) % self.buffer_size\n",
        "        self.data_size = min(self.data_size + count, self.buffer_size)\n",
        "\n",
        "    def __len__(self):\n",
        "        return self.data_size\n",
        "\n",
        "    @property\n",
        "    def size(self):\n",
        "        return self.data_size\n",
        "\n",
        "    def save(self, path):\n",
        "        \"\"\"Save the raw field tensors to `<path>/experience_memory.pt`.\"\"\"\n",
        "        torch.save(self.memory, os.path.join(path, \"experience_memory.pt\"))\n",
        "\n",
        "    def get_last_observation(self):\n",
        "        \"\"\"Return the most recently added transition (each field keeps a leading axis of 1).\"\"\"\n",
        "        return self.sample_by_index(self._recent_indices(1))\n",
        "\n",
        "    def get_last_observations(self, batch_size):\n",
        "        \"\"\"Return the `batch_size` most recently added transitions, oldest first.\"\"\"\n",
        "        return self.sample_by_index(self._recent_indices(batch_size))\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5375-_K6Dv4f"
      },
      "source": [
        "# Agent\n",
        "Defines a reinforcement learning agent that interacts with the environment, stores experiences, and updates its models using soft or hard updates."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 7,
      "metadata": {
        "id": "XFQOR6k3DwqW"
      },
      "outputs": [],
      "source": [
        "\n",
        "# Re-parse the command-line arguments. This rebinds the module-level `args` to a new\n",
        "# namespace, so attributes attached to the previous `args` object are not carried over.\n",
        "args, unknown = parser.parse_known_args()\n",
        "\n",
        "def totorch(x, dtype=th.float32, device=\"cpu\"):\n",
        "    \"\"\"Return `x` as a torch tensor with the given dtype and device.\n",
        "\n",
        "    Uses torch.as_tensor, which avoids a copy when `x` already has that dtype and device.\n",
        "    \"\"\"\n",
        "    converted = th.as_tensor(x, dtype=dtype, device=device)\n",
        "    return converted\n",
        "\n",
        "\n",
        "class Agent(nn.Module):\n",
        "    \"\"\"Base agent: keeps the environment, its space bounds and a replay buffer.\n",
        "\n",
        "    Subclasses are expected to set `self.actor` (used by `store_transition`) and to\n",
        "    implement `learn` and `select_action`.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, env, args):\n",
        "        \"\"\"\n",
        "        Args:\n",
        "            env: environment whose `observation_space` / `action_space` expose\n",
        "                `shape`, `low` and `high`.\n",
        "            args: shared config namespace. This constructor overwrites\n",
        "                `args.buffer_size` (100000) and sets `args.dims` before building the\n",
        "                replay buffer from it.\n",
        "        \"\"\"\n",
        "        super().__init__()\n",
        "        self.args = args\n",
        "        self.device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "        args.buffer_size = 100000\n",
        "        self.tau = 0.005\n",
        "        self.gamma = 0.99\n",
        "        self.env = env\n",
        "        self.dim_obs, self.dim_act = (\n",
        "            self.env.observation_space.shape,\n",
        "            self.env.action_space.shape,\n",
        "        )\n",
        "        print(f\"INFO: dim_obs = {self.dim_obs} dim_act = {self.dim_act}\")\n",
        "        self.dim_obs_flat = np.prod(self.dim_obs)\n",
        "        self.dim_act_flat = np.prod(self.dim_act)\n",
        "        self._u_min = totorch(self.env.action_space.low, device=self.device)\n",
        "        self._u_max = totorch(self.env.action_space.high, device=self.device)\n",
        "        self._x_min = totorch(self.env.observation_space.low, device=self.device)\n",
        "        self._x_max = totorch(self.env.observation_space.high, device=self.device)\n",
        "\n",
        "        self._gamma = self.gamma\n",
        "        self._tau = self.tau\n",
        "\n",
        "        # Replay-buffer tensor shapes; one-value-per-step fields are given as a bare length.\n",
        "        args.dims = {\n",
        "            \"state\": (args.buffer_size, self.dim_obs_flat),\n",
        "            \"action\": (args.buffer_size, self.dim_act_flat),\n",
        "            \"next_state\": (args.buffer_size, self.dim_obs_flat),\n",
        "            \"reward\": (args.buffer_size),\n",
        "            \"terminated\": (args.buffer_size),\n",
        "            \"step\": (args.buffer_size),\n",
        "        }\n",
        "\n",
        "        self.experience_memory = ExperienceMemoryTorch(args)\n",
        "\n",
        "    def _display_name(self):\n",
        "        # Agent name for error messages: an explicit `name` attribute if one exists,\n",
        "        # otherwise the subclass's `_agent_name`, otherwise the class name.\n",
        "        name = getattr(self, \"name\", None)\n",
        "        if name is not None:\n",
        "            return name\n",
        "        return getattr(self, \"_agent_name\", type(self).__name__)\n",
        "\n",
        "    def set_writer(self, writer):\n",
        "        self.writer = writer\n",
        "\n",
        "    def _soft_update(self, local_model, target_model):\n",
        "        \"\"\"Polyak update in place: target <- (1 - tau) * target + tau * local.\n",
        "\n",
        "        tau is `args.tau` when that attribute exists, otherwise `self.tau`.\n",
        "        \"\"\"\n",
        "        tau = getattr(self.args, \"tau\", self.tau)\n",
        "        for target_param, local_param in zip(\n",
        "            target_model.parameters(), local_model.parameters()\n",
        "        ):\n",
        "            target_param.data.mul_(1.0 - tau)\n",
        "            target_param.data.add_(tau * local_param.data)\n",
        "\n",
        "    def _hard_update(self, local_model, target_model):\n",
        "        \"\"\"Copy every parameter of `local_model` into `target_model` in place.\"\"\"\n",
        "        for target_param, local_param in zip(\n",
        "            target_model.parameters(), local_model.parameters()\n",
        "        ):\n",
        "            target_param.data.copy_(local_param.data)\n",
        "\n",
        "    def learn(self, max_iter=1):\n",
        "        raise NotImplementedError(f\"learn() not implemented for {self._display_name()} agent\")\n",
        "\n",
        "    def select_action(self, warmup=False, exploit=False):\n",
        "        raise NotImplementedError(\n",
        "            f\"select_action() not implemented for {self._display_name()} agent\"\n",
        "        )\n",
        "\n",
        "    def store_transition(self, s, a, r, sp, terminated, truncated, step):\n",
        "        \"\"\"Add a transition to the replay buffer and pass the episode-end flag to the actor.\"\"\"\n",
        "        self.experience_memory.add(s, a, r, sp, terminated, step)\n",
        "        self.actor.set_episode_status(terminated or truncated)\n",
        "\n",
        "\n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "# Actor-Critic Agent\n",
        "Creates an actor-critic agent that learns by updating the critic network and adjusting the actor network based on feedback from the critic."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 8,
      "metadata": {},
      "outputs": [],
      "source": [
        "class ActorCritic(Agent):\n",
        "    \"\"\"Actor-critic agent built from an actor wrapper and a critic ensemble.\n",
        "\n",
        "    Each learning iteration samples a batch, updates the critics towards the Bellman\n",
        "    target, updates the actor every `policy_delay` iterations and then calls\n",
        "    `critics.update_target()`.\n",
        "    \"\"\"\n",
        "\n",
        "    _agent_name = \"AC\"\n",
        "\n",
        "    def __init__(self, env, args, actor_nn, critic_nn, CriticEnsembleType=CriticEnsemble, ActorType=Actor):\n",
        "        super().__init__(env, args)\n",
        "        # Critics are built before the actor, which fixes the parameter-initialisation order.\n",
        "        self.critics = CriticEnsembleType(critic_nn, args, self.dim_obs, self.dim_act)\n",
        "        self.actor = ActorType(actor_nn, args, self.dim_obs, self.dim_act)\n",
        "        self.n_iter = 0\n",
        "        self.policy_delay = 1\n",
        "        self.args.batch_size = 256\n",
        "\n",
        "    def set_writer(self, writer):\n",
        "        \"\"\"Share `writer` with the actor and the critics.\"\"\"\n",
        "        self.writer = writer\n",
        "        for component in (self.actor, self.critics):\n",
        "            component.set_writer(writer)\n",
        "\n",
        "    def learn(self, max_iter=1):\n",
        "        \"\"\"Run `max_iter` training iterations; no-op until the buffer holds one batch.\"\"\"\n",
        "        if len(self.experience_memory) < self.args.batch_size:\n",
        "            return None\n",
        "\n",
        "        for _ in range(max_iter):\n",
        "            s, a, r, sp, done, _step = self.experience_memory.sample_random(\n",
        "                self.args.batch_size\n",
        "            )\n",
        "            bellman_target = self.critics.get_bellman_target(r, sp, done, self.actor)\n",
        "            self.critics.update(s, a, bellman_target)\n",
        "\n",
        "            if self.n_iter % self.policy_delay == 0:\n",
        "                self.actor.update(s, self.critics)\n",
        "            self.critics.update_target()\n",
        "            self.n_iter += 1\n",
        "\n",
        "    @torch.no_grad()\n",
        "    def select_action(self, s, is_training=True):\n",
        "        \"\"\"Return the actor's action for `s`, with autograd disabled.\"\"\"\n",
        "        action, _ = self.actor.act(s, is_training=is_training)\n",
        "        return action\n",
        "\n",
        "    def Q_value(self, s, a):\n",
        "        \"\"\"Return Q(s, a) of the first critic as a Python number; 1-D inputs get a batch axis.\"\"\"\n",
        "        s = s[None] if len(s.shape) == 1 else s\n",
        "        a = a[None] if len(a.shape) == 1 else a\n",
        "        if isinstance(self.critics, Critics):\n",
        "            self.critics.unstack(target=False, single=True)\n",
        "        return self.critics[0].Q(s, a).item()\n",
        "    \n",
        "\n",
        "class SoftActorCritic(ActorCritic):\n",
        "    \"\"\"ActorCritic whose defaults are ActorNetProbabilistic, CriticNet, the `Critics`\n",
        "    ensemble and the `SoftActor` wrapper.\"\"\"\n",
        "\n",
        "    _agent_name = \"SAC\"\n",
        "\n",
        "    def __init__(\n",
        "        self,\n",
        "        env,\n",
        "        args,\n",
        "        actor_nn=ActorNetProbabilistic,\n",
        "        critic_nn=CriticNet,\n",
        "        CriticEnsembleType=Critics,\n",
        "        ActorType=SoftActor,\n",
        "    ):\n",
        "        super(SoftActorCritic, self).__init__(\n",
        "            env,\n",
        "            args,\n",
        "            actor_nn=actor_nn,\n",
        "            critic_nn=critic_nn,\n",
        "            CriticEnsembleType=CriticEnsembleType,\n",
        "            ActorType=ActorType,\n",
        "        )\n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DXN8_ffH0k0Z"
      },
      "source": [
        "# DEA\n",
        "Our proposed Directional Ensemble Aggregation (DEA) model."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ffwBGyfi0lyy"
      },
      "outputs": [],
      "source": [
        "\n",
        "class DEACritic(Critic):\n",
        "    \"\"\"Critic member used by DEACritics; its `loss` is set to nn.MSELoss().\"\"\"\n",
        "\n",
        "    def __init__(self, arch, args, n_state, n_action):\n",
        "        super(DEACritic, self).__init__(arch, args, n_state, n_action)\n",
        "        self.loss = nn.MSELoss()\n",
        "\n",
        "\n",
        "class DEACritics(Critics):\n",
        "    \"\"\"Critic ensemble for Directional Ensemble Aggregation (DEA).\n",
        "\n",
        "    Actions are scored with the utility\n",
        "        U(s, a; tau) = mean_i Q_i(s, a) + tan(tau) * d(s, a),\n",
        "    where d is half the mean absolute pairwise disagreement between ensemble members.\n",
        "    Two scalars are adapted with sign-based steps inside `update`: `tau` (clamped to\n",
        "    <= 0 for the critic's Bellman target) and `tau_prime` (read by DEAActor's loss,\n",
        "    clamped to >= 0).\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, arch, args, n_state, n_action, critictype=DEACritic):\n",
        "        super().__init__(arch, args, n_state, n_action, critictype)\n",
        "        self.delta = 1e-4  # floor on the disagreement in the tau gradient ratio\n",
        "        self.args = args\n",
        "        # DEA hyper-parameters; these overwrite any values already present on `args`.\n",
        "        self.args.evaq_learning_rate = 0.1\n",
        "        self.args.evaq_critic_init = -0.8\n",
        "        self.args.evaq_actor_init = 0.0\n",
        "        self.args.evaq_tau_limit = 0.999\n",
        "\n",
        "        # Directional parameters: plain tensors updated manually in `update`.\n",
        "        self.learning_rate_tau = self.args.evaq_learning_rate\n",
        "        self.tau = torch.tensor(self.args.evaq_critic_init, requires_grad=False, device=args.device)\n",
        "        self.tau_prime = torch.tensor(self.args.evaq_actor_init, requires_grad=False, device=args.device)\n",
        "\n",
        "        # Diagnostics collected for plots.\n",
        "        self.iter_num = []\n",
        "        self.alpha_list, self.alpha_ep_list = [], []\n",
        "        self.q_mu_pre, self.q_std_pre, self.q_dis_pre = [], [], []\n",
        "        self.q_mu_post, self.q_std_post, self.q_dis_post = [], [], []\n",
        "        self.q_mu_t, self.q_std_t, self.q_dis_t = [], [], []\n",
        "        self.tau_values, self.tau_prime_values = [], []\n",
        "        self.grad_tau_traj, self.grad_tau_prime_traj = [], []\n",
        "\n",
        "    def pairwise_disagreement(self, q):\n",
        "        \"\"\"Half the mean absolute difference over all unordered pairs of ensemble members.\n",
        "\n",
        "        Args:\n",
        "            q: ensemble outputs stacked along dim 0; expected shape [N, B, 1]\n",
        "               (N members, batch size B).\n",
        "        Returns:\n",
        "            Tensor of shape q.shape[1:]. With fewer than two members there are no pairs,\n",
        "            and zeros are returned instead of the NaN an empty mean would give.\n",
        "        \"\"\"\n",
        "        n_members = q.shape[0]\n",
        "        if n_members < 2:\n",
        "            return q.new_zeros(q.shape[1:])\n",
        "        abs_diffs = (q.unsqueeze(0) - q.unsqueeze(1)).abs()  # [N, N, B, 1]\n",
        "        # Strictly lower-triangular mask (i > j) selects each unordered pair once.\n",
        "        pair_mask = torch.tril(\n",
        "            torch.ones(n_members, n_members, dtype=torch.bool, device=q.device), diagonal=-1\n",
        "        )\n",
        "        return abs_diffs[pair_mask].mean(dim=0) / 2  # [B, 1]\n",
        "\n",
        "    def get_utility(self, s, a, tau, is_target=False):\n",
        "        \"\"\"Ensemble mean plus tan(tau) times disagreement, from the online or target critics.\"\"\"\n",
        "        q = self.Q_t(s, a) if is_target else self.Q(s, a)\n",
        "        uncertainty = self.pairwise_disagreement(q)\n",
        "        return q.mean(dim=0) + tau.tan() * uncertainty\n",
        "\n",
        "    def U(self, s, a, tau):\n",
        "        return self.get_utility(s, a, tau, is_target=False)\n",
        "\n",
        "    def U_t(self, s, a, tau):\n",
        "        return self.get_utility(s, a, tau, is_target=True)\n",
        "\n",
        "    @torch.no_grad()\n",
        "    def get_utility_target(self, y, tau):\n",
        "        \"\"\"Bellman target r + gamma * (U_t(sp, ap; tau) - alpha_ep) * (1 - done).\n",
        "\n",
        "        `y` is the list built by get_bellman_target. Evaluated without autograd: the\n",
        "        target is a constant both for the critic regression and the tau updates.\n",
        "        \"\"\"\n",
        "        _, r, sp, ap, alpha_ep, done = y\n",
        "        u_t = self.U_t(sp, ap, tau)\n",
        "        return r.unsqueeze(-1) + self.args.gamma * (u_t - alpha_ep) * (1 - done.unsqueeze(-1))\n",
        "\n",
        "    @torch.no_grad()\n",
        "    def get_bellman_target(self, r, sp, done, actor, append_alpha=True):\n",
        "        \"\"\"Return [q_t, r, sp, ap, alpha * ep, done] for the batch.\n",
        "\n",
        "        alpha is exp(actor.log_alpha) when the actor has one, else 0; ep is the second\n",
        "        value of actor.act(sp), or 0 when that is None.\n",
        "        \"\"\"\n",
        "        alpha = actor.log_alpha.exp().detach() if hasattr(actor, \"log_alpha\") else 0\n",
        "        ap, ep = actor.act(sp)\n",
        "        qp = self.U_t(sp, ap, self.tau)\n",
        "        if ep is None:\n",
        "            ep = 0\n",
        "        alpha_ep = alpha * ep\n",
        "        q_t = r.unsqueeze(-1) + (self.args.gamma * (qp - alpha_ep) * (1 - done.unsqueeze(-1)))\n",
        "        if append_alpha:\n",
        "            # float() / is_tensor also cover the plain-number case (no log_alpha or no ep).\n",
        "            self.alpha_list.append(float(alpha))\n",
        "            self.alpha_ep_list.append(\n",
        "                alpha_ep.mean().item() if torch.is_tensor(alpha_ep) else float(alpha_ep)\n",
        "            )\n",
        "        return [q_t, r, sp, ap, alpha_ep, done]\n",
        "\n",
        "    def update(self, s, a, y):\n",
        "        \"\"\"One critic regression step, plus sign-gradient steps on tau and tau_prime\n",
        "        every `args.max_iter` calls.\n",
        "\n",
        "        Args:\n",
        "            s, a: batch of states and actions.\n",
        "            y: list from get_bellman_target: [q_t, r, sp, ap, alpha_ep, done].\n",
        "\n",
        "        Returns:\n",
        "            (iter_num, tau_values, tau_prime_values, q_mu_pre, q_std_pre, q_mu_t,\n",
        "             q_std_t, alpha_list) diagnostic lists.\n",
        "        \"\"\"\n",
        "        sp, ap = y[2], y[3]\n",
        "        progress = self.iter / self.args.max_iter\n",
        "\n",
        "        if progress % 2000 == 0:\n",
        "            with torch.no_grad():\n",
        "                q = self.Q(s, a)\n",
        "                self.q_mu_pre.append(q.mean(dim=0).mean().item())\n",
        "                self.q_std_pre.append(q.std(dim=0).mean().item())\n",
        "                self.q_dis_pre.append(self.pairwise_disagreement(q).mean().item())\n",
        "\n",
        "        # Critic step; tau is clamped to <= 0, so the disagreement term can only lower the target.\n",
        "        self.optim.zero_grad()\n",
        "        q_target = self.get_utility_target(y, self.tau.clamp(max=0.0))\n",
        "        self.loss(self.Q(s, a), self.expand(q_target)).backward()\n",
        "        self.optim.step()\n",
        "\n",
        "        if self.iter % self.args.max_iter == 0:\n",
        "            limit = self.args.evaq_tau_limit\n",
        "\n",
        "            # tau step; d / max(d, delta) shrinks the step where the target ensemble\n",
        "            # disagrees less than delta.\n",
        "            q_target = self.get_utility_target(y, self.tau)\n",
        "            u = self.U(s, a, self.tau_prime).detach()\n",
        "            dis_t = self.pairwise_disagreement(self.Q_t(sp, ap)).detach()\n",
        "            self.grad_tau = -((dis_t / dis_t.clamp(min=self.delta)) * (u - q_target).sign()).mean(dim=0)\n",
        "            self.tau = (self.tau - self.learning_rate_tau * self.grad_tau).clamp(-limit, limit)\n",
        "\n",
        "            # tau_prime step, measured against the clamped (<= 0) critic target.\n",
        "            q_target = self.get_utility_target(y, self.tau.clamp(max=0.0))\n",
        "            u = self.U(s, a, self.tau_prime).detach()\n",
        "            dis = self.pairwise_disagreement(self.Q(s, a)).detach()\n",
        "            self.grad_tau_prime = ((dis / dis.clamp(min=self.delta)) * (u - q_target).sign()).mean(dim=0)\n",
        "            self.tau_prime = (self.tau_prime - self.learning_rate_tau * self.grad_tau_prime).clamp(-limit, limit)\n",
        "\n",
        "        if progress % 2 == 0:\n",
        "            self.iter_num.append(progress)\n",
        "            # Record tau / tau_prime at the same cadence so they line up with iter_num.\n",
        "            self.tau_values.append(self.tau.item())\n",
        "            self.tau_prime_values.append(self.tau_prime.item())\n",
        "\n",
        "        self.iter += 1\n",
        "        return self.iter_num, self.tau_values, self.tau_prime_values, self.q_mu_pre, self.q_std_pre, self.q_mu_t, self.q_std_t, self.alpha_list\n",
        "\n",
        "\n",
        "class DEAActor(SoftActor):\n",
        "    \"\"\"SoftActor variant whose policy loss scores actions with critics.U at\n",
        "    tau_prime clamped to >= 0.\"\"\"\n",
        "\n",
        "    def __init__(self, arch, args, n_state, n_action):\n",
        "        super(DEAActor, self).__init__(arch, args, n_state, n_action)\n",
        "        self.iter = 1\n",
        "\n",
        "    def loss(self, s, critics):\n",
        "        \"\"\"Return (mean(-U(s, a; max(tau_prime, 0)) + exp(log_alpha) * e), e) for (a, e) = act(s).\"\"\"\n",
        "        action, e = self.act(s)\n",
        "        actor_tau = critics.tau_prime.clamp(min=0.0)\n",
        "        utility = critics.U(s, action, actor_tau)\n",
        "        return (-utility + self.log_alpha.exp() * e).mean(), e\n",
        "    \n",
        "\n",
        "\n",
        "class DEASoftActorCritic(SoftActorCritic):\n",
        "    \"\"\"SoftActorCritic whose defaults use DEACritics and DEAActor.\"\"\"\n",
        "\n",
        "    _agent_name = \"DEA\"\n",
        "\n",
        "    def __init__(\n",
        "        self,\n",
        "        env,\n",
        "        args,\n",
        "        actor_nn=ActorNetProbabilistic,\n",
        "        critic_nn=CriticNet,\n",
        "        CriticEnsembleType=DEACritics,\n",
        "        ActorType=DEAActor,\n",
        "    ):\n",
        "        super(DEASoftActorCritic, self).__init__(\n",
        "            env,\n",
        "            args,\n",
        "            actor_nn=actor_nn,\n",
        "            critic_nn=critic_nn,\n",
        "            CriticEnsembleType=CriticEnsembleType,\n",
        "            ActorType=ActorType,\n",
        "        )"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "x4t4hxd6xNUO"
      },
      "source": [
        "# Control Experiments\n",
        "Manages the training and evaluation of an agent on a control task, including its interactions with the environment."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "X6664Fo4xOqC"
      },
      "outputs": [],
      "source": [
        "class ControlExperiment(Experiment):\n",
        "    def __init__(self):\n",
        "        \"\"\"Initialise bookkeeping fields and write the default run configuration onto `self.args`.\"\"\"\n",
        "        super().__init__()\n",
        "        self.eval_reward = 0\n",
        "        self.late_eval_reward = 0\n",
        "        self.is_break = False\n",
        "        # NOTE(review): despite its name this holds a torch.device, not a string.\n",
        "        self.device_str = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "        self.optimizer_args = {\"lr\": 4e-3}\n",
        "        self.n_total_steps = 0\n",
        "        # Default run configuration, applied in this order onto `self.args`\n",
        "        # (presumably created by Experiment.__init__).\n",
        "        defaults = (\n",
        "            (\"max_steps\", 100000),\n",
        "            (\"eval_frequency\", 2000),\n",
        "            (\"eval_episodes\", 10),\n",
        "            (\"gamma\", 0.99),\n",
        "            (\"warmup_steps\", 10000),\n",
        "            (\"learn_frequency\", 1),\n",
        "            (\"max_iter\", 1),\n",
        "            (\"n_critics\", 2),\n",
        "            (\"alpha\", 1),\n",
        "            (\"verbose\", False),\n",
        "            (\"progress\", False),\n",
        "            (\"reset_frequency\", 0),\n",
        "        )\n",
        "        for key, value in defaults:\n",
        "            setattr(self.args, key, value)\n",
        "\n",
        "    def train(self):\n",
        "        \"\"\"Run the interaction / learning loop for `args.max_steps` environment steps.\n",
        "\n",
        "        Acts with random actions (clipped to [-1, 1]) while step < args.warmup_steps,\n",
        "        then with the agent's policy. Every transition is stored via\n",
        "        agent.store_transition; agent.learn runs every args.learn_frequency steps once\n",
        "        warm-up is over; self.eval runs every args.eval_frequency steps (including\n",
        "        step 0) and once more after the loop.\n",
        "        NOTE(review): information_dict, time_start and time_end are computed but are\n",
        "        neither returned nor stored on self.\n",
        "        \"\"\"\n",
        "        time_start = time.time()\n",
        "        information_dict = {\n",
        "            \"episode_rewards\": th.zeros(1000000),\n",
        "            \"episode_steps\": th.zeros(1000000),\n",
        "            \"step_rewards\": np.empty((2 * self.args.max_steps), dtype=object),\n",
        "        }\n",
        "\n",
        "        s, _ = self.env.reset()\n",
        "        s = totorch(s, device=self.args.device)\n",
        "        r_cum = np.zeros(1)\n",
        "        episode = 0\n",
        "        # NOTE(review): e_step counts steps within the current episode but is not read\n",
        "        # anywhere else in this method; the logging below uses the global `step`.\n",
        "        e_step = 0\n",
        "\n",
        "        for step in tqdm(\n",
        "            range(self.args.max_steps), leave=True, disable=not self.args.progress\n",
        "        ):\n",
        "            e_step += 1\n",
        "\n",
        "            # After warm-up, call critics.reset() every reset_frequency steps (disabled when 0).\n",
        "            if (\n",
        "                step > self.args.warmup_steps\n",
        "                and self.args.reset_frequency > 0\n",
        "                and step % self.args.reset_frequency == 0\n",
        "            ):\n",
        "                self.agent.critics.reset()\n",
        "\n",
        "            if step % self.args.eval_frequency == 0:\n",
        "                self.eval(step)\n",
        "\n",
        "            if step < self.args.warmup_steps:\n",
        "                a = self.env.action_space.sample()\n",
        "                a = totorch(np.clip(a, -1.0, 1.0), device=self.args.device)\n",
        "\n",
        "            else:\n",
        "                a = self.agent.select_action(s).clip(-1.0, 1.0)\n",
        "            sp, r, done, truncated, info = self.env.step(tonumpy(a))\n",
        "            sp = totorch(sp, device=self.args.device)\n",
        "\n",
        "            if self.args.verbose and \"sp\" in self.args.env:\n",
        "                print(\"X pos: \", info[\"x_pos\"], \"Action norm: \", info[\"action_norm\"])\n",
        "                # TODO: Write this instead into a file!\n",
        "\n",
        "            self.agent.store_transition(s, a, r, sp, done, truncated, step + 1)\n",
        "\n",
        "            information_dict[\"step_rewards\"][self.n_total_steps + step] = (\n",
        "                episode,\n",
        "                step,\n",
        "                r,\n",
        "            )\n",
        "\n",
        "            s = sp  # Update state\n",
        "            r_cum += r  # Update cumulative reward\n",
        "\n",
        "            if (\n",
        "                step >= self.args.warmup_steps\n",
        "                and (step % self.args.learn_frequency) == 0\n",
        "            ):\n",
        "                self.agent.learn(max_iter=self.args.max_iter)\n",
        "\n",
        "            if done or truncated:\n",
        "\n",
        "                information_dict[\"episode_rewards\"][episode] = r_cum.item()\n",
        "                # NOTE(review): stores (and prints) the global step counter, not the episode length.\n",
        "                information_dict[\"episode_steps\"][episode] = step\n",
        "                print('Episode:', episode, ' Reward: %.3f' % np.mean(r_cum), 'N-steps: %d' % step)\n",
        "                s, _ = self.env.reset()\n",
        "                s = totorch(s, device=self.args.device)\n",
        "                r_cum = np.zeros(1)\n",
        "                episode += 1\n",
        "                e_step = 0\n",
        "\n",
        "\n",
        "        # Final evaluation after the last training step.\n",
        "        self.eval(step)\n",
        "        time_end = time.time()\n",
        "\n",
        "    @torch.no_grad()\n",
        "    def eval(self, n_step):\n",
        "        \"\"\"Evaluate the current policy on ``self.eval_env``.\n",
        "\n",
        "        Runs ``self.args.eval_episodes`` episodes, choosing actions with\n",
        "        ``self.agent.select_action(s, is_training=False)``. The agent is put in\n",
        "        eval mode for the duration and switched back to train mode afterwards,\n",
        "        even if an episode raises.\n",
        "\n",
        "        Args:\n",
        "            n_step: training step at which the evaluation runs; only used in\n",
        "                the printed summary.\n",
        "\n",
        "        Returns:\n",
        "            dict with keys\n",
        "              - \"returns\": tensor of shape (eval_episodes,) with the undiscounted\n",
        "                return of each episode,\n",
        "              - \"discounted_returns\": tensor of shape (eval_episodes,) with each\n",
        "                episode's return discounted by ``self.args.gamma``,\n",
        "              - \"infos\": dict mapping episode index -> list of the ``info``\n",
        "                objects returned by ``self.eval_env.step``.\n",
        "        \"\"\"\n",
        "        self.agent.eval()\n",
        "        n_episodes = self.args.eval_episodes\n",
        "        # Use `torch` directly: the import cell only does `import torch`, so the\n",
        "        # `th` alias is not guaranteed to exist.\n",
        "        results = torch.zeros(n_episodes)\n",
        "        discounted = torch.zeros(n_episodes)\n",
        "        collect_infos = {}\n",
        "\n",
        "        try:\n",
        "            for episode in range(n_episodes):\n",
        "                collect_infos[episode] = []\n",
        "                s, info = self.eval_env.reset()\n",
        "                s = totorch(s, device=self.args.device)\n",
        "                step = 0\n",
        "                done = False\n",
        "\n",
        "                while not done:\n",
        "                    # Query the actor exactly once per environment step.\n",
        "                    a = self.agent.select_action(s, is_training=False)\n",
        "                    sp, r, term, trunc, info = self.eval_env.step(tonumpy(a))\n",
        "                    collect_infos[episode].append(info)\n",
        "\n",
        "                    done = term or trunc\n",
        "                    s = totorch(sp, device=self.args.device)\n",
        "                    results[episode] += r\n",
        "                    discounted[episode] += self.args.gamma**step * r\n",
        "                    step += 1\n",
        "        finally:\n",
        "            # Restore training mode even if an evaluation episode fails.\n",
        "            self.agent.actor.states = []\n",
        "            self.agent.train()\n",
        "\n",
        "        print('Eval at step: %d' % n_step, ' Mean reward: %.3f' % results.mean().item())\n",
        "        return {\"returns\": results, \"discounted_returns\": discounted, \"infos\": collect_infos}\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "# Run All"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 399
        },
        "id": "BRFN2y0505pX",
        "outputId": "3d0fc500-6be0-4910-b140-fe099cde25ed"
      },
      "outputs": [],
      "source": [
        "# Instantiate the experiment (no constructor arguments) and run training.\n",
        "# NOTE(review): no random seed is set in this cell; runs may not be reproducible\n",
        "# unless a seed is set elsewhere in the notebook.\n",
        "exp = ControlExperiment()\n",
        "exp.train()"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "display_name": "demo",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.9.18"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
