import argparse
import copy
import gymnasium as gym
import itertools
import logging
import numpy as np
import os
import pickle
import random
import torch
import tqdm
import sys

from collections import namedtuple

from .. import delayed_mdp
import latency_env.misc.argparser_types as at
from latency_env.misc.argparser import Argument as Arg
from latency_env.misc.argparser import ArgumentList as ArgList

# Module-level logger. The NullHandler keeps this module silent until the
# application installs real handlers (see setup_logging in this file).
LOG = logging.getLogger(__name__)
LOG.addHandler(logging.NullHandler())


def seed_all(seed=0):
    """Seed the Python, NumPy and PyTorch random generators with ``seed``."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)


# Identifiers for the kind of summary writer that is currently active.
LOG_WRITER_TYPE_NONE = 0
LOG_WRITER_TYPE_TENSORBOARD = 1

# Process-wide summary-writer state shared by the log_write_* helpers below:
#   "state" -- the active writer object (None while no writer is set up)
#   "type"  -- one of the LOG_WRITER_TYPE_* constants above
LOG_WRITER = {
    "state": None,
    "type": LOG_WRITER_TYPE_NONE,
}

def log_write_scalaridx(name, value, idx):
    """Write scalar ``value`` under tag ``name`` at step ``idx``.

    Does nothing unless a tensorboard writer is active.
    """
    if LOG_WRITER["type"] != LOG_WRITER_TYPE_TENSORBOARD:
        return
    writer = LOG_WRITER["state"]
    writer.add_scalar(name, value, idx)

def log_write_text(name, msg):
    """Write the text ``msg`` under tag ``name``.

    Every newline is doubled before writing. Does nothing unless a
    tensorboard writer is active.
    """
    if LOG_WRITER["type"] != LOG_WRITER_TYPE_TENSORBOARD:
        return
    spaced_msg = msg.replace("\n", "\n\n")
    LOG_WRITER["state"].add_text(name, spaced_msg)

def log_write_histogramidx(name, value, idx):
    """Write ``value`` as a histogram under tag ``name`` at step ``idx``.

    Does nothing unless a tensorboard writer is active.
    """
    if LOG_WRITER["type"] != LOG_WRITER_TYPE_TENSORBOARD:
        return
    writer = LOG_WRITER["state"]
    writer.add_histogram(name, value, idx)

def log_write_flush():
    """Flush the active tensorboard writer, if there is one."""
    if LOG_WRITER["type"] != LOG_WRITER_TYPE_TENSORBOARD:
        return
    LOG_WRITER["state"].flush()

def log_write_close():
    """Close the active tensorboard writer and mark the log writer as unset.

    Does nothing when no tensorboard writer is active.
    """
    if LOG_WRITER["type"] != LOG_WRITER_TYPE_TENSORBOARD:
        return
    writer = LOG_WRITER["state"]
    writer.close()
    LOG_WRITER.update(state=None, type=LOG_WRITER_TYPE_NONE)

def log_write_setup_tensorboard(summary_path, append_date=True):
    """
    Create a tensorboard SummaryWriter and install it as the active log writer.

    Arguments
    ---------
    summary_path : str or path-like
        Directory to write the summaries to. It is created if missing.
    append_date : bool (optional)
        Whether to write into a sub-directory named after the current local
        time ("%Y-%m-%d_%H.%M.%S"). (Default: True)

    Returns the pathlib.Path of the directory that is actually written to.
    """
    # Imported lazily so tensorboard is only required when it is used.
    from torch.utils.tensorboard import SummaryWriter
    from datetime import datetime
    from pathlib import Path
    summary_path = Path(summary_path)
    if append_date:
        summary_path /= datetime.now().strftime("%Y-%m-%d_%H.%M.%S")
    os.makedirs(summary_path, mode=0o755, exist_ok=True)

    # Close a writer left over from an earlier setup call instead of silently
    # dropping it, so it is shut down before being replaced.
    previous = LOG_WRITER["state"]
    if LOG_WRITER["type"] == LOG_WRITER_TYPE_TENSORBOARD and previous is not None:
        previous.close()

    LOG_WRITER["state"] = SummaryWriter(summary_path)
    LOG_WRITER["type"] = LOG_WRITER_TYPE_TENSORBOARD
    return summary_path


def setup_logging(verbosity=None, log_stderr=False, logfile_name=None):
    """
    Configure the root logger.

    Arguments
    ---------
    verbosity : int (optional)
        0 (or less) selects INFO, 1 or more selects DEBUG. When None the root
        logger level is left untouched.
    log_stderr : bool (optional)
        Whether to attach a handler that writes to stderr. (Default: False)
    logfile_name : str (optional)
        Path of a file to log to. Missing parent directories are created.
        (Default: no log file)
    """
    LOG_FMT = logging.Formatter("%(asctime)s %(name)s:%(lineno)d [%(levelname)s]: %(message)s")
    root = logging.getLogger()
    if verbosity is not None:
        loglevels = [logging.INFO, logging.DEBUG]
        # Clamp on both sides; a negative value must not index from the end.
        level_idx = max(0, min(verbosity, len(loglevels) - 1))
        root.setLevel(loglevels[level_idx])
    if log_stderr:
        stderr_handler = logging.StreamHandler(sys.stderr)
        stderr_handler.setFormatter(LOG_FMT)
        root.addHandler(stderr_handler)
    if logfile_name is not None:
        # A bare filename has an empty dirname, which os.makedirs rejects.
        logfile_dir = os.path.dirname(logfile_name)
        if logfile_dir:
            os.makedirs(logfile_dir, mode=0o755, exist_ok=True)
        logfile_handler = logging.FileHandler(logfile_name)
        logfile_handler.setFormatter(LOG_FMT)
        root.addHandler(logfile_handler)

# Command-line arguments controlling logging. The hook forwards the parsed
# values to setup_logging().
# NOTE(review): presumably ArgList runs the hook once the arguments have been
# parsed -- confirm against latency_env.misc.argparser.
# "-q/--quiet" uses action="store_false", i.e. the flag turns log_stderr off.
logging_arguments = ArgList(
    _hooks=[lambda args: setup_logging(
        verbosity=args.verbosity,
        logfile_name=args.logfile_name,
        log_stderr=args.log_stderr)
    ],
    verbosity = Arg("-v", "--verbose", action="count", default=0,
                    help="Verbosity level of log printouts"),
    logfile_name = Arg("--logfile", metavar="PATH", type=str, default=None,
                       help="Output log messages to a file"),
    log_stderr = Arg("-q", "--quiet", action="store_false",
                     help="Do not output log messages to stderr"),
)


# Optimizer classes selectable by name in make_optimizer().
OPTIMIZERS = {
    "adam": torch.optim.Adam,
    "adamw": torch.optim.AdamW,
}

def make_optimizer(*args, optim=None, **kwargs):
    """
    Build one optimizer over the parameters of everything passed in ``args``.

    Arguments
    ---------
    *args : torch.nn.Parameter or object with a parameters() method
        A bare parameter is used directly; for anything else the result of
        its parameters() method is used.
    optim : str
        Key into OPTIMIZERS selecting the optimizer class.
    **kwargs
        Forwarded to the optimizer constructor.

    Raises ValueError when ``optim`` is missing or not a known optimizer.
    """
    if optim is None:
        raise ValueError("missing optimizer")
    if optim not in OPTIMIZERS:
        raise ValueError(f"invalid optimizer \"{optim}\"")

    parameter_sources = []
    for source in args:
        if isinstance(source, torch.nn.Parameter):
            parameter_sources.append([source])
        else:
            parameter_sources.append(source.parameters())

    optimizer_cls = OPTIMIZERS[optim]
    return optimizer_cls(itertools.chain(*parameter_sources), **kwargs)


class ReplayMemoryBase:
    """
    Fixed-capacity replay buffer with per-item trajectory bookkeeping.

    Subclasses list the stored fields in ``ParameterOrder`` and allocate one
    tensor with leading dimension ``maxsize`` per field in ``self.rb``. Items
    are written with add() and read back as named tuples through sample(),
    sample_trajectory() or indexing.

    Replacement policies applied once the buffer is full:
      REPLACEMENT_FIFO    -- overwrite the oldest item
      REPLACEMENT_UNIFORM -- overwrite a uniformly chosen slot with
                             probability len(buffer) / total_adds, otherwise
                             drop the new item
      NO_REPLACEMENT      -- drop new items until reset() is called
    """
    ParameterOrder = []
    REPLACEMENT_FIFO = 1
    REPLACEMENT_UNIFORM = 2
    NO_REPLACEMENT = 3

    def __init__(self, maxsize: int, env: "gym.core.Env",
                 replacement=REPLACEMENT_FIFO):
        """
        Arguments
        ---------
        maxsize : int
            Capacity of the buffer.
        env : gymnasium.core.Env
            Environment whose observation and action spaces determine the
            shapes and dtypes recorded in s_shape/s_type and a_shape/a_type.
        replacement : int (optional)
            One of the REPLACEMENT_* / NO_REPLACEMENT constants.
            (Default: REPLACEMENT_FIFO)

        Raises ValueError if a space uses an unsupported numpy dtype.
        """
        self.maxsize = maxsize
        self.length = 0
        self.next_index = 0
        self.prev_index = None
        self.replacement = replacement
        self.total_adds = 0

        # Replay buffer storage
        self.rb = {}

        # Hooks for fetching/storing data for specific fields in the replay buffer
        # get_hooks[name] = fun(idxs) -> tensor
        # add_hooks[name] = fun(data, idx) -> ()
        self.get_hooks = {}
        self.add_hooks = {}

        # Trajectory information
        self.trajectory_starts = torch.full((self.maxsize,), False, dtype=torch.bool)
        self.trajectory_ends = torch.full((self.maxsize,), False, dtype=torch.bool)
        # by default, the first item inserted is the start of a new trajectory
        self.trajectory_upcoming = True

        # Warning flags to avoid excessive log messages
        self.__warningflag_replacement = False
        self.__warningflag_trajectory_uniform = False

        def dtype_to_torch_dtype(ty):
            # Floats are stored as float32 and integers as int64, whatever
            # precision the space itself declares.
            FLOATS = [np.dtype("float16"), np.dtype("float32"), np.dtype("float64")]
            INTS = [np.dtype("int16"), np.dtype("int32"), np.dtype("int64")]
            if ty in FLOATS:
                return (torch.float32, np.float32)
            elif ty in INTS:
                return (torch.long, np.int64)
            else:
                raise ValueError(f"unsupported numpy dtype {ty}")

        self.s_shape = env.observation_space.shape
        (self.s_type, self.s_nptype) = dtype_to_torch_dtype(env.observation_space.dtype)
        if isinstance(env.action_space, gym.spaces.Discrete):
            # Discrete actions are stored as scalar integer indices
            self.a_shape = ()
            (self.a_type, self.a_nptype) = (torch.long, np.int64)
        else:
            self.a_shape = env.action_space.shape
            (self.a_type, self.a_nptype) = dtype_to_torch_dtype(env.action_space.dtype)

    def _getkeys(self, keys, idxs, device=None):
        """
        Fetch the fields named in ``keys`` at buffer indices ``idxs``.

        Fields with an entry in get_hooks are read through the hook, all
        others directly from self.rb. Returns a named tuple with one entry
        per key, each moved to ``device`` when one is given.
        """
        Sample = namedtuple(type(self).__name__ + "_Sample", keys)
        data = []
        for name in keys:
            if name in self.get_hooks:
                d = self.get_hooks[name](idxs)
            else:
                d = self.rb[name][idxs]

            if device is not None:
                d = d.to(device)
            data.append(d)

        return Sample(*data)

    def _getall(self, idxs, device=None):
        """Fetch every field in ParameterOrder at buffer indices ``idxs``."""
        return self._getkeys(self.ParameterOrder, idxs, device=device)

    def _add(self, *args):
        """
        Internal add function: store one item at self.next_index.

        ``args`` holds one value per entry of ParameterOrder, in that order.
        Raises ValueError if the number of values does not match.
        """
        if len(args) != len(self.ParameterOrder):
            raise ValueError(f"Received {len(args)} values to add to the replay buffer. Expected {len(self.ParameterOrder)}")

        slot = self.next_index
        for arg, name in zip(args, self.ParameterOrder):
            if name in self.add_hooks:
                self.add_hooks[name](arg, slot)
            else:
                self.rb[name][slot] = torch.as_tensor(arg, dtype=self.rb[name].dtype)

        starts_trajectory = self.trajectory_upcoming
        if starts_trajectory:
            self.trajectory_upcoming = False
        elif self.prev_index is not None:
            # The previously added item is no longer the last of its
            # trajectory. Cleared before the flag of the new item is set so
            # that the new item stays marked when both share a slot.
            self.trajectory_ends[self.prev_index] = False

        # The most recently added item is always the current end of its
        # trajectory.
        self.trajectory_starts[slot] = starts_trajectory
        self.trajectory_ends[slot] = True

    def __len__(self):
        """Number of items currently stored."""
        return self.length

    def __getitem__(self, idxs):
        """Return all fields at buffer indices ``idxs`` as a named tuple."""
        return self._getall(idxs)

    def reset(self):
        """
        Empty the buffer.

        Stored data is not erased, only made unreachable. The next added item
        starts a new trajectory and the insertion counter used by uniform
        replacement starts over.
        """
        self.length = 0
        self.next_index = 0
        self.prev_index = None
        self.total_adds = 0
        # Forget trajectory bookkeeping of the discarded items
        self.trajectory_starts.fill_(False)
        self.trajectory_ends.fill_(False)
        self.trajectory_upcoming = True
        # Reset warning flags
        self.__warningflag_replacement = False
        self.__warningflag_trajectory_uniform = False

    def new_trajectory(self):
        """Mark the next added item as the first item of a new trajectory."""
        if self.replacement == self.REPLACEMENT_UNIFORM:
            if not self.__warningflag_trajectory_uniform:
                LOG.warning("Trajectory information does not work with uniform replacement. Resulting data will likely be nonsensical.")
                self.__warningflag_trajectory_uniform = True

        self.trajectory_upcoming = True

    def add(self, *args, **kwargs):
        """
        Add one item, applying the replacement policy when the buffer is full.

        The arguments are forwarded to _add(); pass one value per entry of
        ParameterOrder, in that order.
        """
        self.total_adds += 1
        if self.length < self.maxsize:
            # Always add if there is room over
            self._add(*args, **kwargs)
            self.prev_index = self.next_index
            self.next_index += 1
            self.length += 1
        else:
            if self.replacement == self.REPLACEMENT_FIFO:
                # Replace the oldest thing in the buffer
                self._add(*args, **kwargs)
                self.prev_index = self.next_index
                self.next_index += 1
            elif self.replacement == self.REPLACEMENT_UNIFORM:
                # With probability |B| / total_adds, replace something at random
                # in the buffer. Otherwise throw the inserted data away.
                p = self.length / self.total_adds
                if p > np.random.random():
                    self.prev_index = self.next_index
                    self.next_index = np.random.randint(self.length)
                    self._add(*args, **kwargs)
            elif self.replacement == self.NO_REPLACEMENT:
                if not self.__warningflag_replacement:
                    LOG.warning("Replay buffer is full, discarding added samples until next reset()")
                    self.__warningflag_replacement = True
        # Wrap around at the end of the storage
        if self.next_index >= self.maxsize:
            self.next_index = 0

    def sample(self, n, keys=None, device=None):
        """
        Returns a named tuple holding ``n`` items drawn uniformly (with
        repetition) from the buffer. If keys is given, only those fields are
        returned, otherwise all fields in ParameterOrder. If device is
        specified, then the sample is also moved to that device before being
        returned.

        Raises RuntimeError if the buffer is empty.

        Example usage (field names depend on the subclass):
        bs, ba, br, bsn, bd = replay.sample(n)
        for i in range(n):
            # Extract a single batch item
            s = bs[i]
            a = ba[i]
            r = br[i]
            sn = bsn[i]
            d = bd[i]
        """
        if self.length == 0:
            raise RuntimeError("Cannot sample from an empty buffer.")

        idxs = np.random.randint(0, self.length, size=(n,))
        if keys is not None:
            return self._getkeys(keys, idxs, device=device)
        else:
            return self._getall(idxs, device=device)

    @staticmethod
    def _last_true(flags):
        """Per row of the (n, k) bool tensor ``flags``: index of the last True, or -1."""
        ranks = torch.arange(1, flags.shape[-1] + 1, dtype=torch.long, device=flags.device)
        return (flags.long() * ranks).max(dim=-1).values - 1

    def sample_trajectory(self, n, k, keys=None, device=None):
        """
        Samples n trajectories from the replay buffer, each of length at most
        k. The trajectories are returned together with an integer tensor
        specifying the length of each sampled trajectory.

        A window of k consecutively stored items is drawn for every
        trajectory and then cut at the trajectory boundaries found inside it:
        if the last boundary in the window is a trajectory start, the part
        from that start onwards is returned; if it is a trajectory end, the
        window is moved back so that it finishes on that end and is then cut
        at the last boundary before it.

        Every field of the result has shape (n, k, ...). Positions at or
        beyond the trajectory length are padding (they refer to buffer
        index 0) and must be ignored.

        If device is specified, then the sample and the lengths are moved to
        that device before being returned.

        Raises RuntimeError if the buffer is empty and ValueError if k is not
        in the range 1 <= k < maxsize.

        Example usage (field names depend on the subclass):
        lengths, (bs, ba, br, bsn, bd) = replay.sample_trajectory(n, k)
        for t in range(n):
            for i in range(lengths[t]):
                # Extract a single batch item from this trajectory
                s = bs[t,i]
                a = ba[t,i]
                r = br[t,i]
                sn = bsn[t,i]
                d = bd[t,i]
        """
        if self.length == 0:
            raise RuntimeError("Cannot sample from an empty buffer.")
        if k >= self.maxsize:
            raise ValueError("The trajectory length k must be shorter than the maximum replay size.")
        if k < 1:
            raise ValueError("The trajectory length k must be at least 1.")

        starts = np.random.randint(self.length, size=(n,))

        # window[t, i] is the buffer index of the i:th item of window t
        offsets = torch.arange(k, dtype=torch.long)
        window = (offsets + torch.as_tensor(starts, dtype=torch.long).reshape(n, 1)) % self.length

        # Position of the last boundary of each kind in the window, -1 if none
        last_start = self._last_true(self.trajectory_starts[window])
        last_end = self._last_true(self.trajectory_ends[window])

        # An end is the deciding boundary when it is not followed by a start.
        # Ties (one item that is both start and end) also count as an end, so
        # that nothing stored after that item is appended to it.
        use_end = (last_end >= 0) & (last_end >= last_start)

        # CASE 1: The last boundary is the start of a new trajectory (or there
        # is no boundary at all). Keep the window from that start onwards.
        first = last_start.clamp(min=0)
        start_lengths = k - first
        start_mask = offsets < start_lengths.unsqueeze(-1)
        start_idxs = ((window + first.unsqueeze(-1)) % self.length) * start_mask

        # CASE 2: The last boundary is the end of a trajectory.
        # NOTE: Data in the buffer gets overwritten, so the item following an
        # end is not necessarily the start of a trajectory. In particular the
        # newest item is always an end, and with FIFO replacement it is
        # followed by the oldest item.
        # Move the window back so that its final position is that end.
        shift = (k - 1) - last_end.clamp(min=0)
        shifted = (window - shift.unsqueeze(-1)) % self.length
        shifted_starts = self.trajectory_starts[shifted]
        shifted_ends = self.trajectory_ends[shifted]
        # The final position is the end that was aligned on; only earlier
        # ends cut the trajectory short. (Indexing produced a copy, so the
        # stored flags are not modified.)
        shifted_ends[:, -1] = False
        begin = torch.max(self._last_true(shifted_starts),
                          self._last_true(shifted_ends) + 1)
        end_lengths = k - begin
        end_mask = offsets < end_lengths.unsqueeze(-1)
        end_idxs = ((shifted + begin.unsqueeze(-1)) % self.length) * end_mask

        lengths = torch.where(use_end, end_lengths, start_lengths)
        all_idxs = torch.where(use_end.unsqueeze(-1), end_idxs, start_idxs)

        if keys is not None:
            sample = self._getkeys(keys, all_idxs.reshape(n * k), device=device)
        else:
            sample = self._getall(all_idxs.reshape(n * k), device=device)

        if device is not None:
            lengths = lengths.to(device)

        # Fields were fetched as one flat batch of n * k items
        sample = sample._replace(**{
            sk: sv.reshape((n, k) + tuple(sv.shape)[1:])
            for sk, sv in sample._asdict().items()
        })
        return (lengths, sample)


class ReplayMemory(ReplayMemoryBase):
    """Plain replay buffer of (state, action, reward, next_state, is_done) items."""
    ParameterOrder = ReplayMemoryBase.ParameterOrder + [
        "state", "action", "reward", "next_state", "is_done",
    ]

    def __init__(self, *args, **kwargs):
        """Allocate zero-initialised storage for every field; arguments go to the base class."""
        super().__init__(*args, **kwargs)
        rows = (self.maxsize,)
        # (field name, storage shape, storage dtype)
        layout = (
            ("state", rows + self.s_shape, self.s_type),
            ("action", rows + self.a_shape, self.a_type),
            ("reward", rows, torch.float32),
            ("next_state", rows + self.s_shape, self.s_type),
            ("is_done", rows, torch.float32),
        )
        for field, shape, dtype in layout:
            self.rb[field] = torch.zeros(shape, requires_grad=False, dtype=dtype)


def auto_torch_device(devname):
    """
    Turn a device name into a torch.device.

    The name "auto" (in any letter case) selects "cuda" when CUDA is
    available and "cpu" otherwise; any other name is passed to torch.device
    unchanged.
    """
    if devname.lower() != "auto":
        return torch.device(devname)
    resolved = "cuda" if torch.cuda.is_available() else "cpu"
    return torch.device(resolved)


# Command-line arguments shared by all training algorithms.
# The default device is resolved once, when this module is imported, by
# calling auto_torch_device("auto").
# NOTE(review): the exact effect of _auto_default=False is defined by
# latency_env.misc.argparser -- confirm there.
trainer_arguments = ArgList(
    device = Arg("--device", metavar="DEV", type=auto_torch_device, default=auto_torch_device("auto"),
                help=f"The device to train on."),
    tqdm_freq = Arg("--tqdm-freq", metavar="N", type=at.posint, default=None,
                    _auto_default=False,
                    help=f"How often to tick the tqdm progress bar. (Default: per epoch only)"),
)


class TrainingAlgorithm(object):
    """
    Super class defining common utilities for a training algorithm:
    saving/loading, policy evaluation, episode/iteration progress tracking
    and recording of values to the log writer.

    Subclasses are expected to provide sample_action() and modules().
    """
    def __init__(self, env: "gym.core.Env" = None, args: "trainer_arguments" = None):
        """
        Arguments
        ---------
        env : gymnasium.core.Env (optional)
            Environment to train on. Only its spec is kept (as self.envspec).
        args : parsed trainer_arguments (optional)
            Provides at least ``device`` and ``tqdm_freq``.
        """
        self.__episode_stats = []
        self.__progressbar = None
        self.__current_episode = None
        self.__current_episode_steps = None
        self.__current_episode_total_reward = None
        self.__prev_v_limit = None

        self.__arguments = args
        self.__written_parameters = False

        self.__total_iterations = 0
        self.__iterations_since_tqdm_tick = 0

        self.envspec = None
        if env is not None and env.spec is not None:
            self.envspec = env.spec

        self.logname = None

    def evalcopy(self):
        """
        Returns a dict with the necessary arguments to reconstruct an
        evaluable shallow copy, disregarding anything that might be required
        for further training. Must be extended by the subclass to populate
        necessary arguments.

        Intended usage: trainer.save(path, evalcopy=True)
        """
        blob = {
            "evalcopy": True,
            "logname": self.logname,
            "total_iterations": self.__total_iterations,
            "total_episodes": self.__current_episode,
            "envspec": self.envspec,
            "class": self.__class__,
        }
        return blob

    @classmethod
    def from_evalcopy(cls, blob, device=None):
        """
        Constructs a new instance based on a previous eval copy. Must be
        implemented by the subclass.

        Intended usage: Internally in load function
        """
        raise NotImplementedError

    def save(self, path, evalcopy=False):
        """
        Pickle this trainer to "<path>.pkl".

        With evalcopy=True only the dict returned by evalcopy() is stored,
        otherwise the whole object.
        """
        path = f"{path}.pkl"

        if evalcopy:
            blob = self.evalcopy()
            LOG.debug(f"Saving evalcopy of {type(self).__name__} to {path}")
        else:
            blob = {
                "evalcopy": False,
                "obj": self,
            }
            LOG.debug(f"Saving full copy of {type(self).__name__} to {path}")

        with open(path, "wb") as f:
            pickle.dump(blob, f)

    @staticmethod
    def _read_pickle(path):
        """Unpickle and return the content of the file at ``path``."""
        LOG.debug(f"Attempting to load from {path}")
        with open(path, "rb") as f:
            blob = pickle.load(f)
        LOG.debug(f"Load from {path} successful")
        return blob

    @classmethod
    def load(cls, path, quiet=False, device=None):
        """
        Load a trainer stored with save().

        ``path`` is tried as given first and with ".pkl" appended second.
        Raises FileNotFoundError if neither file exists. ``device`` is passed
        on to from_evalcopy() when the file holds an eval copy. ``quiet`` is
        accepted but currently unused.

        SECURITY: the file is read with pickle, which can execute arbitrary
        code while loading. Only load files from trusted sources.
        """
        try:
            instance = cls._read_pickle(path)
        except FileNotFoundError:
            instance = None

        if instance is None:
            # A failure here is final and propagates to the caller
            instance = cls._read_pickle(f"{path}.pkl")

        if not isinstance(instance, dict):
            return instance # for backward compatibility

        if instance["evalcopy"]:
            trainer = instance["class"].from_evalcopy(instance, device=device)
            trainer.logname = instance["logname"]
            trainer.envspec = instance["envspec"]
        else:
            trainer = instance["obj"]

        return trainer

    def modules(self):
        # Return all the torch.nn.Module instances for this training algorithm
        raise NotImplementedError

    def set_mode(self, mode : str):
        # Sets the mode to either "train" or "eval" for the PyTorch modules
        if mode.lower() not in {"eval", "train"}:
            raise ValueError(f"invalid mode {mode}")
        for m in self.modules():
            m.train(mode.lower() == "train")

    @property
    def device(self):
        """The device given in the arguments passed to the constructor."""
        return self.__arguments.device

    def evaluate(self, trajectory_length, env, render=False, seed=None,
                 fn_reset=None, fn_step=None):
        """
        Evaluates the current policy over a single trajectory, returning the
        tuple (total accumulated reward, trajectory length, info), where info
        is the info dict of the last step taken (or of the reset if no step
        was taken).

        Arguments
        ---------
        trajectory_length : int
            Maximum length of a trajectory
        env : gymnasium.core.Env
            The environment to evaluate on. For correctness sake, this has to
            be explicitly provided.
        render : bool (optional)
            Whether to render the environment. (Default: False)
        seed : int (optional)
            Seed passed to env.reset() by the default reset function.
            (Default: None)
        fn_reset : function (optional)
            Whether to have a specific reset function for wrapping purposes.
            (Default: Use env.reset(seed=seed))
        fn_step : function (optional)
            Whether to use a specific function for stepping.
            (Default: Use env.step())
        """
        r_total = 0.0
        l_total = 0
        if fn_reset is None:
            fn_reset = lambda: env.reset(seed=seed)
        if fn_step is None:
            fn_step = lambda a: env.step(a)
        (s, info), done = fn_reset(), False
        for _ in range(trajectory_length):
            if done:
                break
            if render:
                env.render()
            a = self.sample_action(s, deterministic=True, with_logprob=False, random=False)
            s_next, r, terminated, truncated, info = fn_step(a)
            done = terminated or truncated
            s = s_next
            r_total += r
            l_total += 1
        return (r_total, l_total, info)

    def evaluate_many(self, runs=10, **kwargs):
        """
        Evaluate over many runs and return the per-run values together with
        their mean and std.

        Run i is evaluated with seed=i; the remaining keyword arguments are
        forwarded to evaluate(). The ``delays`` entry of the result counts,
        per str(delay), the values found in info["delay"]["samples"] of the
        runs whose info contains a "delay" entry.
        """
        r_all = np.zeros(shape=(runs,), dtype=np.float32)
        l_all = np.zeros(shape=(runs,), dtype=np.float32)
        delays = {}
        for i in range(runs):
            (r_all[i], l_all[i], evalinfo) = self.evaluate(seed=i, **kwargs)
            if "delay" in evalinfo:
                for d in evalinfo["delay"]["samples"]:
                    delays[str(d)] = delays.get(str(d), 0) + 1

        Result = namedtuple("Result", ["r_all", "r_mean", "r_std",
                                       "l_all", "l_mean", "l_std",
                                       "delays"])
        return Result(r_all, r_all.mean(), r_all.std(),
                      l_all, l_all.mean(), l_all.std(),
                      delays)

    @property
    def total_iterations(self):
        """Total number of iterations counted by progress_steps() so far."""
        return self.__total_iterations

    def progress(self, n_episodes=None, n_steps=None):
        """
        Generator over training episodes that maintains a tqdm progress bar
        and per-episode statistics. Yields the 1-based episode number.

        Exactly one of the limits must be given:
          n_episodes -- yield that many episodes
          n_steps    -- yield episodes until the total iteration count has
                        advanced by n_steps beyond the limit of the previous
                        n_steps run (or beyond the current count on the
                        first run)

        Raises ValueError if both or neither limit is given.
        """
        tqdm_kwargs = {}
        if (n_episodes is None) == (n_steps is None):
            raise ValueError("exactly one of n_episodes and n_steps must be given")
        if n_episodes is not None:
            v_init = 0
            v_limit = n_episodes
            v_incr = lambda self, v: v + 1
        elif n_steps is not None:
            v_init = self.__total_iterations
            if self.__prev_v_limit is not None:
                v_limit = self.__prev_v_limit + n_steps
            else:
                v_limit = v_init + n_steps
            self.__prev_v_limit = v_limit
            # max() guarantees progress even for an episode without steps
            v_incr = lambda self, v: max(self.__total_iterations, v + 1)
            tqdm_kwargs["total"] = v_limit - v_init

        if not self.__written_parameters:
            log_write_text("parsed arguments", str(self.__arguments))
            log_write_text("system arguments", str(sys.argv))
            self.__written_parameters = True

        progressbar = tqdm.tqdm(**tqdm_kwargs)
        self.__progressbar = progressbar
        self.__iterations_since_tqdm_tick = 0

        # try/finally so that the bar is also closed when the consumer stops
        # iterating early (break or exception) and the generator is closed.
        try:
            v = v_init
            while v < v_limit:
                ep = len(self.__episode_stats) + 1
                self.__current_episode = ep
                self.__current_episode_steps = None
                self.__current_episode_total_reward = None
                self.__episode_stats.append(dict())
                yield ep
                self.__episode_stats[ep-1]["steps"] = self.__current_episode_steps
                self.__episode_stats[ep-1]["total_reward"] = self.__current_episode_total_reward
                # update progress bar
                pbar_parts = [f"Episode: {self.__current_episode}"]
                if self.__current_episode_total_reward is not None:
                    pbar_parts.append(f"Total Reward: {self.__current_episode_total_reward:.2f}")
                    # calculate running average over the last 50 episodes
                    running_average_values = [stat["total_reward"] for stat in self.__episode_stats[-50:] if stat["total_reward"] is not None]
                    running_average = sum(running_average_values) / len(running_average_values)
                    pbar_parts.append(f"Running Avg: {running_average:.2f}")
                progressbar.set_postfix_str(" | ".join(pbar_parts))
                # Write episode stats to log
                for name, value in self.__episode_stats[ep-1].items():
                    log_write_scalaridx(f"episode/{name}", value, self.__total_iterations)
                log_write_flush()
                v_next = v_incr(self, v)
                if self.__iterations_since_tqdm_tick > 0:
                    progressbar.update(self.__iterations_since_tqdm_tick)
                    self.__iterations_since_tqdm_tick = 0
                v = v_next
        finally:
            progressbar.close()
            # Leave the shared state alone if a newer progress() run has
            # taken over in the meantime.
            if self.__progressbar is progressbar:
                self.__progressbar = None
                self.__iterations_since_tqdm_tick = 0
                self.__current_episode = None
                self.__current_episode_steps = None
                self.__current_episode_total_reward = None

    def progress_steps(self, limit=None):
        """
        Generator over the steps of the current episode. Yields the 1-based
        step number and counts every step towards total_iterations.

        If ``limit`` is given, iteration stops once the step counter reaches
        it.
        NOTE(review): the counters are advanced before the limit is checked,
        so with limit=L the values 1..L-1 are yielded while L iterations are
        counted. Kept as is because callers may rely on it -- confirm whether
        this is intended.
        """
        self.__current_episode_steps = 0
        while True:
            self.__total_iterations += 1
            self.__current_episode_steps += 1
            self.__iterations_since_tqdm_tick += 1
            # The trainer may have been constructed without arguments, and
            # may be stepped without an active progress() bar.
            tqdm_freq = getattr(self.__arguments, "tqdm_freq", None)
            if (tqdm_freq is not None
                    and self.__progressbar is not None
                    and tqdm_freq <= self.__iterations_since_tqdm_tick):
                self.__progressbar.update(self.__iterations_since_tqdm_tick)
                self.__iterations_since_tqdm_tick = 0
            if limit is not None and self.__current_episode_steps >= limit:
                break
            yield self.__current_episode_steps

    def record_episode_total_reward(self, total_reward):
        """Set the total reward reported for the current episode."""
        self.__current_episode_total_reward = total_reward

    def record_iteration_value(self, name, value):
        """Log a scalar under "iterations/<name>" at the current iteration."""
        log_write_scalaridx(f"iterations/{name}", value, self.__total_iterations)

    def record_iteration_histogram(self, name, value):
        """Log a histogram under "iterations-histogram/<name>" at the current iteration."""
        log_write_histogramidx(f"iterations-histogram/{name}", value, self.__total_iterations)

    def record_iteration_gradients(self, name, torch_module):
        """
        Log a histogram of the gradient of every named parameter of
        ``torch_module`` under "iterations-gradient/<name>/<parameter>".
        Parameters that have no gradient are skipped.
        """
        for pname, p in torch_module.named_parameters():
            if p.grad is None:
                # Nothing to log, e.g. before the first backward pass
                continue
            log_write_histogramidx(f"iterations-gradient/{name}/{pname}", p.grad, self.__total_iterations)

    def record_evaluation_value(self, name, value):
        """Log a scalar under "evaluation/<name>" at the current iteration."""
        log_write_scalaridx(f"evaluation/{name}", value, self.__total_iterations)

    def record_evaluation_action(self, name, action):
        """
        Log every element of the array ``action`` as a separate scalar.

        NOTE(review): the values are written through record_iteration_value,
        i.e. under "iterations/<name>/action/..." rather than "evaluation/",
        and the element index already starts with "/", which gives a double
        slash in the tag. Kept as is so that existing tags stay stable --
        confirm whether this is intended.
        """
        ait = np.nditer(action, flags=["multi_index"])
        for aval in ait:
            idx = ("/" + ",".join(str(i) for i in ait.multi_index)) if ait.multi_index != () else ""
            self.record_iteration_value(f"{name}/action/{idx}", aval)


class DeterministicTrainingAlgorithm(TrainingAlgorithm):
    """
    Training algorithm whose policy ``self.pi`` maps inputs directly to an
    action. Exploration noise comes from get_noise(), which subclasses must
    provide.
    """
    def __init__(self, *args, env : gym.core.Env = None, **kwargs):
        super().__init__(*args, env=env, **kwargs)
        # Kept for random sampling and for clipping actions to the bounds
        self.__action_space = env.action_space

    def get_noise(self, shape=None):
        """Return the noise added to actions. Must be implemented by the subclass."""
        raise NotImplementedError(f"Class {type(self).__name__} does not have get_noise implemented, and cannot sample noisy actions.")

    @torch.no_grad()
    def sample_action(self, *inputs,
                      random=False,
                      add_noise=True,
                      **kwargs):
        """
        Return an action clipped to the bounds of the action space.

        With random=True the action is sampled from the action space,
        otherwise it is the policy output for ``inputs`` (each input is
        converted to a float32 tensor with a leading batch axis of one).
        Unless add_noise is False, get_noise() is added to the action before
        clipping. Additional keyword arguments are ignored.
        """
        space = self.__action_space
        if not random:
            batched_inputs = tuple(
                torch.as_tensor(value, dtype=torch.float32, device=self.device).unsqueeze(0)
                for value in inputs
            )
            policy_output = self.pi(*batched_inputs).cpu()
            # Drop the batch axis that was added above
            action = policy_output.numpy()[0]
        else:
            action = space.sample()

        if add_noise:
            action += self.get_noise()

        return np.clip(action, space.low, space.high)


class StochasticTrainingAlgorithm(TrainingAlgorithm):
    """
    Training algorithm whose policy ``self.pi`` returns the tuple
    (action, raw action, log-probability, info) for a batch of inputs.
    """
    def __init__(self, *args, env : gym.core.Env = None, **kwargs):
        super().__init__(*args, env=env, **kwargs)
        # Kept for random sampling and for clipping actions to the bounds
        self.__action_space = env.action_space

    @torch.no_grad()
    def sample_action(self, *inputs,
                      random=False,
                      deterministic=False,
                      with_raw_sample=False,
                      with_logprob=False,
                      with_info=False,
                      **kwargs):
        """
        Sample an action, either from the action space (random=True) or from
        the policy for ``inputs``.

        The result is the action alone, or a tuple
        (action[, raw action][, log-probability][, info]) containing the
        parts requested through with_raw_sample, with_logprob and with_info.
        For random actions the raw action is the action itself, the
        log-probability is None and info is an empty dict.

        Policy outputs are returned as numpy arrays with the leading batch
        axis removed, and Box actions are clipped to the bounds of the
        action space. Additional keyword arguments are ignored.
        """
        space = self.__action_space
        drop_batch_axis = False
        raw_a = None
        logp_a = None

        if not random:
            batched_inputs = tuple(
                torch.as_tensor(value, dtype=torch.float32, device=self.device).unsqueeze(0)
                for value in inputs
            )
            policy_output = self.pi(*batched_inputs,
                                    deterministic=deterministic,
                                    with_logprob=with_logprob,
                                    with_info=with_info)
            a_gen, raw_a_gen, logp_a_gen, info = policy_output

            a = a_gen.cpu().numpy()
            if isinstance(space, gym.spaces.Box):
                # Keep the generated action within the bounds of the space
                a = np.clip(a, space.low, space.high)

            drop_batch_axis = bool(len(a.shape) > 0)
            if with_raw_sample:
                raw_a = raw_a_gen.cpu().numpy()
            if with_logprob:
                logp_a = logp_a_gen.cpu().numpy()

            # Convert tensors in a copy so the policy's own dict stays untouched
            info = copy.copy(info)
            for key in list(info.keys()):
                if not isinstance(info[key], torch.Tensor):
                    continue
                info[key] = info[key].cpu().numpy()
                if drop_batch_axis:
                    info[key] = info[key].squeeze(0)
        else:
            a = space.sample()
            if with_raw_sample:
                raw_a = a
            info = {}

        outputs = [a]
        if with_raw_sample:
            outputs.append(raw_a)
        if with_logprob:
            outputs.append(logp_a)

        if drop_batch_axis:
            outputs = [value.squeeze(0) for value in outputs]

        if with_info:
            outputs.append(info)

        if len(outputs) > 1:
            return tuple(outputs)
        return outputs[0]
