{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f40608cb-eb7c-45f8-882d-100bbaaf8f21",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "ab0f2e61-c897-492b-8def-e34f2009a56e",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import json\n",
    "#!! do not import matplotlib until you check input arguments\n",
    "import numpy as np\n",
    "import os\n",
    "import seeding\n",
    "import sys\n",
    "import torch\n",
    "from tqdm import tqdm\n",
    "import logging\n",
    "import shutil\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "import copy\n",
    "import pandas as pd\n",
    "import os\n",
    "import pickle\n",
    "from collections import defaultdict\n",
    "import argparse\n",
    "import glob\n",
    "import json\n",
    "import numpy as np\n",
    "import os\n",
    "import random\n",
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "ebe702ce-9091-4357-a40f-3c933b0c9d91",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def get_parser():\n",
    "    \"\"\"Create an empty argument parser whose help output lists default values\"\"\"\n",
    "    parser = argparse.ArgumentParser(\n",
    "        formatter_class=argparse.ArgumentDefaultsHelpFormatter)\n",
    "    return parser\n",
    "\n",
    "def load_experiment(tag, coefs=None):\n",
    "    \"\"\"Load the per-seed training logs of every run whose directory starts with `tag`\n",
    "\n",
    "    Every 'results/logs/<tag>*/train-<seed>.txt' file is read. Each line must be a\n",
    "    JSON object; each of its fields is stacked into a numpy array over the lines.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    tag : str\n",
    "        Prefix of the run directory name under 'results/logs'\n",
    "    coefs : dict, optional\n",
    "        Loss coefficients used to rebuild the total loss 'L' for logs that do not\n",
    "        contain an 'L' field. Defaults to L_inv=1.0, L_fwd=0.1, L_cpc=1.0, L_fac=0.1.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    dict\n",
    "        Maps the seed string parsed from each file name to a dict of\n",
    "        field name -> numpy array. An empty log file maps to an empty dict.\n",
    "    \"\"\"\n",
    "    if coefs is None:\n",
    "        coefs = {\n",
    "            'L_inv': 1.0,\n",
    "            'L_fwd': 0.1,\n",
    "            'L_cpc': 1.0,\n",
    "            'L_fac': 0.1,\n",
    "        }\n",
    "\n",
    "    def read_log(path):\n",
    "        # The context manager closes the handle; previously one was leaked per file.\n",
    "        with open(path, 'r') as log_file:\n",
    "            results = [json.loads(item) for item in log_file.read().splitlines()]\n",
    "        if not results:\n",
    "            # Nothing was logged for this seed (results[0] would raise IndexError).\n",
    "            return {}\n",
    "        fields = results[0].keys()\n",
    "        data = {f: np.asarray([item[f] for item in results]) for f in fields}\n",
    "        if 'L' not in fields:\n",
    "            # 1 is subtracted from L_fac before it is weighted; all other terms are\n",
    "            # weighted as logged.\n",
    "            data['L'] = sum(\n",
    "                coefs[f] * (data[f] - 1) if f == 'L_fac' else coefs[f] * data[f]\n",
    "                for f in coefs)\n",
    "        return data\n",
    "\n",
    "    logfiles = sorted(glob.glob(os.path.join('results/logs', tag + '*', 'train-*.txt')))\n",
    "    # 'train-<seed>.txt' -> '<seed>'\n",
    "    seeds = [path.split('-')[-1].split('.')[0] for path in logfiles]\n",
    "    return dict(zip(seeds, (read_log(path) for path in logfiles)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "f6b1e68a-b021-4092-9b1f-7bae020e3e6f",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "class Network(torch.nn.Module):\n",
    "    \"\"\"Module that, when printed, shows its total number of parameters\n",
    "\n",
    "    Also offers helpers to save/load the state dict and to freeze or unfreeze\n",
    "    every parameter of the module.\n",
    "    \"\"\"\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.frozen = False\n",
    "\n",
    "    def __str__(self):\n",
    "        # numel() is an exact int even for 0-dim parameters, for which\n",
    "        # np.prod(p.size()) evaluates to the float 1.0 and turns the total into a float.\n",
    "        n_params = sum(p.numel() for p in self.parameters())\n",
    "        return '{}\\nTotal params: {}'.format(super().__str__(), n_params)\n",
    "\n",
    "    def print_summary(self):\n",
    "        \"\"\"Print the module structure followed by the parameter count\"\"\"\n",
    "        print(str(self))\n",
    "\n",
    "    def save(self, name, model_dir, is_best=False):\n",
    "        \"\"\"Write the state dict to '<model_dir>/<name>_latest.pytorch'\n",
    "\n",
    "        When `is_best` is true the file is also copied to '<name>_best.pytorch'.\n",
    "        \"\"\"\n",
    "        os.makedirs(model_dir, exist_ok=True)\n",
    "        model_file = os.path.join(model_dir, '{}_latest.pytorch'.format(name))\n",
    "        torch.save(self.state_dict(), model_file)\n",
    "        logging.info('Model saved to {}'.format(model_file))\n",
    "        if is_best:\n",
    "            best_file = os.path.join(model_dir, '{}_best.pytorch'.format(name))\n",
    "            shutil.copyfile(model_file, best_file)\n",
    "            logging.info('New best model! Model copied to {}'.format(best_file))\n",
    "\n",
    "    def load(self, model_file, force_cpu=False):\n",
    "        \"\"\"Load a state dict from `model_file`, optionally mapping tensors to the CPU\"\"\"\n",
    "        logging.info('Loading model from {}...'.format(model_file))\n",
    "        map_loc = 'cpu' if force_cpu else None\n",
    "        # NOTE: torch.load unpickles the file; only load checkpoints you trust.\n",
    "        state_dict = torch.load(model_file, map_location=map_loc)\n",
    "        self.load_state_dict(state_dict)\n",
    "\n",
    "    def freeze(self):\n",
    "        \"\"\"Disable gradients for all parameters (no-op when already frozen)\"\"\"\n",
    "        if not self.frozen:\n",
    "            for param in self.parameters():\n",
    "                param.requires_grad = False\n",
    "            self.frozen = True\n",
    "\n",
    "    def unfreeze(self):\n",
    "        \"\"\"Re-enable gradients for all parameters (no-op unless frozen)\"\"\"\n",
    "        if self.frozen:\n",
    "            for param in self.parameters():\n",
    "                param.requires_grad = True\n",
    "            self.frozen = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "0c4c70d8-d81c-44ec-8793-7a00ac76e611",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "class AutoEncoder(Network):\n",
    "    \"\"\"Autoencoder made of a PhiNet encoder and a mirrored decoder\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    n_actions : int\n",
    "        Stored on the instance; not used by the reconstruction loss\n",
    "    input_shape : int or sequence of int\n",
    "        Shape of a single observation\n",
    "    n_latent_dims : int\n",
    "        Size of the latent code\n",
    "    n_hidden_layers, n_units_per_layer : int\n",
    "        Architecture of the encoder; the decoder mirrors it\n",
    "    lr : float\n",
    "        Learning rate of the Adam optimizer\n",
    "    coefs : dict, optional\n",
    "        Coefficients stored in `self.coefs`; missing keys default to 1.0\n",
    "    \"\"\"\n",
    "    def __init__(self,\n",
    "                 n_actions,\n",
    "                 input_shape=2,\n",
    "                 n_latent_dims=4,\n",
    "                 n_hidden_layers=1,\n",
    "                 n_units_per_layer=32,\n",
    "                 lr=0.001,\n",
    "                 coefs=None):\n",
    "        super().__init__()\n",
    "        self.n_actions = n_actions\n",
    "        self.n_latent_dims = n_latent_dims\n",
    "        self.lr = lr\n",
    "        self.coefs = defaultdict(lambda: 1.0)\n",
    "        # Honour user-supplied coefficients, as FeatureNet does (they were ignored).\n",
    "        if coefs is not None:\n",
    "            for k, v in coefs.items():\n",
    "                self.coefs[k] = v\n",
    "\n",
    "        # Reshape takes one int per dimension, but input_shape may be a bare int\n",
    "        # (its default is 2), which cannot be star-unpacked.\n",
    "        output_shape = tuple(np.atleast_1d(input_shape).tolist())\n",
    "\n",
    "        self.phi = PhiNet(input_shape=input_shape,\n",
    "                          n_latent_dims=n_latent_dims,\n",
    "                          n_units_per_layer=n_units_per_layer,\n",
    "                          n_hidden_layers=n_hidden_layers)\n",
    "        self.reverse_phi = PhiNet(input_shape=input_shape,\n",
    "                                  n_latent_dims=n_latent_dims,\n",
    "                                  n_units_per_layer=n_units_per_layer,\n",
    "                                  n_hidden_layers=n_hidden_layers)\n",
    "        # Mirror the encoder: drop its leading Reshape and final activation, swap the\n",
    "        # in/out sizes of every Linear, reverse the order, and finish with Tanh and a\n",
    "        # Reshape back to the observation shape.\n",
    "        decoder_layers = [Reshape(-1, *output_shape), nn.Tanh()] + [\n",
    "            nn.Linear(layer.out_features, layer.in_features)\n",
    "            if isinstance(layer, nn.Linear) else layer\n",
    "            for layer in self.reverse_phi.layers[1:-1]\n",
    "        ]\n",
    "        self.reverse_phi.phi = nn.Sequential(*reversed(decoder_layers))\n",
    "        self.mse = torch.nn.MSELoss()\n",
    "        self.optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)\n",
    "\n",
    "    def forward(self, *args, **kwargs):\n",
    "        raise NotImplementedError\n",
    "\n",
    "    def encode(self, x0):\n",
    "        \"\"\"Map observations to latent codes\"\"\"\n",
    "        return self.phi(x0)\n",
    "\n",
    "    def decode(self, z0):\n",
    "        \"\"\"Map latent codes back to observation space\"\"\"\n",
    "        return self.reverse_phi(z0)\n",
    "\n",
    "    def compute_loss(self, x0):\n",
    "        \"\"\"Mean squared reconstruction error of `x0`\"\"\"\n",
    "        loss = self.mse(x0, self.decode(self.encode(x0)))\n",
    "        return loss\n",
    "\n",
    "    def train_batch(self, x0, *args, **kwargs):\n",
    "        \"\"\"Take one optimizer step on `x0` and return the loss tensor\n",
    "\n",
    "        Extra positional and keyword arguments are accepted and ignored.\n",
    "        \"\"\"\n",
    "        self.train()\n",
    "        self.optimizer.zero_grad()\n",
    "        loss = self.compute_loss(x0)\n",
    "        loss.backward()\n",
    "        self.optimizer.step()\n",
    "        return loss\n",
    "class ContrastiveNet(Network):\n",
    "    \"\"\"Discriminator that outputs, per (z0, z1) pair, a sigmoid score in [0, 1]\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    n_latent_dims : int\n",
    "        Size of each latent vector; the network input is their concatenation\n",
    "    n_hidden_layers : int\n",
    "        Number of hidden layers (0 gives a single linear layer)\n",
    "    n_units_per_layer : int\n",
    "        Width of each hidden layer\n",
    "    \"\"\"\n",
    "    def __init__(self, n_latent_dims=4, n_hidden_layers=1, n_units_per_layer=32):\n",
    "        super().__init__()\n",
    "        self.frozen = False\n",
    "\n",
    "        self.layers = []\n",
    "        if n_hidden_layers == 0:\n",
    "            self.layers.extend([torch.nn.Linear(2 * n_latent_dims, 1)])\n",
    "        else:\n",
    "            self.layers.extend(\n",
    "                [torch.nn.Linear(2 * n_latent_dims, n_units_per_layer),\n",
    "                 torch.nn.Tanh()])\n",
    "            # Create a fresh Linear per hidden layer: multiplying the list would\n",
    "            # repeat one module instance and silently tie the hidden weights.\n",
    "            for _ in range(n_hidden_layers - 1):\n",
    "                self.layers.extend(\n",
    "                    [torch.nn.Linear(n_units_per_layer, n_units_per_layer),\n",
    "                     torch.nn.Tanh()])\n",
    "            self.layers.extend([torch.nn.Linear(n_units_per_layer, 1)])\n",
    "        self.layers.extend([torch.nn.Sigmoid()])\n",
    "        self.model = torch.nn.Sequential(*self.layers)\n",
    "\n",
    "    def forward(self, z0, z1):\n",
    "        context = torch.cat((z0, z1), -1)\n",
    "        # Squeeze only the output-feature axis so a batch of one keeps its batch axis.\n",
    "        fakes = self.model(context).squeeze(-1)\n",
    "        return fakes\n",
    "class CPCNet(Network):\n",
    "    \"\"\"Discriminator that outputs, per (c, z) pair, a sigmoid score in [0, 1]\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    n_latent_dims : int\n",
    "        Size of each of the two inputs; the network sees their concatenation\n",
    "    n_hidden_layers : int\n",
    "        Number of hidden layers (0 gives a single linear layer)\n",
    "    n_units_per_layer : int\n",
    "        Width of each hidden layer\n",
    "    \"\"\"\n",
    "    def __init__(self, n_latent_dims=4, n_hidden_layers=1, n_units_per_layer=32):\n",
    "        super().__init__()\n",
    "        self.frozen = False\n",
    "\n",
    "        self.layers = []\n",
    "        if n_hidden_layers == 0:\n",
    "            self.layers.extend([torch.nn.Linear(2*n_latent_dims, 1)])\n",
    "        else:\n",
    "            self.layers.extend([torch.nn.Linear(2*n_latent_dims, n_units_per_layer), torch.nn.Tanh()])\n",
    "            # One new Linear per hidden layer; list multiplication would reuse a\n",
    "            # single module instance and tie the hidden weights together.\n",
    "            for _ in range(n_hidden_layers - 1):\n",
    "                self.layers.extend([torch.nn.Linear(n_units_per_layer, n_units_per_layer), torch.nn.Tanh()])\n",
    "            self.layers.extend([torch.nn.Linear(n_units_per_layer, 1)])\n",
    "        self.layers.extend([torch.nn.Sigmoid()])\n",
    "        self.model = torch.nn.Sequential(*self.layers)\n",
    "\n",
    "    def forward(self, c, z):\n",
    "        context = torch.cat((c, z), -1)\n",
    "        # Squeeze only the output-feature axis so a batch of one keeps its batch axis.\n",
    "        fakes = self.model(context).squeeze(-1)\n",
    "        return fakes\n",
    "class FeatureNet(Network):\n",
    "    \"\"\"Encoder `phi` trained with inverse-model, contrastive and distance losses\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    n_actions : int\n",
    "        Number of discrete actions\n",
    "    input_shape : int or sequence of int\n",
    "        Shape of a single observation, passed to PhiNet\n",
    "    n_latent_dims : int\n",
    "        Size of the latent vector produced by `phi`\n",
    "    n_hidden_layers, n_units_per_layer : int\n",
    "        Architecture of `phi`, the inverse model and the inverse discriminator\n",
    "        (the ratio discriminator always uses one hidden layer)\n",
    "    lr : float\n",
    "        Learning rate of the Adam optimizer\n",
    "    coefs : dict, optional\n",
    "        Loss coefficients keyed by 'L_inv', 'L_coinv', 'L_rat', 'L_dis', 'L_ora';\n",
    "        missing keys default to 1.0\n",
    "    \"\"\"\n",
    "    def __init__(self,\n",
    "                 n_actions,\n",
    "                 input_shape=2,\n",
    "                 n_latent_dims=4,\n",
    "                 n_hidden_layers=1,\n",
    "                 n_units_per_layer=32,\n",
    "                 lr=0.001,\n",
    "                 coefs=None):\n",
    "        super().__init__()\n",
    "        self.n_actions = n_actions\n",
    "        self.n_latent_dims = n_latent_dims\n",
    "        self.lr = lr\n",
    "        self.coefs = defaultdict(lambda: 1.0)\n",
    "        if coefs is not None:\n",
    "            for k, v in coefs.items():\n",
    "                self.coefs[k] = v\n",
    "\n",
    "        self.phi = PhiNet(input_shape=input_shape,\n",
    "                          n_latent_dims=n_latent_dims,\n",
    "                          n_units_per_layer=n_units_per_layer,\n",
    "                          n_hidden_layers=n_hidden_layers)\n",
    "        self.inv_model = InvNet(n_actions=n_actions,\n",
    "                                n_latent_dims=n_latent_dims,\n",
    "                                n_units_per_layer=n_units_per_layer,\n",
    "                                n_hidden_layers=n_hidden_layers)\n",
    "        self.inv_discriminator = InvDiscriminator(n_actions=n_actions,\n",
    "                                                  n_latent_dims=n_latent_dims,\n",
    "                                                  n_units_per_layer=n_units_per_layer,\n",
    "                                                  n_hidden_layers=n_hidden_layers)\n",
    "        self.discriminator = ContrastiveNet(n_latent_dims=n_latent_dims,\n",
    "                                            n_hidden_layers=1,\n",
    "                                            n_units_per_layer=n_units_per_layer)\n",
    "\n",
    "        self.cross_entropy = torch.nn.CrossEntropyLoss()\n",
    "        self.bce_loss = torch.nn.BCELoss()\n",
    "        self.mse = torch.nn.MSELoss()\n",
    "        self.optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)\n",
    "\n",
    "    @staticmethod\n",
    "    def _disabled_loss(z):\n",
    "        \"\"\"Scalar zero with the device and dtype of `z`, for losses whose coefficient is 0\"\"\"\n",
    "        return z.new_zeros(())\n",
    "\n",
    "    @staticmethod\n",
    "    def _fake_labels(fakes, n):\n",
    "        \"\"\"BCE targets for `n` real examples followed by `n` fakes, matching `fakes`\"\"\"\n",
    "        # Built from `fakes` so that targets share its device and dtype.\n",
    "        return torch.cat([fakes.new_zeros(n), fakes.new_ones(n)], dim=0)\n",
    "\n",
    "    def inverse_loss(self, z0, z1, a):\n",
    "        \"\"\"Cross-entropy between the inverse model's action logits and `a`\"\"\"\n",
    "        if self.coefs['L_inv'] == 0.0:\n",
    "            return self._disabled_loss(z0)\n",
    "        a_hat = self.inv_model(z0, z1)\n",
    "        return self.cross_entropy(input=a_hat, target=a)\n",
    "\n",
    "    def contrastive_inverse_loss(self, z0, z1, a):\n",
    "        \"\"\"BCE for telling observed actions apart from uniformly resampled ones\"\"\"\n",
    "        if self.coefs['L_coinv'] == 0.0:\n",
    "            return self._disabled_loss(z0)\n",
    "        N = len(z0)\n",
    "\n",
    "        # negative examples: same transitions with randomly drawn actions\n",
    "        a_neg = torch.randint_like(a, low=0, high=self.n_actions)\n",
    "\n",
    "        # concatenate positive and negative examples\n",
    "        z0_extended = torch.cat([z0, z0], dim=0)\n",
    "        z1_extended = torch.cat([z1, z1], dim=0)\n",
    "        a_pos_neg = torch.cat([a, a_neg], dim=0)\n",
    "\n",
    "        # Compute which ones are fakes\n",
    "        fakes = self.inv_discriminator(z0_extended, z1_extended, a_pos_neg)\n",
    "        return self.bce_loss(input=fakes, target=self._fake_labels(fakes, N))\n",
    "\n",
    "    def ratio_loss(self, z0, z1):\n",
    "        \"\"\"BCE for telling true next states apart from shuffled next states\"\"\"\n",
    "        if self.coefs['L_rat'] == 0.0:\n",
    "            return self._disabled_loss(z0)\n",
    "        N = len(z0)\n",
    "        # shuffle next states\n",
    "        idx = torch.randperm(N)\n",
    "        z1_neg = z1.view(N, -1)[idx].view(z1.size())\n",
    "\n",
    "        # concatenate positive and negative examples\n",
    "        z0_extended = torch.cat([z0, z0], dim=0)\n",
    "        z1_pos_neg = torch.cat([z1, z1_neg], dim=0)\n",
    "\n",
    "        # Compute which ones are fakes\n",
    "        fakes = self.discriminator(z0_extended, z1_pos_neg)\n",
    "        return self.bce_loss(input=fakes, target=self._fake_labels(fakes, N))\n",
    "\n",
    "    def distance_loss(self, z0, z1):\n",
    "        \"\"\"Mean squared amount by which ||z1 - z0|| exceeds 0.1\"\"\"\n",
    "        if self.coefs['L_dis'] == 0.0:\n",
    "            return self._disabled_loss(z0)\n",
    "        max_dz = 0.1\n",
    "        dz = torch.norm(z1 - z0, dim=-1, p=2)\n",
    "        excess = torch.nn.functional.relu(dz - max_dz)\n",
    "        return self.mse(excess, torch.zeros_like(excess))\n",
    "\n",
    "    def oracle_loss(self, z0, z1, d):\n",
    "        \"\"\"MSE between latent distances and `d / 10`\n",
    "\n",
    "        The latent distances are those of the N matched pairs followed by those of\n",
    "        z0 paired with z1 in reversed batch order, so `d` must hold 2N values in\n",
    "        that order. This loss is not part of `compute_loss`.\n",
    "        \"\"\"\n",
    "        if self.coefs['L_ora'] == 0.0:\n",
    "            return self._disabled_loss(z0)\n",
    "\n",
    "        dz = torch.cat(\n",
    "            [torch.norm(z1 - z0, dim=-1, p=2),\n",
    "             torch.norm(z1.flip(0) - z0, dim=-1, p=2)], dim=0)\n",
    "\n",
    "        return self.mse(dz, d / 10.0)\n",
    "\n",
    "    def forward(self, *args, **kwargs):\n",
    "        raise NotImplementedError\n",
    "\n",
    "    def predict_a(self, z0, z1):\n",
    "        raise NotImplementedError\n",
    "\n",
    "    def compute_loss(self, z0, z1, a):\n",
    "        \"\"\"Coefficient-weighted sum of the contrastive-inverse, inverse, ratio and distance losses\"\"\"\n",
    "        loss = 0\n",
    "        loss += self.coefs['L_coinv'] * self.contrastive_inverse_loss(z0, z1, a)\n",
    "        loss += self.coefs['L_inv'] * self.inverse_loss(z0, z1, a)\n",
    "        loss += self.coefs['L_rat'] * self.ratio_loss(z0, z1)\n",
    "        loss += self.coefs['L_dis'] * self.distance_loss(z0, z1)\n",
    "        return loss\n",
    "\n",
    "    def train_batch(self, x0, x1, a):\n",
    "        \"\"\"Encode `x0` and `x1`, take one optimizer step, and return the loss tensor\"\"\"\n",
    "        self.train()\n",
    "        self.optimizer.zero_grad()\n",
    "        z0 = self.phi(x0)\n",
    "        z1 = self.phi(x1)\n",
    "        loss = self.compute_loss(z0, z1, a)\n",
    "        loss.backward()\n",
    "        self.optimizer.step()\n",
    "        return loss\n",
    "class InvDiscriminator(Network):\n",
    "    \"\"\"Discriminator that outputs, per (z0, z1, a) triple, a sigmoid score in [0, 1]\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    n_actions : int\n",
    "        Number of discrete actions; actions are one-hot encoded to this length\n",
    "    n_latent_dims : int\n",
    "        Size of each latent vector\n",
    "    n_hidden_layers : int\n",
    "        Number of hidden layers (0 gives a single linear layer)\n",
    "    n_units_per_layer : int\n",
    "        Width of each hidden layer\n",
    "    \"\"\"\n",
    "    def __init__(self, n_actions, n_latent_dims=4, n_hidden_layers=1, n_units_per_layer=32):\n",
    "        super().__init__()\n",
    "        self.n_actions = n_actions\n",
    "\n",
    "        self.layers = []\n",
    "        if n_hidden_layers == 0:\n",
    "            self.layers.extend([torch.nn.Linear(2 * n_latent_dims + n_actions, 1)])\n",
    "        else:\n",
    "            self.layers.extend([\n",
    "                torch.nn.Linear(2 * n_latent_dims + n_actions, n_units_per_layer),\n",
    "                torch.nn.Tanh()\n",
    "            ])\n",
    "            # Create a fresh Linear per hidden layer: multiplying the list would\n",
    "            # repeat one module instance and silently tie the hidden weights.\n",
    "            for _ in range(n_hidden_layers - 1):\n",
    "                self.layers.extend(\n",
    "                    [torch.nn.Linear(n_units_per_layer, n_units_per_layer),\n",
    "                     torch.nn.Tanh()])\n",
    "            self.layers.extend([torch.nn.Linear(n_units_per_layer, 1)])\n",
    "        self.layers.extend([torch.nn.Sigmoid()])\n",
    "        self.model = torch.nn.Sequential(*self.layers)\n",
    "\n",
    "    def forward(self, z0, z1, a):\n",
    "        context = torch.cat((z0, z1, one_hot(a, self.n_actions)), -1)\n",
    "        # Squeeze only the output-feature axis so a batch of one keeps its batch axis.\n",
    "        fakes = self.model(context).squeeze(-1)\n",
    "        return fakes\n",
    "class FwdNet(Network):\n",
    "    \"\"\"Forward model mapping a latent state and a discrete action to a latent vector\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    n_actions : int\n",
    "        Number of discrete actions; actions are one-hot encoded to this length\n",
    "    n_latent_dims : int\n",
    "        Size of the input and output latent vectors\n",
    "    n_hidden_layers : int\n",
    "        Number of hidden layers (0 gives a single linear layer)\n",
    "    n_units_per_layer : int\n",
    "        Width of each hidden layer\n",
    "    \"\"\"\n",
    "    def __init__(self, n_actions, n_latent_dims=4, n_hidden_layers=1, n_units_per_layer=32):\n",
    "        super().__init__()\n",
    "        self.n_actions = n_actions\n",
    "        self.frozen = False\n",
    "\n",
    "        self.fwd_layers = []\n",
    "        if n_hidden_layers == 0:\n",
    "            self.fwd_layers.extend([torch.nn.Linear(n_latent_dims+self.n_actions, n_latent_dims)])\n",
    "        else:\n",
    "            self.fwd_layers.extend([torch.nn.Linear(n_latent_dims + self.n_actions, n_units_per_layer), torch.nn.Tanh()])\n",
    "            # One new Linear per hidden layer; list multiplication would reuse a\n",
    "            # single module instance and tie the hidden weights together.\n",
    "            for _ in range(n_hidden_layers - 1):\n",
    "                self.fwd_layers.extend([torch.nn.Linear(n_units_per_layer, n_units_per_layer), torch.nn.Tanh()])\n",
    "            self.fwd_layers.extend([torch.nn.Linear(n_units_per_layer, n_latent_dims)])\n",
    "        self.fwd_model = torch.nn.Sequential(*self.fwd_layers)\n",
    "\n",
    "    def forward(self, z, a):\n",
    "        a_onehot = one_hot(a, depth=self.n_actions)\n",
    "        context = torch.cat((z, a_onehot), -1)\n",
    "        z_hat = self.fwd_model(context)\n",
    "        return z_hat\n",
    "class InvNet(Network):\n",
    "    \"\"\"Inverse model mapping a (z0, z1) pair of latent states to action logits\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    n_actions : int\n",
    "        Number of output logits\n",
    "    n_latent_dims : int\n",
    "        Size of each latent vector; the network input is their concatenation\n",
    "    n_hidden_layers : int\n",
    "        Number of hidden layers (0 gives a single linear layer)\n",
    "    n_units_per_layer : int\n",
    "        Width of each hidden layer\n",
    "    \"\"\"\n",
    "    def __init__(self, n_actions, n_latent_dims=4, n_hidden_layers=1, n_units_per_layer=32):\n",
    "        super().__init__()\n",
    "        self.n_actions = n_actions\n",
    "\n",
    "        self.layers = []\n",
    "        if n_hidden_layers == 0:\n",
    "            self.layers.extend([torch.nn.Linear(2 * n_latent_dims, n_actions)])\n",
    "        else:\n",
    "            self.layers.extend(\n",
    "                [torch.nn.Linear(2 * n_latent_dims, n_units_per_layer),\n",
    "                 torch.nn.Tanh()])\n",
    "            # Create a fresh Linear per hidden layer: multiplying the list would\n",
    "            # repeat one module instance and silently tie the hidden weights.\n",
    "            for _ in range(n_hidden_layers - 1):\n",
    "                self.layers.extend(\n",
    "                    [torch.nn.Linear(n_units_per_layer, n_units_per_layer),\n",
    "                     torch.nn.Tanh()])\n",
    "            self.layers.extend([torch.nn.Linear(n_units_per_layer, n_actions)])\n",
    "\n",
    "        self.inv_model = torch.nn.Sequential(*self.layers)\n",
    "\n",
    "    def forward(self, z0, z1):\n",
    "        context = torch.cat((z0, z1), -1)\n",
    "        a_logits = self.inv_model(context)\n",
    "        return a_logits\n",
    "class Reshape(torch.nn.Module):\n",
    "    \"\"\"Layer whose output is a view of its input with a different size\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    args : int...\n",
    "        The target size, forwarded to ``Tensor.view``\n",
    "    \"\"\"\n",
    "    def __init__(self, *args):\n",
    "        super().__init__()\n",
    "        self.shape = args\n",
    "\n",
    "    def __repr__(self):\n",
    "        # e.g. 'Reshape(-1, 4)': class name followed by the stored size tuple\n",
    "        return '{}{}'.format(type(self).__name__, self.shape)\n",
    "\n",
    "    def forward(self, input):\n",
    "        target_shape = self.shape\n",
    "        return input.view(*target_shape)\n",
    "\n",
    "\n",
    "class Sequential(torch.nn.Sequential, Network):\n",
    "    \"\"\"torch.nn.Sequential that also inherits the Network helper methods\"\"\"\n",
    "\n",
    "def one_hot(x, depth, dtype=torch.float32):\n",
    "    \"\"\"Convert a batch of indices to a batch of one-hot vectors\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    x : Tensor of integer indices\n",
    "        Indices in [0, depth); may have any number of dimensions\n",
    "    depth : int\n",
    "        The length of each output vector\n",
    "    dtype : torch.dtype\n",
    "        The dtype of the output\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    Tensor\n",
    "        Shape ``x.shape + (depth,)``, on the same device as `x`\n",
    "    \"\"\"\n",
    "    # A size-1 trailing index axis lets scatter_ set exactly one entry per index,\n",
    "    # for any batch shape (expand(-1, depth) only worked for 1-D input).\n",
    "    index = x.long().unsqueeze(-1)\n",
    "    encoded = torch.zeros(*x.shape, depth, dtype=dtype, device=x.device)\n",
    "    return encoded.scatter_(-1, index, 1)\n",
    "\n",
    "def extract(input, idx, idx_dim, batch_dim=0):\n",
    "    '''\n",
    "Extracts slices of input tensor along idx_dim at positions\n",
    "specified by idx.\n",
    "\n",
    "Notes:\n",
    "    idx must have the same size as input.shape[batch_dim].\n",
    "    Output tensor has the shape of input with idx_dim removed.\n",
    "    The selected values are returned directly (not averaged), so\n",
    "    integer tensors are supported as well.\n",
    "\n",
    "Args:\n",
    "    input (Tensor): the source tensor\n",
    "    idx (LongTensor): the indices of slices to extract\n",
    "    idx_dim (int): the dimension along which to extract slices\n",
    "    batch_dim (int): the dimension to treat as the batch dimension\n",
    "\n",
    "Raises:\n",
    "    RuntimeError: if idx_dim and batch_dim name the same dimension, or\n",
    "        if len(idx) differs from input.shape[batch_dim]\n",
    "\n",
    "Example::\n",
    "\n",
    "    >>> t = torch.arange(24, dtype=torch.float32).view(3,4,2)\n",
    "    >>> i = torch.tensor([1, 3, 0], dtype=torch.int64)\n",
    "    >>> extract(t, i, idx_dim=1, batch_dim=0)\n",
    "        tensor([[ 2.,  3.],\n",
    "                [14., 15.],\n",
    "                [16., 17.]])\n",
    "'''\n",
    "    n_dims = input.ndimension()\n",
    "    # Resolve negative dims first so that e.g. idx_dim=-1 and batch_dim=n_dims-1\n",
    "    # are recognized as the same dimension.\n",
    "    if idx_dim < 0:\n",
    "        idx_dim += n_dims\n",
    "    if batch_dim < 0:\n",
    "        batch_dim += n_dims\n",
    "    if idx_dim == batch_dim:\n",
    "        raise RuntimeError('idx_dim cannot be the same as batch_dim')\n",
    "    if len(idx) != input.shape[batch_dim]:\n",
    "        raise RuntimeError(\n",
    "            \"idx length '{}' not compatible with batch_dim '{}' for input shape '{}'\".format(\n",
    "                len(idx), batch_dim, list(input.shape)))\n",
    "    viewshape = [1] * n_dims\n",
    "    viewshape[batch_dim] = input.shape[batch_dim]\n",
    "    # Broadcast idx over every dim except idx_dim, where a single entry suffices;\n",
    "    # gather then yields a size-1 axis that is squeezed away.\n",
    "    index_shape = list(input.shape)\n",
    "    index_shape[idx_dim] = 1\n",
    "    index = idx.view(*viewshape).expand(*index_shape)\n",
    "    return torch.gather(input, idx_dim, index).squeeze(idx_dim)\n",
    "class PhiNet(Network):\n",
    "    \"\"\"Encoder that flattens each observation and maps it to a latent vector\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    input_shape : int or sequence of int\n",
    "        Shape of a single observation; its product is the flattened input size\n",
    "    n_latent_dims : int\n",
    "        Size of the output latent vector\n",
    "    n_hidden_layers : int\n",
    "        Number of hidden layers (0 gives a single linear layer)\n",
    "    n_units_per_layer : int\n",
    "        Width of each hidden layer\n",
    "    final_activation : callable or None\n",
    "        Called with no arguments to build the last layer; None omits it\n",
    "    \"\"\"\n",
    "    def __init__(self,\n",
    "                 input_shape=2,\n",
    "                 n_latent_dims=4,\n",
    "                 n_hidden_layers=1,\n",
    "                 n_units_per_layer=32,\n",
    "                 final_activation=torch.nn.Tanh):\n",
    "        super().__init__()\n",
    "        self.input_shape = input_shape\n",
    "\n",
    "        # Plain Python int, so layer sizes and view() never receive a numpy scalar.\n",
    "        shape_flat = int(np.prod(self.input_shape))\n",
    "\n",
    "        self.layers = []\n",
    "        self.layers.extend([Reshape(-1, shape_flat)])\n",
    "        if n_hidden_layers == 0:\n",
    "            self.layers.extend([torch.nn.Linear(shape_flat, n_latent_dims)])\n",
    "        else:\n",
    "            self.layers.extend([torch.nn.Linear(shape_flat, n_units_per_layer), torch.nn.Tanh()])\n",
    "            # Create a fresh Linear per hidden layer: multiplying the list would\n",
    "            # repeat one module instance and silently tie the hidden weights.\n",
    "            for _ in range(n_hidden_layers - 1):\n",
    "                self.layers.extend(\n",
    "                    [torch.nn.Linear(n_units_per_layer, n_units_per_layer),\n",
    "                     torch.nn.Tanh()])\n",
    "            self.layers.extend([\n",
    "                torch.nn.Linear(n_units_per_layer, n_latent_dims),\n",
    "            ])\n",
    "        if final_activation is not None:\n",
    "            self.layers.extend([final_activation()])\n",
    "        self.phi = torch.nn.Sequential(*self.layers)\n",
    "\n",
    "    def forward(self, x):\n",
    "        z = self.phi(x)\n",
    "        return z\n",
    "class QNet(Network):\n",
    "    \"\"\"ReLU MLP mapping a feature vector to one output per action\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    n_features : int\n",
    "        Size of the input vector\n",
    "    n_actions : int\n",
    "        Number of outputs\n",
    "    n_hidden_layers : int\n",
    "        Number of hidden layers (0 gives a single linear layer)\n",
    "    n_units_per_layer : int\n",
    "        Width of each hidden layer\n",
    "    \"\"\"\n",
    "    def __init__(self, n_features, n_actions, n_hidden_layers=1, n_units_per_layer=32):\n",
    "        super().__init__()\n",
    "        self.n_actions = n_actions\n",
    "\n",
    "        self.layers = []\n",
    "        if n_hidden_layers == 0:\n",
    "            self.layers.extend([torch.nn.Linear(n_features, n_actions)])\n",
    "        else:\n",
    "            self.layers.extend([torch.nn.Linear(n_features, n_units_per_layer), torch.nn.ReLU()])\n",
    "            # Create a fresh Linear per hidden layer: multiplying the list would\n",
    "            # repeat one module instance and silently tie the hidden weights.\n",
    "            for _ in range(n_hidden_layers - 1):\n",
    "                self.layers.extend(\n",
    "                    [torch.nn.Linear(n_units_per_layer, n_units_per_layer),\n",
    "                     torch.nn.ReLU()])\n",
    "            self.layers.extend([torch.nn.Linear(n_units_per_layer, n_actions)])\n",
    "\n",
    "        self.model = torch.nn.Sequential(*self.layers)\n",
    "\n",
    "    def forward(self, z):\n",
    "        return self.model(z)\n",
    "class SimpleNet(Network):\n",
    "    \"\"\"Tanh MLP mapping `n_inputs` features to `n_outputs` values\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    n_inputs : int\n",
    "        Size of the input vector\n",
    "    n_outputs : int\n",
    "        Size of the output vector\n",
    "    n_hidden_layers : int\n",
    "        Number of hidden layers (0 gives a single linear layer)\n",
    "    n_units_per_layer : int\n",
    "        Width of each hidden layer\n",
    "    \"\"\"\n",
    "    def __init__(self, n_inputs, n_outputs, n_hidden_layers=1, n_units_per_layer=32):\n",
    "        super().__init__()\n",
    "        self.n_outputs = n_outputs\n",
    "        self.frozen = False\n",
    "\n",
    "        self.layers = []\n",
    "        if n_hidden_layers == 0:\n",
    "            self.layers.extend([torch.nn.Linear(n_inputs, n_outputs)])\n",
    "        else:\n",
    "            self.layers.extend([torch.nn.Linear(n_inputs, n_units_per_layer), torch.nn.Tanh()])\n",
    "            # Create a fresh Linear per hidden layer: multiplying the list would\n",
    "            # repeat one module instance and silently tie the hidden weights.\n",
    "            for _ in range(n_hidden_layers - 1):\n",
    "                self.layers.extend(\n",
    "                    [torch.nn.Linear(n_units_per_layer, n_units_per_layer),\n",
    "                     torch.nn.Tanh()])\n",
    "            self.layers.extend([torch.nn.Linear(n_units_per_layer, n_outputs)])\n",
    "\n",
    "        self.model = torch.nn.Sequential(*self.layers)\n",
    "\n",
    "    def forward(self, z0):\n",
    "        a_logits = self.model(z0)\n",
    "        return a_logits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "40727b3c-de74-4263-90e8-f64e3da44e78",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Experiment setup: parse the arguments, pick the matplotlib backend, create the\n",
    "# output directories, open the training log and seed the run.\n",
    "parser = get_parser()\n",
    "# parser.add_argument('-d','--dims', help='Number of latent dimensions', type=int, default=2)\n",
    "# yapf: disable\n",
    "parser.add_argument('--type', type=str, default='markov', choices=['markov', 'autoencoder', 'pixel-predictor'],\n",
    "                    help='Which type of representation learning method')\n",
    "parser.add_argument('-n','--n_updates', type=int, default=3000,\n",
    "                    help='Number of training updates')\n",
    "parser.add_argument('-r','--rows', type=int, default=6,\n",
    "                    help='Number of gridworld rows')\n",
    "parser.add_argument('-c','--cols', type=int, default=6,\n",
    "                    help='Number of gridworld columns')\n",
    "parser.add_argument('-w', '--walls', type=str, default='empty', choices=['empty', 'maze', 'spiral', 'loop'],\n",
    "                    help='The wall configuration mode of gridworld')\n",
    "parser.add_argument('-l','--latent_dims', type=int, default=2,\n",
    "                    help='Number of latent dimensions to use for representation')\n",
    "parser.add_argument('--L_inv', type=float, default=1.0,\n",
    "                    help='Coefficient for inverse-model-matching loss')\n",
    "parser.add_argument('--L_coinv', type=float, default=0.0,\n",
    "                    help='Coefficient for *contrastive* inverse-model-matching loss')\n",
    "# parser.add_argument('--L_fwd', type=float, default=0.0,\n",
    "#                     help='Coefficient for forward dynamics loss')\n",
    "parser.add_argument('--L_rat', type=float, default=1.0,\n",
    "                    help='Coefficient for ratio-matching loss')\n",
    "# parser.add_argument('--L_fac', type=float, default=0.0,\n",
    "#                     help='Coefficient for factorization loss')\n",
    "parser.add_argument('--L_dis', type=float, default=0.0,\n",
    "                    help='Coefficient for planning-distance loss')\n",
    "parser.add_argument('--L_ora', type=float, default=0.0,\n",
    "                    help='Coefficient for oracle distance loss')\n",
    "parser.add_argument('-lr','--learning_rate', type=float, default=0.003,\n",
    "                    help='Learning rate for Adam optimizer')\n",
    "parser.add_argument('--batch_size', type=int, default=2048,\n",
    "                    help='Mini batch size for training updates')\n",
    "parser.add_argument('-s','--seed', type=int, default=0,\n",
    "                    help='Random seed')\n",
    "parser.add_argument('-t','--tag', type=str, required=True,\n",
    "                    help='Tag for identifying experiment')\n",
    "parser.add_argument('-v','--video', action='store_true',\n",
    "                    help=\"Save training video\")\n",
    "parser.add_argument('--no_graphics', action='store_true',\n",
    "                    help='Turn off graphics (e.g. for running on cluster)')\n",
    "parser.add_argument('--save', action='store_true',\n",
    "                    help='Save final network weights')\n",
    "parser.add_argument('--cleanvis', action='store_true',\n",
    "                    help='Switch to representation-only visualization')\n",
    "parser.add_argument('--no_sigma', action='store_true',\n",
    "                    help='Turn off sensors and just use true state; i.e. x=s')\n",
    "parser.add_argument('--rearrange_xy', action='store_true',\n",
    "                    help='Rearrange discrete x-y positions to break smoothness')\n",
    "# When sys.argv[0] contains 'ipykernel' (i.e. the code runs in a Jupyter kernel),\n",
    "# the real command line is ignored and this fixed argument list is parsed instead.\n",
    "if 'ipykernel' in sys.argv[0]:\n",
    "    arglist = [\n",
    "        '--type', 'markov',\n",
    "        '-w', 'spiral',\n",
    "        '--tag', 'test-spiral',\n",
    "        '-r', '6',\n",
    "        '-c', '6',\n",
    "        '--L_ora', '1.0',\n",
    "        '--save'\n",
    "    ]\n",
    "    args = parser.parse_args(arglist)\n",
    "else:\n",
    "    args = parser.parse_args()\n",
    "# matplotlib is imported only now, after the arguments are known, so that the\n",
    "# 'Agg' backend can be selected before pyplot is loaded.\n",
    "if args.no_graphics:\n",
    "    import matplotlib\n",
    "    # Force matplotlib to not use any Xwindows backend.\n",
    "    matplotlib.use('Agg')\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Output locations, all keyed by the experiment tag.\n",
    "log_dir = 'results/logs/' + str(args.tag)\n",
    "vid_dir = 'results/videos/' + str(args.tag)\n",
    "maze_dir = 'results/mazes/' + str(args.tag)\n",
    "os.makedirs(log_dir, exist_ok=True)\n",
    "# The video and maze directories, and the three file names below, are created\n",
    "# only when --video is set; the names are undefined otherwise.\n",
    "if args.video:\n",
    "    os.makedirs(vid_dir, exist_ok=True)\n",
    "    os.makedirs(maze_dir, exist_ok=True)\n",
    "    video_filename = vid_dir + '/video-{}.mp4'.format(args.seed)\n",
    "    image_filename = vid_dir + '/final-{}.png'.format(args.seed)\n",
    "    maze_file = maze_dir + '/maze-{}.png'.format(args.seed)\n",
    "\n",
    "# The training log is opened for writing (truncating any existing file) and is\n",
    "# not closed in this cell. The parsed arguments are saved next to it.\n",
    "log = open(log_dir + '/train-{}.txt'.format(args.seed), 'w')\n",
    "with open(log_dir + '/args-{}.txt'.format(args.seed), 'w') as arg_file:\n",
    "    arg_file.write(repr(args))\n",
    "\n",
    "# `seeding` is a project module not shown here; presumably it seeds the random\n",
    "# number generators used in the run -- TODO confirm.\n",
    "seeding.seed(args.seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "17574b0c-8c57-46ec-9501-9479b133d232",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "class FQE_eval(torch.nn.Module):\n",
    "    \"\"\"MLP that maps an input vector of size `in_dim` to `action_size` outputs\n",
    "\n",
    "    The stack is Linear(in_dim, n_nodes) + activation, then `n_layers - 1`\n",
    "    further Linear(n_nodes, n_nodes) + activation pairs, then a final\n",
    "    Linear(n_nodes, action_size). The same `activation` module instance is\n",
    "    reused at every position. The module is put in training mode on creation.\n",
    "    \"\"\"\n",
    "    def __init__(self, in_dim, action_size, n_layers=2, n_nodes=32, activation=nn.ReLU()):\n",
    "        super().__init__()\n",
    "        self.action_size = action_size\n",
    "\n",
    "        # Widths of the hidden stack; at least one hidden layer is always built.\n",
    "        widths = [in_dim] + [n_nodes] * max(n_layers, 1)\n",
    "        self.net = []\n",
    "        for fan_in, fan_out in zip(widths[:-1], widths[1:]):\n",
    "            self.net.extend([nn.Linear(fan_in, fan_out), activation])\n",
    "        self.net.append(nn.Linear(n_nodes, action_size))\n",
    "\n",
    "        self.FQE_net = nn.Sequential(*self.net)\n",
    "        self.train()\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.FQE_net(x)\n",
    "    \n",
    "def train_FQE_step(model, optimizer, x, a, r, x_next, terminal, observed_s_next, target_policy, gamma=0.99):\n",
    "    \"\"\"Run one Fitted-Q-Evaluation gradient step on a batch and return the loss.\n",
    "\n",
    "    Args:\n",
    "        model: network mapping a batch of (abstract) states to per-action values.\n",
    "        optimizer: optimizer over `model`'s parameters.\n",
    "        x, x_next: current / next states fed to `model`.\n",
    "        a: integer actions taken, used to index the model output.\n",
    "        r: rewards.\n",
    "        terminal: per-transition terminal flags; bool, integer or float tensors are accepted.\n",
    "        observed_s_next: next states in the space `target_policy` expects.\n",
    "        target_policy: callable returning one integer action per row of `observed_s_next`.\n",
    "        gamma: discount factor.\n",
    "\n",
    "    Returns:\n",
    "        The MSE loss of this step as a Python float.\n",
    "    \"\"\"\n",
    "    optimizer.zero_grad()\n",
    "    model.train()\n",
    "    criterion_FQE = nn.MSELoss()\n",
    "\n",
    "    batch_size = x.shape[0]\n",
    "    order = torch.arange(batch_size, device=x.device)\n",
    "    pi_s_next = target_policy(observed_s_next)  # the policy acts on the observed state space\n",
    "\n",
    "    outputs_FQE = model(x)\n",
    "    with torch.no_grad():\n",
    "        FQE_next = model(x_next)\n",
    "\n",
    "    # 1 for non-terminal rows, 0 for terminal rows (no bootstrapping past an episode end).\n",
    "    # Casting the flags instead of subtracting them from torch.ones(batch_size) accepts bool\n",
    "    # flags (tensor subtraction rejects bool operands) and keeps the mask on the\n",
    "    # device and dtype of the Q-values.\n",
    "    not_done = 1.0 - terminal.to(FQE_next.dtype)\n",
    "\n",
    "    # Start from the current predictions so only the entry of the taken action\n",
    "    # contributes to the loss; all other entries have zero error.\n",
    "    FQE_targets = outputs_FQE.detach().clone()\n",
    "    FQE_targets[order, a] = r + gamma * FQE_next[order, pi_s_next] * not_done\n",
    "\n",
    "    loss_FQE = criterion_FQE(outputs_FQE, FQE_targets)\n",
    "\n",
    "    loss_FQE.backward()\n",
    "    optimizer.step()\n",
    "\n",
    "    return loss_FQE.item()\n",
    "\n",
    "def train_FQE(data, num_epochs, target_policy, n_layers=3, n_nodes=32, lr=0.001, action_size=None):\n",
    "    \"\"\"Fit an FQE network on logged transitions and estimate the target policy's initial-state value.\n",
    "\n",
    "    Args:\n",
    "        data: six aligned tensors [x, a, r, x_next, terminal, s_next]; x / x_next are the\n",
    "            states fed to the Q-network and s_next the states given to `target_policy`.\n",
    "        num_epochs: number of passes over the data.\n",
    "        target_policy: callable returning one integer action per row of a state batch.\n",
    "        n_layers, n_nodes: hidden depth / width forwarded to FQE_eval.\n",
    "        lr: Adam learning rate.\n",
    "        action_size: number of discrete actions. When None it is inferred as\n",
    "            (largest action index in the data) + 1; pass it explicitly when the target\n",
    "            policy can select an action that is larger than any logged one.\n",
    "\n",
    "    Returns:\n",
    "        0-dim tensor: mean over episode starts of Q(x_init, target_policy(observed_init)).\n",
    "    \"\"\"\n",
    "    # data = [x, a, r, x', terminal, s']\n",
    "    obs_size = data[0][0].shape[0]\n",
    "    if action_size is None:\n",
    "        # Counting distinct logged actions under-sizes the output layer whenever an\n",
    "        # action is missing from the sample, so size it from the largest index instead.\n",
    "        action_size = int(data[1].max().item()) + 1\n",
    "    model = FQE_eval(obs_size, action_size, n_layers, n_nodes).double()\n",
    "    optimizer = optim.Adam(model.parameters(), lr=lr)\n",
    "\n",
    "    n_transitions = data[1].shape[0]\n",
    "    # Rows flagged terminal; reshape(-1) keeps this 1-D even when only one row is flagged.\n",
    "    terminal_index = torch.nonzero(data[4].reshape(-1)).reshape(-1).long()\n",
    "    # The row after a terminal one starts the next episode; a terminal flag on the\n",
    "    # very last row has no successor and is dropped.\n",
    "    terminal_index = terminal_index[terminal_index + 1 < n_transitions]\n",
    "    init_index = torch.cat([torch.tensor([0]), terminal_index + 1])  # the first row is always initial\n",
    "    # NOTE(review): data[5] holds next states, so entry 0 is the successor of the first\n",
    "    # transition rather than the first observed state, and the remaining entries rely on\n",
    "    # s_next at a terminal row being the next episode's first state -- confirm against\n",
    "    # how the dataset was built.\n",
    "    observed_init_index = torch.cat([torch.tensor([0]), terminal_index])\n",
    "    initial_x = data[0][init_index]\n",
    "    observed_init = data[5][observed_init_index]\n",
    "    target_init = target_policy(observed_init)\n",
    "    num_episode = initial_x.shape[0]\n",
    "\n",
    "    batch_size = max(n_transitions // 20, 10)\n",
    "    batch_data = DataLoader(TensorDataset(*data), batch_size=batch_size)\n",
    "\n",
    "    for _ in range(num_epochs):\n",
    "        for x, a, r, x_next, terminal, observed_s_next in batch_data:\n",
    "            train_FQE_step(model, optimizer, x, a, r, x_next, terminal, observed_s_next, target_policy)\n",
    "\n",
    "    # Only the final estimate is returned, so evaluate once after training.\n",
    "    model.eval()\n",
    "    with torch.no_grad():\n",
    "        preds = model(initial_x)  # Q-value estimation is based on the abstracted space\n",
    "    estimated_value = preds[torch.arange(num_episode), target_init].mean()\n",
    "    return estimated_value\n",
    "\n",
    "def behavior_policy(state, epsilon=0):\n",
    "    \"\"\"Epsilon-greedy rule on the third state component.\n",
    "\n",
    "    With probability `epsilon` a uniformly random action from {0, 1} is returned;\n",
    "    otherwise the action is 0 when state[2] < 0 and 1 otherwise.\n",
    "    \"\"\"\n",
    "    angle = state[2]\n",
    "    explore = np.random.binomial(1, epsilon) == 1\n",
    "    if explore:\n",
    "        return np.random.choice([0,1])\n",
    "    return 0 if angle < 0 else 1\n",
    "        \n",
    "def random_policy(state, action_size=2, batch=True):\n",
    "    \"\"\"Uniformly random policy over the actions 0 .. action_size - 1.\n",
    "\n",
    "    Args:\n",
    "        state: batch of states (only state.shape[0] is used) when `batch` is True;\n",
    "            ignored otherwise.\n",
    "        action_size: number of discrete actions.\n",
    "        batch: return a long tensor with one action per row when True, a single\n",
    "            NumPy integer otherwise.\n",
    "    \"\"\"\n",
    "    if batch:\n",
    "        size = state.shape[0]\n",
    "        return torch.randint(0, action_size, (size,)).long()\n",
    "\n",
    "    # Sample from range(action_size); a hard-coded [0, 1] would ignore action_size.\n",
    "    return np.random.choice(action_size)\n",
    "\n",
    "def nondyna_policy(state, action=1, batch=True):\n",
    "    \"\"\"Constant policy that always selects `action`.\n",
    "\n",
    "    Args:\n",
    "        state: batch of states (only state.shape[0] is used) when `batch` is True;\n",
    "            ignored otherwise.\n",
    "        action: the integer action to return.\n",
    "        batch: return a long tensor with one entry per row when True, the bare\n",
    "            `action` otherwise.\n",
    "    \"\"\"\n",
    "    if batch:\n",
    "        # Build the constant tensor directly instead of sampling from the one-value\n",
    "        # range [action, action + 1), so a deterministic policy does not go through\n",
    "        # the global torch random generator.\n",
    "        return torch.full((state.shape[0],), action, dtype=torch.long)\n",
    "\n",
    "    return action\n",
    "def cartpole_policy(state, batch=True):\n",
    "    \"\"\"Stochastic policy choosing action 1 with probability 1 - 1/(1 + exp(angle - pos)).\n",
    "\n",
    "    `pos` is state component 0 and `angle` is component 2. With `batch` True, `state`\n",
    "    is indexed as a 2-D tensor and a long tensor of actions is returned; otherwise a\n",
    "    single state is sampled through NumPy.\n",
    "    \"\"\"\n",
    "    if not batch:\n",
    "        pos, angle = state[0], state[2]\n",
    "        prob_0 = 1/(1+np.exp(angle-pos))\n",
    "        return np.random.binomial(1, 1 - prob_0)\n",
    "\n",
    "    pos, angle = state[:,0], state[:,2]\n",
    "    prob_0 = 1/(1+torch.exp(angle-pos))\n",
    "    return torch.bernoulli(1 - prob_0).long()\n",
    "\n",
    "\n",
    "def angle_policy(state, batch=True):\n",
    "    \"\"\"Deterministic policy: action 1 when state component 2 is >= 0, else action 0.\n",
    "\n",
    "    With `batch` True the rule is applied to column 2 of a 2-D tensor and a long\n",
    "    tensor is returned; a single state is delegated to behavior_policy with its\n",
    "    default epsilon.\n",
    "    \"\"\"\n",
    "    if not batch:\n",
    "        return behavior_policy(state)\n",
    "\n",
    "    return (state[:,2] >= 0).long()\n",
    "\n",
    "def plot_helper(df, title, xticks=None, xlabel=\"x\", ylabel=\"y\"):\n",
    "    \"\"\"Draw every column of `df` as a connected scatter series in one figure and show it.\n",
    "\n",
    "    Args:\n",
    "        df: DataFrame with one series per column.\n",
    "        title: figure title.\n",
    "        xticks: x positions for the rows; when given they are also used as tick\n",
    "            locations, otherwise the rows are placed at 0 .. len(df) - 1.\n",
    "        xlabel, ylabel: axis labels.\n",
    "    \"\"\"\n",
    "    plt.figure(figsize=(6, 4))\n",
    "    explicit_ticks = xticks is not None\n",
    "    positions = xticks if explicit_ticks else np.arange(df.shape[0])\n",
    "\n",
    "    for column in df.columns:\n",
    "        # Tuple (multi-level) labels are shown by their first level; any other label is\n",
    "        # used whole, so a plain string name is not cut down to its first character.\n",
    "        label = column[0] if isinstance(column, tuple) else column\n",
    "        plt.scatter(positions, df[column], label=label)\n",
    "        plt.plot(positions, df[column], linestyle='-')\n",
    "\n",
    "    if explicit_ticks:\n",
    "        plt.xticks(positions)\n",
    "    plt.title(title)\n",
    "    plt.xlabel(xlabel)\n",
    "    plt.ylabel(ylabel)\n",
    "    plt.legend()\n",
    "    plt.show()\n",
    "\n",
    "class DQN(nn.Module):\n",
    "    \"\"\"Three-layer MLP Q-network: in_dim -> 64 -> 64 -> action_size with ReLU in between.\n",
    "\n",
    "    The layer attributes are named fc1 / fc2 / fc3, which fixes the state_dict keys.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, in_dim, action_size):\n",
    "        super().__init__()\n",
    "        hidden = 64\n",
    "        self.fc1 = nn.Linear(in_dim, hidden)\n",
    "        self.fc2 = nn.Linear(hidden, hidden)\n",
    "        self.fc3 = nn.Linear(hidden, action_size)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return one Q-value per action for every row of `x`.\"\"\"\n",
    "        hidden_1 = self.fc1(x).relu()\n",
    "        hidden_2 = self.fc2(hidden_1).relu()\n",
    "        return self.fc3(hidden_2)\n",
    "\n",
    "def target_lunar(state, batch=True):\n",
    "    \"\"\"Greedy action(s) of the module-level `policy_net` network.\n",
    "\n",
    "    Args:\n",
    "        state: with `batch` True, a 2-D tensor of states; otherwise a single state.\n",
    "            In both cases only the first 8 features are passed to the network.\n",
    "        batch: selects between the batched and the single-state path.\n",
    "\n",
    "    Returns:\n",
    "        A long tensor with one action per row (batch), or a NumPy integer (single state).\n",
    "    \"\"\"\n",
    "    policy_net.eval()\n",
    "    with torch.no_grad():\n",
    "        if batch:\n",
    "            return torch.argmax(policy_net(state[:, :8]), dim=1).long()\n",
    "\n",
    "        # as_tensor accepts lists, arrays and tensors without re-wrapping a tensor;\n",
    "        # the slice mirrors the batch path so longer state vectors are handled alike.\n",
    "        single = torch.as_tensor(state).double().view(1, -1)[:, :8]\n",
    "        return np.argmax(policy_net(single).squeeze(0).numpy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "230c7222-6342-49b8-ad8d-aee5f5f5f056",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Build the DQN with 8 inputs and 4 outputs and restore its saved weights on CPU.\n",
    "policy_net = DQN(8, 4)\n",
    "model_path = \"dqn_lunar_lander.pt\"\n",
    "policy_net.load_state_dict(torch.load(model_path, map_location=torch.device('cpu'))) # load the trained model weights\n",
    "# Cast the parameters to float64; target_lunar converts single states with .double() before calling the network.\n",
    "policy_net = policy_net.double()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "27838d83-6f2f-48bb-ae7e-5d7391a18d78",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Markov abstraction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "cb41af47-2466-46c0-a011-0cd89f5806a2",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Loss-term weights taken from the parsed arguments; passed as `coefs` to the abstraction networks.\n",
    "coefs = {\n",
    "    'L_inv': args.L_inv,\n",
    "    'L_coinv': args.L_coinv,\n",
    "    # 'L_fwd': args.L_fwd,\n",
    "    'L_rat': args.L_rat,\n",
    "    # 'L_fac': args.L_fac,\n",
    "    'L_dis': args.L_dis,\n",
    "    'L_ora': args.L_ora,\n",
    "}\n",
    "# Policy evaluated by FQE in the experiment cell. Alternative: target_policy = angle_policy\n",
    "target_policy = target_lunar\n",
    "def get_batch(x0, x1, a, batch_size=10):\n",
    "    \"\"\"Sample a random minibatch of aligned rows from (x0, x1, a) without replacement.\n",
    "\n",
    "    Args:\n",
    "        x0, x1: states and next states, indexable by an integer array.\n",
    "        a: actions, aligned with x0 / x1.\n",
    "        batch_size: requested number of rows; capped at len(a).\n",
    "\n",
    "    Returns:\n",
    "        (tx0, tx1, ta, idx): float tensors for the states, a long tensor for the\n",
    "        actions, and the NumPy array of sampled row indices.\n",
    "    \"\"\"\n",
    "    n_rows = len(a)\n",
    "    # Sampling without replacement cannot draw more rows than exist, so cap the\n",
    "    # request instead of letting np.random.choice raise on small datasets.\n",
    "    idx = np.random.choice(n_rows, min(batch_size, n_rows), replace=False)\n",
    "    tx0 = torch.as_tensor(x0[idx]).float()\n",
    "    tx1 = torch.as_tensor(x1[idx]).float()\n",
    "    ta = torch.as_tensor(a[idx]).long()\n",
    "    return tx0, tx1, ta, idx\n",
    "\n",
    "def encode(data, model):\n",
    "    \"\"\"Replace the raw states of a transition batch by their abstraction `model.phi`.\n",
    "\n",
    "    Args:\n",
    "        data: sequence [s, a, r, s', terminal, ...] of aligned tensors.\n",
    "        model: abstraction network exposing `eval()` and `phi(states)`.\n",
    "\n",
    "    Returns:\n",
    "        A new list in which entries 0 and 3 are the float64 encodings of `s` and `s'`\n",
    "        and every other entry is a deep copy of the input entry. `data` is not modified.\n",
    "    \"\"\"\n",
    "    # data = [s,a,r,s',terminal] or [x,a,r,x',terminal,s,s']\n",
    "    model.eval()\n",
    "    with torch.no_grad():\n",
    "        phi_s = model.phi(data[0].float())\n",
    "        # Encode the stored next states themselves. Shifting phi(s) by one row is only\n",
    "        # valid when s'[i] equals s[i + 1] on every row, which does not hold across\n",
    "        # episode boundaries unless the dataset was built that way.\n",
    "        phi_s_next = model.phi(data[3].float())\n",
    "\n",
    "    # The two state entries are replaced, so only the remaining entries need copying.\n",
    "    encoded = [None if i in (0, 3) else copy.deepcopy(item) for i, item in enumerate(data)]\n",
    "    encoded[0] = phi_s.double()\n",
    "    encoded[3] = phi_s_next.double()\n",
    "    return encoded\n",
    "\n",
    "def train_abstraction(data, action_size, n_frames, n_updates_per_frame, coefs, type):\n",
    "    \"\"\"Train a state-abstraction network on random minibatches of (s, s', a) from `data`.\n",
    "\n",
    "    Args:\n",
    "        data: sequence whose entries 0, 1 and 3 are the states, actions and next states.\n",
    "        action_size: number of actions, passed to the network as `n_actions`.\n",
    "        n_frames: the outer loop runs n_frames + 1 times.\n",
    "        n_updates_per_frame: minibatch updates per outer iteration.\n",
    "        coefs: loss-term weights passed to the network.\n",
    "        type: 'markov' builds a FeatureNet; any other value builds an AutoEncoder.\n",
    "\n",
    "    Returns:\n",
    "        The trained network.\n",
    "    \"\"\"\n",
    "    states, actions, next_states = data[0], data[1], data[3]\n",
    "    batch_size = max((states.shape[0])//20, 10)\n",
    "\n",
    "    # Both network classes take the same constructor arguments.\n",
    "    net_class = FeatureNet if type == 'markov' else AutoEncoder\n",
    "    fnet = net_class(n_actions=action_size,\n",
    "                     input_shape=states.shape[1:],\n",
    "                     n_latent_dims=args.latent_dims,\n",
    "                     n_hidden_layers=1,\n",
    "                     n_units_per_layer=32,\n",
    "                     lr=args.learning_rate,\n",
    "                     coefs=coefs)\n",
    "\n",
    "    for _frame in tqdm(range(n_frames + 1)):\n",
    "        for _update in range(n_updates_per_frame):\n",
    "            tx0, tx1, ta, _ = get_batch(states, next_states, actions, batch_size)\n",
    "            fnet.train_batch(tx0, tx1, ta)\n",
    "    return fnet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "09aac743-e1d1-4fdc-91f3-b00605554bc8",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Load the pickled dataset; the experiment cell indexes it by sample size (data01[sample_size]).\n",
    "# NOTE: pickle.load can execute arbitrary code from the file -- only load trusted data.\n",
    "with open(\"lunar_0.1_data_bfbfnew.pickle\", 'rb') as f:\n",
    "    data01 = pickle.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "272e8ac3-f662-4f66-aab7-6e3f73529ef0",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# NOTE(review): absolute, machine-specific path, and `sars_by_episode` is not referenced by the\n",
    "# neighbouring cells -- confirm this cell is still needed before relying on it.\n",
    "with open(\"/home/jupyter/03_data.pickle\", 'rb') as f:\n",
    "    sars_by_episode = pickle.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6748377a-8291-449d-84dc-f63d322618c8",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "b2ccbf44-7598-465f-954a-3fabf30b1c03",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 21/21 [00:10<00:00,  1.95it/s]\n",
      "100%|██████████| 21/21 [00:03<00:00,  5.47it/s]\n",
      "100%|██████████| 21/21 [00:05<00:00,  3.63it/s]\n",
      "100%|██████████| 21/21 [00:04<00:00,  5.25it/s]\n",
      "100%|██████████| 21/21 [00:05<00:00,  4.09it/s]\n",
      "100%|██████████| 21/21 [00:02<00:00,  8.34it/s]\n",
      "100%|██████████| 21/21 [00:05<00:00,  3.86it/s]\n",
      "100%|██████████| 21/21 [00:03<00:00,  5.95it/s]\n",
      "100%|██████████| 21/21 [00:06<00:00,  3.41it/s]\n",
      "100%|██████████| 21/21 [00:03<00:00,  6.04it/s]\n",
      "100%|██████████| 21/21 [00:06<00:00,  3.23it/s]\n",
      "100%|██████████| 21/21 [00:02<00:00,  8.97it/s]\n",
      "100%|██████████| 21/21 [00:02<00:00,  8.19it/s]\n",
      "100%|██████████| 21/21 [00:01<00:00, 14.48it/s]\n",
      "100%|██████████| 21/21 [00:02<00:00,  8.33it/s]\n",
      "100%|██████████| 21/21 [00:01<00:00, 14.22it/s]\n",
      "100%|██████████| 21/21 [00:02<00:00,  8.12it/s]\n",
      "100%|██████████| 21/21 [00:01<00:00, 13.20it/s]\n",
      "100%|██████████| 21/21 [00:02<00:00,  7.82it/s]\n",
      "100%|██████████| 21/21 [00:01<00:00, 13.05it/s]\n",
      "100%|██████████| 21/21 [00:02<00:00,  7.10it/s]\n",
      "100%|██████████| 21/21 [00:01<00:00, 12.90it/s]\n",
      "100%|██████████| 21/21 [00:02<00:00,  9.03it/s]\n",
      "100%|██████████| 21/21 [00:01<00:00, 16.85it/s]\n",
      "100%|██████████| 21/21 [00:02<00:00,  8.18it/s]\n",
      "100%|██████████| 21/21 [00:01<00:00, 13.59it/s]\n",
      "100%|██████████| 21/21 [00:02<00:00,  8.41it/s]\n",
      "100%|██████████| 21/21 [00:01<00:00, 13.53it/s]\n",
      "100%|██████████| 21/21 [00:02<00:00,  8.17it/s]\n",
      "100%|██████████| 21/21 [00:01<00:00, 14.18it/s]\n",
      "100%|██████████| 21/21 [00:02<00:00,  9.38it/s]\n",
      "100%|██████████| 21/21 [00:01<00:00, 17.49it/s]\n",
      "100%|██████████| 21/21 [00:02<00:00,  7.97it/s]\n",
      "100%|██████████| 21/21 [00:01<00:00, 13.43it/s]\n",
      "100%|██████████| 21/21 [00:02<00:00,  8.41it/s]\n",
      "100%|██████████| 21/21 [00:01<00:00, 14.29it/s]\n",
      "100%|██████████| 21/21 [00:02<00:00,  8.95it/s]\n",
      "100%|██████████| 21/21 [00:01<00:00, 16.32it/s]\n",
      "100%|██████████| 21/21 [00:03<00:00,  6.97it/s]\n",
      "100%|██████████| 21/21 [00:01<00:00, 13.36it/s]\n",
      "100%|██████████| 21/21 [00:02<00:00,  7.67it/s]\n",
      "100%|██████████| 21/21 [00:01<00:00, 10.85it/s]\n",
      "100%|██████████| 21/21 [00:02<00:00,  7.12it/s]\n",
      "100%|██████████| 21/21 [00:02<00:00, 10.10it/s]\n",
      "100%|██████████| 21/21 [00:02<00:00,  7.68it/s]\n",
      "100%|██████████| 21/21 [00:01<00:00, 11.33it/s]\n",
      "100%|██████████| 21/21 [00:02<00:00,  7.32it/s]\n",
      "100%|██████████| 21/21 [00:01<00:00, 10.99it/s]\n",
      "100%|██████████| 21/21 [00:02<00:00,  7.37it/s]\n",
      "100%|██████████| 21/21 [00:01<00:00, 10.87it/s]\n",
      "100%|██████████| 21/21 [00:02<00:00,  7.67it/s]\n",
      "100%|██████████| 21/21 [00:01<00:00, 11.18it/s]\n",
      "100%|██████████| 21/21 [00:02<00:00,  7.82it/s]\n",
      "100%|██████████| 21/21 [00:01<00:00, 11.03it/s]\n",
      "100%|██████████| 21/21 [00:02<00:00,  7.63it/s]\n",
      "100%|██████████| 21/21 [00:01<00:00, 10.92it/s]\n",
      "100%|██████████| 21/21 [00:02<00:00,  7.25it/s]\n",
      "100%|██████████| 21/21 [00:01<00:00, 11.04it/s]\n",
      "100%|██████████| 21/21 [00:02<00:00,  7.46it/s]\n",
      "100%|██████████| 21/21 [00:01<00:00, 10.77it/s]\n",
      "100%|██████████| 21/21 [00:02<00:00,  7.40it/s]\n",
      "100%|██████████| 21/21 [00:01<00:00, 11.23it/s]\n",
      "100%|██████████| 21/21 [00:02<00:00,  7.49it/s]\n",
      "100%|██████████| 21/21 [00:01<00:00, 11.31it/s]\n",
      "100%|██████████| 21/21 [00:02<00:00,  7.34it/s]\n",
      "100%|██████████| 21/21 [00:01<00:00, 10.54it/s]\n",
      "100%|██████████| 21/21 [00:02<00:00,  7.45it/s]\n",
      "100%|██████████| 21/21 [00:01<00:00, 11.37it/s]\n",
      "100%|██████████| 21/21 [00:03<00:00,  6.61it/s]\n",
      "100%|██████████| 21/21 [00:01<00:00, 11.38it/s]\n",
      "100%|██████████| 21/21 [00:02<00:00,  7.23it/s]\n",
      "100%|██████████| 21/21 [00:02<00:00,  9.46it/s]\n",
      "100%|██████████| 21/21 [00:02<00:00,  7.69it/s]\n",
      "100%|██████████| 21/21 [00:02<00:00,  8.81it/s]\n",
      "100%|██████████| 21/21 [00:02<00:00,  7.82it/s]\n",
      "100%|██████████| 21/21 [00:01<00:00, 10.80it/s]\n",
      "100%|██████████| 21/21 [00:02<00:00,  7.28it/s]\n",
      "100%|██████████| 21/21 [00:02<00:00,  9.39it/s]\n",
      "100%|██████████| 21/21 [00:02<00:00,  8.33it/s]\n",
      "100%|██████████| 21/21 [00:01<00:00, 11.88it/s]\n",
      "100%|██████████| 21/21 [00:03<00:00,  6.65it/s]\n",
      "100%|██████████| 21/21 [00:02<00:00,  8.49it/s]\n",
      "100%|██████████| 21/21 [00:03<00:00,  6.72it/s]\n",
      "100%|██████████| 21/21 [00:02<00:00,  8.68it/s]\n",
      "100%|██████████| 21/21 [00:03<00:00,  6.65it/s]\n",
      "100%|██████████| 21/21 [00:02<00:00,  8.78it/s]\n",
      "100%|██████████| 21/21 [00:03<00:00,  6.73it/s]\n",
      "100%|██████████| 21/21 [00:02<00:00,  9.05it/s]\n",
      "100%|██████████| 21/21 [00:03<00:00,  6.48it/s]\n",
      "100%|██████████| 21/21 [00:02<00:00,  8.59it/s]\n",
      "100%|██████████| 21/21 [00:03<00:00,  5.74it/s]\n",
      "100%|██████████| 21/21 [00:02<00:00,  8.73it/s]\n",
      "100%|██████████| 21/21 [00:03<00:00,  6.39it/s]\n",
      "100%|██████████| 21/21 [00:02<00:00,  8.45it/s]\n",
      "100%|██████████| 21/21 [00:03<00:00,  6.11it/s]\n",
      "100%|██████████| 21/21 [00:02<00:00,  8.15it/s]\n",
      "100%|██████████| 21/21 [00:03<00:00,  6.26it/s]\n",
      "100%|██████████| 21/21 [00:02<00:00,  8.30it/s]\n",
      "100%|██████████| 21/21 [00:03<00:00,  6.80it/s]\n",
      "100%|██████████| 21/21 [00:02<00:00,  8.56it/s]\n",
      "100%|██████████| 21/21 [00:03<00:00,  6.66it/s]\n",
      "100%|██████████| 21/21 [00:02<00:00,  8.75it/s]\n",
      "100%|██████████| 21/21 [00:03<00:00,  6.70it/s]\n",
      "100%|██████████| 21/21 [00:02<00:00,  8.64it/s]\n",
      "100%|██████████| 21/21 [00:03<00:00,  6.74it/s]\n",
      "100%|██████████| 21/21 [00:02<00:00,  7.99it/s]\n",
      "100%|██████████| 21/21 [00:03<00:00,  6.42it/s]\n",
      "100%|██████████| 21/21 [00:02<00:00,  7.28it/s]\n",
      "100%|██████████| 21/21 [00:03<00:00,  6.62it/s]\n",
      "100%|██████████| 21/21 [00:02<00:00,  8.61it/s]\n",
      "100%|██████████| 21/21 [00:03<00:00,  6.14it/s]\n",
      "100%|██████████| 21/21 [00:02<00:00,  8.34it/s]\n",
      "100%|██████████| 21/21 [00:03<00:00,  6.41it/s]\n",
      "100%|██████████| 21/21 [00:02<00:00,  8.68it/s]\n",
      "100%|██████████| 21/21 [00:03<00:00,  6.51it/s]\n",
      "100%|██████████| 21/21 [00:02<00:00,  8.42it/s]\n",
      "100%|██████████| 21/21 [00:03<00:00,  5.90it/s]\n",
      "100%|██████████| 21/21 [00:02<00:00,  8.12it/s]\n",
      "100%|██████████| 21/21 [00:03<00:00,  6.10it/s]\n",
      "100%|██████████| 21/21 [00:02<00:00,  8.11it/s]\n",
      "100%|██████████| 21/21 [00:06<00:00,  3.42it/s]\n",
      "100%|██████████| 21/21 [00:04<00:00,  4.81it/s]\n",
      "100%|██████████| 21/21 [00:04<00:00,  4.95it/s]\n",
      "100%|██████████| 21/21 [00:03<00:00,  5.69it/s]\n",
      "100%|██████████| 21/21 [00:03<00:00,  5.31it/s]\n",
      "100%|██████████| 21/21 [00:03<00:00,  6.45it/s]\n",
      "100%|██████████| 21/21 [00:03<00:00,  5.38it/s]\n",
      "100%|██████████| 21/21 [00:03<00:00,  6.44it/s]\n",
      "100%|██████████| 21/21 [00:04<00:00,  4.98it/s]\n",
      "100%|██████████| 21/21 [00:03<00:00,  6.21it/s]\n",
      "100%|██████████| 21/21 [00:03<00:00,  5.30it/s]\n",
      "100%|██████████| 21/21 [00:03<00:00,  6.56it/s]\n",
      "100%|██████████| 21/21 [00:03<00:00,  5.42it/s]\n",
      "100%|██████████| 21/21 [00:03<00:00,  6.53it/s]\n",
      "100%|██████████| 21/21 [00:04<00:00,  5.04it/s]\n",
      "100%|██████████| 21/21 [00:03<00:00,  5.73it/s]\n",
      "100%|██████████| 21/21 [00:04<00:00,  4.76it/s]\n",
      "100%|██████████| 21/21 [00:03<00:00,  6.07it/s]\n",
      "100%|██████████| 21/21 [00:03<00:00,  5.37it/s]\n",
      "100%|██████████| 21/21 [00:03<00:00,  6.68it/s]\n",
      "100%|██████████| 21/21 [00:03<00:00,  5.39it/s]\n",
      "100%|██████████| 21/21 [00:03<00:00,  6.71it/s]\n",
      "100%|██████████| 21/21 [00:03<00:00,  5.28it/s]\n",
      "100%|██████████| 21/21 [00:03<00:00,  6.32it/s]\n",
      "100%|██████████| 21/21 [00:03<00:00,  5.35it/s]\n",
      "100%|██████████| 21/21 [00:03<00:00,  6.17it/s]\n",
      "100%|██████████| 21/21 [00:04<00:00,  5.07it/s]\n",
      "100%|██████████| 21/21 [00:03<00:00,  6.31it/s]\n",
      "100%|██████████| 21/21 [00:03<00:00,  5.47it/s]\n",
      "100%|██████████| 21/21 [00:03<00:00,  6.45it/s]\n",
      "100%|██████████| 21/21 [00:04<00:00,  5.10it/s]\n",
      "100%|██████████| 21/21 [00:03<00:00,  5.90it/s]\n",
      "100%|██████████| 21/21 [00:04<00:00,  5.11it/s]\n",
      "100%|██████████| 21/21 [00:03<00:00,  6.30it/s]\n",
      "100%|██████████| 21/21 [00:04<00:00,  5.05it/s]\n",
      "100%|██████████| 21/21 [00:03<00:00,  6.43it/s]\n",
      "100%|██████████| 21/21 [00:04<00:00,  5.24it/s]\n",
      "100%|██████████| 21/21 [00:03<00:00,  6.27it/s]\n",
      "100%|██████████| 21/21 [00:04<00:00,  5.16it/s]\n",
      "100%|██████████| 21/21 [00:03<00:00,  6.03it/s]\n"
     ]
    }
   ],
   "source": [
    "# Abstraction-training schedule passed to train_abstraction.\n",
    "n_updates_per_frame = 20\n",
    "n_frames = 20\n",
    "# Keys of data01 that are evaluated, in order.\n",
    "sample_sizes = [10,20,35,60]\n",
    "# total_bias[i] collects one (markov, autoencoder) bias tensor per dataset of sample_sizes[i].\n",
    "total_bias = []\n",
    "for sample_size in sample_sizes:\n",
    "    # Each entry of `samples` is one dataset built from `sample_size` episodes\n",
    "    # (the original note says 30 datasets per size -- TODO confirm).\n",
    "    samples = data01[sample_size]\n",
    "    # Reference value subtracted from both estimates.\n",
    "    # NOTE(review): 61.7 is hard-coded here; record where it comes from.\n",
    "    oracle = torch.tensor(61.7).repeat(2)\n",
    "    biases = []\n",
    "    for sample in samples:\n",
    "        #sample=[s,a,r,s',terminal]\n",
    "        s_next = sample[3]\n",
    "        # Train both abstractions on the same raw sample (4 actions).\n",
    "        markov_model = train_abstraction(sample, 4, n_frames, n_updates_per_frame, coefs, \"markov\")\n",
    "        autoencoder = train_abstraction(sample, 4, n_frames, n_updates_per_frame, coefs, \"autoencoder\")\n",
    "        \n",
    "        # Encoded sample plus the raw next states as entry 5, which train_FQE feeds to the target policy.\n",
    "        markov_sample = [*encode(sample, markov_model), s_next]\n",
    "        auto_sample = [*encode(sample, autoencoder), s_next]\n",
    "\n",
    "        # 20 FQE epochs on each encoded sample.\n",
    "        markov_value = train_FQE(markov_sample, 20, target_policy)\n",
    "        auto_value = train_FQE(auto_sample, 20, target_policy)\n",
    "        \n",
    "        bias = torch.tensor([markov_value, auto_value ]) - oracle\n",
    "        biases.append(bias)\n",
    "    total_bias.append(biases)\n",
    "\n",
    "# Persist the nested list of bias tensors.\n",
    "with open(\"lunar0.1_markov_auto_bias.pickle\", \"wb\") as f:\n",
    "    pickle.dump(total_bias,f)"
   ]
  }
 ],
 "metadata": {
  "environment": {
   "kernel": "conda-env-pytorch-pytorch",
   "name": "workbench-notebooks.m125",
   "type": "gcloud",
   "uri": "us-docker.pkg.dev/deeplearning-platform-release/gcr.io/workbench-notebooks:m125"
  },
  "kernelspec": {
   "display_name": "PyTorch 1-13 (Local)",
   "language": "python",
   "name": "conda-env-pytorch-pytorch"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
