{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aea037e8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import re\n",
    "import os\n",
    "import pickle\n",
    "import torch\n",
    "\n",
    "from sklearn.datasets import load_iris\n",
    "from sklearn.svm import SVC\n",
    "from sklearn.tree import DecisionTreeClassifier\n",
    "from sklearn.cluster import KMeans\n",
    "from sklearn_extra.cluster import KMedoids\n",
    "\n",
    "from joblib import dump, load\n",
    "from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier, ExtraTreesClassifier, AdaBoostClassifier\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.metrics import accuracy_score\n",
    "from sklearn import tree\n",
    "\n",
    "# from arguments.args import get_args\n",
    "# import argparse\n",
    "\n",
    "import random\n",
    "from collections import Counter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "11f16faa",
   "metadata": {},
   "outputs": [],
   "source": [
    "# load q values\n",
    "file = open('q_values_hopper.pkl', 'rb')\n",
    "data = pickle.load(file)\n",
    "file.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d0360b2b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def normalize(q_list):\n",
    "    max_p = max(q_list)\n",
    "    min_p = min(q_list)\n",
    "    alpha = (max_p + min_p) / 2\n",
    "\n",
    "    # Check if max_p is zero\n",
    "    if max_p == 0:\n",
    "        return q_list\n",
    "\n",
    "    for i in range(0, len(q_list)):\n",
    "        if max_p != 0:\n",
    "            q_list[i] = (q_list[i] - alpha) / max_p\n",
    "        else:\n",
    "            # If max_p is zero, set the value to zero\n",
    "            q_list[i] = 0\n",
    "\n",
    "    return q_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "04315ccf",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn_extra.cluster import KMedoids\n",
    "from sklearn.decomposition import PCA\n",
    "from sklearn.model_selection import train_test_split\n",
    "import tempfile"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "03fa4698",
   "metadata": {},
   "source": [
    "q_list = []\n",
    "tanh_q_list = []\n",
    "obs_list = []\n",
    "\n",
    "for a_i in range(0, 10):\n",
    "    q_val = []\n",
    "    obs = []\n",
    "#     for batch in range(len(data)):\n",
    "    for batch in range(5):\n",
    "        q_val.extend(data[batch]['q_values'][a_i])\n",
    "        obs.extend(np.mean(data[batch]['observations'][a_i], axis=1))\n",
    "    q_list.append(q_val)\n",
    "    obs_list.append(obs)\n",
    " \n",
    "    tanh_q_list.append(np.tanh(normalize(q_val)))\n",
    "state = KMedoids(n_clusters=10, metric='cosine', random_state=0).fit(np.array(q_list).T)\n",
    "labels = state.labels_\n",
    "X_train, X_test, y_train, y_test = train_test_split(np.array(obs_list).T, labels, test_size=0.2, random_state=42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6a673530",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "\n",
    "# Define the Encoder model\n",
    "class Encoder(nn.Module):\n",
    "    def __init__(self, input_size=11, hidden_size=32, encoded_size=16):\n",
    "        super(Encoder, self).__init__()\n",
    "        self.fc1 = nn.Linear(input_size, hidden_size)\n",
    "        self.fc2 = nn.Linear(hidden_size, encoded_size)\n",
    "        self.relu = nn.ReLU()\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.relu(self.fc1(x))\n",
    "        x = self.fc2(x)\n",
    "        return x\n",
    "\n",
    "# Define the Decoder model\n",
    "class Decoder(nn.Module):\n",
    "    def __init__(self, encoded_size=16, hidden_size=32, output_size=1):\n",
    "        super(Decoder, self).__init__()\n",
    "        self.fc1 = nn.Linear(encoded_size, hidden_size)\n",
    "        self.fc2 = nn.Linear(hidden_size, output_size)\n",
    "        self.relu = nn.ReLU()\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.relu(self.fc1(x))\n",
    "        x = self.fc2(x)\n",
    "        return x\n",
    "\n",
    "# Define the Encoder-Decoder model\n",
    "class EncoderDecoder(nn.Module):\n",
    "    def __init__(self, input_size=11, hidden_size=32, encoded_size=16, output_size=1):\n",
    "        super(EncoderDecoder, self).__init__()\n",
    "        self.encoder = Encoder(input_size, hidden_size, encoded_size)\n",
    "        self.decoder = Decoder(encoded_size, hidden_size, output_size)\n",
    "\n",
    "    def forward(self, x):\n",
    "        encoded = self.encoder(x)\n",
    "        decoded = self.decoder(encoded)\n",
    "        return decoded\n",
    "\n",
    "# Instantiate the model\n",
    "model = EncoderDecoder()\n",
    "\n",
    "# Define a simple dataset\n",
    "# Example input: 11 float numbers\n",
    "example_input = torch.randn((5, 11))  # batch size of 5, 11 float numbers each\n",
    "\n",
    "# Example forward pass\n",
    "output = model(example_input)\n",
    "print(output)\n",
    "\n",
    "# Optionally, define a loss function and an optimizer\n",
    "criterion = nn.MSELoss()\n",
    "optimizer = optim.Adam(model.parameters(), lr=0.001)\n",
    "\n",
    "# Example training loop\n",
    "# Assume we have input data `inputs` and target data `targets`\n",
    "inputs = torch.randn((100, 11))  # 100 samples of 11 float numbers\n",
    "targets = torch.randn((100, 1))  # 100 target values\n",
    "\n",
    "for epoch in range(100):  # Number of epochs\n",
    "    model.train()\n",
    "    optimizer.zero_grad()\n",
    "    outputs = model(inputs)\n",
    "    loss = criterion(outputs, targets)\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "\n",
    "    if (epoch+1) % 10 == 0:\n",
    "        print(f'Epoch [{epoch+1}/100], Loss: {loss.item():.4f}')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2d038b4d",
   "metadata": {},
   "outputs": [],
   "source": [
    "model.encoder(torch.from_numpy(data[batch]['observations'][0][0]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ea5a70a2",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1434e518",
   "metadata": {},
   "outputs": [],
   "source": [
    "q_list = []\n",
    "tanh_q_list = []\n",
    "obs_list = []\n",
    "\n",
    "num_batches = len(data)\n",
    "num_actions = 10\n",
    "sample_size = 10000  # Adjust sample size based on available memory\n",
    "\n",
    "for a_i in range(num_actions):\n",
    "    q_val = []\n",
    "    obs = []\n",
    "\n",
    "    for batch in range(num_batches):\n",
    "        q_val.extend(data[batch]['q_values'][a_i])\n",
    "#         obs.extend(np.mean(data[batch]['observations'][a_i], axis=1))\n",
    "#         obs.extend(data[batch]['observations'][a_i])\n",
    "        obs.extend(model.encoder(torch.from_numpy(data[batch]['observations'][a_i])).detach().numpy()) \n",
    "    q_list.append(q_val)\n",
    "    obs_list.append(obs)\n",
    "    tanh_q_list.append(np.tanh(normalize(q_val)))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "588bdabf",
   "metadata": {},
   "outputs": [],
   "source": [
    "len(q_list[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "43ad8622",
   "metadata": {},
   "outputs": [],
   "source": [
    "q_list[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a78d793d",
   "metadata": {},
   "outputs": [],
   "source": [
    "q_array = np.array(q_list).T\n",
    "obs_array = np.array(obs_list).T\n",
    "\n",
    "\n",
    "pca = PCA(n_components=min(20, q_array.shape[1]))  \n",
    "q_array_reduced = pca.fit_transform(q_array)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e419579b",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "indices = np.random.choice(q_array_reduced.shape[0], sample_size, replace=False)\n",
    "q_array_sampled = q_array_reduced[indices]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c85cdf6a",
   "metadata": {},
   "outputs": [],
   "source": [
    "state = KMedoids(n_clusters=10, metric='cosine', random_state=0).fit(q_array_sampled)\n",
    "labels = np.full(q_array_reduced.shape[0], -1)\n",
    "labels[indices] = state.labels_\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "567e6779",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.neighbors import KNeighborsClassifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2c23fcd8",
   "metadata": {},
   "outputs": [],
   "source": [
    "knn = KNeighborsClassifier(n_neighbors=10, metric='cosine')\n",
    "knn.fit(q_array_sampled, state.labels_)\n",
    "\n",
    "full_labels = knn.predict(q_array_reduced)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3510cfb3",
   "metadata": {},
   "outputs": [],
   "source": [
    "len(obs_array[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "485fa9e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "len(full_labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f38f923c",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4a807552",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Split the data\n",
    "X_train, X_test, y_train, y_test = train_test_split(obs_array, full_labels, test_size=0.2, random_state=42)\n",
    "\n",
    "# X_train, X_test, y_train, y_test = train_test_split(np.array(obs_list).T, full_labels, test_size=0.2, random_state=42)\n",
    "\n",
    "# X_train, X_test, y_train, y_test = train_test_split(obs_list, full_labels, test_size=0.2, random_state=42)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "51b46c77",
   "metadata": {},
   "outputs": [],
   "source": [
    "X_train_bs, X_test_bs, y_train_bs, y_test_bs = train_test_split(np.array(obs_list).T, labels, test_size=0.2, random_state=42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "acd1c839",
   "metadata": {},
   "outputs": [],
   "source": [
    "len(X_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ff380b9f",
   "metadata": {},
   "outputs": [],
   "source": [
    "len(X_train_bs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d301408f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import defaultdict\n",
    "\n",
    "stats = defaultdict(int)\n",
    "for i in labels:\n",
    "    stats[i] += 1\n",
    "\n",
    "stats"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ea74c33",
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import defaultdict\n",
    "\n",
    "stats = defaultdict(int)\n",
    "for i in full_labels:\n",
    "    stats[i] += 1\n",
    "\n",
    "stats"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0147f3c5",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ae45ab3",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1e4c6348",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "stateAdaBoostClassifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "700d128a",
   "metadata": {},
   "outputs": [],
   "source": [
    "np.array(obs_list).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d5c7155f",
   "metadata": {},
   "outputs": [],
   "source": [
    "labels.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "514435c8",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.ensemble import HistGradientBoostingClassifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b6d19735",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# train\n",
    "tree_with_svm = HistGradientBoostingClassifier(max_iter=1000,max_leaf_nodes=128)\n",
    "tree_with_svm.fit(X_train_bs, y_train_bs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "35490d4d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# train\n",
    "tree_with_svm = GradientBoostingClassifier(max_leaf_nodes=40)\n",
    "tree_with_svm.fit(X_train, y_train)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "70841892",
   "metadata": {},
   "source": [
    "# save model\n",
    "dump(tree_with_svm, os.path.join('.', 'tree_with_svm.joblib'))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6dd0d44d",
   "metadata": {},
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "efe33b19",
   "metadata": {},
   "source": [
    "# test\n",
    "tree_with_svm = load(os.path.join('.', 'tree_with_svm_hop.joblib'))           \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "86a2c1f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "pred = tree_with_svm.predict(X_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "646154f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "action_array = []\n",
    "for first_action in np.arange(-0.90, 0.901, 0.2):\n",
    "    action_array.append(first_action)\n",
    "action_array = np.array(action_array)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bc6a0921",
   "metadata": {},
   "outputs": [],
   "source": [
    "action_array[6]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "43727a27",
   "metadata": {},
   "outputs": [],
   "source": [
    "action_array"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "de2148d3",
   "metadata": {},
   "outputs": [],
   "source": [
    "pred = tree_with_svm.predict(X_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5afcab17",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "\n",
    "from torch import nn\n",
    "\n",
    "import pickle\n",
    "import random\n",
    "\n",
    "import d4rl\n",
    "import gym\n",
    "import numpy as np\n",
    "import pyrallis\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import wandb\n",
    "from torch.distributions import Normal\n",
    "from tqdm import trange\n",
    "\n",
    "\n",
    "import math\n",
    "import os\n",
    "import random\n",
    "import uuid\n",
    "from copy import deepcopy\n",
    "from dataclasses import asdict, dataclass\n",
    "from typing import Any, Dict, List, Optional, Tuple, Union\n",
    "\n",
    "TensorBatch = List[torch.Tensor]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "08f07a9e",
   "metadata": {},
   "outputs": [],
   "source": [
    "project: str = \"CORL\"\n",
    "group: str = \"SAC-N\"\n",
    "name: str = \"SAC-N\"\n",
    "# model params\n",
    "hidden_dim: int = 256\n",
    "num_critics: int = 200\n",
    "    \n",
    "# CHANGE NUM_CRITICS    # CHANGE NUM_CRITICS   # CHANGE NUM_CRITICS   # CHANGE NUM_CRITICS   # CHANGE NUM_CRITICS   # CHANGE NUM_CRITICS   # CHANGE NUM_CRITICS   # CHANGE NUM_CRITICS   # CHANGE NUM_CRITICS   # CHANGE NUM_CRITICS   \n",
    "# 10 for halfcheetah 200 for hopper\n",
    "gamma: float = 0.99\n",
    "tau: float = 5e-3\n",
    "actor_learning_rate: float = 3e-4\n",
    "critic_learning_rate: float = 3e-4\n",
    "alpha_learning_rate: float = 3e-4\n",
    "max_action: float = 1.0\n",
    "# training params\n",
    "buffer_size: int = 2_000_000\n",
    "env_name: str = \"hopper-expert-v2\"\n",
    "batch_size: int = 256\n",
    "num_epochs: int = 3000\n",
    "num_updates_on_epoch: int = 1000\n",
    "normalize_reward: bool = False\n",
    "# evaluation params\n",
    "eval_episodes: int = 10\n",
    "eval_every: int = 20\n",
    "# general params\n",
    "checkpoints_path: Optional[str] = \"./checkpoints\"\n",
    "deterministic_torch: bool = False\n",
    "train_seed: int = 10\n",
    "eval_seed: int = 42\n",
    "log_every: int = 100\n",
    "device: str = \"cuda\"\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e97f4d87",
   "metadata": {},
   "outputs": [],
   "source": [
    "def wrap_env(\n",
    "    env: gym.Env,\n",
    "    state_mean: Union[np.ndarray, float] = 0.0,\n",
    "    state_std: Union[np.ndarray, float] = 1.0,\n",
    "    reward_scale: float = 1.0,\n",
    ") -> gym.Env:\n",
    "    def normalize_state(state):\n",
    "        return (state - state_mean) / state_std\n",
    "\n",
    "    def scale_reward(reward):\n",
    "        return reward_scale * reward\n",
    "\n",
    "    env = gym.wrappers.TransformObservation(env, normalize_state)\n",
    "    if reward_scale != 1.0:\n",
    "        env = gym.wrappers.TransformReward(env, scale_reward)\n",
    "    return env\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a313b004",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "acdc2e0d",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d4591519",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2b8b4574",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fca161c2",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c8438ede",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bff87471",
   "metadata": {},
   "outputs": [],
   "source": [
    "class VectorizedLinear(nn.Module):\n",
    "    def __init__(self, in_features: int, out_features: int, ensemble_size: int):\n",
    "        super().__init__()\n",
    "        self.in_features = in_features\n",
    "        self.out_features = out_features\n",
    "        self.ensemble_size = ensemble_size\n",
    "\n",
    "        self.weight = nn.Parameter(torch.empty(ensemble_size, in_features, out_features))\n",
    "        self.bias = nn.Parameter(torch.empty(ensemble_size, 1, out_features))\n",
    "\n",
    "        self.reset_parameters()\n",
    "\n",
    "    def reset_parameters(self):\n",
    "        # default pytorch init for nn.Linear module\n",
    "        for layer in range(self.ensemble_size):\n",
    "            nn.init.kaiming_uniform_(self.weight[layer], a=math.sqrt(5))\n",
    "\n",
    "        fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight[0])\n",
    "        bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0\n",
    "        nn.init.uniform_(self.bias, -bound, bound)\n",
    "\n",
    "    def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
    "        # input: [ensemble_size, batch_size, input_size]\n",
    "        # weight: [ensemble_size, input_size, out_size]\n",
    "        # out: [ensemble_size, batch_size, out_size]\n",
    "        return x @ self.weight + self.bias\n",
    "\n",
    "\n",
    "class Actor(nn.Module):\n",
    "    def __init__(\n",
    "        self, state_dim: int, action_dim: int, hidden_dim: int, max_action: float = 1.0\n",
    "    ):\n",
    "        super().__init__()\n",
    "        self.trunk = nn.Sequential(\n",
    "            nn.Linear(state_dim, hidden_dim),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(hidden_dim, hidden_dim),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(hidden_dim, hidden_dim),\n",
    "            nn.ReLU(),\n",
    "        )\n",
    "        # with separate layers works better than with Linear(hidden_dim, 2 * action_dim)\n",
    "        self.mu = nn.Linear(hidden_dim, action_dim)\n",
    "        self.log_sigma = nn.Linear(hidden_dim, action_dim)\n",
    "\n",
    "        # init as in the EDAC paper\n",
    "        for layer in self.trunk[::2]:\n",
    "            torch.nn.init.constant_(layer.bias, 0.1)\n",
    "\n",
    "        torch.nn.init.uniform_(self.mu.weight, -1e-3, 1e-3)\n",
    "        torch.nn.init.uniform_(self.mu.bias, -1e-3, 1e-3)\n",
    "        torch.nn.init.uniform_(self.log_sigma.weight, -1e-3, 1e-3)\n",
    "        torch.nn.init.uniform_(self.log_sigma.bias, -1e-3, 1e-3)\n",
    "\n",
    "        self.action_dim = action_dim\n",
    "        self.max_action = max_action\n",
    "\n",
    "    def forward(\n",
    "        self,\n",
    "        state: torch.Tensor,\n",
    "        deterministic: bool = False,\n",
    "        need_log_prob: bool = False,\n",
    "    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:\n",
    "        hidden = self.trunk(state)\n",
    "        mu, log_sigma = self.mu(hidden), self.log_sigma(hidden)\n",
    "\n",
    "        # clipping params from EDAC paper, not as in SAC paper (-20, 2)\n",
    "        log_sigma = torch.clip(log_sigma, -5, 2)\n",
    "        policy_dist = Normal(mu, torch.exp(log_sigma))\n",
    "\n",
    "        if deterministic:\n",
    "            action = mu\n",
    "        else:\n",
    "            action = policy_dist.rsample()\n",
    "\n",
    "        tanh_action, log_prob = torch.tanh(action), None\n",
    "        if need_log_prob:\n",
    "            # change of variables formula (SAC paper, appendix C, eq 21)\n",
    "            log_prob = policy_dist.log_prob(action).sum(axis=-1)\n",
    "            log_prob = log_prob - torch.log(1 - tanh_action.pow(2) + 1e-6).sum(axis=-1)\n",
    "\n",
    "        return tanh_action * self.max_action, log_prob\n",
    "\n",
    "    @torch.no_grad()\n",
    "    def act(self, state: np.ndarray, device: str) -> np.ndarray:\n",
    "        deterministic = not self.training\n",
    "        state = torch.tensor(state, device=device, dtype=torch.float32)\n",
    "        action = self(state, deterministic=deterministic)[0].cpu().numpy()\n",
    "        return action\n",
    "\n",
    "\n",
    "class VectorizedCritic(nn.Module):\n",
    "    def __init__(\n",
    "        self, state_dim: int, action_dim: int, hidden_dim: int, num_critics: int\n",
    "    ):\n",
    "        super().__init__()\n",
    "        self.critic = nn.Sequential(\n",
    "            VectorizedLinear(state_dim + action_dim, hidden_dim, num_critics),\n",
    "            nn.ReLU(),\n",
    "            VectorizedLinear(hidden_dim, hidden_dim, num_critics),\n",
    "            nn.ReLU(),\n",
    "            VectorizedLinear(hidden_dim, hidden_dim, num_critics),\n",
    "            nn.ReLU(),\n",
    "            VectorizedLinear(hidden_dim, 1, num_critics),\n",
    "        )\n",
    "        # init as in the EDAC paper\n",
    "        for layer in self.critic[::2]:\n",
    "            torch.nn.init.constant_(layer.bias, 0.1)\n",
    "\n",
    "        torch.nn.init.uniform_(self.critic[-1].weight, -3e-3, 3e-3)\n",
    "        torch.nn.init.uniform_(self.critic[-1].bias, -3e-3, 3e-3)\n",
    "\n",
    "        self.num_critics = num_critics\n",
    "\n",
    "    def forward(self, state: torch.Tensor, action: torch.Tensor) -> torch.Tensor:\n",
    "        # [batch_size, state_dim + action_dim]\n",
    "        state_action = torch.cat([state, action], dim=-1)\n",
    "        # [num_critics, batch_size, state_dim + action_dim]\n",
    "        state_action = state_action.unsqueeze(0).repeat_interleave(\n",
    "            self.num_critics, dim=0\n",
    "        )\n",
    "        # [num_critics, batch_size]\n",
    "        q_values = self.critic(state_action).squeeze(-1)\n",
    "        return q_values\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "63a348f7",
   "metadata": {},
   "outputs": [],
   "source": [
    "class SACN:\n",
    "    def __init__(\n",
    "        self,\n",
    "        actor: Actor,\n",
    "        actor_optimizer: torch.optim.Optimizer,\n",
    "        critic: VectorizedCritic,\n",
    "        critic_optimizer: torch.optim.Optimizer,\n",
    "        gamma: float = 0.99,\n",
    "        tau: float = 0.005,\n",
    "        alpha_learning_rate: float = 1e-4,\n",
    "        device: str = \"cpu\",\n",
    "    ):\n",
    "        self.device = device\n",
    "\n",
    "        self.actor = actor\n",
    "        self.critic = critic\n",
    "        with torch.no_grad():\n",
    "            self.target_critic = deepcopy(self.critic)\n",
    "\n",
    "        self.actor_optimizer = actor_optimizer\n",
    "        self.critic_optimizer = critic_optimizer\n",
    "\n",
    "        self.tau = tau\n",
    "        self.gamma = gamma\n",
    "\n",
    "        # adaptive alpha setup\n",
    "        self.target_entropy = -float(self.actor.action_dim)\n",
    "        self.log_alpha = torch.tensor(\n",
    "            [0.0], dtype=torch.float32, device=self.device, requires_grad=True\n",
    "        )\n",
    "        self.alpha_optimizer = torch.optim.Adam([self.log_alpha], lr=alpha_learning_rate)\n",
    "        self.alpha = self.log_alpha.exp().detach()\n",
    "\n",
    "    def _alpha_loss(self, state: torch.Tensor) -> torch.Tensor:\n",
    "        with torch.no_grad():\n",
    "            action, action_log_prob = self.actor(state, need_log_prob=True)\n",
    "\n",
    "        loss = (-self.log_alpha * (action_log_prob + self.target_entropy)).mean()\n",
    "\n",
    "        return loss\n",
    "\n",
    "    def _actor_loss(self, state: torch.Tensor) -> Tuple[torch.Tensor, float, float]:\n",
    "        action, action_log_prob = self.actor(state, need_log_prob=True)\n",
    "        q_value_dist = self.critic(state, action)\n",
    "        assert q_value_dist.shape[0] == self.critic.num_critics\n",
    "        q_value_min = q_value_dist.min(0).values\n",
    "        # needed for logging\n",
    "        q_value_std = q_value_dist.std(0).mean().item()\n",
    "        batch_entropy = -action_log_prob.mean().item()\n",
    "\n",
    "        assert action_log_prob.shape == q_value_min.shape\n",
    "        loss = (self.alpha * action_log_prob - q_value_min).mean()\n",
    "\n",
    "        return loss, batch_entropy, q_value_std\n",
    "\n",
    "    def _critic_loss(\n",
    "        self,\n",
    "        state: torch.Tensor,\n",
    "        action: torch.Tensor,\n",
    "        reward: torch.Tensor,\n",
    "        next_state: torch.Tensor,\n",
    "        done: torch.Tensor,\n",
    "    ) -> torch.Tensor:\n",
    "        with torch.no_grad():\n",
    "            next_action, next_action_log_prob = self.actor(\n",
    "                next_state, need_log_prob=True\n",
    "            )\n",
    "            q_next = self.target_critic(next_state, next_action).min(0).values\n",
    "            q_next = q_next - self.alpha * next_action_log_prob\n",
    "\n",
    "            assert q_next.unsqueeze(-1).shape == done.shape == reward.shape\n",
    "            q_target = reward + self.gamma * (1 - done) * q_next.unsqueeze(-1)\n",
    "\n",
    "        q_values = self.critic(state, action)\n",
    "        # [ensemble_size, batch_size] - [1, batch_size]\n",
    "        loss = ((q_values - q_target.view(1, -1)) ** 2).mean(dim=1).sum(dim=0)\n",
    "\n",
    "        return loss\n",
    "\n",
    "    def update(self, batch: TensorBatch) -> Dict[str, float]:\n",
    "        state, action, reward, next_state, done = [arr.to(self.device) for arr in batch]\n",
    "        # Usually updates are done in the following order: critic -> actor -> alpha\n",
    "        # But we found that EDAC paper uses reverse (which gives better results)\n",
    "\n",
    "        # Alpha update\n",
    "        alpha_loss = self._alpha_loss(state)\n",
    "        self.alpha_optimizer.zero_grad()\n",
    "        alpha_loss.backward()\n",
    "        self.alpha_optimizer.step()\n",
    "\n",
    "        self.alpha = self.log_alpha.exp().detach()\n",
    "\n",
    "        # Actor update\n",
    "        actor_loss, actor_batch_entropy, q_policy_std = self._actor_loss(state)\n",
    "        self.actor_optimizer.zero_grad()\n",
    "        actor_loss.backward()\n",
    "        self.actor_optimizer.step()\n",
    "\n",
    "        # Critic update\n",
    "        critic_loss = self._critic_loss(state, action, reward, next_state, done)\n",
    "        self.critic_optimizer.zero_grad()\n",
    "        critic_loss.backward()\n",
    "        self.critic_optimizer.step()\n",
    "\n",
    "        #  Target networks soft update\n",
    "        with torch.no_grad():\n",
    "            soft_update(self.target_critic, self.critic, tau=self.tau)\n",
    "            # for logging, Q-ensemble std estimate with the random actions:\n",
    "            # a ~ U[-max_action, max_action]\n",
    "            max_action = self.actor.max_action\n",
    "            random_actions = -max_action + 2 * max_action * torch.rand_like(action)\n",
    "\n",
    "            q_random_std = self.critic(state, random_actions).std(0).mean().item()\n",
    "\n",
    "        update_info = {\n",
    "            \"alpha_loss\": alpha_loss.item(),\n",
    "            \"critic_loss\": critic_loss.item(),\n",
    "            \"actor_loss\": actor_loss.item(),\n",
    "            \"batch_entropy\": actor_batch_entropy,\n",
    "            \"alpha\": self.alpha.item(),\n",
    "            \"q_policy_std\": q_policy_std,\n",
    "            \"q_random_std\": q_random_std,\n",
    "        }\n",
    "        return update_info\n",
    "\n",
    "    def state_dict(self) -> Dict[str, Any]:\n",
    "        state = {\n",
    "            \"actor\": self.actor.state_dict(),\n",
    "            \"critic\": self.critic.state_dict(),\n",
    "            \"target_critic\": self.target_critic.state_dict(),\n",
    "            \"log_alpha\": self.log_alpha.item(),\n",
    "            \"actor_optim\": self.actor_optimizer.state_dict(),\n",
    "            \"critic_optim\": self.critic_optimizer.state_dict(),\n",
    "            \"alpha_optim\": self.alpha_optimizer.state_dict(),\n",
    "        }\n",
    "        return state\n",
    "\n",
    "    def load_state_dict(self, state_dict: Dict[str, Any]):\n",
    "        self.actor.load_state_dict(state_dict[\"actor\"])\n",
    "        self.critic.load_state_dict(state_dict[\"critic\"])\n",
    "        self.target_critic.load_state_dict(state_dict[\"target_critic\"])\n",
    "        self.actor_optimizer.load_state_dict(state_dict[\"actor_optim\"])\n",
    "        self.critic_optimizer.load_state_dict(state_dict[\"critic_optim\"])\n",
    "        self.alpha_optimizer.load_state_dict(state_dict[\"alpha_optim\"])\n",
    "        self.log_alpha.data[0] = state_dict[\"log_alpha\"]\n",
    "        self.alpha = self.log_alpha.exp().detach()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c05e3428",
   "metadata": {},
   "outputs": [],
   "source": [
    "class ReplayBuffer:\n",
    "    \"\"\"Fixed-capacity store of transitions kept in pre-allocated float32 tensors.\n",
    "\n",
    "    The buffer is filled once through ``load_d4rl_dataset`` and then read with\n",
    "    ``sample``; ``add_transition`` is left unimplemented.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        state_dim: int,\n",
    "        action_dim: int,\n",
    "        buffer_size: int,\n",
    "        device: str = \"cpu\",\n",
    "    ):\n",
    "        \"\"\"Allocate zero-filled storage for ``buffer_size`` transitions on ``device``.\"\"\"\n",
    "\n",
    "        def _zeros(width: int) -> torch.Tensor:\n",
    "            # One row per buffer slot, ``width`` columns.\n",
    "            return torch.zeros((buffer_size, width), dtype=torch.float32, device=device)\n",
    "\n",
    "        self._buffer_size = buffer_size\n",
    "        self._pointer = 0\n",
    "        self._size = 0\n",
    "\n",
    "        self._states = _zeros(state_dim)\n",
    "        self._actions = _zeros(action_dim)\n",
    "        self._rewards = _zeros(1)\n",
    "        self._next_states = _zeros(state_dim)\n",
    "        self._dones = _zeros(1)\n",
    "        self._device = device\n",
    "\n",
    "    def _to_tensor(self, data: np.ndarray) -> torch.Tensor:\n",
    "        \"\"\"Copy ``data`` into a float32 tensor on the buffer's device.\"\"\"\n",
    "        placement = dict(dtype=torch.float32, device=self._device)\n",
    "        return torch.tensor(data, **placement)\n",
    "\n",
    "    def load_d4rl_dataset(self, data: Dict[str, np.ndarray]):\n",
    "        \"\"\"Fill the empty buffer from a d4rl-style ``Dict[str, np.ndarray]``.\n",
    "\n",
    "        Reads the keys ``observations``, ``actions``, ``rewards``,\n",
    "        ``next_observations`` and ``terminals``; rewards and terminals get a\n",
    "        trailing axis added before being stored.\n",
    "\n",
    "        Raises:\n",
    "            ValueError: if the buffer already holds data, or if the dataset has\n",
    "                more transitions than the buffer has slots.\n",
    "        \"\"\"\n",
    "        if self._size != 0:\n",
    "            raise ValueError(\"Trying to load data into non-empty replay buffer\")\n",
    "        n_transitions = data[\"observations\"].shape[0]\n",
    "        if n_transitions > self._buffer_size:\n",
    "            raise ValueError(\n",
    "                \"Replay buffer is smaller than the dataset you are trying to load!\"\n",
    "            )\n",
    "\n",
    "        # (destination tensor, dataset key, whether a trailing axis is appended)\n",
    "        copy_plan = (\n",
    "            (self._states, \"observations\", False),\n",
    "            (self._actions, \"actions\", False),\n",
    "            (self._rewards, \"rewards\", True),\n",
    "            (self._next_states, \"next_observations\", False),\n",
    "            (self._dones, \"terminals\", True),\n",
    "        )\n",
    "        for storage, key, as_column in copy_plan:\n",
    "            values = data[key]\n",
    "            if as_column:\n",
    "                values = values[..., None]\n",
    "            storage[:n_transitions] = self._to_tensor(values)\n",
    "\n",
    "        self._size += n_transitions\n",
    "        self._pointer = min(self._size, n_transitions)\n",
    "\n",
    "        print(f\"Dataset size: {n_transitions}\")\n",
    "\n",
    "    def sample(self, batch_size: int) -> TensorBatch:\n",
    "        \"\"\"Return ``[states, actions, rewards, next_states, dones]`` for random rows.\n",
    "\n",
    "        Row indices come from ``np.random.randint`` over the filled part of the\n",
    "        buffer, so the same row may appear more than once in a batch.\n",
    "        \"\"\"\n",
    "        upper = min(self._size, self._pointer)\n",
    "        indices = np.random.randint(0, upper, size=batch_size)\n",
    "        fields = (\n",
    "            self._states,\n",
    "            self._actions,\n",
    "            self._rewards,\n",
    "            self._next_states,\n",
    "            self._dones,\n",
    "        )\n",
    "        return [storage[indices] for storage in fields]\n",
    "\n",
    "    def add_transition(self):\n",
    "        \"\"\"Placeholder for adding new data to the buffer during fine-tuning.\"\"\"\n",
    "        raise NotImplementedError\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5fc2b10f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# data, evaluation, env setup\n",
    "# Build the wrapped evaluation environment and read the observation/action\n",
    "# sizes from its spaces.\n",
    "eval_env = wrap_env(gym.make(env_name))\n",
    "state_dim = eval_env.observation_space.shape[0]\n",
    "action_dim = eval_env.action_space.shape[0]\n",
    "\n",
    "# Offline transitions for this environment, obtained through d4rl.\n",
    "d4rl_dataset = d4rl.qlearning_dataset(eval_env)\n",
    "\n",
    "# The return value is ignored, so modify_reward presumably edits the dataset in\n",
    "# place -- TODO confirm.\n",
    "if normalize_reward:\n",
    "    modify_reward(d4rl_dataset, env_name)\n",
    "\n",
    "# buffer_size must be at least the number of transitions in the dataset;\n",
    "# load_d4rl_dataset raises ValueError otherwise.\n",
    "buffer = ReplayBuffer(\n",
    "    state_dim=state_dim,\n",
    "    action_dim=action_dim,\n",
    "    buffer_size=buffer_size,\n",
    "    device=device,\n",
    ")\n",
    "buffer.load_d4rl_dataset(d4rl_dataset)\n",
    "\n",
    "# Actor & Critic setup\n",
    "# Each network is moved to `device` before its Adam optimizer is created from\n",
    "# its parameters.\n",
    "actor = Actor(state_dim, action_dim, hidden_dim, max_action)\n",
    "actor.to(device)\n",
    "actor_optimizer = torch.optim.Adam(actor.parameters(), lr=actor_learning_rate)\n",
    "critic = VectorizedCritic(\n",
    "    state_dim, action_dim, hidden_dim, num_critics\n",
    ")\n",
    "critic.to(device)\n",
    "critic_optimizer = torch.optim.Adam(\n",
    "    critic.parameters(), lr=critic_learning_rate\n",
    ")\n",
    "\n",
    "# Bundle the networks and optimizers into the SAC-N trainer used by the\n",
    "# evaluation cells below.\n",
    "trainer = SACN(\n",
    "    actor=actor,\n",
    "    actor_optimizer=actor_optimizer,\n",
    "    critic=critic,\n",
    "    critic_optimizer=critic_optimizer,\n",
    "    gamma=gamma,\n",
    "    tau=tau,\n",
    "    alpha_learning_rate=alpha_learning_rate,\n",
    "    device=device,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8903f376",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the saved trainer checkpoint. map_location remaps stored tensors onto\n",
    "# the configured device, so a checkpoint written on a GPU also loads on a\n",
    "# CPU-only machine.\n",
    "# NOTE: torch.load unpickles the file -- only load checkpoints you trust.\n",
    "model = torch.load(\n",
    "    \"./SAC-N-hopper-medium-expert-v2-ac25f151/2900.pt\", map_location=device\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "19964f9b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Push the loaded checkpoint dict into the trainer.\n",
    "trainer.load_state_dict(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ec9f46d6",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "831b908e",
   "metadata": {},
   "source": [
    "# SAC_N"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "da85f8bb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the collected per-episode returns.\n",
    "# NOTE(review): episode_rewards is not assigned in the cells directly above --\n",
    "# confirm which cell populates it for the SAC-N baseline.\n",
    "print(episode_rewards)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6d140d77",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert the returns with the environment's get_normalized_score and scale by 100.\n",
    "print(eval_env.get_normalized_score(np.array(episode_rewards)) * 100.0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a5506e1c",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc25f77f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import random"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ae9b3b5b",
   "metadata": {},
   "source": [
    "# OURS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e606596c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Roll out the SAC-N actor for eval_episodes episodes, overriding action[0]\n",
    "# with the action_array entry selected by the tree_with_svm prediction.\n",
    "episode_rewards = []\n",
    "for _ in range(eval_episodes):\n",
    "    state, done = eval_env.reset(), False\n",
    "    episode_reward = 0.0\n",
    "    while not done:\n",
    "        action = trainer.actor.act(state, device)\n",
    "        # NOTE(review): the classifier input is the state mean repeated 10 times --\n",
    "        # confirm this matches the features the model was fitted on.\n",
    "        # predict() is given one sample and returns one label per sample; take that\n",
    "        # single label so the lookup below does not produce a 1-element array that\n",
    "        # must be implicitly converted when stored in action[0].\n",
    "        predict_label = tree_with_svm.predict([np.repeat(state.mean(), 10)])[0]\n",
    "        # random.random() > 0.0 holds on (almost) every step, so the override is\n",
    "        # effectively always applied.\n",
    "        if random.random() > 0.0:\n",
    "            action_0 = action_array[predict_label]\n",
    "            action[0] = action_0\n",
    "\n",
    "        state, reward, done, _ = eval_env.step(action)\n",
    "        episode_reward += reward\n",
    "    episode_rewards.append(episode_reward)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "123c3961",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-episode returns collected by the evaluation loop above.\n",
    "print(episode_rewards)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bd4aeab2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert the returns with the environment's get_normalized_score and scale by 100.\n",
    "print(eval_env.get_normalized_score(np.array(episode_rewards)) * 100.0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "75fdd8da",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist tree_with_svm to tree_with_svm.pkl in the working directory.\n",
    "with open(r\"tree_with_svm.pkl\", \"wb\") as output_file:\n",
    "    pickle.dump(tree_with_svm, output_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "09e21b3d",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "214483dc",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0a390c80",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the action_array entry at index 6 (presumably the action value\n",
    "# associated with label 6 -- TODO confirm).\n",
    "action_array[6]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bcfbf2b2",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "54ab3481",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "583e1249",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8bd266af",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c3e6264",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Count how often each value occurs in pred; missing keys start at 0.\n",
    "action_stats = defaultdict(int)\n",
    "for i in pred:\n",
    "    action_stats[i] += 1\n",
    "\n",
    "# Bare expression so the notebook displays the counts.\n",
    "action_stats"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9c020001",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hold out 20% of the samples with a fixed seed. obs_list is transposed first,\n",
    "# so rows of the feature matrix are presumably individual samples -- TODO confirm.\n",
    "X_train_bs, X_test_bs, y_train_bs, y_test_bs = train_test_split(np.array(obs_list).T, labels, test_size=0.2, random_state=42)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8e236be8",
   "metadata": {},
   "source": [
    "# DecisionTreeClassifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0f36ce89",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fit a CART tree capped at 40 leaves. random_state is fixed (same seed as the\n",
    "# train/test split) so refitting after a kernel restart rebuilds the same tree.\n",
    "tree_with_cart = DecisionTreeClassifier(max_leaf_nodes=40, random_state=42)\n",
    "tree_with_cart.fit(X_train_bs, y_train_bs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a99ced49",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Roll out the SAC-N actor for eval_episodes episodes, overriding action[0]\n",
    "# with the action_array entry selected by the tree_with_cart prediction.\n",
    "episode_rewards = []\n",
    "for _ in range(eval_episodes):\n",
    "    state, done = eval_env.reset(), False\n",
    "    episode_reward = 0.0\n",
    "    while not done:\n",
    "        action = trainer.actor.act(state, device)\n",
    "        # NOTE(review): the classifier input is the state mean repeated 10 times --\n",
    "        # confirm this matches the features the model was fitted on.\n",
    "        # predict() returns one label per sample; take the single label so the\n",
    "        # lookup below does not produce a 1-element array that must be implicitly\n",
    "        # converted when stored in action[0].\n",
    "        predict_label = tree_with_cart.predict([np.repeat(state.mean(), 10)])[0]\n",
    "        # random.random() > 0.0 holds on (almost) every step, so the override is\n",
    "        # effectively always applied.\n",
    "        if random.random() > 0.0:\n",
    "            action_0 = action_array[predict_label]\n",
    "            action[0] = action_0\n",
    "\n",
    "        state, reward, done, _ = eval_env.step(action)\n",
    "        episode_reward += reward\n",
    "    episode_rewards.append(episode_reward)\n",
    "print(episode_rewards)\n",
    "print(eval_env.get_normalized_score(np.array(episode_rewards)) * 100.0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f0c107cf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the fitted CART tree to tree_with_cart.pkl in the working directory.\n",
    "with open(r\"tree_with_cart.pkl\", \"wb\") as output_file:\n",
    "    pickle.dump(tree_with_cart, output_file)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f4a3c96f",
   "metadata": {},
   "source": [
    "# RandomForestClassifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "107e3350",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fit a random forest whose trees are capped at 40 leaves. random_state is\n",
    "# fixed (same seed as the train/test split) so the bootstrap samples and\n",
    "# feature draws -- and therefore the saved model -- are reproducible.\n",
    "tree_with_rf = RandomForestClassifier(max_leaf_nodes=40, random_state=42)\n",
    "tree_with_rf.fit(X_train_bs, y_train_bs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1478e6b0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Roll out the SAC-N actor for eval_episodes episodes, overriding action[0]\n",
    "# with the action_array entry selected by the tree_with_rf prediction.\n",
    "episode_rewards = []\n",
    "for _ in range(eval_episodes):\n",
    "    state, done = eval_env.reset(), False\n",
    "    episode_reward = 0.0\n",
    "    while not done:\n",
    "        action = trainer.actor.act(state, device)\n",
    "        # NOTE(review): the classifier input is the state mean repeated 10 times --\n",
    "        # confirm this matches the features the model was fitted on.\n",
    "        # predict() returns one label per sample; take the single label so the\n",
    "        # lookup below does not produce a 1-element array that must be implicitly\n",
    "        # converted when stored in action[0].\n",
    "        predict_label = tree_with_rf.predict([np.repeat(state.mean(), 10)])[0]\n",
    "        # random.random() > 0.0 holds on (almost) every step, so the override is\n",
    "        # effectively always applied.\n",
    "        if random.random() > 0.0:\n",
    "            action_0 = action_array[predict_label]\n",
    "            action[0] = action_0\n",
    "\n",
    "        state, reward, done, _ = eval_env.step(action)\n",
    "        episode_reward += reward\n",
    "    episode_rewards.append(episode_reward)\n",
    "print(episode_rewards)\n",
    "print(eval_env.get_normalized_score(np.array(episode_rewards)) * 100.0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2ebe4148",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the fitted random forest to tree_with_rf.pkl in the working directory.\n",
    "with open(r\"tree_with_rf.pkl\", \"wb\") as output_file:\n",
    "    pickle.dump(tree_with_rf, output_file)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "137b1fea",
   "metadata": {},
   "source": [
    "# ExtraTreesClassifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "184f3211",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fit an extra-trees ensemble whose trees are capped at 128 leaves.\n",
    "# random_state is fixed (same seed as the train/test split) so the random\n",
    "# split thresholds -- and therefore the saved model -- are reproducible.\n",
    "tree_with_et = ExtraTreesClassifier(max_leaf_nodes=128, random_state=42)\n",
    "tree_with_et.fit(X_train_bs, y_train_bs)\n",
    "# tree_with_et.fit(X_train, y_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b8244747",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Roll out the SAC-N actor for eval_episodes episodes, overriding action[0]\n",
    "# with the action_array entry selected by the tree_with_et prediction.\n",
    "episode_rewards = []\n",
    "for _ in range(eval_episodes):\n",
    "    state, done = eval_env.reset(), False\n",
    "    episode_reward = 0.0\n",
    "    while not done:\n",
    "        action = trainer.actor.act(state, device)\n",
    "        # NOTE(review): the classifier input is the state mean repeated 10 times --\n",
    "        # confirm this matches the features the model was fitted on.\n",
    "        # predict() returns one label per sample; take the single label so the\n",
    "        # lookup below does not produce a 1-element array that must be implicitly\n",
    "        # converted when stored in action[0].\n",
    "        predict_label = tree_with_et.predict([np.repeat(state.mean(), 10)])[0]\n",
    "        # random.random() > 0.0 holds on (almost) every step, so the override is\n",
    "        # effectively always applied.\n",
    "        if random.random() > 0.0:\n",
    "            action_0 = action_array[predict_label]\n",
    "            action[0] = action_0\n",
    "\n",
    "        state, reward, done, _ = eval_env.step(action)\n",
    "        episode_reward += reward\n",
    "    episode_rewards.append(episode_reward)\n",
    "print(episode_rewards)\n",
    "print(eval_env.get_normalized_score(np.array(episode_rewards)) * 100.0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "26518a8a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "05f3c8e7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the fitted extra-trees model to tree_with_et.pkl in the working directory.\n",
    "with open(r\"tree_with_et.pkl\", \"wb\") as output_file:\n",
    "    pickle.dump(tree_with_et, output_file)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "161f0e01",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload tree_with_et from tree_with_et.pkl, replacing the in-memory model.\n",
    "# NOTE(review): pickle.load can execute code embedded in the file -- only load\n",
    "# pickles you created or trust. The handle is opened for reading even though it\n",
    "# is named output_file.\n",
    "with open(r\"tree_with_et.pkl\", \"rb\") as output_file:\n",
    "    tree_with_et = pickle.load(output_file)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1f9ba6f3",
   "metadata": {},
   "source": [
    "# What if the action is zeroed out (or its first component is randomized)?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7b9ee145",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ablation: roll out the SAC-N actor but zero the whole action on every step.\n",
    "# No classifier is queried here; its prediction was never used in this cell.\n",
    "episode_rewards = []\n",
    "for _ in range(eval_episodes):\n",
    "    state, done = eval_env.reset(), False\n",
    "    episode_reward = 0.0\n",
    "    while not done:\n",
    "        action = trainer.actor.act(state, device)\n",
    "        # random.random() > 0.0 holds on (almost) every step; the draw is kept so\n",
    "        # the `random` stream advances exactly as before.\n",
    "        if random.random() > 0.0:\n",
    "            action *= 0\n",
    "\n",
    "        state, reward, done, _ = eval_env.step(action)\n",
    "        episode_reward += reward\n",
    "    episode_rewards.append(episode_reward)\n",
    "print(episode_rewards)\n",
    "print(eval_env.get_normalized_score(np.array(episode_rewards)) * 100.0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "facf459e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ablation: roll out the SAC-N actor but replace action[0] with a uniform\n",
    "# random value in [-1, 1) on every step.\n",
    "# No classifier is queried here; its prediction was never used in this cell.\n",
    "episode_rewards = []\n",
    "for _ in range(eval_episodes):\n",
    "    state, done = eval_env.reset(), False\n",
    "    episode_reward = 0.0\n",
    "    while not done:\n",
    "        action = trainer.actor.act(state, device)\n",
    "        # random.random() > 0.0 holds on (almost) every step; the draw is kept so\n",
    "        # the `random` stream advances exactly as before.\n",
    "        if random.random() > 0.0:\n",
    "            # action[0] * 0 cancels the actor's value, leaving only the random term.\n",
    "            action[0] = action[0] * 0 + random.random() * 2 - 1\n",
    "\n",
    "        state, reward, done, _ = eval_env.step(action)\n",
    "        episode_reward += reward\n",
    "    episode_rewards.append(episode_reward)\n",
    "print(episode_rewards)\n",
    "print(eval_env.get_normalized_score(np.array(episode_rewards)) * 100.0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2f46622",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8f07812b",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
